## 從pytorch建構網路、生成onnx檔,到網頁中以onnx.js執行類神經網路。 ### Part 0、前言 本篇主要是以最精省的方式示範,如何將在Pytorch中的網路模型導出成onnx檔,再由網頁的onnx.js讀取並執行該神經網路,那麼為甚麼要這麼做呢? 我的大目標是要建構出,能跟人類玩遊戲能競爭、合作的AI,雖然我們在訓練模型時是以pytorch在後端訓練,但實際在執行網頁遊戲時,神經網路需要時時控制著遊戲腳色,但要讓伺服器以如此高頻率的方式同時應付大量客戶端是不現實的,必須將**運行神經網路、控制遊戲腳色**的重擔交到**客戶端**,因此必須找到將Pytorch模型放進網頁執行的方法。 ### Part 1、生成onnx檔 首先想到的問題是,pytorch訓練好的網路模型要怎麼丟給前端,這邊用到的是名為onnx(Open Neural Network Exchange),是一套開放神經網路交換格式,就是各個玩神經網路的平台大家都說好用這種格式描述神經網路結構,因此第一步就是要將pytorch的神經網路以onnx格式輸出。 第3~5行就是建構一個1進3出的簡單循序模型,之後可以帶替換成更複雜的神經網路。 ```python= import torch # 建構序列型模型 model = torch.nn.Sequential( torch.nn.Linear(1,3) ) dummy_input = torch.randn(1, 1) # 輸出onnx torch.onnx.export(model,dummy_input,'./model.onnx', export_params=True) ``` 第9行torch.onnx.export(),裡面的四個參數都必不可少,model就是你要輸出的網路模型,dummy_input是示範該模型的輸入,在第7行中dummy_input可以設定為,torch.randn(1, 1) 或 torch.randn(1) ,兩者分別為張量與向量,雖然不管選哪個都能正常輸出onnx檔,但到下一階段onnx.js套件只能以張量進行運算,所以這裡**dummy_input**必須設定為張量**torch.randn(1, 1)**。 './model.onnx'是onnx檔案的路徑與名稱,最後**export_params=True**設定輸出的模型是帶有**預訓練好的模型參數**,包含權重、偏移等等,如果少了這個參數,模型就白輸出了。 ### Part 2、讀進onnx檔 網頁上讀進onnx檔靠的是onnx.js套件,onnx.js套件於第1行中載入。 主程式由session這個Promise進行異步操作,用了兩層**then()**的**鍊式調用**結構,以確保程式執行的先後順序,第一層等模型載入完才開始做推理運算,第二層等算完之後才取得輸出結果,如此才算完成一次神經網路的運作。 ```javascript= <script src="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js"></script> <script> // 載入模型 const session = new onnx.InferenceSession(); session.loadModel("http://localhost:8000/model.onnx").then(() => { // 建構輸入張量 const input = new onnx.Tensor(new Float32Array([1]), 'float32', [1, 1]); // 進行推理 session.run([input]).then(outputMap => { // 獲取輸出張量 const outputTensor = outputMap.values().next().value; // 輸出结果 console.log(outputTensor.data); }) }); </script> ``` 第7行input = new onnx.Tensor(),同樣是用於宣告網路模型的輸入數據,第3個參數是定義數據結構[1,1],必須與上一段dummy_input的設定一樣。 ### Part 3、建立網頁伺服器 上面這段網頁程式是放在本地端做測試,網頁伺服器要調用相同路徑下的onnx檔,就會受到CORS policy 阻擋,我的作法是用python -m http.server指令建立最簡單的伺服器,再搭配Chrome上的插件Allow CORS便能成功執行,執行成功便會在瀏覽器的console中得到神經網路的輸出值。 ![](https://hackmd.io/_uploads/HkP7E5Muh.png) 另外,有發生過後端onnx檔已經重新輸出並覆蓋舊檔,但網頁刷新後還是抓到舊的內容,這是由於瀏覽器的快取(cache)機制造成,會將已經載入過的檔案暫存起來以提高網頁載入速度,可以用Ctrl + Shift + R強制重新載入。