QizNLP使用:利用Transformer训练单轮闲聊机器人

QizNLP介绍:基于tensorflow(1.x)的NLP框架,提供NLP多种任务(分类、匹配、序列标注、生成等)代码模板,包括数据处理、模型训练、部署推断的全流程,同时内置一些常见模型提供调用,并支持基于horovod的数据并行式分布训练。

Qznan/QizNLP         https://github.com/Qznan/QizNLP

前言

深度学习中的单轮闲聊机器人(single-turn chitchat-bot),通常采用与机器翻译相同的处理范式,即序列到序列模型(seq2seq)。本文将介绍在QizNLP框架中如何利用经典的Transformer模型训练一个闲聊机器人。(想看效果可直接翻到文末截图)

训练数据

训练数据采用清华组发布的LCCC闲聊语料(github地址),该语料包括base&large版本,其中base版本使用了更严格的过滤规则,所以更干净,故本次实验采用base版本。

训练过程

利用QizNLP中的seq2seq模型训练代码模板来进行基于Transformer的闲聊机器人训练。

  • 步骤一:安装QizNLP:

pip install QizNLP
  • 步骤二:在新建项目地址中初始化:

mkdir ~/myproject && cd ~/myproject
qiznlp_init
# 执行完毕后会看到myproject中已生成run、model、data等目录
  • 步骤三:下载LCCC-base-split.zip语料解压后放入/data目录下:

└─data
    └─LCCC-base-split
         ├─LCCC-base_test.json
         ├─LCCC-base_train.json
         └─LCCC-base_valid.json
  • 步骤四:修改run/run_s2s.py中的相关代码,包括:

  1. 设置训练参数,如训练轮数设为4:

‘‘‘ line 23~29 ’’’
conf = utils.dict2obj({
    'early_stop_patience': None,
    'just_save_best': False,
    'n_epochs': 4,
    'data_type': 'tfrecord',
    # 'data_type': 'pkldata',
})

2. 选择字典文件名(通过仅保留LCCC项而注释其它,下同):

’’’ line 43~46 ’’’
self.token2id_dct = {
    # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/s2s_char2id.dct', use_line_no=True),  # 自有数据
    # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/XHJ_s2s_char2id.dct', use_line_no=True),  # 小黄鸡
    'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/LCCC_s2s_char2id.dct', use_line_no=True),  # LCCC
}

3. 选择模型(Transformer):

’’’ line 359~360 ’’’
rm_s2s = Run_Model_S2S('trans')  # use transformer seq2seq
# rm_s2s = Run_Model_S2S('rnn_s2s')  # use biGRU encoder + bah_attn + GRU decoder

4. 选择训练函数和指定batch_size:

’’’ line 362~369 ’’’
# 训练自有数据
# rm_s2s.train('s2s_ckpt_1', '../data/s2s_example_data.txt', preprocess_raw_data=preprocess_raw_data, batch_size=512)  # train

# 训练小黄鸡
# rm_s2s.train('s2s_ckpt_XHJ1', '', preprocess_raw_data=preprocess_common_dataset_XiaoHJ, batch_size=512)  # train

# 训练LCCC语料
rm_s2s.train('s2s_ckpt_LCCC1', '', preprocess_raw_data=preprocess_common_dataset_LCCC, batch_size=512)  # train

5. 选择在训练完要测试模型时载入的ckpt名:

’’’ line 373 ’’’
rm_s2s.restore('s2s_ckpt_LCCC1')  # for infer
  • 步骤五:执行python run_s2s.py,开始进行训练。并且训练结束后会自动载入模型进行推断(输入问句,闲聊机器人给出回复):

训练截图:

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

此图是后面补的(重新运行训练),所以tfrecord等文件都已存在,另外请忽略速度(因为这是在cpu..)

训练4个epo后模型loss是:训练集3.32/测试集3.41(目测继续训练还能再降)。载入该模型进行推断。(推断时已集成了简单的回复后处理排序模块)

推断例子:

输入“你好”,模型原始输出(分数越大即绝对值越小越好):

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

输入“你好”,模型原始输出

输入“你好”,模型后处理排序后输出 (bad respond的分数设为极小值-1e7) :

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

输入“你好”,模型后处理排序后输出

输入“最近工作压力好大啊”,模型原始输出:

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

输入“最近工作压力好大啊”,模型原始输出

输入“最近工作压力好大啊”,模型后处理排序后输出:

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

输入“最近工作压力好大啊”,模型后处理排序后输出

  • 步骤六:根据回复效果调整模型解码时的相关参数(model/s2s_model.py中):

# line 15~31
conf = utils.dict2obj({
    'vocab_size': 4500,
    'embed_size': 300,
    'hidden_size': 300,
    'num_heads': 6,
    'num_encoder_layers': 6,
    'num_decoder_layers': 6,
    'dropout_rate': 0.2,
    'lr': 1e-3,
    'pretrain_emb': None,
    # 以上是模型参数,以下是解码时相关参数
    'beam_size': 40,
    'max_decode_len': 50,
    'eos_id': 2,  # 句子结束符对应词典id(默认为2)
    'gamma': 1,  # 多样性鼓励因子
    'num_group': 1,  # 分组beam_search
    'top_k': 30  # 分组beam_search首字符采样范围
})

(步骤五中的回复示例的参数设置即如上图所示)

主要参数说明

  • beam_size:beam_search时束的大小。越大则回复的多样性越好,但计算量及响应时间会增加。设为1 即使用greedy search

  • max_decode_len:最大解码长度

  • gamma:多样性鼓励因子,越大多样性越好。一般设为1。参考此论文

  • num_group:beam_search时设置分组数量,每组共享同一个开头字符,组间不同开头字符,以保持多样性。设为1即不分组

  • top_k:分组时每组开头字符从前k个字符中采样。要求top_k >= num_group

这里尝试展示一下设置num_group=10的效果,此时top_k=30会生效。仍以输入“最近工作压力好大啊”为例,模型原始输出:

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

设置分组后,输入“最近工作压力好大啊”,模型原始输出

输入“最近工作压力好大啊”,模型后处理排序后输出:

《QizNLP使用:利用Transformer训练单轮闲聊机器人》

设置分组后,输入“最近工作压力好大啊”,模型后处理排序后输出

在最终实践应用时,可根据情况截取后处理排序后的前k个作为候选回复,并从中随机采样1个作为最终回复。

模型部署

部署可参考deploy/下的example.py和web_api.py文件。前者有如何载入已训练好的模型(ckpt或pbmodel)的示例,后者有利用tornado搭建模型webAPI服务的示例。具体可参考QizNLP项目所在github的说明^_^

结语

以上就是使用QizNLP框架来快速工程化实现一个基于Transformer的闲聊机器人流程。由于LCCC是中文领域中难得的高质量对话语料,尽管只进行了简单的训练,但机器人回复效果还不错。当然这也离不开在beam_search解码过程中使用的一些优化手段:如多样性鼓励、分组采样等。之后会围绕该框架继续介绍更多NLP实践,欢迎各位大佬star~~

参考

LCCC语料:A Large-Scale Chinese Short-Text Conversation Dataset.(github.com/thu-coai/CDi)

QizNLP:(github.com/Qznan/QizNLP)

 

    原文作者:just do it now
    原文地址: https://blog.csdn.net/yaohaishen/article/details/112848422
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞