定义一个特征提取的类:
参考 PyTorch 论坛帖子：“How to extract features of an image from a trained model”
import torch
import torch.nn as nn
from torchvision.models import resnet18

# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=ResNet18_Weights.DEFAULT`; kept here so older torchvision versions still work.
myresnet = resnet18(pretrained=True)
print(myresnet)


class FeatureExtractor(nn.Module):
    """Run a model's direct children in order and collect selected intermediate outputs.

    This re-implements the model's forward pass as a plain chain of its direct
    children, so it is only valid for models whose forward is exactly that
    (as for torchvision ResNets, whose only extra step is flattening before ``fc``).

    Args:
        submodule: model whose direct children are executed in registration order.
        extracted_layers: names of direct children whose outputs are collected.
    """

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return a list with the outputs of the requested layers.

        Outputs are ordered by the model's layer order, not by the order of
        ``extracted_layers``. Each executed layer's name is printed as a trace.
        """
        outputs = []
        for name, module in self.submodule.named_children():
            if name == "fc":
                # The linear head expects (N, C); flatten everything after the batch dim.
                x = torch.flatten(x, 1)
            x = module(x)  # this layer's output becomes the next layer's input
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs


exact_list = ["conv1", "layer1", "avgpool"]
# Fall back to CPU instead of crashing when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
myexactor = FeatureExtractor(myresnet, exact_list).to(device)
# eval(): BatchNorm must use its stored running statistics; in train mode each
# forward pass would overwrite the pretrained running mean/var.
myexactor.eval()
# torch.autograd.Variable is deprecated; tensors carry requires_grad directly.
x = torch.rand(5, 3, 224, 224, device=device, requires_grad=True)
y = myexactor(x)  # 5x64x112x112, 5x64x56x56, 5x512x1x1
print(myexactor)
print(type(y))
print(type(y[0]))
for feat in y:
    arr = feat.detach().cpu().numpy()
    print(arr.size)
    print(arr.shape)
# Expected output:
# <class 'list'>
# <class 'torch.Tensor'>
# 4014080
# (5, 64, 112, 112)
# 1003520
# (5, 64, 56, 56)
# 2560
# (5, 512, 1, 1)
# 特征输出可视化 (visualize the extracted feature maps)
import matplotlib.pyplot as plt

# Plot the channels of the first extracted feature map for the first image in the batch.
# Assumes `y` is the list returned by the FeatureExtractor run above, whose first
# entry is the "conv1" output of shape (5, 64, 112, 112).
# detach().cpu() is required before .numpy(): the tensor may be on the GPU and
# part of an autograd graph.
feature_maps = y[0][0].detach().cpu().numpy()  # (channels, H, W)
n_channels = min(64, feature_maps.shape[0])  # the 8x8 grid holds at most 64 panels
plt.figure(figsize=(16, 16))
for i in range(n_channels):
    ax = plt.subplot(8, 8, i + 1)
    ax.set_title('Channel #{}'.format(i))
    ax.axis('off')
    ax.imshow(feature_maps[i], cmap='jet')
plt.show()
- Accessing and modifying different layers of a pretrained model in PyTorch: https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch