修改 PyTorch 官方示例，使其适用于自己的二分类迁移学习项目

本 demo 由 PyTorch 官方的迁移学习示例修改而来，增加了以下功能：

  1. 根据验证集 AUC 选择并保留最优模型参数；
  2. 五折交叉验证；
  3. 输出验证集中被错误分类的图片文件名；
  4. 输出分类报告，并保存各类别预测概率（CSV）与 ROC 曲线图片。
      1 import os
      2 import numpy as np
      3 import torch
      4 import torch.nn as nn
      5 from torch.optim import lr_scheduler
      6 import torchvision
      7 from torchvision import datasets, models, transforms
      8 from torch.utils.data import DataLoader
      9 from sklearn.metrics import roc_auc_score, classification_report
     10 from sklearn.model_selection import KFold
     11 from torch.autograd import Variable
     12 import torch.optim as optim
     13 import time
     14 import copy
     15 import shutil
     16 import sys
     17 import scikitplot as skplt
     18 import matplotlib.pyplot as plt
     19 import pandas as pd
     20 
     21 plt.switch_backend('agg')
     22 N_CLASSES = 2
     23 BATCH_SIZE = 8
     24 DATA_DIR = './data'
     25 LABEL_DICT = {0: 'class_1', 1: 'class_2'}
     26 
     27 
     28 def imshow(inp, title=None):
     29     """Imshow for Tensor."""
     30     inp = inp.numpy().transpose((1, 2, 0))
     31     mean = np.array([0.485, 0.456, 0.406])
     32     std = np.array([0.229, 0.224, 0.225])
     33     inp = std * inp + mean
     34     inp = np.clip(inp, 0, 1)
     35     plt.imshow(inp)
     36     if title is not None:
     37         plt.title(title)
     38     plt.pause(100)
     39 
     40 
     41 def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
     42     since = time.time()
     43     # 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
     44     best_model_wts = copy.deepcopy(model.state_dict())
     45     # best_acc = 0.0
     46     # 初始auc
     47     best_auc = 0.0
     48     best_desc = [0, 0, None]
     49     best_img_name = None
     50     plt_auc = [None, None]
     51 
     52     for epoch in range(num_epochs):
     53         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
     54         print('- ' * 50)
     55 
     56         for phase in ['train', 'val']:
     57             if phase == 'train':
     58                 # 训练的时候进行学习率规划,其定义在下面给出
     59                 scheduler.step()
     60                 model.train(True)
     61             else:
     62                 model.train(False)
     63             phase_pred = np.array([])
     64             phase_label = np.array([])
     65             img_name = np.zeros((1, 2))
     66             prob_pred = np.zeros((1, 2))
     67             running_loss = 0.0
     68             running_corrects = 0
     69             # 这样迭代方便跟踪图片路径,输出错误图片名称
     70             for data, index in zip(dataloaders[phase], dataloaders[phase].batch_sampler):
     71                 inputs, labels = data
     72                 if use_gpu:
     73                     inputs = Variable(inputs.cuda())
     74                     labels = Variable(labels.cuda())
     75                 else:
     76                     inputs, labels = Variable(inputs), Variable(labels)
     77 
     78                 # 梯度参数设为0
     79                 optimizer.zero_grad()
     80 
     81                 # forward
     82                 outputs = model(inputs)
     83                 _, preds = torch.max(outputs.data, 1)
     84                 loss = criterion(outputs, labels)
     85 
     86                 # backward + 训练阶段优化
     87                 if phase == 'train':
     88                     loss.backward()
     89                     optimizer.step()
     90 
     91                 if phase == 'val':
     92                     img_name = np.append(img_name, np.array(dataloaders[phase].dataset.imgs)[index], axis=0)
     93                     prob = outputs.data.cpu().numpy()
     94                     prob_pred = np.append(prob_pred, prob, axis=0)
     95 
     96                 phase_pred = np.append(phase_pred, preds.cpu().numpy())
     97                 phase_label = np.append(phase_label, labels.data.cpu().numpy())
     98                 running_loss += loss.item() * inputs.size(0)
     99                 running_corrects += torch.sum(preds == labels.data).float()
    100             print()
    101             epoch_loss = running_loss / dataset_sizes[phase]
    102             epoch_acc = running_corrects / dataset_sizes[phase]
    103             epoch_auc = roc_auc_score(phase_label, phase_pred)
    104             print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
    105                 phase, epoch_loss, epoch_acc, epoch_auc))
    106             report = classification_report(phase_label, phase_pred, target_names=class_names)
    107             print(report)
    108 
    109             img_name = zip(img_name[1:], phase_pred)
    110             # 当验证时遇到了更好的模型则予以保留
    111             if phase == 'val' and epoch_auc > best_auc:
    112                 best_auc = epoch_auc
    113                 best_desc = epoch_acc, epoch_auc, report
    114                 best_img_name = img_name
    115                 # 深拷贝模型参数
    116                 best_model_wts = copy.deepcopy(model.state_dict())
    117                 plt_auc = phase_label, prob_pred[1:]
    118 
    119         print()
    120     print(plt_auc[0].shape, plt_auc[1].shape)
    121     csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
    122     csv_file['true_label'] = pd.DataFrame(plt_auc[0])
    123     csv_file['true_label'] = csv_file['true_label'].apply(lambda x: LABEL_DICT[x])
    124     csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
    125     skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
    126     plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
    127     time_elapsed = time.time() - since
    128     print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    129     reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAclass_2 -> {:4f}\n\n{}'.format(best_desc[0], best_desc[1],
    130                                                                                          best_desc[2])
    131     report_file.write(reports)
    132     print(reports)
    133     print('List the wrong judgement img ...')
    134     count = 0
    135     for i in best_img_name:
    136         actual_label = int(i[0][1])
    137         pred_label = i[1]
    138         if actual_label != pred_label:
    139             tmp_word = f'{i[0][0].split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
    140                        f'pred: {LABEL_DICT[pred_label]}'
    141             print(tmp_word)
    142             label_file.write(tmp_word + '\n')
    143             count += 1
    144     print(f'This fold has {count} wrong records ...')
    145 
    146     # 载入最优模型参数
    147     model.load_state_dict(best_model_wts)
    148     return model
    149 
    150 
    151 def plot_img():
    152     for i, data in enumerate(dataloaders['train']):
    153         inputs, classes = data
    154         out = torchvision.utils.make_grid(inputs)
    155         imshow(out, title=[class_names[x] for x in classes])
    156 
    157 
    158 # 此函数可以修改适用于自己项目的图片文件名
    159 def move_file(data, file_path, dir_path, root_path):
    160     label_0 = 'class_2'
    161     label_1 = 'class_1'
    162     print(f'start copy the {file_path} file ...')
    163     os.chdir(dir_path)
    164     if os.path.exists(file_path):
    165         print(f'Find exist {file_path} file, the file will be dropped.')
    166         shutil.rmtree(os.path.join(root_path, dir_path, file_path))
    167         print(f'Finish drop the {file_path} file.')
    168 
    169     os.mkdir(file_path)
    170     tmp_path = os.path.join(os.getcwd(), file_path)
    171     tmp_pre_path = os.getcwd()
    172     for d in data:
    173         pre_path = os.path.join(tmp_pre_path, d)
    174         os.chdir(tmp_path)
    175         if d[:2] == label_0:
    176             if not os.path.exists(label_0):
    177                 os.mkdir(label_0)
    178             cur_path = os.path.join(tmp_path, label_0, d)
    179             shutil.copyfile(pre_path, cur_path)
    180         if d[:2] == label_1:
    181             if not os.path.exists(label_1):
    182                 os.mkdir(label_1)
    183             cur_path = os.path.join(tmp_path, label_1, d)
    184             shutil.copyfile(pre_path, cur_path)
    185     print('finish this work ...')
    186 
    187 
    188 if __name__ == "__main__":
    189     if not os.path.exists('roc_img'):
    190         os.mkdir('roc_img')
    191     if not os.path.exists('prob_result'):
    192         os.mkdir('prob_result')
    193     if not os.path.exists('report'):
    194         os.mkdir('report')
    195     if not os.path.exists('error_record'):
    196         os.mkdir('error_record')
    197     if not os.path.exists('model'):
    198         os.mkdir('model')
    199     label_file = open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w')
    200 
    201     kf = KFold(n_splits=5, shuffle=True, random_state=1)
    202     origin_path = '/home/project/'
    203     dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))])
    204 
    205     for m, n in enumerate(kf.split(dd_list), start=1):
    206         report_file = open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w')
    207         print(f'The {m} fold for copy file and training ...')
    208         move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
    209         os.chdir(origin_path)
    210         move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
    211         os.chdir(origin_path)
    212         data_transforms = {
    213             'train': transforms.Compose([
    214                 # 裁剪到224,224
    215                 transforms.RandomResizedCrop(224),
    216                 # 随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
    217                 transforms.RandomHorizontalFlip(),
    218                 # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),  # HSV以及对比度变化
    219                 transforms.ToTensor(),
    220                 # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的FloadTensor
    221                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    222             ]),
    223             'val': transforms.Compose([
    224                 transforms.Resize(256),
    225                 transforms.CenterCrop(224),
    226                 transforms.ToTensor(),
    227                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    228             ]),
    229         }
    230 
    231         image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
    232                                                   data_transforms[x])
    233                           for x in ['train', 'val']}
    234         dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
    235                                                       shuffle=True, num_workers=8, pin_memory=False)
    236                        for x in ['train', 'val']}
    237 
    238         dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    239 
    240         class_names = image_datasets['train'].classes
    241         size = len(class_names)
    242         print('label mapping: ')
    243         print(image_datasets['train'].class_to_idx)
    244         use_gpu = torch.cuda.is_available()
    245         model_ft = None
    246         if sys.argv[1] == 'resnet':
    247             model_ft = models.resnet50(pretrained=True)
    248             num_ftrs = model_ft.fc.in_features
    249             model_ft.fc = nn.Sequential(
    250                 nn.Linear(num_ftrs, N_CLASSES),
    251                 nn.Sigmoid()
    252             )
    253 
    254         # 这边可以自行把inception模型加进去
    255         if sys.argv[1] == 'inception':
    256             raise Exception("not provide inception model ...")
    257             # model_ft = models.inception_v3(pretrained=True)
    258 
    259         if sys.argv[1] == 'desnet':
    260             model_ft = models.densenet121(pretrained=True)
    261             num_ftrs = model_ft.classifier.in_features
    262             model_ft.classifier = nn.Sequential(
    263                 nn.Linear(num_ftrs, N_CLASSES),
    264                 nn.Sigmoid()
    265             )
    266             # use_gpu = False
    267 
    268         if use_gpu:
    269             model_ft = model_ft.cuda()
    270 
    271         criterion = nn.CrossEntropyLoss()
    272         optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    273         # 每7个epoch衰减0.1倍
    274         exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    275         model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1], num_epochs=25)
    276         print('Start save the model ...')
    277         torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
    278         print(f'The mission of the fold {m} finished.')
    279         print('# '*50)
    280         report_file.close()
    281     label_file.close()

     

    原文作者:pytorch
    原文地址: https://www.cnblogs.com/laresh/p/9552332.html
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞