TensorFlow中的变量管理

xiaoxiao2021-02-28  19

本文参考《TensorFlow实战Google深度学习框架》一书,总结了一些在TensorFlow中与变量管理相关的一些API和使用技巧

1.创建变量

TensorFlow中可以通过tf.Variable和tf.get_variable两个函数来创建变量,两者基本功能相同,但是用法存在差别。

#下面两个定义是等价的,只不过变量的名称不同 v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = 'v1') v2 = tf.get_variable('v2', shape = [1], initializer = tf.constant_initializer(1.0))tf.get_variable在函数调用时名称是必须输入的,而tf.Variable则不是必选的。此外tf.get_variable函数调用时提供的维度和初始化方法与tf.Variable也类似,TF中提供的initializer函数与随机数和常量的生成函数大部分是一一对应的。

在变量初始化的过成中,tf.Variable可以使用其他已经初始化的变量对其进行初始化

v3 = tf.Variable(v1.initialized_value() * 2, name = 'v3')而tf.get_variable函数可以通过tf.variable_scope函数来生成一个上下文管理器,并且明确指定在这个上下文管理器中,tf.get_variable将直接获取已经生成的变量。 import tensorflow as tf with tf.variable_scope('scope1', reuse = False) as scope: print(tf.get_variable_scope().reuse) x1 = tf.get_variable('x', [1]) with tf.variable_scope('scope1', reuse = True) as scope: print(tf.get_variable_scope().reuse) x2 = tf.get_variable('x', [1]) print(x1 == x2)#输出Truetf.variable_scope函数的reuse默认为None。当reuse=False或者None时,tf.get_variable将创建新的变量,如果同名的变量已经存在了,那么会报错。如果reuse=True,tf.get_variable函数将会直接获取已经创建的变量,如果变量不存在,则会报错。

此外tf.variable_scope函数可以嵌套,当有外层已经被指定为reuse=True之后,内层嵌套的其他同类函数的reuse都会被默认设置为True,除非有显示地指明。此处就不再以代码示人。

2.命名空间

tf.variable_scope生成的上下文管理器会创建一个TF的命名空间,在该空间内创建的变量name都会带上这个命名空间名作为前缀。命名空间随着tf.variable_scope的嵌套,也可以进行嵌套,会有不同的name属性。但是如果在内层中使用的变量标识符与外层使用的相同,则该变量会被更新。如果是并列的没有包含关系的命名空间,使用相同的标识符表示变量则不会有冲突。 import tensorflow as tf sess = tf.InteractiveSession() with tf.variable_scope('scope1') as scope: print(tf.get_variable_scope().reuse) x1 = tf.get_variable('x', [1], initializer = tf.constant_initializer(1)) print(x1.name)#print scope1/x:0 with tf.variable_scope('scope2') as scope: x2 = tf.get_variable('x', [1],initializer = tf.constant_initializer(2))#如果这里使用x1作为标识符,则x1会被更新为2 print(x2.name)#print scope1/scope2/x:0 with tf.variable_scope('scope1', reuse = True) as scope: x3 = tf.get_variable('x', [1]) print(x3.name) print(x1 == x3)#print scope1/x:0 with tf.variable_scope('', reuse = True) as scope:#这里只能用空的名称 print(tf.get_variable_scope().reuse) x4 = tf.get_variable('scope1/x', [1], initializer = tf.constant_initializer(3))#这里的初始化没有作用,仍然为1 print(x4.name)#print scope1/x:0 print(x4 == x1)#True 如果scope2中的标识符也叫x1,那么会输出Fasle sess.run(tf.initialize_all_variables()) print(x1.eval(), x2.eval(), x3.eval(), x4.eval())#print 1 2 1 1 print(tf.get_default_graph().get_tensor_by_name('scope1/x:0').eval())#1 print(tf.get_default_graph().get_tensor_by_name('scope1/scope2/x:0').eval())#2

3.初始化

TF中使用tf.initialize_all_variables()对所有的变量进行初始化。一般有两种调用方法。 sess.run(tf.initialize_all_variables()) #tf.initialize_all_variables().run()也可以单独对某个变量进行初始化,但是不常见 sess.run(x1.initializer)
转载请注明原文地址: https://www.6miu.com/read-2050128.html

最新回复(0)