深度學習程式設計入門(一):利用keras實現手寫數字識別
part 1 寫在前面
這篇文章涉及到MLP還有softmax、RMSprop等等一些基礎。
1、RMSprop的原理在這裡:blog.csdn.net/bvl10101111…
2、多層感知器大概長這樣:zhidao.baidu.com/question/16…
3、softmax的原理在這裡:www.zhihu.com/question/23…
本文中使用mnist資料集進行訓練和測試,mnist是一個很出名的手寫數字資料集,包含了60000個訓練資料和10000個測試資料,該資料集下載地址和說明在這裡:yann.lecun.com/exdb/mnist/
part2 確定目標&網路建立
前輩們說,如果要用神經網路解決問題,最好先嚐試最簡單的神經網路模型。
如果最簡單的神經網路解決不了問題,再嘗試使用更復雜的神經網路。
根據上述經驗,我們先使用最簡單的單層神經網路進行訓練和預測:
1、引入資料集
path = './mnist.npz'f = np.load(path)x_train, y_train = f['x_train'], f['y_train']x_test, y_test = f['x_test'], f['y_test']f.close()複製程式碼
2、資料預處理
將原來(n,28,28)的資料變成(n,784)的形式。
(28,28)是圖片的長寬,n為訓練or測試集的資料量
x_train = x_train.reshape(60000, 784).astype('float32')x_test = x_test.reshape(10000, 784).astype('float32')#將28X28的二維資料轉成 一維資料x_train /= 255x_test /= 255複製程式碼
3、確定我們使用的模型
根據"先簡單後複雜"的嘗試原則,我們先使用單隱藏層的神經網路進行訓練,模型大概這樣:
輸入層→單隱藏層(relu)→softmax層
model = Sequential()model.add(Dense(64, activation='relu', input_dim=784))model.add(Dense(10, activation='softmax'))複製程式碼
4、確定損失函式,優化方法
#確定學習率等等引數,在此我只調整了學習率,其他引數是copy API的rms = keras.optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=1e-06)#確定損失函式,優化器model.compile(loss='categorical_crossentropy', optimizer=rms, metrics=['accuracy'])#載入訓練資料,設定mini-batch的值,訓練波數model.fit(x_train, y_train, epochs=10, batch_size=128)複製程式碼
5、查看準確率
score = model.evaluate(x_test, y_test, batch_size=128) #檢視測試值準確度print(score)複製程式碼
6、執行,該模型準確率為97%左右
part2 調整超引數
我們知道人在辨認手寫數字時是很難出錯的,所以97%的準確率顯然不盡如人意,因此開始調參。
model.add(Dense(300, activation='relu', input_dim=784))複製程式碼
將單隱藏層的神經元數目增加到300個之後,正確率達到了98%,由於po主要去洗澡了,所以就不繼續往下調引數了。
神經元增加之後,正確率上升,個人感覺是之前的神經網路神經元太少,欠擬合了,但事實上,當隱藏層神經元達到300的時候,雖然正確率上升了,但這個網路是處於過擬合狀態的(因為訓練資料的識別正確率為99.9%,但測試資料的識別正確率只有98%),再增加神經元個數已經失去了意義。
鑑於之前的模型過擬合了,參考了相關論文,我將隱藏層的300個神經元減為150個神經元,並增加了一層100個神經元的隱藏層:
model.add(Dense(150, activation='relu', input_dim=784))model.add(Dense(100, activation='relu'))model.add(Dense(10, activation='softmax'))複製程式碼
經過以上修改後,過擬合現象基本消失了,但是識別率依然不盡如人意,在感到調整其他引數作用不大的情況下,我選擇改變損失函式的型別,將損失函式換為log損失函式。
log損失函式的參考資料在這裡:
https://www.zhihu.com/question/27126057複製程式碼
更換了損失函式之後,識別率上升至99.57%
更換了損失函式後為何識別率大幅上升呢?個人認為是因為之前的損失函式即使收斂到最小值,也只是一個“相對”的最優解,更換損失函式後,得到的值是一個近似的最優解了。
part3 總結
神經網路真好玩但是好難學,學了好久也才勉強入門,要多多寫程式碼研究多多學習才能進步呀,如果本文有錯歡迎指正,畢竟寫這文章的人也是個菜雞╮(╯▽╰)╭
還有就是調參如同煉丹一般,非常玄學