首頁>技術>

深度學習程式設計入門(一):利用keras實現手寫數字識別

part 1 寫在前面

這篇文章涉及到MLP還有softmax、RMSprop等等一些基礎。

1、RMSprop的原理在這裡:blog.csdn.net/bvl10101111…

2、多層感知器大概長這樣:zhidao.baidu.com/question/16…

3、softmax的原理在這裡:www.zhihu.com/question/23…

本文中使用mnist資料集進行訓練和測試,mnist是一個很出名的手寫數字資料集,包含了60000個訓練資料和10000個測試資料,該資料集下載地址和說明在這裡:yann.lecun.com/exdb/mnist/

part2 確定目標&網路建立

前輩們說,如果要用神經網路解決問題,最好先嚐試最簡單的神經網路模型。

如果最簡單的神經網路解決不了問題,再嘗試使用更復雜的神經網路。

根據上述經驗,我們先使用最簡單的單層神經網路進行訓練和預測:

1、引入資料集

path = './mnist.npz'f = np.load(path)x_train, y_train = f['x_train'], f['y_train']x_test, y_test = f['x_test'], f['y_test']f.close()複製程式碼

2、資料預處理

將原來(n,28,28)的資料變成(n,784)的形式。

(28,28)是圖片的長寬,n為訓練or測試集的資料量

x_train = x_train.reshape(60000, 784).astype('float32')x_test = x_test.reshape(10000, 784).astype('float32')#將28X28的二維資料轉成 一維資料x_train /= 255x_test /= 255複製程式碼

3、確定我們使用的模型

根據"先簡單後複雜"的嘗試原則,我們先使用單隱藏層的神經網路進行訓練,模型大概這樣:

輸入層→單隱藏層(relu)→softmax層

model = Sequential()model.add(Dense(64, activation='relu', input_dim=784))model.add(Dense(10, activation='softmax'))複製程式碼

4、確定損失函式,優化方法

#確定學習率等等引數,在此我只調整了學習率,其他引數是copy API的rms = keras.optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=1e-06)#確定損失函式,優化器model.compile(loss='categorical_crossentropy', optimizer=rms, metrics=['accuracy'])#載入訓練資料,設定mini-batch的值,訓練波數model.fit(x_train, y_train, epochs=10, batch_size=128)複製程式碼

5、查看準確率

score = model.evaluate(x_test, y_test, batch_size=128) #檢視測試值準確度print(score)複製程式碼

6、執行,該模型準確率為97%左右

part2 調整超引數

我們知道人在辨認手寫數字時是很難出錯的,所以97%的準確率顯然不盡如人意,因此開始調參。

model.add(Dense(300, activation='relu', input_dim=784))複製程式碼

將單隱藏層的神經元數目增加到300個之後,正確率達到了98%,由於po主要去洗澡了,所以就不繼續往下調引數了。

神經元增加之後,正確率上升,個人感覺是之前的神經網路神經元太少,欠擬合了,但事實上,當隱藏層神經元達到300的時候,雖然正確率上升了,但這個網路是處於過擬合狀態的(因為訓練資料的識別正確率為99.9%,但測試資料的識別正確率只有98%),再增加神經元個數已經失去了意義。

鑑於之前的模型過擬合了,參考了相關論文,我將隱藏層的300個神經元減為150個神經元,並增加了一層100個神經元的隱藏層:

model.add(Dense(150, activation='relu', input_dim=784))model.add(Dense(100, activation='relu'))model.add(Dense(10, activation='softmax'))複製程式碼

經過以上修改後,過擬合現象基本消失了,但是識別率依然不盡如人意,在感到調整其他引數作用不大的情況下,我選擇改變損失函式的型別,將損失函式換為log損失函式。

log損失函式的參考資料在這裡:

https://www.zhihu.com/question/27126057複製程式碼

更換了損失函式之後,識別率上升至99.57%

更換了損失函式後為何識別率大幅上升呢?個人認為是因為之前的損失函式即使收斂到最小值,也只是一個“相對”的最優解,更換損失函式後,得到的值是一個近似的最優解了。

part3 總結

神經網路真好玩但是好難學,學了好久也才勉強入門,要多多寫程式碼研究多多學習才能進步呀,如果本文有錯歡迎指正,畢竟寫這文章的人也是個菜雞╮(╯▽╰)╭

還有就是調參如同煉丹一般,非常玄學

  • BSA-TRITC(10mg/ml) TRITC-BSA 牛血清白蛋白改性標記羅丹明
  • mysql連線報“Communications link failure”錯誤