# 客製化NN模型轉ONNX檔
## 1. 已知NN模型input及output

## 2. ChatGPT生成NN模型
```python=
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.onnx
class SimpleNN(nn.Module):
    """Three-layer fully connected network: flatten -> fc1 -> fc2 -> fc3.

    With the default arguments the model matches the original hard-coded
    network: it expects ``1 * 635 * 128`` elements per sample (e.g. an input
    tensor of shape ``(batch_size, 1, 635, 128)``) and returns raw,
    un-normalised scores of shape ``(batch_size, 31)``.

    Args:
        in_features: Number of elements per sample after flattening.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_outputs: Width of the output layer.
    """

    def __init__(
        self,
        in_features: int = 1 * 635 * 128,
        hidden1: int = 64,
        hidden2: int = 32,
        num_outputs: int = 31,
    ) -> None:
        super().__init__()
        # nn.Flatten keeps dim 0 (the batch) and merges all remaining dims.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_outputs)
        # Initialize weights with random values
        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise every ``nn.Linear``: Kaiming-normal weights, zero bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Default kaiming_normal_ settings (fan_in, leaky_relu with a=0)
                # give the same gain as ReLU, which the hidden layers use.
                init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw scores of shape ``(batch, num_outputs)`` for input ``x``."""
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
# batch_size only fixes the shape of the example tensor used for tracing;
# dim 0 stays symbolic in the exported graph thanks to dynamic_axes below.
batch_size = 1  # You can change this to your desired batch size
model = SimpleNN()
# Export in inference mode, as the PyTorch ONNX tutorial does. This is a no-op
# for this Linear/ReLU-only model, but it matters once Dropout/BatchNorm are added.
model.eval()
# Example input with the specified size
dummy_input_tensor = torch.randn((batch_size, 1, 635, 128))
# Mark dim 0 of BOTH tensors as dynamic under the same symbolic name, so the
# graph reads input [batch_size, 1, 635, 128] -> output [batch_size, 31].
# The string is only a label and never pins the dimension to a fixed size.
# Reusing "batch_size" shows that the output's dim 0 is the input's batch dim.
dynamic_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}
# Export the model to ONNX format
onnx_filename = "simple_nn_model.onnx"
torch.onnx.export(
    model,
    dummy_input_tensor,
    onnx_filename,
    input_names=["input"],  # must match the keys used in dynamic_axes
    output_names=["output"],
    dynamic_axes=dynamic_axes,
    verbose=True,
)
print(f"Model exported to {onnx_filename}")
```
**關鍵在於加入 `dynamic_axes`：它把 input（以及 output）的第 0 維標記為動態維度，讓匯出的 ONNX 模型可以接受任意 batch size。字典中的字串（例如 `'batch_size'`）只是該維度的符號名稱，並不會把維度固定成某個數值。**
## 3. ONNX Prediction
```python=
import onnxruntime
import numpy as np
# Load the ONNX model. Since onnxruntime 1.9, builds that ship more than one
# execution provider (e.g. GPU builds) raise ValueError unless `providers` is
# passed explicitly. Passing every available provider keeps ORT's priority
# order and works on CPU-only builds too.
onnx_filename = "simple_nn_model.onnx"
ort_session = onnxruntime.InferenceSession(
    onnx_filename, providers=onnxruntime.get_available_providers()
)
# Generate a sample input with the specified size. The export marked dim 0 as
# dynamic, so batch_size can be changed here without re-exporting the model.
batch_size = 1
input_data = np.random.randn(batch_size, 1, 635, 128).astype(np.float32)
# Read the input name from the model instead of hard-coding it.
input_name = ort_session.get_inputs()[0].name
# Perform inference; output_names=None asks for every model output.
ort_inputs = {input_name: input_data}
ort_outputs = ort_session.run(None, ort_inputs)
# run() returns one array per output; the exported model has a single output.
output_data = ort_outputs[0]
print("Output size:", output_data.shape)
print("Output values:", output_data)
```
## 4. 輸出ONNX graph
input: [batch_size, 1, 635, 128]
output: [1,31]

## 參考文獻
1. https://answers.opencv.org/question/224547/dnn-onnx-model-with-variable-batch-size/
2. https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html