# PyFusion 1 - Initial Investigation
###### tags: `MSRA Intern`
## Summary
Initial investigation of the syntax/API design, plus a small demo implementation.
## Details
### Decorators and/or with-statement
#### with-statement
In my understanding, a with-statement is not a great fit.
```python
with EXPR as VAR:
BLOCK
```
is roughly equivalent to
```python
mgr = EXPR
VAR = mgr.__enter__()
try:
    BLOCK
finally:
    mgr.__exit__(...)  # receives the exception info, if any
```
The key point is that every line of `BLOCK` really does get executed.
For NNFusionRT, as I understand it, the whole `BLOCK` needs to be replaced, so this is probably not a good fit.
If we really wanted to make it work, I believe the only option would be something like
```python
with nnfusion.jit() as ctx:
ctx.init() # Raise an exception here...
BLOCK
```
i.e. raise an exception to skip `BLOCK` and run NNFusionRT in `__exit__`; moreover, the code and logic inside `BLOCK` could only be recovered through fairly convoluted means.
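For concreteness, here is a minimal sketch of what the object returned by `nnfusion.jit()` could look like under this exception-based scheme; the `jit_ctx` and `SkipBlock` names are made up, and the NNFusionRT call is only indicated by a comment:
```python
class SkipBlock(Exception):
    """Sentinel exception used only to jump over BLOCK."""

class jit_ctx:
    def __enter__(self):
        return self

    def init(self):
        # Deliberately raise so the rest of BLOCK never executes.
        raise SkipBlock

    def __exit__(self, exc_type, exc_value, traceback):
        # This is where NNFusionRT would run the captured/compiled BLOCK.
        # Returning True suppresses only our sentinel; real errors still propagate.
        return exc_type is SkipBlock
```
Returning a truthy value from `__exit__` is what suppresses the sentinel exception; any other exception keeps propagating.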
Either way this feels too tricky, and it would hurt both usability and ease of development and maintenance.
> Update: skipping `BLOCK` without explicitly raising an exception requires a hack-level [trick](https://stackoverflow.com/questions/12594148/skipping-execution-of-with-block).
#### Decorators
Decorators, by contrast, are very natural; something along the lines of
```python
def jit(func):
def wrapper(*args, **kwargs):
if wrapper.forward is None:
wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
return wrapper.forward(*args, **kwargs)
wrapper.forward = None
return wrapper
```
is enough to capture the logic of `func`, tune it on the first call, and then replace `func`. (Tuning could also be done asynchronously, swapping `func` out once it finishes.)
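Hypothetical usage, assuming the decorator ends up being exposed as `nnfusion.jit` (the function and tensors below are made up for illustration):
```python
import torch
import nnfusion  # assumed package exposing the jit decorator sketched above

@nnfusion.jit
def linear_relu(x, w):
    return (x @ w).relu()

x = torch.randn(32, 64, device="cuda")
w = torch.randn(64, 16, device="cuda")
y = linear_relu(x, w)  # first call: capture func, tune, swap in the tuned forward
y = linear_relu(x, w)  # later calls: run the tuned forward directly
```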
### Reuse / caching
My understanding is that it should be designed so that once a function has been tuned, the result is reused directly on later runs as long as nothing has changed.
What counts as "nothing has changed" needs an explicit rule.
**Reuse**
Signature (a rough code sketch follows after this list)
+ To avoid collisions
    + file path
    + function name
        + may still be identical (inner functions, redefinition, ...)
    + input shape?
    + ...
+ To detect modification
    + function source code?
        + auto-format + strip comments + hash
    + some signature of the other invoked functions?
    + ...
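As a rough sketch of how such a signature could be assembled (the field choice and the `cache_signature` name are purely illustrative):
```python
import hashlib
import inspect

def cache_signature(func, *args):
    # Collision avoidance: where the function lives and what it is called.
    path = inspect.getsourcefile(func) or "<unknown>"
    name = func.__qualname__
    shapes = ",".join(str(tuple(a.shape)) for a in args if hasattr(a, "shape"))
    # Modification detection: hash of the (ideally normalized) source code.
    source = inspect.getsource(func)
    digest = hashlib.md5(source.encode("utf-8")).hexdigest()
    return f"{path}:{name}@{digest}@{shapes}"
```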
**Restrictions**
+ Can external variables be used?
+ Do other invoked functions also need to be considered?
Alternatively, the restrictions could be kept fairly loose, with a check on the first call that the NNFusionRT output matches the ground truth.
+ Options like this could be passed as [optional keyword arguments](https://realpython.com/primer-on-python-decorators/#both-please-but-never-mind-the-bread):
+ ```python
@nnfusion.jit(check_first=True)
def some_function(...):
...
```
+ ```python
@nnfusion.jit(check_first=True, async_=True)  # "async" is reserved in Python 3.7+, so the real keyword would need another name
```
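A rough sketch of how `check_first` could be honored inside the decorator, written as a decorator factory; `get_nnfusionrt_forward` is the same placeholder as in the earlier sketch, and the check assumes `func` returns a single tensor:
```python
import functools
import torch

def jit(check_first=False):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args):
            if wrapper.forward is None:
                wrapper.forward = get_nnfusionrt_forward(func, *args)  # placeholder
                if check_first:
                    # One-time sanity check: compiled output vs. eager PyTorch output.
                    expected = func(*args)
                    actual = wrapper.forward(*args)
                    assert torch.allclose(expected, actual, rtol=1e-2, atol=1e-3)
            return wrapper.forward(*args)
        wrapper.forward = None
        return wrapper
    return decorator
```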
### Demo
The demo uses ORT as its runtime for now, because environment issues currently prevent the NNFusionRT examples from running; design-wise, though, I believe it makes no difference at this stage.
<details>
<summary>Details (Click to expand) </summary>
<br>
nnfusion installed via docker:
```
[ERROR] 2022-01-28T06:16:13z src/nnfusion/util/errors.hpp 169 Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
terminate called after throwing an instance of 'nnfusion::errors::CheckError'
what(): Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
Aborted (core dumped)
```
main branch nnfusion (related to double precision and ORT):
```
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 88, in <module>
ort_session.set_providers([args.provider])
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 155, in set_providers
self._reset_session(providers, provider_options)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 407, in _reset_session
self._create_inference_session(providers, provider_options)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Conv(1) node with name 'Conv_1'
```
```
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 85, in <module>
ort_session = ort.InferenceSession(args.file, sess_options, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 335, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains compiled nodes. Please disable any execution providers which generate compiled nodes.
```
</details>
<br>
**code**
```python=
import functools
import inspect
from hashlib import md5
from pathlib import Path
import torch
import torch.nn.functional as F
# torch.set_default_tensor_type(torch.DoubleTensor)
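# Wrap a plain function in an nn.Module so it can be exported via torch.onnx.export.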
class TorchModule(torch.nn.Module):
def __init__(self, func):
super(TorchModule, self).__init__()
self.func = func
def forward(self, *args, **kwargs):
return self.func(*args, **kwargs)
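# Cache key: function name joined with an md5 hash of its source code (cf. the signature discussion above).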
def get_signature(func):
# TODO add more
func_name = func.__name__
source_hash = str(md5(inspect.getsource(func).encode('utf-8')).hexdigest())
return '@'.join((func_name, source_hash))
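# Export func to ONNX (cached on disk under its signature), build an ONNX Runtime session,
# and return a forward function backed by that session.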
def ort_inference_func(func, *args):
def export_onnx_model(model_name):
outputs = func(*args)
if isinstance(outputs, torch.Tensor):
outputs = [outputs]
model = TorchModule(func)
torch.onnx.export(model,
args,
model_name,
# verbose=True,
input_names=[f"input_{i}" for i in range(len(args))],
output_names=[f"output_{i}" for i in range(len(outputs))])
def signature2name(signature):
return signature + ".onnx"
signature = get_signature(func)
model_name = signature2name(signature)
# TODO clean out-of-date cache
if not Path(model_name).is_file():
export_onnx_model(model_name)
import onnxruntime as ort
ort_session = ort.InferenceSession(
model_name,
providers=['CUDAExecutionProvider'])
def forward(*inputs):
ort_inputs = {
ort_session.get_inputs()[i].name: input_.cpu().numpy()
for i, input_ in enumerate(inputs)
}
ort_outputs = ort_session.run(None, ort_inputs)
return [
torch.from_numpy(output)
for output in ort_outputs
]
return forward
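# Decorator counterpart of nnfusion.jit in this demo: lazily builds the ORT forward on the first call, then reuses it.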
def ort_decorator(func):
@functools.wraps(func)
def wrapper(*args): # TODO support kwargs?
if wrapper.forward is None:
wrapper.forward = ort_inference_func(func, *args)
return wrapper.forward(*args)
wrapper.forward = None
return wrapper
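# One conjugate-gradient iteration for a Poisson problem; filter_ holds the 5-point Laplacian stencil applied via conv2d.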
def poisson(r0, p, phi, filter_):
conv_out = F.conv2d(p, filter_, padding=1)
alpha = torch.mul(r0, r0).sum() / torch.mul(p, conv_out).sum()
phi = alpha * p + phi
r1 = r0 - alpha * conv_out
r1_sum = torch.mul(r1, r1).sum()
beta = r1_sum / torch.mul(r0, r0).sum()
p = r1 + beta * p
return r1_sum, phi, p, r1
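# Run `step` iterations of func and report the average per-iteration time measured with CUDA events.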
def compute(func, r0, p, phi, filter_, step=20):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
p_p = p
start.record()
for i in range(step):
print(i, end='\r')
r1_sum_p, phi_p, p_p, r1_p = func(r0, p, phi, filter_)
r0, p = r1_p, p_p
end.record()
print()
torch.cuda.synchronize()
time_elapsed = start.elapsed_time(end) / step
result = r1_sum_p, phi_p, p_p, r1_p
return time_elapsed, result
if __name__ == "__main__":
M = 1024 * 4
N = 1024 * 8
torch.cuda.set_device(0)
alpha = torch.randn([1], device="cuda")
r0 = torch.randn(1, 1, M, N, device="cuda")
p = r0
phi = torch.randn(1, 1, M, N, device="cuda")
filter_ = torch.tensor(
[[0., 1., 0.],
[1., -4., 1.],
[0., 1., 0.]], device="cuda"
).view(1, 1, 3, 3)
time_pytorch, res_pytorch = compute(poisson, r0, p, phi, filter_)
print(f"PyTorch step time: {time_pytorch} ms")
time_ort, res_ort = compute(ort_decorator(poisson), r0, p, phi, filter_)
print(f"ORT step time: {time_ort} ms")
for t1, t2 in zip(res_pytorch, res_ort):
assert torch.allclose(t1.cpu(), t2.cpu(), rtol=1e-2, atol=1e-3), (t1, t2)
```