关键字:不同网络之间的梯度传递
在使用TF进行链式法则求梯度非常方便,但是有时我们需要在不同的网络之间传递梯度,本文讨论这种情况下如何coding。
1. 简单链式传递
最简单的一种形式,是将一个网络的梯度传递至另外一个网络。
有两个网络,串联结构:网络1参数的优化目标由网络2的输出进行定义。
即:loss_function_net1 = loss_function_net2 = LOSS
很简单,使用公式 即可。
具体代码如下:
# 传递给net 1使用的梯度
grad_to_net1 = tf.gradients(loss, input_net2)
# net 1进行参数优化时使用的梯度
grad_on_params_net1 = tf.gradients(input_net2, params_net1, grad_to_net1)
2. 复杂情况下的梯度传递—V1
在某些情况下,网络1的loss function并非和网络 2的loss function相同,反而可能是其函数。
在这种情况下,net2 只能传过来关于output1的梯度,具体的代码如下:
# 将loss_2转换为list
# 这里假设loss_2为一维的tensor,如果是多维,需要进行对应的使用tf.slice的切片
loss2_list = [tf.slice(loss2, [i], [1]) for i in range(loss2.shape[0])]
# 传递给net1使用的梯度
grad_on_output1 = [tf.gradients(l2, output1)[0] for l2 in loss2_list]
# 计算loss1关于loss2的梯度
grads_on_loss2 = tf.gradients(loss1, loss2)
grads_on_loss2_list = [tf.slice(grads_on_loss2 , [i], [1]) for i in range(grads_on_loss2 .shape[0])]
# 计算loss1关于output1的梯度
grads_tmp = sum([g_l2 * g_o1] for g_l2, g_o1 in zip(grads_on_loss2_list, grads_on_output1))
# result
grad_on_params1 = tf.gradients(output1, params_net1, grad_tmp)
关键:TF在计算一个list对tensor的导数时,总是会求得其sum,必须手动切片分开,再进行求导,方可实现更为自由地链式法则求导。
2. 复杂情况下的梯度传递—V2
V1版本的实现方式虽然可以解决问题,但是在TF内部,会对整个切片进行构建graph,这将消耗大量内存,当loss2是一个高参数量的值时,并不是一种好的实现方式。
更为快速的方式为,两个网络之间来回调用 tf.gradients(ys, xs, ys_grad)
具体代码如下
# net 2 计算
loss2 = func() # 发送给net1
# net 1 计算
grads_on_loss2 = tf.gradients(loss1, loss2) # 发送给net2
# net 2 计算
grads_on_output1 = tf.gradient(loss1, output1, grads_on_loss2) # 发送给net1
# net1 计算
grads_on_params1 = tf.gradients(output1, params_net1, grads_on_output1)