您好,登錄后才能下訂單哦!
環境:Ubuntu14.04,tensorflow=1.4(bazel源碼安裝),Anaconda python=3.6
聲明變量主要有兩種方法:tf.Variable和 tf.get_variable,二者的最大區別是:
(1) tf.Variable是一個類,自帶很多屬性函數;而 tf.get_variable是一個函數;
(2) tf.Variable只能生成獨一無二的變量,即如果給出的name已經存在,則會自動修改生成新的變量name;
(3) tf.get_variable可以用于生成共享變量。默認情況下,該函數會進行變量名檢查,如果有重復則會報錯。當在指定變量域中聲明可
以變量共享時,可以重復使用該變量(例如RNN中的參數共享)。
下面給出簡單的的示例程序:
import tensorflow as tf with tf.variable_scope('scope1',reuse=tf.AUTO_REUSE) as scope1: x1 = tf.Variable(tf.ones([1]),name='x1') x2 = tf.Variable(tf.zeros([1]),name='x1') y1 = tf.get_variable('y1',initializer=1.0) y2 = tf.get_variable('y1',initializer=0.0) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print(x1.name,x1.eval()) print(x2.name,x2.eval()) print(y1.name,y1.eval()) print(y2.name,y2.eval())
輸出結果為:
scope1/x1:0 [ 1.] scope1/x1_1:0 [ 0.] scope1/y1:0 1.0 scope1/y1:0 1.0
1. tf.Variable(…)
tf.Variable(…)使用給定初始值來創建一個新變量,該變量會默認添加到 graph collections listed in collections, which defaults to [GraphKeys.GLOBAL_VARIABLES]。
如果trainable屬性被設置為True,該變量同時也會被添加到graph collection GraphKeys.TRAINABLE_VARIABLES.
# tf.Variable __init__( initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None, constraint=None )
2. tf.get_variable(…)
tf.get_variable(…)的返回值有兩種情形:
使用指定的initializer來創建一個新變量;
當變量重用時,根據變量名搜索返回一個由tf.get_variable創建的已經存在的變量;
get_variable( name, shape=None, dtype=None, initializer=None, regularizer=None, trainable=True, collections=None, caching_device=None, partitioner=None, validate_shape=True, use_resource=None, custom_getter=None, constraint=None )
3. 根據名稱查找變量
在創建變量時,即使我們不指定變量名稱,程序也會自動進行命名。于是,我們可以很方便的根據名稱來查找變量,這在抓取參數、finetune模型等很多時候都很有用。
示例1:
通過在tf.global_variables()變量列表中,根據變量名進行匹配搜索查找。 該種搜索方式,可以同時找到由tf.Variable或者tf.get_variable創建的變量。
import tensorflow as tf x = tf.Variable(1,name='x') y = tf.get_variable(name='y',shape=[1,2]) for var in tf.global_variables(): if var.name == 'x:0': print(var)
示例2:
利用get_tensor_by_name()同樣可以獲得由tf.Variable或者tf.get_variable創建的變量。
需要注意的是,此時獲得的是Tensor, 而不是Variable,因此 x不等于x1.
import tensorflow as tf x = tf.Variable(1,name='x') y = tf.get_variable(name='y',shape=[1,2]) graph = tf.get_default_graph() x1 = graph.get_tensor_by_name("x:0") y1 = graph.get_tensor_by_name("y:0")
示例3:
針對tf.get_variable創建的變量,可以利用變量重用來直接獲取已經存在的變量。
with tf.variable_scope("foo"): bar1 = tf.get_variable("bar", (2,3)) # create with tf.variable_scope("foo", reuse=True): bar2 = tf.get_variable("bar") # reuse with tf.variable_scope("", reuse=True): # root variable scope bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above) print((bar1 is bar2) and (bar2 is bar3))
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。