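Introduction
This post briefly documents how to classify news articles by category with the Python version of fastText, using jieba for word segmentation and pandas for data handling. For the data, the news dataset from Tsinghua University can be used.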
Installing dependencies
Python version: 3.6
Install jieba and fasttext:
pip install jieba
pip install fasttext
Word segmentation
Common stopwords are removed during segmentation. A stopword list is available at https://github.com/dongxiexidian/Chinese/tree/master/dict
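To make the target format concrete, here is a minimal sketch of how a single output line is built: the tokens minus stopwords, then a tab, then the `__label__` prefix and the class id. `to_fasttext_line` is a hypothetical helper of my own, not part of segmentation.py, and it assumes the text has already been segmented (for example by `jieba.cut`).

```python
def to_fasttext_line(tokens, stopwords, label):
    """Build one training line: filtered tokens, a tab, then __label__<id>."""
    # Drop empty/whitespace tokens and anything in the stopword set
    kept = [tok for tok in tokens if tok.strip() and tok not in stopwords]
    # The label may arrive as a float from pandas, so cast it to int first
    return " ".join(kept) + "\t__label__" + str(int(label)) + "\n"


# Hypothetical tokens and stopwords, only for illustration
stopwords = {"的", "了"}
line = to_fasttext_line(["股市", "的", "上涨", "了"], stopwords, 3.0)
```

The full script below follows the same steps, just applied row by row to the merged DataFrame.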
segmentation.py
import jieba
import pandas as pd
import codecs
import math
import random
stopwords_set = set()
basedir = '/Users/derry/Desktop/Data/'
# Output files for the segmented text
train_file = codecs.open(basedir + "news.data.seg.train", 'w', 'utf-8')
test_file = codecs.open(basedir + "news.data.seg.test", 'w', 'utf-8')
# Load the stopword list
with open(basedir + 'stop_text.txt', 'r', encoding='utf-8') as infile:
    for line in infile:
        stopwords_set.add(line.strip())
train_data = pd.read_table(basedir + 'News_info_train.txt', header=None, error_bad_lines=False)
label_data = pd.read_table(basedir + 'News_pic_label_train.txt', header=None, error_bad_lines=False)
train_data.drop([2], axis=1, inplace=True)
train_data.columns = ['id', 'text']
label_data.drop([2, 3], axis=1, inplace=True)
label_data.columns = ['id', 'class']
train_data = pd.merge(train_data, label_data, on='id', how='outer')
for index, row in train_data.iterrows():
    # Segment with jieba (tabs and newlines are turned into spaces first)
    seg_text = jieba.cut(row['text'].replace("\t", " ").replace("\n", " "))
    outline = " ".join(seg_text)
    outline = " ".join(outline.split())
    # Remove stopwords (HTML tags are dropped only if the stopword list contains them)
    outline_list = outline.split(" ")
    outline_list_filter = [item for item in outline_list if item not in stopwords_set]
    outline = " ".join(outline_list_filter)
    # Write the line only when the row has a label
    if not math.isnan(row['class']):
        outline = outline + "\t__label__" + str(int(row['class'])) + "\n"
        train_file.write(outline)
        train_file.flush()
        # Train/test split (disabled); use it in place of the two lines above
        # if random.random() > 0.7:
        #     test_file.write(outline)
        #     test_file.flush()
        # else:
        #     train_file.write(outline)
        #     train_file.flush()
train_file.close()
test_file.close()
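Note that with the split commented out, news.data.seg.test stays empty, while classification.py evaluates against that file. Below is a minimal sketch of the same roughly 70/30 split as a standalone function. `split_lines`, `test_ratio` and `seed` are names I made up for illustration; the fixed seed is an addition so the split can be reproduced.

```python
import random


def split_lines(lines, test_ratio=0.3, seed=42):
    """Randomly assign each line to train or test; about test_ratio go to test."""
    rng = random.Random(seed)  # private RNG, so the global random state is untouched
    train, test = [], []
    for line in lines:
        if rng.random() < test_ratio:
            test.append(line)
        else:
            train.append(line)
    return train, test


# Hypothetical labelled lines in the same format the script writes
lines = ["text %d\t__label__%d\n" % (i, i % 5) for i in range(100)]
train, test = split_lines(lines)
```

Each returned list can then be written to news.data.seg.train and news.data.seg.test respectively.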
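Classification and prediction
For training I changed the word_ngrams parameter from its default of 1 to 3, which may improve the results a little. Note that bucket=2000000 (the default value) then has to be passed explicitly as well, otherwise training fails. From what I found in the issues, this seems to be because the Python fasttext package is an older version; the official C++ version does not have this problem.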
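As a rough illustration of what these two parameters mean: with word_ngrams=3, runs of two and three adjacent words are used as features in addition to single words, and since there are far too many such combinations to store individually, they are hashed into a fixed number of slots given by bucket. The sketch below is a simplification under my own assumptions; `word_ngrams` and `bucket_index` are hypothetical helpers, and `zlib.crc32` is only a stand-in, not fastText's actual hash function.

```python
import zlib


def word_ngrams(tokens, max_n):
    """Return all contiguous word n-grams for n = 1..max_n, shortest first."""
    grams = []
    for n in range(1, max_n + 1):
        for i in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[i:i + n]))
    return grams


def bucket_index(ngram, bucket=2000000):
    """Map an n-gram to a slot in [0, bucket); crc32 is a stand-in hash."""
    return zlib.crc32(ngram.encode("utf-8")) % bucket


tokens = ["股市", "今日", "上涨"]
features = word_ngrams(tokens, 3)
# Only the multi-word n-grams are hashed in this sketch
slots = {g: bucket_index(g) for g in features if " " in g}
```

This also hints at why a bucket count is needed at all once word_ngrams is greater than 1: without slots to hash into, the multi-word features have nowhere to go.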
classification.py
import logging
import fasttext
import pandas as pd
import codecs
basedir = '/Users/derry/Desktop/Data/'
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
# Train
classifier = fasttext.supervised(basedir + "news.data.seg.train", basedir + "news.dat.seg.model", label_prefix="__label__", word_ngrams=3, bucket=2000000)
# Evaluate on the test set and print the F-score
result = classifier.test(basedir + "news.data.seg.test")
print(result.precision * result.recall * 2 / (result.recall + result.precision))
# Load the segmented validation set
validate_texts = []
with open(basedir + 'news.data.seg.validate', 'r', encoding='utf-8') as infile:
    for line in infile:
        validate_texts.append(line)
# Predict labels
labels = classifier.predict(validate_texts)
# Result file
result_file = codecs.open(basedir + "result.txt", 'w', 'utf-8')
validate_data = pd.read_table(basedir + 'News_info_validate.txt', header=None, error_bad_lines=False)
validate_data.drop([2], axis=1, inplace=True)
validate_data.columns = ['id', 'text']
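# Write one result line per validation row
for index, row in validate_data.iterrows():
    # str() guards against ids that pandas parsed as numbers
    outline = str(row['id']) + '\t' + labels[index][0] + '\tNULL\tNULL\n'
    result_file.write(outline)
    result_file.flush()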
result_file.close()
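The F-score printed in classification.py is the harmonic mean of precision and recall, and the expression divides by zero when both are 0 (for instance when the test file is empty). A minimal sketch of a guarded version follows; `f1_score` is a hypothetical helper name, not something provided by fasttext.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical values, only to show the call
example = f1_score(0.5, 1.0)
```

In the script it would be called as `f1_score(result.precision, result.recall)`.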
Finally, the code is on GitHub: https://github.com/DerryChan/CSIT6000/tree/master/Derry