tf.get_variable()
方法是TensorFlow提供的比tf.Variable()
稍微高级的创建/获取变量的方法,它的工作方式根据当前的变量域(Variable Scope)的reuse
属性变化而变化,我们可以通过tf.get_variable_scope().reuse
来查看这个属性,它默认是False
。
-
tf.get_variable_scope().reuse == False
此时调用tf.get_variable(name, shape, dtype, initializer)
,我们可以创建一个新的变量(或者说张量),这个变量的名字为name,维度是shape,数据类型是dtype,初始化方法是指定的initializer。如果名字为name的变量已经存在的话,会导致ValueError
。
一个例子如下:
# create var
entity = tf.get_variable(name='entity', initializer=...)
-
tf.get_variable_scope().reuse == True
此时调用tf.get_variable(name)
,我们 可以 得到一个已经存在的名字为name的变量,如果这个变量不存在的话,会导致ValueError
。
一个例子如下:
# reuse var
tf.get_variable_scope().reuse_variables() # set reuse to True
entity = tf.get_variable(name='entity')
上面的两种情况得到的变量的名字都为name,这是假设在默认的变量域中调用tf.get_variable()
,如果在指定的变量域中调用,比如:
# create var
with tf.variable_scope('embedding'):
entity = tf.get_variable(name='entity', initializer=...)
# reuse var
with tf.variable_scope('embedding', reuse=True):
entity = tf.get_variable(name='entity')
那么得到的变量entity
的名字则是embedding/entity
。