本文参考《TensorFlow实战Google深度学习框架》一书,总结了一些在TensorFlow中与变量管理相关的一些API和使用技巧
1.创建变量
TensorFlow中可以通过tf.Variable和tf.get_variable两个函数来创建变量,两者基本功能相同,但是用法存在差别。
#下面两个定义是等价的,只不过变量的名称不同
v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = 'v1')
v2 = tf.get_variable('v2', shape = [1], initializer = tf.constant_initializer(1.0))tf.get_variable在函数调用时名称是必须输入的,而tf.Variable则不是必选的。此外tf.get_variable函数调用时提供的维度和初始化方法与tf.Variable也类似,TF中提供的initializer函数与随机数和常量的生成函数大部分是一一对应的。
在变量初始化的过成中,tf.Variable可以使用其他已经初始化的变量对其进行初始化
v3 = tf.Variable(v1.initialized_value() * 2, name = 'v3')而tf.get_variable函数可以通过tf.variable_scope函数来生成一个上下文管理器,并且明确指定在这个上下文管理器中,tf.get_variable将直接获取已经生成的变量。
import tensorflow as tf
with tf.variable_scope('scope1', reuse = False) as scope:
print(tf.get_variable_scope().reuse)
x1 = tf.get_variable('x', [1])
with tf.variable_scope('scope1', reuse = True) as scope:
print(tf.get_variable_scope().reuse)
x2 = tf.get_variable('x', [1])
print(x1 == x2)#输出Truetf.variable_scope函数的reuse默认为None。当reuse=False或者None时,tf.get_variable将创建新的变量,如果同名的变量已经存在了,那么会报错。如果reuse=True,tf.get_variable函数将会直接获取已经创建的变量,如果变量不存在,则会报错。
此外tf.variable_scope函数可以嵌套,当有外层已经被指定为reuse=True之后,内层嵌套的其他同类函数的reuse都会被默认设置为True,除非有显示地指明。此处就不再以代码示人。
2.命名空间
tf.variable_scope生成的上下文管理器会创建一个TF的命名空间,在该空间内创建的变量name都会带上这个命名空间名作为前缀。命名空间随着tf.variable_scope的嵌套,也可以进行嵌套,会有不同的name属性。但是如果在内层中使用的变量标识符与外层使用的相同,则该变量会被更新。如果是并列的没有包含关系的命名空间,使用相同的标识符表示变量则不会有冲突。
import tensorflow as tf
sess = tf.InteractiveSession()
with tf.variable_scope('scope1') as scope:
print(tf.get_variable_scope().reuse)
x1 = tf.get_variable('x', [1], initializer = tf.constant_initializer(1))
print(x1.name)#print scope1/x:0
with tf.variable_scope('scope2') as scope:
x2 = tf.get_variable('x', [1],initializer = tf.constant_initializer(2))#如果这里使用x1作为标识符,则x1会被更新为2
print(x2.name)#print scope1/scope2/x:0
with tf.variable_scope('scope1', reuse = True) as scope:
x3 = tf.get_variable('x', [1])
print(x3.name)
print(x1 == x3)#print scope1/x:0
with tf.variable_scope('', reuse = True) as scope:#这里只能用空的名称
print(tf.get_variable_scope().reuse)
x4 = tf.get_variable('scope1/x', [1], initializer = tf.constant_initializer(3))#这里的初始化没有作用,仍然为1
print(x4.name)#print scope1/x:0
print(x4 == x1)#True 如果scope2中的标识符也叫x1,那么会输出Fasle
sess.run(tf.initialize_all_variables())
print(x1.eval(), x2.eval(), x3.eval(), x4.eval())#print 1 2 1 1
print(tf.get_default_graph().get_tensor_by_name('scope1/x:0').eval())#1
print(tf.get_default_graph().get_tensor_by_name('scope1/scope2/x:0').eval())#2
3.初始化
TF中使用tf.initialize_all_variables()对所有的变量进行初始化。一般有两种调用方法。
sess.run(tf.initialize_all_variables())
#tf.initialize_all_variables().run()也可以单独对某个变量进行初始化,但是不常见
sess.run(x1.initializer)