# Converting a Custom NN Model to an ONNX File

## 1. Known input and output of the NN model

![image](https://hackmd.io/_uploads/S1HuBjIwa.png)

## 2. NN model generated by ChatGPT

```python=
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.onnx


class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1 * 635 * 128, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 31)

        # Initialize weights with random values
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


### Assuming batch_size is known during initialization
batch_size = 1  # You can change this to your desired batch size
model = SimpleNN()

### Example input with the specified size
dummy_input_tensor = torch.randn((batch_size, 1, 635, 128))

dynamic_axes = {'input': {0: 'batch_size'},
                'output': {0: '1'}}

### Export the model to ONNX format
onnx_filename = "simple_nn_model.onnx"
torch.onnx.export(model,
                  dummy_input_tensor,
                  onnx_filename,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

print(f"Model exported to {onnx_filename}")
```

**The key is adding `dynamic_axes`, which tells ONNX that the input's first dimension is a dynamic `batch_size` and that the output is a 1-D vector.**

## 3. ONNX Prediction

```python=
import onnxruntime
import numpy as np

# Load the ONNX model
onnx_filename = "simple_nn_model.onnx"
ort_session = onnxruntime.InferenceSession(onnx_filename)

# Generate a sample input with the specified size
batch_size = 1
input_data = np.random.randn(batch_size, 1, 635, 128).astype(np.float32)

# Get the input name from the ONNX model
input_name = ort_session.get_inputs()[0].name

# Perform inference
ort_inputs = {input_name: input_data}
ort_outputs = ort_session.run(None, ort_inputs)

# Print the output
output_data = ort_outputs[0]
print("Output size:", output_data.shape)
print("Output values:", output_data)
```

## 4. Exported ONNX graph

input: [batch_size, 1, 635, 128]
output: [1, 31]

![image](https://hackmd.io/_uploads/Sy46uXPwa.png)
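
The same shapes can also be confirmed programmatically. Below is a minimal sketch, assuming the `simple_nn_model.onnx` file exported in step 2 is in the working directory and that the `onnx` package is installed; it validates the file, prints the input/output shapes reported by ONNX Runtime, and then runs the model with an arbitrary batch size of 4 to confirm that the batch dimension is really dynamic:

```python=
import numpy as np
import onnx
import onnxruntime

onnx_filename = "simple_nn_model.onnx"

# Structural validation of the exported file
onnx.checker.check_model(onnx.load(onnx_filename))

session = onnxruntime.InferenceSession(onnx_filename)

# Symbolic dimensions are reported by name, e.g. ['batch_size', 1, 635, 128]
print("input :", session.get_inputs()[0].shape)
print("output:", session.get_outputs()[0].shape)

# Run with a batch size different from the dummy input used at export time
batch_size = 4  # arbitrary value, only for this check
input_data = np.random.randn(batch_size, 1, 635, 128).astype(np.float32)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: input_data})

# The first output dimension should follow the runtime batch size, i.e. (4, 31)
print("output shape:", outputs[0].shape)
```

Without `dynamic_axes`, the exported graph would keep the fixed export-time shape `[1, 1, 635, 128]`, and ONNX Runtime would reject the batch-size-4 input with a shape mismatch error.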