長短期記憶神經網絡(LSTM)介紹以及簡單應用分析

本文分為四個部分,第一部分簡要介紹LSTM的應用現狀;第二部分介紹LSTM的發展歷史,并引出了受眾多學者關注的LSTM變體——門控遞歸單元(GRU);第三部分介紹LSTM的基本結構,由基本循環神經網絡結構引出LSTM的具體結構。第四部分,應用Keras框架提供的API,比較和分析簡單循環神經網絡(SRN)、LSTM和GRU在手寫數字mnist數據集上的表現。

 


應用現狀

       長短期記憶神經網絡(LSTM)是一種特殊的循環神經網絡(RNN)。原始的RNN在訓練中,隨著訓練時間的加長以及網絡層數的增多,很容易出現梯度爆炸或者梯度消失的問題,導致無法處理較長序列數據,從而無法獲取長距離數據的信息。

       LSTM應用的領域包括:文本生成、機器翻譯、語音識別、生成圖像描述和視頻標記等。

      2009年, 應用LSTM搭建的神經網絡模型贏得了ICDAR手寫識別比賽冠軍。

      2015年以來,在機械故障診斷和預測領域,相關學者應用LSTM來處理機械設備的振動信號。

      2016年, 谷歌公司應用LSTM來做語音識別和文字翻譯,其中Google翻譯用的就是一個7-8層的LSTM模型。

      2016年, 蘋果公司使用LSTM來優化Siri應用。

 


發展歷史

        1997年,Sepp Hochreiter 和 Jürgen Schmidhuber[1]提出了長短期記憶神經網絡(LSTM),有效解決了RNN難以解決的人為延長時間任務的問題,并解決了RNN容易出現梯度消失的問題。

        1999年,Felix A. Gers等人[2]發現[1]中提出的LSTM在處理連續輸入數據時,如果沒有重置網絡內部的狀態,最終會導致網絡崩潰。因此,他們在文獻[1]基礎上引入了遺忘門機制,使得LSTM能夠重置自己的狀態。

         2000年,Felix A. Gers和Jiirgen Schmidhuber[3]發現,通過在LSTM內部狀態單元內添加窺視孔(Peephole)連接,可以增強網絡對輸入序列之間細微特征的區分能力。

         2005年,Alex Graves和Jürgen Schmidhuber[4]在文獻[1] [2] [3]的基礎上提出了一種雙向長短期記憶神經網絡(BLSTM),也稱為vanilla LSTM,是當前應用最廣泛的一種LSTM模型。

         2005年-2015年期間,相關學者提出了多種LSTM變體模型,此處不多做描述。

         2016年,Klaus Greff 等人[5]回顧了LSTM的發展歷程,并比較分析了八種LSTM變體在語音識別、手寫識別和弦音樂建模方面的能力,實驗結果表明這些變體不能顯著改進標準LSTM體系結構,并證明了遺忘門和輸出激活功能是LSTM的關鍵組成部分。在這八種變體中,vanilla LSTM的綜合表現能力最佳。另外,還探索了LSTM相關超參數的設定影響,實驗結果表明學習率是最關鍵的超參數,其次是網絡規模(網絡層數和隱藏層單元數),而動量梯度等設置對最終結果影響不大。

        下圖展示了Simple RNN(圖左)和vanilla LSTM(圖右,圖中藍色線條表示窺視孔連接)的基本單元結構圖[5]:

        在眾多LSTM變體中,2014年Kyunghyun Cho等人[6]提出的變體引起了眾多學者的關注。Kyunghyun Cho等人簡化了LSTM架構,稱為門控遞歸單元(GRU)。GRU擺脫了單元狀態,基本結構由重置門和更新門組成。LSTM和GRU的基本結構單元如下圖(具體可參考:Illustrated Guide to LSTM’s and GRU’s: A step by step explanation)。

 

        在GRU被提出后,Junyoung Chung等人[7]比較了LSTM和GRU在復音音樂和語音信號建模方面的能力,實驗結果表明GRU和LSTM表現相當。

        GRU被提出至今(2019年),也只有幾年時間,關于它的一些應用利弊到目前還未探索清楚。不過,相對于LSTM架構,GRU的的參數較少,在數據量較大的情況下,其訓練速度更快。

         LSTM是深度學習技術中的一員,其基本結構比較復雜,計算復雜度較高,導致較難進行較深層次的學習,例如谷歌翻譯也只是應用7-8層的LSTM網絡結構。另外,在訓練學習過程中有可能會出現過擬合,可以通過應用dropout來解決過擬合問題(這在Keras等框架中均有實現,具體可參考:LSTM原理與實踐,原來如此簡單)。

         LSTM在當前應用比較的結構是雙向LSTM或者多層堆疊LSTM,這兩種結構的實現在Keras等框架中均有對應的API可以調用。

        下圖展示一個堆疊兩層的LSTM結構圖(來源:運用TensorFlow處理簡單的NLP問題):

 

        下圖展示了一個雙向LSTM的結構圖(來源:雙向LSTM

 

 

 

 


基本原理

        本節首先講解一下RNN的基本結構,然后說明LSTM的具體原理(下面要介紹的LSTM即為vanilla LSTM)。

        原始的RNN基本結構圖如下圖所示(原圖來源:Understanding LSTM Networks)。

       由上圖可知,RNN展開后由多個相同的單元連續連接。但是,RNN的實際結構確和上圖左邊的結構所示,是一個自我不斷循環的結構。即隨著輸入數據的不斷增加,上述自我循環的結構把上一次的狀態傳遞給當前輸入,一起作為新的輸入數據進行當前輪次的訓練和學習,一直到輸入或者訓練結束,最終得到的輸出即為最終的預測結果。

        LSTM是一種特殊的RNN,兩者的區別在于普通的RNN單個循環結構內部只有一個狀態。而LSTM的單個循環結構(又稱為細胞)內部有四個狀態。相比于RNN,LSTM循環結構之間保持一個持久的單元狀態不斷傳遞下去,用于決定哪些信息要遺忘或者繼續傳遞下去。

        包含三個連續循環結構的RNN如下圖,每個循環結構只有一個輸出:

        包含三個連續循環結構的LSTM如下圖,每個循環結構有兩個輸出,其中一個即為單元狀態:

        一層LSTM是由單個循環結構結構組成,既由輸入數據的維度和循環次數決定單個循環結構需要自我更新幾次,而不是多個單個循環結構連接組成(此處關于這段描述,在實際操作的理解詳述請參考:Keras關于LSTM的units參數,還是不理解? ),即當前層LSTM的參數總個數只需計算一個循環單元就行,而不是計算多個連續單元的總個數。

       下面將由一組圖來詳細結構LSTM細胞的基本組成和實現原理。LSTM細胞由輸入門、遺忘門、輸出門和單元狀態組成。

  • 輸入門:決定當前時刻網絡的輸入數據有多少需要保存到單元狀態。
  • 遺忘門:決定上一時刻的單元狀態有多少需要保留到當前時刻。
  • 輸出門:控制當前單元狀態有多少需要輸出到當前的輸出值。

       下圖展示了應用上一個時刻的輸出h_t-1和當前的數據輸入x_t,通過遺忘門得到f_t的過程。(下面的一組原圖來源:Understanding LSTM Networks

       下圖展示了應用上一個時刻的輸出h_t-1和當前的數據輸入x_t,通過輸入門得到i_t,以及通過單元狀態得到當前時刻暫時狀態C~t的過程。

       下圖展示了應用上一個細胞結構的單元狀態C_t-1、遺忘門輸出f_t、輸入門輸出i_t以及單元狀態的輸出C~t,得到當前細胞的狀態C_t的過程。

       下圖展示了應用上一個時刻的輸出h_t-1和當前的數據輸入x_t,通過輸出門得到o_t的過程,以及結合當前細胞的單元狀態C_t和o_t得到最終的輸出h_t的過程。

 

 


基于Keras框架的手寫數字識別實驗

        本節應用Keras提供的API,比較和分析Simple RNN、LSTM和GRU在手寫數字mnist數據集上的預測準確率。

應用Simple RNN進行手寫數字預測訓練的代碼如下:

import keras
from keras.layers import LSTM , SimpleRNN, GRU
from keras.layers import Dense, Activation
from keras.datasets import mnist
from keras.models import Sequential
from keras.optimizers import Adam
learning_rate
= 0.001 training_iters = 20 batch_size = 128 display_step = 10 n_input = 28 n_step = 28 n_hidden = 128 n_classes = 10 (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train = x_train.reshape(-1, n_step, n_input) x_test = x_test.reshape(-1, n_step, n_input) x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 y_train = keras.utils.to_categorical(y_train, n_classes) y_test = keras.utils.to_categorical(y_test, n_classes) model = Sequential() model.add(SimpleRNN(n_hidden, batch_input_shape=(None, n_step, n_input), unroll=True)) model.add(Dense(n_classes)) model.add(Activation('softmax')) adam = Adam(lr=learning_rate) model.summary() model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy']) history = model.fit(x_train, y_train, batch_size=batch_size, epochs=training_iters, verbose=1, validation_data=(x_test, y_test)) scores = model.evaluate(x_test, y_test, verbose=0) print('Simple RNN test score(loss value):', scores[0]) print('Simple RNN test accuracy:', scores[1])

訓練結果:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn_1 (SimpleRNN)     (None, 128)               20096     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
=================================================================
Total params: 21,386
Trainable params: 21,386
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 3s 51us/step - loss: 0.4584 - acc: 0.8615 - val_loss: 0.2459 - val_acc: 0.9308
Epoch 2/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1923 - acc: 0.9440 - val_loss: 0.1457 - val_acc: 0.9578
Epoch 3/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1506 - acc: 0.9555 - val_loss: 0.1553 - val_acc: 0.9552
Epoch 4/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1326 - acc: 0.9604 - val_loss: 0.1219 - val_acc: 0.9642
Epoch 5/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1184 - acc: 0.9651 - val_loss: 0.1014 - val_acc: 0.9696
Epoch 6/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1021 - acc: 0.9707 - val_loss: 0.1254 - val_acc: 0.9651
Epoch 7/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0987 - acc: 0.9708 - val_loss: 0.0946 - val_acc: 0.9733
Epoch 8/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0959 - acc: 0.9722 - val_loss: 0.1163 - val_acc: 0.9678
Epoch 9/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0888 - acc: 0.9742 - val_loss: 0.0983 - val_acc: 0.9718
Epoch 10/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0833 - acc: 0.9750 - val_loss: 0.1199 - val_acc: 0.9651
Epoch 11/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0814 - acc: 0.9750 - val_loss: 0.0939 - val_acc: 0.9722
Epoch 12/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0767 - acc: 0.9773 - val_loss: 0.0865 - val_acc: 0.9761
Epoch 13/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0747 - acc: 0.9778 - val_loss: 0.1077 - val_acc: 0.9697
Epoch 14/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0746 - acc: 0.9779 - val_loss: 0.1098 - val_acc: 0.9693
Epoch 15/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0671 - acc: 0.9799 - val_loss: 0.0776 - val_acc: 0.9771
Epoch 16/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0639 - acc: 0.9810 - val_loss: 0.0961 - val_acc: 0.9730
Epoch 17/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0701 - acc: 0.9792 - val_loss: 0.1046 - val_acc: 0.9713
Epoch 18/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0600 - acc: 0.9822 - val_loss: 0.0865 - val_acc: 0.9767
Epoch 19/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0635 - acc: 0.9813 - val_loss: 0.0812 - val_acc: 0.9790
Epoch 20/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0579 - acc: 0.9827 - val_loss: 0.0981 - val_acc: 0.9733
Simple RNN test score(loss value): 0.09805978989955037
Simple RNN test accuracy: 0.9733

        可知Simple RNN在測試集上的最終預測準確率為97.33%。

        只需修改下方代碼中Simple RNN為LSTM,即可調用LSTM進行模型訓練:

model.add(SimpleRNN(n_hidden,
               batch_input_shape=(None, n_step, n_input),
               unroll=True))

改變為:

model.add(LSTM(n_hidden,
               batch_input_shape=(None, n_step, n_input),
               unroll=True))

訓練結果:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 128)               80384     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0         
=================================================================
Total params: 81,674
Trainable params: 81,674
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 10s 172us/step - loss: 0.5226 - acc: 0.8277 - val_loss: 0.1751 - val_acc: 0.9451
Epoch 2/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.1474 - acc: 0.9549 - val_loss: 0.1178 - val_acc: 0.9641
Epoch 3/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.1017 - acc: 0.9690 - val_loss: 0.0836 - val_acc: 0.9748
Epoch 4/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0764 - acc: 0.9764 - val_loss: 0.0787 - val_acc: 0.9759
Epoch 5/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0607 - acc: 0.9811 - val_loss: 0.0646 - val_acc: 0.9813
Epoch 6/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0542 - acc: 0.9834 - val_loss: 0.0630 - val_acc: 0.9801
Epoch 7/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0452 - acc: 0.9859 - val_loss: 0.0603 - val_acc: 0.9803
Epoch 8/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0406 - acc: 0.9874 - val_loss: 0.0531 - val_acc: 0.9849
Epoch 9/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0345 - acc: 0.9888 - val_loss: 0.0540 - val_acc: 0.9834
Epoch 10/20
60000/60000 [==============================] - 8s 132us/step - loss: 0.0305 - acc: 0.9901 - val_loss: 0.0483 - val_acc: 0.9848
Epoch 11/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0281 - acc: 0.9913 - val_loss: 0.0517 - val_acc: 0.9843
Epoch 12/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0256 - acc: 0.9918 - val_loss: 0.0472 - val_acc: 0.9847
Epoch 13/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0229 - acc: 0.9929 - val_loss: 0.0441 - val_acc: 0.9874
Epoch 14/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0204 - acc: 0.9935 - val_loss: 0.0490 - val_acc: 0.9855
Epoch 15/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0192 - acc: 0.9938 - val_loss: 0.0486 - val_acc: 0.9851
Epoch 16/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0203 - acc: 0.9937 - val_loss: 0.0450 - val_acc: 0.9866
Epoch 17/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0160 - acc: 0.9948 - val_loss: 0.0391 - val_acc: 0.9882
Epoch 18/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0147 - acc: 0.9955 - val_loss: 0.0544 - val_acc: 0.9834
Epoch 19/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0147 - acc: 0.9953 - val_loss: 0.0456 - val_acc: 0.9880
Epoch 20/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0153 - acc: 0.9952 - val_loss: 0.0465 - val_acc: 0.9867
LSTM test score(loss value): 0.046479647984029725
LSTM test accuracy: 0.9867

       可知LSTM在測試集上的最終預測準確率為98.67%。

       采用同樣的思路,把Simple RNN改為GRU,即可調用GRU進行模型訓練。

訓練結果:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
gru_1 (GRU)                  (None, 128)               60288     
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_3 (Activation)    (None, 10)                0         
=================================================================
Total params: 61,578
Trainable params: 61,578
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 10s 166us/step - loss: 0.6273 - acc: 0.7945 - val_loss: 0.2062 - val_acc: 0.9400
Epoch 2/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.1656 - acc: 0.9501 - val_loss: 0.1261 - val_acc: 0.9606
Epoch 3/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.1086 - acc: 0.9667 - val_loss: 0.0950 - val_acc: 0.9697
Epoch 4/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0824 - acc: 0.9745 - val_loss: 0.0761 - val_acc: 0.9769
Epoch 5/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0644 - acc: 0.9797 - val_loss: 0.0706 - val_acc: 0.9793
Epoch 6/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0540 - acc: 0.9829 - val_loss: 0.0678 - val_acc: 0.9799
Epoch 7/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0479 - acc: 0.9854 - val_loss: 0.0601 - val_acc: 0.9811
Epoch 8/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0402 - acc: 0.9877 - val_loss: 0.0495 - val_acc: 0.9848
Epoch 9/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0346 - acc: 0.9895 - val_loss: 0.0591 - val_acc: 0.9821
Epoch 10/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0306 - acc: 0.9901 - val_loss: 0.0560 - val_acc: 0.9836
Epoch 11/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0290 - acc: 0.9910 - val_loss: 0.0473 - val_acc: 0.9857
Epoch 12/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0249 - acc: 0.9922 - val_loss: 0.0516 - val_acc: 0.9852
Epoch 13/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0222 - acc: 0.9930 - val_loss: 0.0448 - val_acc: 0.9863
Epoch 14/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0206 - acc: 0.9934 - val_loss: 0.0453 - val_acc: 0.9872
Epoch 15/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0178 - acc: 0.9944 - val_loss: 0.0559 - val_acc: 0.9833
Epoch 16/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0173 - acc: 0.9947 - val_loss: 0.0502 - val_acc: 0.9854
Epoch 17/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0150 - acc: 0.9955 - val_loss: 0.0401 - val_acc: 0.9880
Epoch 18/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0164 - acc: 0.9949 - val_loss: 0.0486 - val_acc: 0.9872
Epoch 19/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0133 - acc: 0.9960 - val_loss: 0.0468 - val_acc: 0.9882
Epoch 20/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0107 - acc: 0.9965 - val_loss: 0.0470 - val_acc: 0.9879
GRU test score(loss value): 0.04698457587567973
GRU test accuracy: 0.9879

       可知GRU在測試集上的最終預測準確率為98.79%。

      由上述實驗結果可知,LSTM和GRU的預測準確率要顯著高于Simple RNN,而LSTM和GRU的預測準確率相差較小。

 

 


參考文獻

[1] S. Hochreiter and J. Schmidhuber, “Long Short-Term Memory,” Neural Comput, vol. 9, no. 8, pp. 1735–1780, Nov. 1997.

[2] F. A. Gers, J. Schmidhuber, and F. A. Cummins, “Learning to Forget: Continual Prediction with LSTM,” Neural Comput., vol. 12, pp. 2451–2471, 2000.

[3] F. A. Gers and J. Schmidhuber, “Recurrent nets that time and count,” Proc. IEEE-INNS-ENNS Int. Jt. Conf. Neural Netw. IJCNN 2000 Neural Comput. New Chall. Perspect. New Millenn., vol. 3, pp. 189–194 vol.3, 2000.

[4] A. Graves and J. Schmidhuber, “Framewise phoneme classification with bidirectional LSTM and other neural network architectures,” Neural Netw., vol. 18, no. 5, pp. 602–610, Jul. 2005.

[5] K. Greff, R. K. Srivastava, J. Koutník, B. R. Steunebrink, and J. Schmidhuber, “LSTM: A Search Space Odyssey,” IEEE Trans. Neural Netw. Learn. Syst., vol. 28, no. 10, pp. 2222–2232, Oct. 2017.

[6] K. Cho et al., “Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation,” ArXiv14061078 Cs Stat, Jun. 2014.

[7] J. Chung, C. Gulcehre, K. Cho, and Y. Bengio, “Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling,” ArXiv14123555 Cs, Dec. 2014.

 

posted @ 2019-10-05 20:39 舞動的心 閱讀(...) 評論(...) 編輯 收藏
手机投注彩票合法吗