tensorflow 中的tf.gradients 与 tf.stop

xiaoxiao2021-02-27  231

转载网址:http://blog.csdn.net/u012436149/article/details/53905797

gradient

tensorflow中有一个计算梯度的函数tf.gradients(ys, xs),要注意的是,xs中的x必须要与ys相关,不相关的话,会报错。 代码中定义了两个变量w1, w2, 但res只与w1相关

#wrong import tensorflow as tf w1 = tf.Variable([[1,2]]) w2 = tf.Variable([[3,4]]) res = tf.matmul(w1, [[2],[1]]) grads = tf.gradients(res,[w1,w2]) with tf.Session() as sess: tf.global_variables_initializer().run() re = sess.run(grads) print(re)

错误信息 TypeError: Fetch argument None has invalid type

# right import tensorflow as tf w1 = tf.Variable([[1,2]]) w2 = tf.Variable([[3,4]]) res = tf.matmul(w1, [[2],[1]]) grads = tf.gradients(res,[w1]) with tf.Session() as sess: tf.global_variables_initializer().run() re = sess.run(grads) print(re) # [array([[2, 1]], dtype=int32)]

tf.stop_gradient()

阻挡节点BP的梯度

import tensorflow as tf w1 = tf.Variable(2.0) w2 = tf.Variable(2.0) a = tf.multiply(w1, 3.0) a_stoped = tf.stop_gradient(a) # b=w1*3.0*w2 b = tf.multiply(a_stoped, w2) gradients = tf.gradients(b, xs=[w1, w2]) print(gradients) #输出 #[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]

可见,一个节点被 stop之后,这个节点上的梯度,就无法再向前BP了。由于w1变量的梯度只能来自a节点,所以,计算梯度返回的是None。

a = tf.Variable(1.0) b = tf.Variable(1.0) c = tf.add(a, b) c_stoped = tf.stop_gradient(c) d = tf.add(a, b) e = tf.add(c_stoped, d) gradients = tf.gradients(e, xs=[a, b]) with tf.Session() as sess: tf.global_variables_initializer().run() print(sess.run(gradients)) #输出 [1.0, 1.0]

虽然 c节点被stop了,但是a,b还有从d传回的梯度,所以还是可以输出梯度值的。

import tensorflow as tf w1 = tf.Variable(2.0) w2 = tf.Variable(2.0) a = tf.multiply(w1, 3.0) a_stoped = tf.stop_gradient(a) # b=w1*3.0*w2 b = tf.multiply(a_stoped, w2) opt = tf.train.GradientDescentOptimizer(0.1) gradients = tf.gradients(b, xs=tf.trainable_variables()) tf.summary.histogram(gradients[0].name, gradients[0])# 这里会报错,因为gradients[0]是None #其它地方都会运行正常,无论是梯度的计算还是变量的更新。总觉着tensorflow这么设计有点不好, #不如改成流过去的梯度为0 train_op = opt.apply_gradients(zip(gradients, tf.trainable_variables())) print(gradients) with tf.Session() as sess: tf.global_variables_initializer().run() print(sess.run(train_op)) print(sess.run([w1, w2]))
转载请注明原文地址: https://www.6miu.com/read-3351.html

最新回复(0)