复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...

摘要

本论文是基于RR间期的使用双向LSTM网络的房颤自动检测算法研究。房颤是成年人最常见的持久性的心律失调。在成年人中的发病率大约是0.4%,并且随年龄增长而变得更为普遍。随着老龄化的加剧,房颤自动检测变得更有意义。研究表明,房颤会大幅增加中风和血栓的风险,导致更高的发病率和死亡率。

研究发现,房颤的RR间期分布规律不同于正常窦性心律的分布规律[T. Hennig, P. Maass, J. Hayano, S. Heinrichs, Exponential distribution of long heart beat intervals during atrial brillation and their relevance for white noise behaviour in power spectrum, Journal of biological physics 32 (2006) 383{392}.],临床实践中也通过节律变化来进行人工检测。其灵敏度能达到90%以上,特异度也可以达到70%以上。

本文通过提取ECG信号的RR间期特征序列作为模型的输入数据,输入数据长度为100个RR间期的序列,使用双向LSTM网络对labelled数据的RR间期序列的特征进行学习,经过一个最大池化层和一个全连接层,最终对房颤和正常窦性节律进行分类,达到了98%以上的灵敏度和特异度。

数据来源

1.数据格式

本文所用数据来自MIT-BIH房颤数据库(afdb,MIT-BIH Atrial Fibrillation Database,https://physionet.org/physiobank/database/afdb/),该数据库共有25条记录(含3种类型的数据:.hea格式的头文件,.dat格式的数据文件和.atr/.qrs格式的注释文件),但只有23组含.dat数据文件(00735和03665号数据没有.dat文件),所有记录均为双导联。

其中.dat文件格式为format212,即三个字节存储2个数,每个数12bits,每个数分别代表其中某一个导联数据点,数据存储结构参考MIT-BIH Arrhythmia的数据格式(format212,11bits),心律不齐数据库的格式如下:

《复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...》 format212格式

MIT-BIH Arrhythmia数据的读取方法可以参考这篇博客的Matlab版的代码(
https://blog.csdn.net/chenyusiyuan/article/details/2040234),还有Python版的代码(
https://www.zhihu.com/question/273874101/answer/389351234)。

2.数据识读

afdb和mitdb的数据格式虽然都是format212,但仍然有差异,不能直接使用上述读取代码。可以通过在physionet官网下载安装wfdb-app-matlab包,使用MATLAB在线读取afdb的数据,并且将其转化为通用的txt格式的数据。以下是我使用wfdb包自动下载和转换格式的matlab代码,供参考。

clc;clear;
path = 'E:\project_TensorFlow\ecg_classify\MIT-BIH AF DB\';
file_list= {'00735','03665','04015','04043','04048','04126','04746','04908','04936','05091','05121','05261','06426','06453','06995','07162','07859','07879','07910','08215','08219','08378','08405','08434','08455'};

% 转换数据和注释文件
for n = 1:length(file_list)
    fprintf('开始下载转换 %s 号数据',file_list{n});
    dpath = strcat('afdb\',file_list{n});
    
    %.dat文件转换
    [sig,fs,tm]=rdsamp(dpath,[1;2]);
    data = [tm,sig];
    dataname = strcat(path,file_list{n},'dat','.txt');
    dlmwrite(dataname, data, 'precision', '%.4f', 'delimiter', '\t');
%     fid=fopen(dataname,'w');%建立文件
%     %循环写入数据
%     for i=1:length(data)
%         fprintf(fid,'%.4f,%.4f,%.4f\r\n',data(i,:));%保存小数点后4位
%     end
%     fclose(fid);
    
    %.atr文件转换
    [atr,anntype,subtype,chan,num,comments] = rdann(dpath,'atr'); %标注的房颤位置
    atr1 = num2cell(atr*0.004);
    atr2 = [atr1,comments];
    atrname = strcat(path,file_list{n},'atr','.txt');
    fid=fopen(atrname,'w');%建立文件
    %循环写入数据
    for i=1:length(atr)
        fprintf(fid,'%.4f\t%s\r\n',cell2mat(atr2(i,1)),cell2mat(atr2(i,2)));%保存小数点后4位
    end
    fclose(fid);
    
    %.qrs文件转换
    [qrs] = rdann(dpath,'qrs')*0.004; %标注的qrs波位置    
    qrsname = strcat(path,file_list{n},'qrs','.txt');
    dlmwrite(qrsname, qrs, 'precision', '%.4f', 'delimiter', '\t');
%     fid=fopen(qrsname,'w');%建立文件
%     %循环写入数据
%     for i=1:length(qrs)
%         fprintf(fid,'%.4f\r\n',qrs(i,:));%保存小数点后4位
%     end
%     fclose(fid);
end

其中,filelist来自于afdb数据库的RECORDS文件(https://physionet.org/physiobank/database/afdb/RECORDS

《复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...》 file_list来源

转换完后,得到三类数据的txt文件:

《复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...》 下载转换后的数据

预处理

由于该数据库有标注了的QRS波的位置(.qrs注释文件中),可以直接使用,因此不需要我们再去检测R波了,也就没有必要对ECG信号做预处理了。直接应用.qrs注释文件中的QRS位置标注信息提取RR间期特征,构成新的输入序列。论文把RR间期序列作为输入数据,每100个RR间期作为一段输入,根据.atr注释文件中的AFIB标注信息为每段输入模型的数据打标签,只要这段数据里面有一个心拍被标注为了AFIB,即认为这段数据的label为AFIB,除此之外其余所有数据段的label均为Normal。相关Python代码如下:

import os
import numpy as np
import pandas as pd

def rWave(filename_dat,filename_atr,filename_qrs):
    # 并联.dat,.atr,.qrs三类数据,把R波信息筛选出来,包括每个R波对应的时间,是否AFIB
    names1 = ['tm','sig1','sig2']
    dat = pd.read_csv(filename_dat,'\t',header=None,names=names1)
    names2 = ['tm','atr']
    atr = pd.read_csv(filename_atr,'\t',header=None,names=names2)
    names3 = ['tm','qrs']
    qrs = pd.read_csv(filename_qrs,'\t',header=None,names=names3)
    qrs = qrs.fillna(1)
    # print (df.fillna(method='pad')) #向前填充
    # print (df.fillna(method='backfill')) #向后填充
    df = pd.merge(dat, atr, how='outer', on=['tm'])
    df = df.fillna(method='backfill')
    df = df.fillna('(N')
    df1 = pd.merge(df, qrs, how='outer', on=['tm'])
    df1 = df1.fillna(0)
    # print(df1[(df1.qrs==1) & (df1.atr=='(N')][0:1000])
    df2 = df1[df1.qrs==1]
    return df2

def rrInterval(df):
    # 提取RR间期序列及其对应的是否AFIB的标签
    df1 = df.reindex(columns=['tm','sig1','sig2','rr','atr','qrs'])
    df1['rr'] = df1['tm'].diff(-1) * -1
    df2 = df1[0:-1]
    return df2

def inputData(df,num,step):
    # 生成LSTM需要的输入数据,num个RR间期为一段,步长为step
    i = 0
    feature,label = [],[]
    while i < len(df1) - num + 1:
        feature.append(list(df['rr'][i:i+num]))
        if '(AFIB' in list(df['atr'][i:i+num]):
            label.append(1.0)
        else:
            label.append(0.0)
        i = i + step
    # featues = np.array(feature)
    # labels = np.array(label)
    return feature,label


if __name__ == '__main__':

    path = "E:\project_TensorFlow\ecg_classify\data"

    # 遍历path目录及其子目录下的所有文件,将其保存在filenames中
    filenames_dat,filenames_atr,filenames_qrs = [],[],[]
    for root, dirs, files in os.walk(path, topdown=False):
        for name in files:
            # print(os.path.join(root, name))
            if "dat" in name:
                filenames_dat.append(os.path.join(root, name))
            elif "atr" in name:
                filenames_atr.append(os.path.join(root, name))
            else:
                filenames_qrs.append(os.path.join(root, name))

    xdata,ydata = [],[]
    for i in range(0,len(filenames_dat)):
        # print(filenames_dat[i])
        filename_dat = filenames_dat[i]
        filename_atr = filenames_atr[i]
        filename_qrs = filenames_qrs[i]
        df = rWave(filename_dat,filename_atr,filename_qrs)
        # print(df)
        df1 = rrInterval(df)
        # print(df1)
        featues, labels = inputData(df1,100,1)
        # xdata = xdata + featues
        # ydata = ydata + labels
        xdata.extend(featues)
        ydata.extend(labels)
        # print(featues, labels)
        # print(featues.shape, labels.shape)
        print(len(xdata), len(ydata))

    xdata = np.array(xdata)
    ydata = np.array(ydata)
    np.savetxt("E:/project_TensorFlow/ecg_classify/xdata.txt", xdata, fmt='%.3f')
    np.savetxt("E:/project_TensorFlow/ecg_classify/ydata.txt", ydata, fmt='%.1f')

模型构建

该论文中的模型由一个双向LSTM网络、一个最大池化层和两个全连接层组成,双向LSTM网络可以有效的提取输入的100个RR间期序列的信息,经过一个最大池化层保留重要信息,最后经过全连接层做一个二分类——房颤或者非房颤。模型架构如下图所示:

《复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...》 双向LSTM网络识别房颤模型

各网络层的参数设置文中也给了明确的说明:

《复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...》 模型参数

论文作者提出把23组数据即23个患者的数据被拆分为两部分,其中20组用来训练,其他3组作为盲测。在训练时,又做了一个10折交叉验证,以提高模型的泛化能力。最终达到98%以上的灵敏度和特异度。

《复现论文Automated detection of atrial fibrillation using long short-term memory network with RR inter...》 交叉验证的结果和测试集的效果

根据文论所述,本文作者使用TensorFlow+Keras完整实现了模型的构建,并在本地开始了漫长的训练和调优。以下是相关Python代码:

import numpy as np
from keras import optimizers
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense,Dropout,Embedding,LSTM,Bidirectional,GlobalMaxPooling1D
from keras.layers.core import Reshape
from keras.layers.normalization import BatchNormalization
from keras.utils import plot_model
from keras.models import load_model
from keras.callbacks import ModelCheckpoint,CSVLogger
from keras.initializers import glorot_uniform
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedShuffleSplit,cross_val_score,cross_val_predict
import time


def build_model():
    """
    建立模型,包含一个双向LSTM层,1维全局最大池化层,两个全连接层
    """
    model = Sequential()
    model.add(Reshape((100,1),input_shape=(100,)))
    model.add(Bidirectional(LSTM(200,return_sequences=True,kernel_initializer=glorot_uniform(seed=1),recurrent_dropout=0.1),merge_mode='concat'))
    # model.add(BatchNormalization())
    model.add(GlobalMaxPooling1D())
    model.add(Dense(50, activation='relu',kernel_initializer=glorot_uniform(seed=2)))
    # model.add(BatchNormalization())
    model.add(Dropout(0.1))
    model.add(Dense(1, activation='sigmoid',kernel_initializer=glorot_uniform(seed=3)))
    adam = optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08) 
    model.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])
    return model

if __name__ == '__main__':
    # 读取afdb所有数据
    xdata = np.loadtxt('./xdata.txt')
    ydata = np.loadtxt('./ydata.txt')

    # 取2号数据做demo,约6万的特征和label对
    x_train = xdata[0:968303]
    x_test = xdata[968303:1126261]
    y_train = ydata[0:968303]
    y_test = ydata[968303:1126261]

    filepath="./model_{epoch:03d}-{val_acc:.4f}.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc',verbose=1, save_best_only=True)
    csvlog = CSVLogger('./trainlog.csv', separator=',', append=True)
    # cross validation
    skf = StratifiedShuffleSplit(n_splits=10,random_state=0)
    model = build_model()
    for cnt,(train,test) in enumerate(skf.split(x_train,y_train)):
        train_data=x_train[train,:]
        test_data=x_train[test,:]
        train_label=y_train[train]
        test_label=y_train[test]
        # 若训练的模型中断了,可以加载已训练出来的模型,修改fit函数中initial_epoch的值,initial_epoch=0表示从第1个epoch训练
        model.fit(train_data, train_label,batch_size=1024,epochs=80,validation_data=(test_data,test_label),
            shuffle=True,callbacks=[csvlog,checkpoint],initial_epoch=0)
        model = build_model()

测试集效果评估:

import sys
import numpy as np
from keras.models import load_model
from keras.utils import plot_model

def calAccSeSp(real,pred):
    pred1 = np.around(pred,0)
    # real,pred均为np.array类型的数据
    acc=0;se=0;sp=0
    for i in range(len(real)):
        if real[i] == pred1[i]:
            acc+=1
            if real[i] == 1.0:
                se+=1
            elif real[i]==0.0:
                sp+=1
    Sensitivity = se/np.sum(real==1.0)
    Specificity = sp/np.sum(real==0.0)
    Accuracy = acc/len(real)
    return Accuracy,Sensitivity,Specificity

if __name__ == '__main__':
    # 读取afdb所有数据
    xdata = np.loadtxt('./xdata.txt')
    ydata = np.loadtxt('./ydata.txt')

    # 取2号数据做demo,约6万的特征和label对
    x_train = xdata[0:968303]
    x_test = xdata[968303:1126261]
    y_train = ydata[0:968303]
    y_test = ydata[968303:1126261]

    # 加载训练好的模型,命令行传入模型文件(含路径) eg: python3 model_predict.py ./model_020-0.9999.hdf5
    modelfile = sys.argv[1]
    model = load_model(modelfile)
    y_pred = model.predict(x_test)

    acc,se,sp = calAccSeSp(y_test,y_pred)
    print("Performance of the BiLSTM Model is: Accuracy:%.2f%% Sensitivity:%.2f%% Specificity:%.2f%%" % (acc*100,se*100,sp*100))

然而,经过了数次试验,包括增加BN层、修改学习率、调高dropout比例……毫无例外,均出现了模型在训练集和验证集acc达到99%以上,但在测试集上最好不超过85%的情况。与几位同行和前辈讨论,一致认为是样本太少,数据存在病人维度的差异。
那么问题来了,本论文的作者是怎么做到在测试集上达到甚至超过训练集的99%的acc?

(以上便是本论文复现的全部过程,仅供参考,如果不是我的思路或者代码有问题,那么现在的学术论文的质量当真堪忧呀……)

    原文作者:龙少伊
    原文地址: https://www.jianshu.com/p/757c7b7d810f
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞