這個問題也困擾了我很久,後來終於明白了,很多資料都沒有在這個地方做詳細的解釋,那就是 LSTM 的 cell 裡面的 num_units 該怎麼理解,其實也是很簡單,看看下圖:
可以看到中間的 cell 裡面有四個黃色小框,你如果理解了那個代表的含義一切就明白了,每一個小黃框代表一個前饋網路層,對,就是經典的神經網路的結構,num_units就是這個層的隱藏神經元個數,就這麼簡單。其中1、2、4的啟用函式是 sigmoid,第三個的啟用函式是 tanh。
另外幾個需要注意的地方:
1、 cell 的狀態是一個向量,是有多個值的。。。一開始沒有理解這點的時候怎麼都想不明白
2、 上一次的狀態 h(t-1)是怎麼和下一次的輸入 x(t) 結合(concat)起來的,這也是很多資料沒有明白講的地方,也很簡單,concat, 直白的說就是把二者直接拼起來,比如 x是28位的向量,h(t-1)是128位的,那麼拼起來就是156位的向量,就是這麼簡單。。
3、 cell 的權重是共享的,這是什麼意思呢?這是指這張圖片上有三個綠色的大框,代表三個 cell 對吧,但是實際上,它只是代表了一個 cell 在不同時序時候的狀態,所有的資料只會透過一個 cell,然後不斷更新它的權重。
4、那麼一層的 LSTM 的引數有多少個?根據第 3 點的說明,我們知道引數的數量是由 cell 的數量決定的,這裡只有一個 cell,所以引數的數量就是這個 cell 裡面用到的引數個數。假設 num_units 是128,輸入是28位的,那麼根據上面的第 2 點,可以得到,四個小黃框的引數一共有 (128+28)*(128*4),也就是156 * 512,可以看看 TensorFlow 的最簡單的 LSTM 的案例,中間層的引數就是這樣,不過還要加上輸出的時候的啟用函式的引數,假設是10個類的話,就是128*10的 W 引數和10個bias 引數
5、cell 最上面的一條線的狀態即 s(t) 代表了長時記憶,而下面的 h(t)則代表了工作記憶或短時記憶
暫時這麼多。
這個問題也困擾了我很久,後來終於明白了,很多資料都沒有在這個地方做詳細的解釋,那就是 LSTM 的 cell 裡面的 num_units 該怎麼理解,其實也是很簡單,看看下圖:
可以看到中間的 cell 裡面有四個黃色小框,你如果理解了那個代表的含義一切就明白了,每一個小黃框代表一個前饋網路層,對,就是經典的神經網路的結構,num_units就是這個層的隱藏神經元個數,就這麼簡單。其中1、2、4的啟用函式是 sigmoid,第三個的啟用函式是 tanh。
另外幾個需要注意的地方:
1、 cell 的狀態是一個向量,是有多個值的。。。一開始沒有理解這點的時候怎麼都想不明白
2、 上一次的狀態 h(t-1)是怎麼和下一次的輸入 x(t) 結合(concat)起來的,這也是很多資料沒有明白講的地方,也很簡單,concat, 直白的說就是把二者直接拼起來,比如 x是28位的向量,h(t-1)是128位的,那麼拼起來就是156位的向量,就是這麼簡單。。
3、 cell 的權重是共享的,這是什麼意思呢?這是指這張圖片上有三個綠色的大框,代表三個 cell 對吧,但是實際上,它只是代表了一個 cell 在不同時序時候的狀態,所有的資料只會透過一個 cell,然後不斷更新它的權重。
4、那麼一層的 LSTM 的引數有多少個?根據第 3 點的說明,我們知道引數的數量是由 cell 的數量決定的,這裡只有一個 cell,所以引數的數量就是這個 cell 裡面用到的引數個數。假設 num_units 是128,輸入是28位的,那麼根據上面的第 2 點,可以得到,四個小黃框的引數一共有 (128+28)*(128*4),也就是156 * 512,可以看看 TensorFlow 的最簡單的 LSTM 的案例,中間層的引數就是這樣,不過還要加上輸出的時候的啟用函式的引數,假設是10個類的話,就是128*10的 W 引數和10個bias 引數
5、cell 最上面的一條線的狀態即 s(t) 代表了長時記憶,而下面的 h(t)則代表了工作記憶或短時記憶
暫時這麼多。