# 深度學習第三次作業 ## 前導 * **數據** 這份作業的目標是要識別手寫數字的軌跡,分別有 train_in.csv、train_out.csv、test_in.csv 三筆資料。數據集內容為手寫數字的軌跡,每組 ID 對應到的軌跡有八個二維坐標點,目標是建立一個 RNN 循環神經網絡來進行 label 的預測。 ## 實作 * **數據預處理** 一開始先從 train_in.csv 和 train_out.csv 讀取訓練數據集,以及從 test_in.csv 讀取測試數據集。並將數據轉換為NumPy數組,重新塑造並訓練數據集的形狀(8, 2),使其符合RNN所需的形狀。  * **Model** 這份作業我採用的是 RNN 循環神經網絡,先進行 train_test_split 增加訓練的隨機性以防止過度擬合,模型一共有兩層,分別是 LSTM 層和輸出層, LSTM 層作為循環神經網絡的核心,input 的形狀圍(8, 2),對應到八個不同時間下的座標,並設立 early_stopping 防止過度訓練。  * **激活函數的選用** 這份作是屬於多分類的問題,分別要識別出數字 0~9 ,因此激活函數選擇最適合用來解決多酚類問題的 soft max。 * **神經元數目** 測試我各種各樣的神經元組合,從中找出表現最為優異的參數組:  * **損失函數** 損失函數的部分,我選擇使用適合多分類問題用的: 交叉熵損失函數,它可以計算預測結果的概率分佈與真實標籤之間的交叉熵損失。可以將訓練以最小化預測結果與真實標籤之間的差異,從而提高模型的準確性和性能。 * **預測結果** 最後將訓練好的模型拿來預測 test_in.csv 的內容,每個軌跡找出最大機率值的 label,並將 label 配對在對應的 ID 上。最後將結果存成 test_out.csv 檔案。  ## 結果與討論 這次的作業內容相當有趣,一開始因為對 RNN 不熟的關係,在 input 的處理上沒有用好,卡了一段時間,後來稍微爬一下文後才知道要如何設定 input,知道如何設定 input 後就輕鬆許多,剩下的 code 也快速寫完,最後的成果也相當不錯,達到了 0.98 的精準率,這樣的結果令我很滿意!
×
Sign in
Email
Password
Forgot password
or
By clicking below, you agree to our
terms of service
.
Sign in via Facebook
Sign in via Twitter
Sign in via GitHub
Sign in via Dropbox
Sign in with Wallet
Wallet (
)
Connect another wallet
New to HackMD?
Sign up