tensorflow
想到的時候就做點筆記,不然過三天之後就忘光了(x
Carrot蘿蔔Sun, May 31, 2020 11:33 AM
沒想到再回到tensorflow懷抱的這天
版本已經更新到了2.1.0
蘿蔔頭頂上都新開了幾場初雪
TF更早已不是當初那青澀的少年了
遙想起sess.run()的那個年代
其實說不定還是帶著幾分懷念的呢 ~
好吧並沒有,我再也不想用session來開圖了
那根本就不是應該出現在python style的語言形式
隨著版本更迭來到了2.0.0
昔日神經網路API的另外一大巨頭Keras
現在被收錄到了TF底下
對於我這一路走來都信奉著TF神教的狂熱者
真的是一大福音呢 ~
儘管我認為總有一天也要把pytorch學起來就是了
大家做的事情都差不多
不要細分那麼多彼此嘛 ~
以下記錄了一些蘿蔔使用TF2開荒以來
遇到的各式各樣神奇的問題
又或是一些值得記錄下來的code
希望這篇文章的內容可以為需要TF2教學的你
或是過了幾個月之後忘東忘西的蘿蔔
帶來相當實質的幫助 ~~
在快速模型寫法當中
我們可以先直觀的定義一個ForwardPropagation的function
就像上面的customize_model
一樣
讓假想的訓練資料一層一層的傳遞下去
直到獲得output
之後
再使用tf.keras.Model()
指定輸入tensor與輸出tensor
便完成了model物件初步的建立
下一步呢則是為model指定optimizer以及loss function
使用model.compile()
並且傳入參數optimizer= "Adam"
以及loss= "categorical_crossentropy"
目前這裡是假設要做一個圖片類別的分類器啦
參數的選擇會因為任務型態的不同而有所差別
在model物件回傳回來之後
便接著使用model.fit()
開始訓練
一般而言
在通用模型中,我們會需要實作以下幾個區塊
他們分別為
@tf.function
這些區塊在快速寫法中
都被keras高級的API包的妥妥貼貼
造成這些寫法雖然很方便
但真要實作某些內容的時候卻又綁手綁腳
於是,我們終究還是得回來看看這些底層的寫法
並把所有區塊都寫成自己想要的樣子吧!
底層??
Am I a joke to you???
TensorFlow
這裡最重要的是class繼承自tf.keras.layers.Layer
而且一定要把需要的layer定義成member
Model class則是繼承自tf.keras.Model
首先train_step必須被定義為@tf.function
接著資料流將會經過以下數個階段
值得一提的是,如果你的網路中存在著需要迭代訓練的多個模型
比如說GAN中的generator和discriminator
那就可以使用上方第二個多個tape的寫法
各自定義出各自的tape,獨立地進行梯度下降
這樣一來
就算網路架構再複雜都不用怕啦 ~
這個大區塊主要處理了以下幾件事情
千萬記得要在訓練過程中才callckpt_manager.save()
呀
我們會在下個區塊中再次看到它~
train_model的核心步驟如下