PyTorch_0

import torch
from torch.autograd import Variable

batch_n      = 100         # number of samples in one batch
hidden_layer = 100         # size of the hidden layer
input_data   = 1000        # number of features per sample
output_data  = 10          # number of outputs per sample


class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input, w1, w2):
        x = torch.mm(input, w1)        # hidden layer: input x w1
        x = torch.clamp(x, min=0)      # ReLU activation
        x = torch.mm(x, w2)            # output layer: hidden x w2
        return x

    def backward(self):
        pass

# x = torch.randn(batch_n, input_data)                 # example input (method one)
# y = torch.randn(batch_n, output_data)                # target output
x = Variable(torch.randn(batch_n, input_data), requires_grad=False)      # methods two and three
y = Variable(torch.randn(batch_n, output_data), requires_grad=False)

# w1 = torch.randn(input_data, hidden_layer)           # weights (method one)
# w2 = torch.randn(hidden_layer, output_data)
w1 = Variable(torch.randn(input_data, hidden_layer), requires_grad=True)     # methods two and three
w2 = Variable(torch.randn(hidden_layer, output_data), requires_grad=True)

epoch_n = 30                # number of training epochs
learning_rate = 1e-6        # learning rate

model = Model()

for epoch in range(epoch_n):
    # y_pred = x.mm(w1).clamp(min=0).mm(w2)
    y_pred = model(x, w1, w2)
    loss = (y_pred - y).pow(2).sum()            # sum-of-squares loss
    # print("Epoch:{}, Loss:{:.4f}".format(epoch, loss.item()))
    print("Epoch:", epoch, " Loss:", loss.item())

    # Backpropagation, method one: compute the gradients by hand and update the weights
    # h1 = x.mm(w1).clamp(min=0)                # hidden activations, needed for the manual gradients
    # grad_y_pred = 2 * (y_pred - y)
    # grad_w2 = h1.t().mm(grad_y_pred)
    #
    # grad_h = grad_y_pred.clone()
    # grad_h = grad_h.mm(w2.t())
    # grad_h[h1 < 0] = 0                        # gradient of the ReLU: zero where the activation was clamped
    # grad_w1 = x.t().mm(grad_h)
    #
    # w1 -= learning_rate * grad_w1
    # w2 -= learning_rate * grad_w2

    # Backpropagation, method two: autograd
    loss.backward()

    w1.data -= learning_rate * w1.grad
    w2.data -= learning_rate * w2.grad

    w1.grad.data.zero_()
    w2.grad.data.zero_()
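
The comments above mention a "method three" that this snippet never shows; in the usual version of this tutorial series that means handing the weight update to torch.optim instead of touching .data and .grad by hand. Below is a minimal sketch under that assumption (the Sequential model, the SGD optimizer, and reusing the same learning rate are my choices, not part of the original post). Note that in current PyTorch, Variable has been merged into Tensor, so requires_grad can be passed to torch.randn directly.

import torch

batch_n, input_data, hidden_layer, output_data = 100, 1000, 100, 10
x = torch.randn(batch_n, input_data)      # no Variable wrapper needed any more
y = torch.randn(batch_n, output_data)

# The two Linear layers own the weights that played the role of w1 and w2 above.
model = torch.nn.Sequential(
    torch.nn.Linear(input_data, hidden_layer),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_layer, output_data),
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)   # same learning rate as above

for epoch in range(30):
    y_pred = model(x)
    loss = (y_pred - y).pow(2).sum()
    print("Epoch:", epoch, " Loss:", loss.item())

    optimizer.zero_grad()    # replaces w1.grad.data.zero_() / w2.grad.data.zero_()
    loss.backward()          # autograd fills p.grad for every parameter
    optimizer.step()         # applies p -= lr * p.grad (plain SGD)

With plain SGD, optimizer.step() performs exactly the w.data -= learning_rate * w.grad update written out manually above, just applied to every parameter the optimizer was given.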

 

Original author: pytorch
Original source: https://www.cnblogs.com/RaspberryFarmer/p/10348742.html