--- tags: tensorflow,keras,cnn --- # Mnist--手寫辨識使用CNN ``` python= from tensorflow.keras.datasets import mnist import matplotlib.pyplot as plt import pandas as pd import numpy as np (x_train, y_train), (x_test,y_test) = mnist.load_data() def one_hot(data,size): shape=(len(data),size) val=np.zeros(shape=shape) for i in range(len(data)): val[i][data[i]]=1 return val np.set_printoptions(edgeitems=256) #print(x_train) ``` ## 資料預處理部分 將圖片資料歸一化(除以255),因為像素範圍從0-255 Conv2D,這裡的D不是維度(Dimension),而是軸(axis)的概念 會要reshape的原因是 因為是灰階圖片,會省略最後面的通道 通常Conv2D輸入(pixel,pixel,channel) channel=3為彩色圖片(ex: RGB三通道) ``` python= x_train=x_train.astype('float32')/255 x_train=x_train.reshape(-1,28,28,1) y_train=one_hot(y_train,10) x_test=x_test.astype('float32')/255 x_test=x_test.reshape(-1,28,28,1) y_test=one_hot(y_test,10) ``` ## 建立卷積神經網路 - Conv2D卷積層 - 有點類似一直放大(zoom)圖片 - padding為填補,設定為same會==自動補齊==因卷積後,導致輸入輸出圖片大小不一的情形 - MaxPooling2D池化層(Pooling Layer) - 用來鎖定特徵,假設今天圖片為32x32的大小,而池化層設為2,那麼池化層的輸出便是16x16,也就是將像素都除以2 - Flatten平坦層 - 負責下降維度,將矩陣拉成一維陣列的概念。假設有8張50x50的特徵圖(Feature map),在該層會被轉成1D向量,變成50x50x8=20000維 - Dense全連接層(又稱Linear layer) - 可做線性轉換 - 最後一層可用來分類(sigmoid二元分類、softmax多元分類) - Dropout是丟棄層 - 用來減緩神經網路的過度擬和問題(Overfitting) relu函數,通常用來當作卷積神經網路的開端 ![](https://i.imgur.com/tmcz0Zd.png) softmax在另一篇Mnist有提到,就不說明了 ```python= from tensorflow.keras import layers, models cnn=models.Sequential() cnn.add(layers.Conv2D(64,(3,3),padding='same',activation='relu',input_shape=(28,28,1))) cnn.add(layers.MaxPooling2D((2,2))) cnn.add(layers.Conv2D(64,(3,3),padding='same',activation='relu')) cnn.add(layers.MaxPooling2D((2,2))) cnn.add(layers.Flatten()) cnn.add(layers.Dense(10,activation='softmax')) cnn.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc']) history=cnn.fit(x=x_train,y=y_train,batch_size=128,epochs=20,validation_split=0.1) #x_train=x_train.reshape((60000,28,28,3)) #print(x_train[0].shape) ``` 精確度,找時間補齊Confusion metrix ![](https://i.imgur.com/u95h158.png) 印出訓練趨勢圖 ```python= plt.plot(history.history['acc'],"r-") plt.plot(history.history['val_acc'],"b--") plt.title('Training/validating accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') plt.legend(['training accuracy','validating accuracy'],loc="best",frameon=False) plt.show() ``` ![](https://i.imgur.com/ZoyqMCH.png)