
Heartbeat Classification on ECG + Transfer Learning


Based on the paper below, we work through 1-D CNNs, transfer learning, and ECG data processing, with a focus on implementation.

Abstract

ECG Heartbeat Classification: A Deep Transferable Representation

Conventional machine-learning methods for ECG analysis have been built independently for each task; this paper sets out to test whether knowledge can be reused across tasks.

First, a network is trained on the MIT-BIH dataset for a 5-class arrhythmia classification task. The resulting pretrained NN is then trained (transfer learning) on the PTB dataset for a binary myocardial infarction (MI) classification task.

The resulting average accuracy was 93.4% on the arrhythmia task and 95.9% on the binary MI task, both highly accurate predictors. From this the authors argue that knowledge from the arrhythmia task was successfully transferred to the MI task.

Method & Implementation

Preprocessing

The ECG waveform is preprocessed in the following steps:

  1. Split the ECG signal into 10-second windows and take one of them
  2. Normalize the amplitude to the range 0-1
  3. Find all local maxima
  4. Treat local maxima of 0.9 or higher as R-peaks
  5. Compute the mean RR interval and use it as the window length T
  6. From each R-peak, take 1.2T worth of data
  7. Pad segments shorter than the predefined fixed length with zeros (zero-padding)

The preprocessed datasets are already published on Kaggle:

https://www.kaggle.com/shayanfazeli/heartbeat

To apply the trained model to real data, the input format must match the one used during training. Real-world ECG recordings also come at different sampling frequencies, so on top of the steps above we need to adjust for that as well.

The training dataset is sampled at 125 Hz, so let's start by resampling to that frequency.

For this tutorial, we use SciPy's bundled ECG sample data, which is sampled at 360 Hz.

import numpy as np
from scipy.misc import electrocardiogram  # moved to scipy.datasets.electrocardiogram in SciPy >= 1.10
import matplotlib.pyplot as plt

V = electrocardiogram()            # about 5 minutes of ECG, in mV
Hz = 360                           # the bundled signal is sampled at 360 Hz
T = np.arange(V.size) * 1000 / Hz  # time axis in milliseconds
plt.figure(figsize=(10,7))
plt.plot(T, V)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.xlim(0, 10000)
plt.ylim(-1.1, 2.0)
plt.show()

[Figure: the first 10 seconds of the 360 Hz ECG signal]

First, we resample at 125 Hz.

from scipy import interpolate

def resample(T, V, Hz=125, kind='linear'):
    # interpolate the original signal, then sample it on a new 1000/Hz ms grid
    f = interpolate.interp1d(T, V, kind=kind)
    T = np.arange(np.min(T), np.max(T), 1000 / Hz)
    V = f(T)
    return T, V

T_new, V_new = resample(T, V)
plt.figure(figsize=(10,7))
plt.plot(T, V)
plt.plot(T_new, V_new)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.xlim(0, 10000)
plt.ylim(-1.1, 2.0)
plt.show()

[Figure: the original 360 Hz signal and the 125 Hz resampled signal, nearly overlapping]

You can see the two match almost exactly. Now let's go through the steps, starting with step 1.

  1. Split the ECG signal into 10-second windows and take one of them
def split(T, V, window=10):
    # split the signal into consecutive windows of `window` seconds;
    # slicing past the end simply yields a shorter final window
    Hz = int(1000 / (T[1] - T[0]))  # sampling rate recovered from the ms time axis
    Ts = []
    Vs = []
    for i in range(0, len(T), window * Hz):
        Ts.append(T[i:i + window * Hz])
        Vs.append(V[i:i + window * Hz])
    return Ts, Vs

Ts, Vs = split(T_new, V_new)

plt.figure(figsize=(10,7))
for Ti, Vi in zip(Ts, Vs):  # keep the full-length T, V intact for later use
    plt.plot(Ti, Vi)
plt.xlabel("[ms]")
plt.ylabel("[mV]")
plt.show()

[Figure: the signal split into 10-second windows, one color per window]

  2. Normalize the amplitude to the range 0-1
def normalize(V):
    # min-max scale to the 0-1 range
    return (V - np.min(V)) / (np.max(V) - np.min(V))

T_new, V_new = Ts[0], Vs[0]
V_new = normalize(V_new)

plt.figure(figsize=(10,7))
plt.plot(T_new, V_new)
plt.xlabel("[ms]")
plt.ylabel("normalized amplitude")
plt.show()

[Figure: the first window after 0-1 normalization]

  3. Find all local maxima
  4. Treat local maxima of 0.9 or higher as R-peaks
from scipy.signal import find_peaks

def find_R_peaks(V, threshold=0.9):
    # local maxima at or above `threshold` are treated as R-peaks
    R_peaks, _ = find_peaks(V, height=threshold)
    return R_peaks

R_peaks = find_R_peaks(V_new)

plt.figure(figsize=(10,7))
plt.plot(T_new, V_new)
plt.scatter(T_new[R_peaks], V_new[R_peaks], color='r')
plt.xlabel("[ms]")
plt.ylabel("normalized amplitude")
plt.show()

[Figure: detected R-peaks marked in red on the normalized window]

A threshold of 0.9 does not look like a great choice for this data, but let's stick with the paper.

  5. Compute the mean RR interval and use it as the window length T
def find_mean_interval(R_peaks):
    return np.mean(np.diff(R_peaks))  # mean RR interval, measured in samples

interval = find_mean_interval(R_peaks)
  6. From each R-peak, take 1.2T worth of data
  7. Pad segments shorter than the predefined fixed length with zeros (zero-padding)
def extract_beats(T, V, R_peaks, interval, max_duration=187):
    window = int(1.2 * interval)  # segment length in samples
    beats = []
    durations = []
    for peak in R_peaks:
        beat = np.zeros(max_duration)  # fixed-length zero vector, so padding comes for free

        if peak + window <= len(V):  # keep only beats with a full window after the R-peak

            if window > max_duration:  # window exceeds the fixed length: truncate
                duration = [T[peak], T[peak + max_duration - 1]]
                beat = V[peak:peak + max_duration]
                beats.append(beat)
                durations.append(duration)

            else:  # window fits: copy it in and leave the tail zero-padded
                duration = [T[peak], T[peak + window - 1]]
                beat[:window] = V[peak:peak + window]
                beats.append(beat)
                durations.append(duration)

    return np.array(beats), durations  # extracted beats plus each beat's start and end times

beats, durations = extract_beats(T_new, V_new, R_peaks, interval)
print("Shape of the extracted beats data: ", beats.shape)
Shape of the extracted beats data:  (4, 187)
def ecg_with_beats(T,V,durations):
    fig = plt.figure(figsize = (10,7))
    ax = fig.add_subplot(111)
    for i in range(len(durations)):
        duration = durations[i]
        ax.axvspan(duration[0], duration[1],color="coral" if i%2 == 0 else "lime" ,alpha=0.3)
    ax.plot(T,V)
    plt.xlabel("[ms]")
    plt.ylabel("[mV]")
    plt.show()
    return

ecg_with_beats(T_new, V_new, durations)

[Figure: the window with each extracted beat span shaded]

plt.figure(figsize=(10,7))
for beat in beats:
    plt.plot(beat)
plt.xlabel("index")
plt.ylabel("normalized amplitude")
plt.show()

[Figure: the four extracted, zero-padded beats]

You can see the heartbeats have been extracted cleanly.

Finally, let's wrap all of the above into a single preprocess function.

def preprocess(T, V, Hz=125, max_duration=187):
    T, V = resample(T, V, Hz)
    Ts, Vs = split(T, V)
    Beats = []
    Durations = []
    for T, V in zip(Ts, Vs):
        V = normalize(V)
        R_peaks = find_R_peaks(V)
        if len(R_peaks) >= 2:
            interval = find_mean_interval(R_peaks)
            beats, durations = extract_beats(T, V, R_peaks, interval, max_duration)
            if len(beats) >= 1:
                Beats.append(beats)
                Durations += durations
    Beats = np.vstack(Beats)
    return Beats, Durations

beats, durations = preprocess(T, V)
print("Shape of the extracted beats data: ", beats.shape)
ecg_with_beats(T, V, durations)
Shape of the extracted beats data:  (104, 187)

[Figure: the full record with the beats that survived preprocessing shaded]

Only a small fraction of the heartbeats get extracted. It's starting to make me worry about this paper.

Model

The paper's model is implemented here:

https://github.com/CVxTz/ECG_Heartbeat_Classification

It is a slightly tweaked version of the model in the paper (the variant without residual blocks).

1D-Convolution layer

"all convolution layers are applying 1-D convolution through time and each have 32 kernels of size 5"

  • Kernel: the matrix convolved with the input, one-dimensional in this case. "32 kernels" is the number of kernels, which corresponds to **filters = 32** in Keras.
  • Size: the kernel's window length, also one-dimensional here. "size 5" corresponds to **kernel_size = 5** in Keras.

A 1-D convolution layer naturally has fewer parameters than its 2-D counterpart, and since the input here is one-dimensional, it is the natural choice.
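
To make the mapping concrete, here is a minimal sketch (mine, not from the paper or the repo) of one such layer in the same functional style as the models below. The trainable parameter count of a Conv1D layer is filters * (kernel_size * in_channels + 1); with 32 kernels of size 5 on a single-channel input that is 32 * (5 * 1 + 1) = 192.

from keras import models
from keras.layers import Input, Convolution1D

# One 1-D convolution with "32 kernels of size 5" applied to a
# 187-sample, single-channel heartbeat
inp = Input(shape=(187, 1))
out = Convolution1D(filters=32, kernel_size=5, activation="relu")(inp)
sketch = models.Model(inputs=inp, outputs=out)
sketch.summary()  # this layer: 32 * (5 * 1 + 1) = 192 trainable parameters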

For more on convolution layers, see https://towardsdatascience.com/types-of-convolution-kernels-simplified-f040cb307c37

Dataset

We use the dataset at https://www.kaggle.com/shayanfazeli/heartbeat. It contains preprocessed heartbeats of maximum length 187, stored as CSV files. The 188th column holds each beat's label (premature ventricular contraction, myocardial infarction, and so on) as an integer class.

The MITBIH annotations are grouped as follows (this summary follows the AAMI EC57 beat categories used in the paper):
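
  • N (class 0): normal beats (including bundle branch block and escape beats)
  • S (class 1): supraventricular ectopic beats (e.g. atrial premature beats)
  • V (class 2): ventricular ectopic beats (e.g. premature ventricular contractions)
  • F (class 3): fusion of ventricular and normal beats
  • Q (class 4): paced or unclassifiable beats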

N, S, V, F, and Q correspond to classes 0, 1, 2, 3, and 4 respectively. Let's take an actual look at the dataset.

import pandas as pd

df_train = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_train.csv", header=None) # download the dataset to e.g. your own Google Drive first
print("Data shape: ", df_train.shape)
print("All classes (shown in 188th column): ", df_train.iloc[:,187].unique())
Data shape:  (87554, 188)
All classes (shown in 188th column):  [0. 1. 2. 3. 4.]
plt.figure(figsize=(10,7))
for beat in df_train.iloc[:5, :187].values:  # columns 0-186 hold the waveform; column 187 is the label
    plt.plot(beat)
plt.xlabel("index")
plt.ylabel("normalized amplitude")
plt.show()

[Figure: five preprocessed beats from the MITBIH training set]

You can see the data has been preprocessed according to the algorithm described in the paper.

Training the Arrhythmia Classifier

We first train on the MITBIH dataset. The code is lifted straight from the GitHub repo above.

from keras import optimizers, losses, activations, models
from keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from keras.layers import Dense, Input, Dropout, Convolution1D, MaxPool1D, GlobalMaxPool1D, GlobalAveragePooling1D, \
    concatenate
from sklearn.metrics import f1_score, accuracy_score


df_train = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_train.csv", header=None) 
df_train = df_train.sample(frac=1)
df_test = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_test.csv", header=None)

Y = np.array(df_train[187].values).astype(np.int8)
X = np.array(df_train[list(range(187))].values)[..., np.newaxis]

Y_test = np.array(df_test[187].values).astype(np.int8)
X_test = np.array(df_test[list(range(187))].values)[..., np.newaxis]


def get_model_mitbih():
    nclass = 5
    inp = Input(shape=(187, 1))
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(inp)
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = GlobalMaxPool1D()(img_1)
    img_1 = Dropout(rate=0.2)(img_1)

    dense_1 = Dense(64, activation=activations.relu, name="dense_1")(img_1)
    dense_1 = Dense(64, activation=activations.relu, name="dense_2")(dense_1)
    dense_1 = Dense(nclass, activation=activations.softmax, name="dense_3_mitbih")(dense_1)

    model = models.Model(inputs=inp, outputs=dense_1)
    opt = optimizers.Adam(0.001)

    model.compile(optimizer=opt, loss=losses.sparse_categorical_crossentropy, metrics=['acc'])
    model.summary()
    return model

model = get_model_mitbih()
file_path = "/content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5"
checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
early = EarlyStopping(monitor="val_acc", mode="max", patience=5, verbose=1)
redonplat = ReduceLROnPlateau(monitor="val_acc", mode="max", patience=3, verbose=2)
callbacks_list = [checkpoint, early, redonplat]

model.fit(X, Y, epochs=1000, verbose=2, callbacks=callbacks_list, validation_split=0.1)
model.load_weights(file_path)

pred_test = model.predict(X_test)
pred_test = np.argmax(pred_test, axis=-1)

f1 = f1_score(Y_test, pred_test, average="macro")

print("Test f1 score : %s "% f1)

acc = accuracy_score(Y_test, pred_test)

print("Test accuracy score : %s "% acc)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 187, 1)]          0         
_________________________________________________________________
conv1d (Conv1D)              (None, 183, 16)           96        
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 179, 16)           1296      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 89, 16)            0         
_________________________________________________________________
dropout (Dropout)            (None, 89, 16)            0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 87, 32)            1568      
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 85, 32)            3104      
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 42, 32)            0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 42, 32)            0         
_________________________________________________________________
conv1d_4 (Conv1D)            (None, 40, 32)            3104      
_________________________________________________________________
conv1d_5 (Conv1D)            (None, 38, 32)            3104      
_________________________________________________________________
max_pooling1d_2 (MaxPooling1 (None, 19, 32)            0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 19, 32)            0         
_________________________________________________________________
conv1d_6 (Conv1D)            (None, 17, 256)           24832     
_________________________________________________________________
conv1d_7 (Conv1D)            (None, 15, 256)           196864    
_________________________________________________________________
global_max_pooling1d (Global (None, 256)               0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                16448     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3_mitbih (Dense)       (None, 5)                 325       
=================================================================
Total params: 254,901
Trainable params: 254,901
Non-trainable params: 0
_________________________________________________________________
Epoch 1/1000

Epoch 00001: val_acc improved from -inf to 0.92005, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.3817 - acc: 0.8841 - val_loss: 0.2877 - val_acc: 0.9201
Epoch 2/1000

Epoch 00002: val_acc improved from 0.92005 to 0.95420, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.2364 - acc: 0.9326 - val_loss: 0.1660 - val_acc: 0.9542
Epoch 3/1000

Epoch 00003: val_acc improved from 0.95420 to 0.96551, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1691 - acc: 0.9537 - val_loss: 0.1357 - val_acc: 0.9655
Epoch 4/1000

Epoch 00004: val_acc improved from 0.96551 to 0.96962, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1409 - acc: 0.9620 - val_loss: 0.1146 - val_acc: 0.9696
Epoch 5/1000

Epoch 00005: val_acc improved from 0.96962 to 0.97270, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1234 - acc: 0.9672 - val_loss: 0.1008 - val_acc: 0.9727
Epoch 6/1000

Epoch 00006: val_acc improved from 0.97270 to 0.97487, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.1114 - acc: 0.9697 - val_loss: 0.0901 - val_acc: 0.9749
Epoch 7/1000

Epoch 00007: val_acc improved from 0.97487 to 0.97602, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.1013 - acc: 0.9723 - val_loss: 0.0847 - val_acc: 0.9760
Epoch 8/1000

Epoch 00008: val_acc did not improve from 0.97602
2463/2463 - 10s - loss: 0.0945 - acc: 0.9741 - val_loss: 0.0886 - val_acc: 0.9751
Epoch 9/1000

Epoch 00009: val_acc improved from 0.97602 to 0.97796, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0894 - acc: 0.9752 - val_loss: 0.0814 - val_acc: 0.9780
Epoch 10/1000

Epoch 00010: val_acc improved from 0.97796 to 0.97967, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0840 - acc: 0.9771 - val_loss: 0.0723 - val_acc: 0.9797
Epoch 11/1000

Epoch 00011: val_acc did not improve from 0.97967
2463/2463 - 10s - loss: 0.0795 - acc: 0.9775 - val_loss: 0.0743 - val_acc: 0.9788
Epoch 12/1000

Epoch 00012: val_acc improved from 0.97967 to 0.98184, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0772 - acc: 0.9776 - val_loss: 0.0666 - val_acc: 0.9818
Epoch 13/1000

Epoch 00013: val_acc did not improve from 0.98184
2463/2463 - 10s - loss: 0.0741 - acc: 0.9790 - val_loss: 0.0649 - val_acc: 0.9814
Epoch 14/1000

Epoch 00014: val_acc did not improve from 0.98184
2463/2463 - 11s - loss: 0.0702 - acc: 0.9798 - val_loss: 0.0660 - val_acc: 0.9802
Epoch 15/1000

Epoch 00015: val_acc improved from 0.98184 to 0.98241, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0711 - acc: 0.9795 - val_loss: 0.0708 - val_acc: 0.9824
Epoch 16/1000

Epoch 00016: val_acc did not improve from 0.98241
2463/2463 - 10s - loss: 0.0679 - acc: 0.9804 - val_loss: 0.0646 - val_acc: 0.9823
Epoch 17/1000

Epoch 00017: val_acc did not improve from 0.98241
2463/2463 - 10s - loss: 0.0656 - acc: 0.9808 - val_loss: 0.0651 - val_acc: 0.9812
Epoch 18/1000

Epoch 00018: val_acc improved from 0.98241 to 0.98344, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0641 - acc: 0.9813 - val_loss: 0.0602 - val_acc: 0.9834
Epoch 19/1000

Epoch 00019: val_acc did not improve from 0.98344
2463/2463 - 10s - loss: 0.0615 - acc: 0.9820 - val_loss: 0.0659 - val_acc: 0.9814
Epoch 20/1000

Epoch 00020: val_acc did not improve from 0.98344
2463/2463 - 10s - loss: 0.0615 - acc: 0.9820 - val_loss: 0.0592 - val_acc: 0.9828
Epoch 21/1000

Epoch 00021: val_acc did not improve from 0.98344

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
2463/2463 - 11s - loss: 0.0599 - acc: 0.9828 - val_loss: 0.0545 - val_acc: 0.9831
Epoch 22/1000

Epoch 00022: val_acc improved from 0.98344 to 0.98561, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0449 - acc: 0.9867 - val_loss: 0.0477 - val_acc: 0.9856
Epoch 23/1000

Epoch 00023: val_acc improved from 0.98561 to 0.98584, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 11s - loss: 0.0408 - acc: 0.9879 - val_loss: 0.0453 - val_acc: 0.9858
Epoch 24/1000

Epoch 00024: val_acc improved from 0.98584 to 0.98595, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0390 - acc: 0.9882 - val_loss: 0.0439 - val_acc: 0.9860
Epoch 25/1000

Epoch 00025: val_acc improved from 0.98595 to 0.98721, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5
2463/2463 - 12s - loss: 0.0358 - acc: 0.9891 - val_loss: 0.0426 - val_acc: 0.9872
Epoch 26/1000

Epoch 00026: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0347 - acc: 0.9897 - val_loss: 0.0430 - val_acc: 0.9864
Epoch 27/1000

Epoch 00027: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0351 - acc: 0.9889 - val_loss: 0.0438 - val_acc: 0.9869
Epoch 28/1000

Epoch 00028: val_acc did not improve from 0.98721

Epoch 00028: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
2463/2463 - 11s - loss: 0.0338 - acc: 0.9897 - val_loss: 0.0425 - val_acc: 0.9862
Epoch 29/1000

Epoch 00029: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0324 - acc: 0.9896 - val_loss: 0.0421 - val_acc: 0.9862
Epoch 30/1000

Epoch 00030: val_acc did not improve from 0.98721
2463/2463 - 10s - loss: 0.0324 - acc: 0.9897 - val_loss: 0.0422 - val_acc: 0.9866
Epoch 00030: early stopping
Test f1 score : 0.9158830356755775 
Test accuracy score : 0.9850630367257446 

Results

  • Test f1 score : 0.9158830356755775
  • Test accuracy score : 0.9850630367257446

Training the MI Predictor

Using the NN obtained on the MITBIH dataset, we now train on the PTBDB dataset for the binary MI classification task.

In the paper, only the last two layers of the arrhythmia network were fine-tuned. The GitHub implementation used here does not freeze the weights outside those two layers; it retrains all of them together, which in practice scored better than training with the earlier weights frozen, so that is the version presented here.
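
For reference, the paper's frozen variant would look roughly like this sketch (my reconstruction, not code from the repo): build the PTBDB model, load the MITBIH weights, mark everything except the final two Dense layers non-trainable, and recompile so the flags take effect.

# Sketch of the paper's setup (assumes `model` comes from get_model_ptbdb()
# below, with the MITBIH weights already loaded): freeze all layers except
# the last two Dense ones, then recompile so the trainable flags take effect.
for layer in model.layers[:-2]:
    layer.trainable = False
model.compile(optimizer=optimizers.Adam(0.001),
              loss=losses.binary_crossentropy,
              metrics=['acc'])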

from keras import optimizers, losses, activations, models
from keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from keras.layers import Dense, Input, Dropout, Convolution1D, MaxPool1D, GlobalMaxPool1D, GlobalAveragePooling1D, \
    concatenate
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

df_1 = pd.read_csv("/content/drive/My Drive/kaggle_ECG/ptbdb_normal.csv", header=None)
df_2 = pd.read_csv("/content/drive/My Drive/kaggle_ECG/ptbdb_abnormal.csv", header=None)
df = pd.concat([df_1, df_2])

df_train, df_test = train_test_split(df, test_size=0.2, random_state=1337, stratify=df[187])


Y = np.array(df_train[187].values).astype(np.int8)
X = np.array(df_train[list(range(187))].values)[..., np.newaxis]

Y_test = np.array(df_test[187].values).astype(np.int8)
X_test = np.array(df_test[list(range(187))].values)[..., np.newaxis]


def get_model_ptbdb():
    nclass = 1
    inp = Input(shape=(187, 1))
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(inp)
    img_1 = Convolution1D(16, kernel_size=5, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(32, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = MaxPool1D(pool_size=2)(img_1)
    img_1 = Dropout(rate=0.1)(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = Convolution1D(256, kernel_size=3, activation=activations.relu, padding="valid")(img_1)
    img_1 = GlobalMaxPool1D()(img_1)
    img_1 = Dropout(rate=0.2)(img_1)

    dense_1 = Dense(64, activation=activations.relu, name="dense_1")(img_1)
    dense_1 = Dense(64, activation=activations.relu, name="dense_2")(dense_1)
    dense_1 = Dense(nclass, activation=activations.sigmoid, name="dense_3_ptbdb")(dense_1)

    model = models.Model(inputs=inp, outputs=dense_1)
    opt = optimizers.Adam(0.001)

    model.compile(optimizer=opt, loss=losses.binary_crossentropy, metrics=['acc'])
    model.summary()
    return model

model = get_model_ptbdb()
file_path = "/content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5"
checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
early = EarlyStopping(monitor="val_acc", mode="max", patience=5, verbose=1)
redonplat = ReduceLROnPlateau(monitor="val_acc", mode="max", patience=3, verbose=2)
callbacks_list = [checkpoint, early, redonplat]
# by_name=True copies weights only into layers whose names match the saved
# file, so the renamed output layer dense_3_ptbdb keeps its fresh weights
model.load_weights("/content/drive/My Drive/kaggle_ECG/baseline_cnn_mitbih.h5", by_name=True)
model.fit(X, Y, epochs=1000, verbose=2, callbacks=callbacks_list, validation_split=0.1)
model.load_weights(file_path)

pred_test = model.predict(X_test)
pred_test = (pred_test>0.5).astype(np.int8)

f1 = f1_score(Y_test, pred_test)

print("Test f1 score : %s "% f1)

acc = accuracy_score(Y_test, pred_test)

print("Test accuracy score : %s "% acc)
Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 187, 1)]          0         
_________________________________________________________________
conv1d_8 (Conv1D)            (None, 183, 16)           96        
_________________________________________________________________
conv1d_9 (Conv1D)            (None, 179, 16)           1296      
_________________________________________________________________
max_pooling1d_3 (MaxPooling1 (None, 89, 16)            0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 89, 16)            0         
_________________________________________________________________
conv1d_10 (Conv1D)           (None, 87, 32)            1568      
_________________________________________________________________
conv1d_11 (Conv1D)           (None, 85, 32)            3104      
_________________________________________________________________
max_pooling1d_4 (MaxPooling1 (None, 42, 32)            0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 42, 32)            0         
_________________________________________________________________
conv1d_12 (Conv1D)           (None, 40, 32)            3104      
_________________________________________________________________
conv1d_13 (Conv1D)           (None, 38, 32)            3104      
_________________________________________________________________
max_pooling1d_5 (MaxPooling1 (None, 19, 32)            0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 19, 32)            0         
_________________________________________________________________
conv1d_14 (Conv1D)           (None, 17, 256)           24832     
_________________________________________________________________
conv1d_15 (Conv1D)           (None, 15, 256)           196864    
_________________________________________________________________
global_max_pooling1d_1 (Glob (None, 256)               0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                16448     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3_ptbdb (Dense)        (None, 1)                 65        
=================================================================
Total params: 254,641
Trainable params: 254,641
Non-trainable params: 0
_________________________________________________________________
Epoch 1/1000

Epoch 00001: val_acc improved from -inf to 0.78112, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.4867 - acc: 0.7630 - val_loss: 0.4635 - val_acc: 0.7811
Epoch 2/1000

Epoch 00002: val_acc improved from 0.78112 to 0.87382, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.3440 - acc: 0.8470 - val_loss: 0.2850 - val_acc: 0.8738
Epoch 3/1000

Epoch 00003: val_acc improved from 0.87382 to 0.87811, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.2697 - acc: 0.8834 - val_loss: 0.2732 - val_acc: 0.8781
Epoch 4/1000

Epoch 00004: val_acc improved from 0.87811 to 0.92275, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.2303 - acc: 0.9035 - val_loss: 0.1837 - val_acc: 0.9227
Epoch 5/1000

Epoch 00005: val_acc improved from 0.92275 to 0.93476, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.2027 - acc: 0.9156 - val_loss: 0.1670 - val_acc: 0.9348
Epoch 6/1000

Epoch 00006: val_acc improved from 0.93476 to 0.95880, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.1842 - acc: 0.9235 - val_loss: 0.1247 - val_acc: 0.9588
Epoch 7/1000

Epoch 00007: val_acc did not improve from 0.95880
328/328 - 1s - loss: 0.1509 - acc: 0.9394 - val_loss: 0.1494 - val_acc: 0.9485
Epoch 8/1000

Epoch 00008: val_acc did not improve from 0.95880
328/328 - 1s - loss: 0.1482 - acc: 0.9408 - val_loss: 0.1179 - val_acc: 0.9562
Epoch 9/1000

Epoch 00009: val_acc improved from 0.95880 to 0.97082, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.1253 - acc: 0.9518 - val_loss: 0.0895 - val_acc: 0.9708
Epoch 10/1000

Epoch 00010: val_acc did not improve from 0.97082
328/328 - 1s - loss: 0.1143 - acc: 0.9570 - val_loss: 0.0847 - val_acc: 0.9648
Epoch 11/1000

Epoch 00011: val_acc did not improve from 0.97082
328/328 - 1s - loss: 0.0967 - acc: 0.9634 - val_loss: 0.0795 - val_acc: 0.9708
Epoch 12/1000

Epoch 00012: val_acc improved from 0.97082 to 0.97253, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0927 - acc: 0.9648 - val_loss: 0.0813 - val_acc: 0.9725
Epoch 13/1000

Epoch 00013: val_acc improved from 0.97253 to 0.97425, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0887 - acc: 0.9675 - val_loss: 0.0731 - val_acc: 0.9742
Epoch 14/1000

Epoch 00014: val_acc improved from 0.97425 to 0.98026, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0780 - acc: 0.9702 - val_loss: 0.0661 - val_acc: 0.9803
Epoch 15/1000

Epoch 00015: val_acc did not improve from 0.98026
328/328 - 1s - loss: 0.0755 - acc: 0.9702 - val_loss: 0.0704 - val_acc: 0.9760
Epoch 16/1000

Epoch 00016: val_acc did not improve from 0.98026
328/328 - 1s - loss: 0.0778 - acc: 0.9716 - val_loss: 0.0559 - val_acc: 0.9803
Epoch 17/1000

Epoch 00017: val_acc improved from 0.98026 to 0.98112, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0717 - acc: 0.9715 - val_loss: 0.0539 - val_acc: 0.9811
Epoch 18/1000

Epoch 00018: val_acc improved from 0.98112 to 0.98283, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0669 - acc: 0.9746 - val_loss: 0.0513 - val_acc: 0.9828
Epoch 19/1000

Epoch 00019: val_acc did not improve from 0.98283
328/328 - 1s - loss: 0.0622 - acc: 0.9774 - val_loss: 0.0822 - val_acc: 0.9760
Epoch 20/1000

Epoch 00020: val_acc did not improve from 0.98283
328/328 - 1s - loss: 0.0598 - acc: 0.9771 - val_loss: 0.0530 - val_acc: 0.9820
Epoch 21/1000

Epoch 00021: val_acc improved from 0.98283 to 0.98455, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0569 - acc: 0.9791 - val_loss: 0.0457 - val_acc: 0.9845
Epoch 22/1000

Epoch 00022: val_acc did not improve from 0.98455
328/328 - 1s - loss: 0.0610 - acc: 0.9790 - val_loss: 0.0497 - val_acc: 0.9785
Epoch 23/1000

Epoch 00023: val_acc did not improve from 0.98455
328/328 - 1s - loss: 0.0517 - acc: 0.9817 - val_loss: 0.0607 - val_acc: 0.9803
Epoch 24/1000

Epoch 00024: val_acc improved from 0.98455 to 0.98627, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0545 - acc: 0.9812 - val_loss: 0.0416 - val_acc: 0.9863
Epoch 25/1000

Epoch 00025: val_acc did not improve from 0.98627
328/328 - 2s - loss: 0.0542 - acc: 0.9796 - val_loss: 0.0483 - val_acc: 0.9837
Epoch 26/1000

Epoch 00026: val_acc did not improve from 0.98627
328/328 - 2s - loss: 0.0496 - acc: 0.9830 - val_loss: 0.0621 - val_acc: 0.9803
Epoch 27/1000

Epoch 00027: val_acc improved from 0.98627 to 0.98884, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0475 - acc: 0.9824 - val_loss: 0.0342 - val_acc: 0.9888
Epoch 28/1000

Epoch 00028: val_acc did not improve from 0.98884
328/328 - 2s - loss: 0.0449 - acc: 0.9817 - val_loss: 0.0705 - val_acc: 0.9760
Epoch 29/1000

Epoch 00029: val_acc did not improve from 0.98884
328/328 - 2s - loss: 0.0451 - acc: 0.9823 - val_loss: 0.0426 - val_acc: 0.9871
Epoch 30/1000

Epoch 00030: val_acc did not improve from 0.98884

Epoch 00030: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
328/328 - 2s - loss: 0.0425 - acc: 0.9844 - val_loss: 0.0460 - val_acc: 0.9837
Epoch 31/1000

Epoch 00031: val_acc improved from 0.98884 to 0.99142, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0290 - acc: 0.9905 - val_loss: 0.0298 - val_acc: 0.9914
Epoch 32/1000

Epoch 00032: val_acc did not improve from 0.99142
328/328 - 1s - loss: 0.0196 - acc: 0.9938 - val_loss: 0.0288 - val_acc: 0.9914
Epoch 33/1000

Epoch 00033: val_acc improved from 0.99142 to 0.99313, saving model to /content/drive/My Drive/kaggle_ECG/baseline_cnn_ptbdb_transfer_fullupdate.h5
328/328 - 2s - loss: 0.0166 - acc: 0.9942 - val_loss: 0.0299 - val_acc: 0.9931
Epoch 34/1000

Epoch 00034: val_acc did not improve from 0.99313
328/328 - 1s - loss: 0.0160 - acc: 0.9937 - val_loss: 0.0294 - val_acc: 0.9914
Epoch 35/1000

Epoch 00035: val_acc did not improve from 0.99313
328/328 - 1s - loss: 0.0160 - acc: 0.9951 - val_loss: 0.0281 - val_acc: 0.9923
Epoch 36/1000

Epoch 00036: val_acc did not improve from 0.99313

Epoch 00036: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
328/328 - 1s - loss: 0.0169 - acc: 0.9944 - val_loss: 0.0275 - val_acc: 0.9906
Epoch 37/1000

Epoch 00037: val_acc did not improve from 0.99313
328/328 - 1s - loss: 0.0150 - acc: 0.9950 - val_loss: 0.0279 - val_acc: 0.9923
Epoch 38/1000

Epoch 00038: val_acc did not improve from 0.99313
328/328 - 1s - loss: 0.0121 - acc: 0.9959 - val_loss: 0.0275 - val_acc: 0.9914
Epoch 00038: early stopping
Test f1 score : 0.995249406175772 
Test accuracy score : 0.9931295087598764 

Results

  • Test f1 score : 0.995249406175772
  • Test accuracy score : 0.9931295087598764

Issues

Preprocessing

As we saw in the Preprocessing part, on signals with baseline wander the per-window scaling varies, so the paper's algorithm fails to capture enough heartbeats. Using this algorithm to extract beats from real data means inference can only be run on a small fraction of the recording.

Also, in cases such as ST elevation, the ST segment can rise above the R peak; the algorithm would then mistake the crest of the ST segment for an R peak and feed it into the dataset, another case the paper's algorithm does not cover.

Dataset

Only lead II of the ECG is used for training. The flip side of being able to run inference on so little data is that the data this predictor can be applied to is extremely limited.

There are also distribution problems such as class imbalance and patient overlap between train and test, so I am skeptical that this transfers to real clinical ECG data.
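
The class imbalance at least is easy to check from the CSV itself. A quick sketch (output omitted; note that df_train was reassigned to the PTBDB split above, so we reload the MITBIH file):

import pandas as pd

# Share of each of the 5 MITBIH classes in the training set
# (column 187 holds the integer label)
df_mitbih = pd.read_csv("/content/drive/My Drive/kaggle_ECG/mitbih_train.csv", header=None)
print(df_mitbih[187].value_counts(normalize=True))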

Conclusion

This paper classified ECG waveforms with a 1-D CNN. We covered the benefits of transfer learning, preprocessing that puts data into an easily learnable form, and how to handle ECG waveforms. We also saw that even a relatively simple NN can be turned into a strong predictor.

On the other hand, the strong scores hold only within that dataset and its train/test validation split. As we covered in the Kaggle study group, unless the validation scheme and the design of the dataset and its distribution are grounded in the real-world problem, the resulting predictor is a frog in a well that knows nothing of the ocean.

To build medical AI that works in real clinical practice, we first need datasets that resemble the real clinical data distribution as closely as possible, and preprocessing algorithms that can extract them. That is the point to take to heart in closing.
