PyTorch Kaggle 泰坦尼克生存预测

《PyTorch Kaggle 泰坦尼克生存预测》

不确定是否完全正确，只是按照自己的思路写了一个版本。

数据集:https://www.kaggle.com/c/titanic/data

 

  1 import torch
  2 import torch.nn as nn
  3 import pandas as pd
  4 import numpy as np
  5 
  6 
  7 class DataProcessing(object):
  8     def __init__(self):
  9         pass
 10 
 11     def get_data(self):
 12         data_train = pd.read_csv('train.csv')
 13         label = data_train[['Survived']]
 14         data_test = pd.read_csv('test.csv')
 15         # 读取指定列
 16         gender = pd.read_csv('gender_submission.csv', usecols=[1])
 17         return data_train, label, data_test, gender
 18 
 19     def data_processing(self, data_):
 20         # 训练集测试集都进行相同的处理
 21         data = data_[['Pclass', 'Sex', 'Age', 'SibSp', 'Fare', 'Cabin', 'Embarked']]
 22         data['Age'] = data['Age'].fillna(data['Age'].mean())
 23         data['Cabin'] = pd.factorize(data.Cabin)[0]
 24         data.fillna(0, inplace=True)
 25         data['Sex'] = [1 if x == 'male' else 0 for x in data.Sex]
 26         data['p1'] = np.array(data['Pclass'] == 1).astype(np.int32)
 27         data['p2'] = np.array(data['Pclass'] == 2).astype(np.int32)
 28         data['p3'] = np.array(data['Pclass'] == 3).astype(np.int32)
 29         data['e1'] = np.array(data['Embarked'] == 'S').astype(np.int32)
 30         data['e2'] = np.array(data['Embarked'] == 'C').astype(np.int32)
 31         data['e3'] = np.array(data['Embarked'] == 'Q').astype(np.int32)
 32         del data['Pclass']
 33         del data['Embarked']
 34         return data
 35 
 36     def data(self):
 37         # 读数据
 38         train_data, label, test_data, gender = self.get_data()
 39         # 处理数据
 40         # 训练集输入数据
 41         train = np.array(data_processing.data_processing(train_data))
 42         # 训练集标签
 43         train_label = np.array(label)
 44         # 测试集
 45         test = np.array(data_processing.data_processing(test_data))
 46         # 测试集标签
 47         test_label = np.array(gender)
 48 
 49         train = torch.from_numpy(train).float()
 50         train_label = torch.tensor(train_label).float()
 51         test = torch.tensor(test).float()
 52         test_label = torch.tensor(test_label)
 53 
 54         return train, train_label, test, test_label
 55 
 56 
 57 class MyNet(nn.Module):
 58     def __init__(self):
 59         super(MyNet, self).__init__()
 60         self.fc = nn.Sequential(
 61             nn.Linear(11, 7),
 62             nn.Sigmoid(),
 63             nn.Linear(7, 7),
 64             nn.Sigmoid(),
 65             nn.Linear(7, 1),
 66         )
 67         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
 68         self.mls = nn.MSELoss()
 69 
 70     def forward(self, inputs):
 71         # 前向传播
 72         return self.fc(inputs)
 73 
 74     def train(self, inputs, y):
 75         # 训练
 76         out = self.forward(inputs)
 77         loss = self.mls(out, y)
 78         self.opt.zero_grad()
 79         loss.backward()
 80         self.opt.step()
 81         # print(loss)
 82 
 83     def test(self, x, y):
 84         # 测试
 85         # 将variable张量转为numpy
 86         # out = self.fc(x).data.numpy()
 87         count = 0
 88         out = self.fc(x)
 89         sum = len(y)
 90         for i, j in zip(out, y):
 91             i = i.detach().numpy()
 92             j = j.detach().numpy()
 93             loss = abs((i - j)[0])
 94             if loss < 0.3:
 95                 count += 1
 96         # 误差0.3内的正确率
 97         print(count/sum)
 98 
 99 
if __name__ == '__main__':
    # Hyper-parameters of the original script.
    EPOCHS = 20000
    BATCH_SIZE = 100

    # Keep this global name: older versions of DataProcessing.data() look it up.
    data_processing = DataProcessing()
    train_data, train_label, test_data, test_label = data_processing.data()
    net = MyNet()
    for epoch in range(EPOCHS):
        # Mini-batch training: one optimizer step per slice of BATCH_SIZE rows.
        # The original computed these slices but then trained on the full set.
        # range(0, n, BATCH_SIZE) also never yields an empty trailing batch.
        for start in range(0, len(train_data), BATCH_SIZE):
            batch_data = train_data[start:start + BATCH_SIZE]
            batch_label = train_label[start:start + BATCH_SIZE]
            net.train(batch_data, batch_label)
    # Print the share of test predictions within 0.3 of the labels.
    # (The original full-batch version reported 0.7488038277511961.)
    net.test(test_data, test_label)

效果一般，不过至少跑出了结果，hiahiahia

    原文作者:pytorch
    原文地址: https://www.cnblogs.com/MC-Curry/p/10120265.html
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞