1.想好你的网络结构,这点是无人可以帮你的,无论是什么结构,你在写代码前一定要对自己的结构了若指掌,具体包括维度,结构,要用的函数,输出是什么等,总而言之,不能在对自己结构还不是很清楚的时候就开始写代码,否则你会写的非常迷糊,各种数据的维度混乱,一路报错。
2.分割数据集,对数据进行预处理,这一部分也可以说是一个重点内容,一个优良处理过的数据集会使得你搭建网络不用照顾太多关于数据本身的问题,而专注于构建模型的结构,可以说无论怎么分割数据,你的数据都是要为了你的网络结构而服务,而非是你定义网络时要去遵从你数据的样子,否则这样你的模型泛用性能会非常的低下,毕竟如果要为了数据的输入而刻意更改网络结构的话,则相当于换一个数据集你可能就要重构一次网络的前向传播方法,这使得当你用该模型做一个论文时,模型相当于是在不断的变化的,这一点尤其不利。
3.搭建自己的网络,根据pytorch的官方指南,一个网络至少要实现以下两种方法:
- 网络的初始化,这些定义了整个网络要输入的所有参数,以及要前向传播所要用所有函数,这个可以在每一个网络都要继承的初始化里看出来
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
- 网络的前向传播过程,利用初始化中的函数对输入进行前向传播得到你要的结果
- 你自己需要网络实现的其它功能
4.搭建好了网络,我们就需要定义loss,可能有的人会问为什么loss不定义在网络结构里,我个人认为是这样理解的,可以设置在网络中,但是没有必要,因为网络负责的是输出,至于和标注数据的比较,这不是网络要做的事情,这一点也设计到了我第二点阐述的内容,如果你要将loss直接放入你的model中,那么说明你的model只是对这个数据集specific的,换一句话将,也是换一个数据集你的loss函数可能就要变化,但是我们写论文要求的是模型不变的,如果你loss放在模型的定义中,那就相当于你的模型变了,这一点也是不好的。回想一下类似BERT,ELMO,LSTM这样的网络结构,它并没有把loss加入进他们所画的网络图中,因为一个model的输出是这样的,但是具体和哪个标注数据比较,loss怎么算,并不是归模型管,我们是利用loss去优化这个网络的参数,而非是loss本身就是结构中的一环
5.进行梯度回传,这里涉及到了优化器的选择问题,其实这也和loss一样,都不是模型的一部分。
6.随后就是效果测试,参数调节,模型保存。