在使用PyTorch时,你可能会遇到这种问题,网络层被Sequential类包装了起来,没办法在其中插入print语句来查看在网络中流动的张量的形状与具体的值.例如以下代码:
model = nn.Sequential(
nn.Linear(1, 5),
nn.ReLU(),
nn.Linear(5,1),
nn.LogSigmoid()
)
通过求助PyTorch官方论坛,找到了以下解决方法:
我们可以通过继承,创建一个PrintLayer,专门用于调试
class PrintLayer(nn.Module):
def __init__(self):
super(PrintLayer, self).__init__()
def forward(self, x):
# Do your print / debug stuff here
print(x) #print(x.shape)
return x
在我们需要调试的网络层之间插入PrintLayer的实例化对象,例如
model = nn.Sequential(
nn.Linear(1, 5),
PrintLayer(), # Add Print layer for debug
nn.ReLU(),
nn.Linear(5,1),
nn.LogSigmoid(),
)
x = Variable(torch.randn(10, 1))
output = model(x)
这样就可以打印出通过nn.Linear(1,5)后,张量的形状或数值.
参考文献: