PyTorch debug技巧 如何打印Sequential包装的网络层的输出

在使用PyTorch时,你可能会遇到这种问题,网络层被Sequential类包装了起来,没办法在其中插入print语句来查看在网络中流动的张量的形状与具体的值.例如以下代码:

model = nn.Sequential(
        nn.Linear(1, 5),
        nn.ReLU(),
        nn.Linear(5,1),
        nn.LogSigmoid()
 )

通过求助PyTorch官方论坛,找到了以下解决方法:

我们可以通过继承,创建一个PrintLayer,专门用于调试

class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        # Do your print / debug stuff here
        print(x)      #print(x.shape)
        return x

在我们需要调试的网络层之间插入PrintLayer的实例化对象,例如

model = nn.Sequential(
        nn.Linear(1, 5),

        PrintLayer(), # Add Print layer for debug
        
        nn.ReLU(),
        nn.Linear(5,1),
        nn.LogSigmoid(),
)

x = Variable(torch.randn(10, 1))
output = model(x)

这样就可以打印出通过nn.Linear(1,5)后,张量的形状或数值.

参考文献:

https://discuss.pytorch.org/t/how-do-i-print-output-of-each-layer-in-sequential/5773/4discuss.pytorch.org

    原文作者:wenxin
    原文地址: https://zhuanlan.zhihu.com/p/52013707
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞