『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究

查看非叶节点梯度的两种方法

在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:

  • 使用autograd.grad函数
  • 使用hook

autograd.gradhook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。

求z对y的导数

x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

# hook
# hook没有返回值,参数是函数,函数的参数是梯度值
def variable_hook(grad):
    print("hook梯度输出:\r\n",grad)

hook_handle = y.register_hook(variable_hook)         # 注册hook
z.backward(retain_graph=True)                        # 内置输出上面的hook
hook_handle.remove()                                 # 释放

print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度输出:
 Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

autograd.grad输出:
 (Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
,)

 

多次反向传播试验

实际就是使用retain_graph参数,

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

如果不使用retain_graph参数,

实际上效果是一样的,AccumulateGrad object仍然会积累梯度

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward()
print(w.grad)
y = w.mul(x)  # <-----
z = y.sum()  # <-----
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

分析:

这里的重新建立高级节点意义在这里:实际上高级节点在创建时,会缓存用于输入的低级节点的信息(值,用于梯度计算),但是这些buffer在backward之后会被清空(推测是节省内存),而这个buffer实际也体现了上面说的动态图的”动态”过程,之后的反向传播需要的数据被清空,则会报错,这样我们上面过程就分别从:保留数据不被删除&重建数据两个角度实现了多次backward过程。

实际上第二次的z.backward()已经不是第一次的z所在的图了,体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,动态图更类似我们自己实现的机器学习框架实践,相较于静态逻辑简单一点,只是PyTorch的静态图和我们的比会在反向传播后清空存下的数据:下次要么完全重建,要么反向传播之后指定不舍弃图z.backward(retain_graph=True)。

总之图上的节点是依赖buffer记录来完成反向传播,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。

    原文作者:pytorch
    原文地址: https://www.cnblogs.com/hellcat/p/8449801.html
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞