查看非叶节点梯度的两种方法
在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:
- 使用autograd.grad函数
- 使用hook
autograd.grad
和hook
方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook
方法,但是在实际使用中应尽量避免修改grad的值。
求z对y的导数
x = V(t.ones(3)) w = V(t.rand(3),requires_grad=True) y = w.mul(x) z = y.sum() # hook # hook没有返回值,参数是函数,函数的参数是梯度值 def variable_hook(grad): print("hook梯度输出:\r\n",grad) hook_handle = y.register_hook(variable_hook) # 注册hook z.backward(retain_graph=True) # 内置输出上面的hook hook_handle.remove() # 释放 print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度输出: Variable containing: 1 1 1 [torch.FloatTensor of size 3] autograd.grad输出: (Variable containing: 1 1 1 [torch.FloatTensor of size 3] ,)
多次反向传播试验
实际就是使用retain_graph参数,
# 构件图 x = V(t.ones(3)) w = V(t.rand(3),requires_grad=True) y = w.mul(x) z = y.sum() z.backward(retain_graph=True) print(w.grad) z.backward() print(w.grad)
Variable containing: 1 1 1 [torch.FloatTensor of size 3] Variable containing: 2 2 2 [torch.FloatTensor of size 3]
如果不使用retain_graph参数,
实际上效果是一样的,AccumulateGrad object仍然会积累梯度
# 构件图 x = V(t.ones(3)) w = V(t.rand(3),requires_grad=True) y = w.mul(x) z = y.sum() z.backward() print(w.grad) y = w.mul(x) # <----- z = y.sum() # <----- z.backward() print(w.grad)
Variable containing: 1 1 1 [torch.FloatTensor of size 3] Variable containing: 2 2 2 [torch.FloatTensor of size 3]
分析:
这里的重新建立高级节点意义在这里:实际上高级节点在创建时,会缓存用于输入的低级节点的信息(值,用于梯度计算),但是这些buffer在backward之后会被清空(推测是节省内存),而这个buffer实际也体现了上面说的动态图的”动态”过程,之后的反向传播需要的数据被清空,则会报错,这样我们上面过程就分别从:保留数据不被删除&重建数据两个角度实现了多次backward过程。
实际上第二次的z.backward()已经不是第一次的z所在的图了,体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,动态图更类似我们自己实现的机器学习框架实践,相较于静态逻辑简单一点,只是PyTorch的静态图和我们的比会在反向传播后清空存下的数据:下次要么完全重建,要么反向传播之后指定不舍弃图z.backward(retain_graph=True)。
总之图上的节点是依赖buffer记录来完成反向传播,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。