# PyFusion 1 - Preliminary Investigation

###### tags: `MSRA Intern`

## Summary

Preliminary investigation of syntax options + a small demo implementation.

## Details

### Decorators and/or with-statement

#### with-statement

As I understand it, a with-statement is not a good fit.

```python
with EXPR as VAR:
    BLOCK
```

is roughly equivalent to

```python
VAR = EXPR
VAR.__enter__()
try:
    BLOCK
finally:
    VAR.__exit__()
```

The key point is that every line of `BLOCK` really does get executed. For NNFusionRT, my understanding is that `BLOCK` needs to be replaced as a whole, so a with-statement is probably not suitable.

If we really wanted to go this route, I think the only option is something like

```python
with nnfusion.jit() as ctx:
    ctx.init()  # Raise an exception here...
    BLOCK
```

i.e. raise an exception to skip `BLOCK` and run NNFusionRT in `__exit__`; even then, the code and logic of `BLOCK` can only be recovered through fairly convoluted means. This feels too tricky and would hurt both usability and maintainability.

> Update: skipping `BLOCK` without explicitly raising an exception requires a hack-level [trick](https://stackoverflow.com/questions/12594148/skipping-execution-of-with-block).

#### Decorators

Decorators are much more natural. Something like

```python
def jit(func):
    def wrapper(*args, **kwargs):
        if wrapper.forward is None:
            wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
        return wrapper.forward(*args, **kwargs)
    wrapper.forward = None
    return wrapper
```

captures the logic of `func` naturally, tunes it on the first call, and then replaces `func`. (Alternatively, tune asynchronously and swap the tuned version in once it is ready.)

### Reuse

My understanding is that the design should be: after tuning, if nothing has changed, reuse the tuned result directly on the next run. We need explicit rules for what "nothing has changed" means.

**Reuse** signature

+ To avoid collisions
  + file path + function name
    + could still be the same (inner functions, redefinition, ...)
  + input shape?
  + ...
+ To detect modification
  + function source code?
    + auto-format
    + remove comments
    + hash
  + some signature of other invoked functions?
  + ...

**Restrictions**

+ Can external variables be used?
+ Do other invoked functions also need to be considered?

Alternatively, keep the restrictions loose and, on the first call, check whether the NNFusionRT output matches the ground truth (a sketch follows the list below).

+ [optional keyword arguments](https://realpython.com/primer-on-python-decorators/#both-please-but-never-mind-the-bread)
+ ```python
  @nnfusion.jit(check_first=True)
  def some_function(...):
      ...
  ```
+ ```python
  @nnfusion.jit(check_first=True, async_=True)
  ```
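A minimal sketch of how such a decorator factory could look. Everything here is illustrative: `jit`, `check_first`, and the `get_nnfusionrt_forward` stub are hypothetical names standing in for whatever builds the tuned callable, not an existing nnfusion API, and the first-call check assumes a single-tensor output.

```python
import functools

import torch


def get_nnfusionrt_forward(func, *args, **kwargs):
    # Stand-in for the (hypothetical) builder that returns a tuned callable;
    # here it just returns the original function so the sketch is runnable.
    return func


def jit(check_first=False):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if wrapper.forward is None:
                wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
                if check_first:
                    # First call: compare the tuned result against the original
                    # function's output (assumes a single tensor is returned).
                    expected = func(*args, **kwargs)
                    actual = wrapper.forward(*args, **kwargs)
                    assert torch.allclose(expected, actual, rtol=1e-3, atol=1e-4)
                    return actual
            return wrapper.forward(*args, **kwargs)
        wrapper.forward = None
        return wrapper
    return decorator


@jit(check_first=True)
def scale(x):
    return 2 * x
```

An `async_=True` variant would presumably kick off tuning in a background thread and keep calling the original `func` until the tuned callable is ready.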
### Demo

The demo currently uses ORT as the runtime, because environment issues prevent the NNFusionRT example from running for now; design-wise, I believe this makes no difference at this stage.

<details>
<summary>Details (Click to expand) </summary>
<br>

nnfusion installed via Docker:

```
[ERROR] 2022-01-28T06:16:13z src/nnfusion/util/errors.hpp 169 Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
terminate called after throwing an instance of 'nnfusion::errors::CheckError'
  what():  Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
Aborted (core dumped)
```

main-branch nnfusion (related to double precision and ORT):

```
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
  File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 88, in <module>
    ort_session.set_providers([args.provider])
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 155, in set_providers
    self._reset_session(providers, provider_options)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 407, in _reset_session
    self._create_inference_session(providers, provider_options)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Conv(1) node with name 'Conv_1'
```

```
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
  File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 85, in <module>
    ort_session = ort.InferenceSession(args.file, sess_options, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 335, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains compiled nodes. Please disable any execution providers which generate compiled nodes.
```

</details>
<br>

**code**

```python=
import functools
import inspect
from hashlib import md5
from pathlib import Path

import torch
import torch.nn.functional as F

# torch.set_default_tensor_type(torch.DoubleTensor)


class TorchModule(torch.nn.Module):
    def __init__(self, func):
        super(TorchModule, self).__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def get_signature(func):
    # TODO add more
    func_name = func.__name__
    source_hash = str(md5(inspect.getsource(func).encode('utf-8')).hexdigest())
    return '@'.join((func_name, source_hash))


def ort_inference_func(func, *args):
    def export_onnx_model(model_name):
        outputs = func(*args)
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        model = TorchModule(func)
        torch.onnx.export(model, args, model_name,
                          # verbose=True,
                          input_names=[f"input_{i}" for i in range(len(args))],
                          output_names=[f"output_{i}" for i in range(len(outputs))])

    def signature2name(signature):
        return signature + ".onnx"

    signature = get_signature(func)
    model_name = signature2name(signature)

    # TODO clean out-of-date cache
    if not Path(model_name).is_file():
        export_onnx_model(model_name)

    import onnxruntime as ort
    ort_session = ort.InferenceSession(
        model_name, providers=['CUDAExecutionProvider'])

    def forward(*inputs):
        ort_inputs = {
            ort_session.get_inputs()[i].name: input_.cpu().numpy()
            for i, input_ in enumerate(inputs)
        }
        ort_outputs = ort_session.run(None, ort_inputs)
        return [
            torch.from_numpy(output)
            for output in ort_outputs
        ]

    return forward


def ort_decorator(func):
    @functools.wraps(func)
    def wrapper(*args):
        # TODO support kwargs?
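        # Build the ORT inference function lazily on the first call and
        # cache it on the wrapper for reuse on subsequent calls.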
        if wrapper.forward is None:
            wrapper.forward = ort_inference_func(func, *args)
        return wrapper.forward(*args)
    wrapper.forward = None
    return wrapper


def poisson(r0, p, phi, filter_):
    conv_out = F.conv2d(p, filter_, padding=1)
    alpha = torch.mul(r0, r0).sum() / torch.mul(p, conv_out).sum()
    phi = alpha * p + phi
    r1 = r0 - alpha * conv_out
    r1_sum = torch.mul(r1, r1).sum()
    beta = r1_sum / torch.mul(r0, r0).sum()
    p = r1 + beta * p
    return r1_sum, phi, p, r1


def compute(func, r0, p, phi, filter_, step=20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        p_p = p
        start.record()
        for i in range(step):
            print(i, end='\r')
            r1_sum_p, phi_p, p_p, r1_p = func(r0, p, phi, filter_)
            r0, p = r1_p, p_p
        end.record()
        print()

    torch.cuda.synchronize()
    time_elapsed = start.elapsed_time(end) / step
    result = r1_sum_p, phi_p, p_p, r1_p
    return time_elapsed, result


if __name__ == "__main__":
    M = 1024 * 4
    N = 1024 * 8

    torch.cuda.set_device(0)

    alpha = torch.randn([1], device="cuda")
    r0 = torch.randn(1, 1, M, N, device="cuda")
    p = r0
    phi = torch.randn(1, 1, M, N, device="cuda")
    filter_ = torch.tensor(
        [[0., 1., 0.],
         [1., -4., 1.],
         [0., 1., 0.]],
        device="cuda"
    ).view(1, 1, 3, 3)

    time_pytorch, res_pytorch = compute(poisson, r0, p, phi, filter_)
    print(f"PyTorch step time: {time_pytorch} ms")

    time_ort, res_ort = compute(ort_decorator(poisson), r0, p, phi, filter_)
    print(f"ORT step time: {time_ort} ms")

    for t1, t2 in zip(res_pytorch, res_ort):
        assert torch.allclose(t1.cpu(), t2.cpu(), rtol=1e-2, atol=1e-3), (t1, t2)
```
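As a possible direction for the `# TODO add more` in `get_signature`, here is a rough sketch that also folds the file path and the input tensor shapes into the cache key, along the lines of the collision-avoidance ideas in the reuse section above. The field choice and the name `get_signature_v2` are hypothetical, not a settled design.

```python
import inspect
from hashlib import md5

import torch


def get_signature_v2(func, *args):
    # Hypothetical extension of get_signature: combine the defining file,
    # the function name, a hash of the source code, and the input shapes.
    func_name = func.__name__
    file_path = inspect.getsourcefile(func) or "<unknown>"
    source_hash = md5(inspect.getsource(func).encode('utf-8')).hexdigest()
    shapes = '-'.join(
        'x'.join(str(dim) for dim in arg.shape)
        for arg in args if isinstance(arg, torch.Tensor)
    )
    key = '@'.join((file_path, func_name, source_hash, shapes))
    # Hash the combined key so the resulting cache file name stays short
    # and free of path separators.
    return md5(key.encode('utf-8')).hexdigest()
```

Hashing the whole key also avoids `/` characters from the file path leaking into the ONNX file name.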