
PyFusion 1 - Preliminary Investigation

tags: MSRA Intern

Summary

Preliminary syntax investigation + a small demo implementation

Details

Decorators and/or with-statement

with-statement

In my understanding, the with-statement is not a good fit.

with EXPR as VAR:
    BLOCK

is roughly equivalent to

mgr = EXPR
VAR = mgr.__enter__()   # VAR is bound to the return value of __enter__
try:
    BLOCK
finally:
    mgr.__exit__(exc_type, exc_value, exc_tb)   # exception info, or (None, None, None)

The key point is that every line of BLOCK is actually executed.
For NNFusionRT, as I understand it, BLOCK needs to be replaced in its entirety, so the with-statement is probably not a good fit.

If we really wanted to make it work, I think the only option would be

with nnfusion.jit() as ctx:
    ctx.init()  # Raise an exception here...
    BLOCK

i.e., skip BLOCK by raising an exception and run NNFusionRT in __exit__ instead; on top of that, the code and logic of BLOCK could only be recovered through fairly convoluted means.

This feels too tricky, and it would hurt both usability and development/maintenance.

Update: skipping BLOCK without explicitly raising an exception would require hack-level tricks.
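
For reference, a minimal sketch of the exception-based trick described above (the names JitContext, init, and _SkipBlock are hypothetical, not an actual nnfusion API):

import inspect

class _SkipBlock(Exception):
    pass

class JitContext:
    def __enter__(self):
        # Remember the calling frame so BLOCK's source could (painfully)
        # be recovered later, e.g. via inspect.getsource / getframeinfo.
        self._caller = inspect.currentframe().f_back
        return self

    def init(self):
        # Raising here aborts the rest of BLOCK...
        raise _SkipBlock

    def __exit__(self, exc_type, exc_value, traceback):
        # ...and returning True here swallows it; this is where
        # NNFusionRT would be invoked instead of BLOCK.
        return exc_type is _SkipBlock

# Usage: everything after ctx.init() is skipped.
with JitContext() as ctx:
    ctx.init()
    print("never executed")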

Decorators

Decorators, on the other hand, are very natural. Something like

def jit(func):
    def wrapper(*args, **kwargs):
        if wrapper.forward is None:
            wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
        return wrapper.forward(*args, **kwargs)
    wrapper.forward = None
    return wrapper

already captures the logic of func naturally, tunes on the first call, and then replaces func. (Tuning could also be done asynchronously, swapping func out once it finishes.)
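
The asynchronous variant could look roughly like the following sketch (get_nnfusionrt_forward is the same hypothetical helper as above; this is an illustration, not the final design):

import threading

def jit_async(func):
    def tune(*args, **kwargs):
        # Runs in the background; once done, later calls use the tuned kernel.
        wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)

    def wrapper(*args, **kwargs):
        if wrapper.forward is not None:
            return wrapper.forward(*args, **kwargs)
        if wrapper.tuning is None:
            wrapper.tuning = threading.Thread(
                target=tune, args=args, kwargs=kwargs, daemon=True)
            wrapper.tuning.start()
        return func(*args, **kwargs)  # fall back to the original func meanwhile

    wrapper.forward = None
    wrapper.tuning = None
    return wrapper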

Reuse considerations

My understanding is that it should be designed so that, after tuning, if nothing has changed, the tuned result is reused directly the next time.

There needs to be a clear rule for what counts as "unchanged".

Reuse

Signature

  • To avoid collision
    • file path
    • function name
      • may still be the same (inner functions, redefinitions)
    • input shape?
  • To detect modification (see the sketch after this list)
    • function source code?
      • auto-format + remove comments + hash
    • some signature of other invoked functions?
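
A rough sketch of such a signature, combining file path, qualified name, input shapes, and a source hash (the exact fields are still open; tensor inputs are assumed for the shape part):

import inspect
from hashlib import md5

def get_signature(func, *args):
    source = inspect.getsource(func)  # TODO auto-format + strip comments first
    shapes = '-'.join(str(tuple(t.shape)) for t in args)
    return '@'.join((
        inspect.getfile(func),
        func.__qualname__,
        shapes,
        md5(source.encode('utf-8')).hexdigest(),
    ))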

Restrictions

  • Can external variables be used?
  • Do other invoked functions also need to be considered?

Alternatively, keep the restrictions relatively loose and, on the first call, check whether the NNFusionRT output matches the ground truth (a sketch of this follows the list below).

  • optional keyword arguments
    • @nnfusion.jit(check_first=True)
      def some_function(...):
          ...
    • @nnfusion.jit(check_first=True, async_=True)
      (async itself is a reserved keyword in Python, so the actual argument name would need to be something like async_)
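A minimal sketch of how check_first might behave (assumed API, not the final design; get_nnfusionrt_forward is hypothetical as before): on the first call, run both the original function and the tuned kernel and compare before trusting the tuned one.

import torch

def jit(check_first=False):
    def decorator(func):
        def wrapper(*args, **kwargs):
            if wrapper.forward is None:
                wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
                if check_first:
                    # assuming a single tensor output for simplicity
                    expected = func(*args, **kwargs)           # ground truth
                    actual = wrapper.forward(*args, **kwargs)  # tuned result
                    assert torch.allclose(expected, actual, rtol=1e-2, atol=1e-3)
            return wrapper.forward(*args, **kwargs)
        wrapper.forward = None
        return wrapper
    return decorator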

Demo

The demo's runtime uses ORT for now, because environment issues currently prevent the NNFusionRT examples from running, but I don't think this changes the design for the time being.

Details

nnfusion installed via docker

[ERROR] 2022-01-28T06:16:13z src/nnfusion/util/errors.hpp 169   Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
terminate called after throwing an instance of 'nnfusion::errors::CheckError'
  what():  Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
Aborted (core dumped)

main branch nnfusion (related to double precision and ORT)

ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
  File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 88, in <module>
    ort_session.set_providers([args.provider])
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 155, in set_providers
    self._reset_session(providers, provider_options)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 407, in _reset_session
    self._create_inference_session(providers, provider_options)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Conv(1) node with name 'Conv_1'
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
  File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 85, in <module>
    ort_session = ort.InferenceSession(args.file, sess_options, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 335, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains compiled nodes. Please disable any execution providers which generate compiled nodes.

Code

import functools
import inspect
from hashlib import md5
from pathlib import Path

import torch
import torch.nn.functional as F

# torch.set_default_tensor_type(torch.DoubleTensor)


class TorchModule(torch.nn.Module):
    def __init__(self, func):
        super(TorchModule, self).__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def get_signature(func):
    # TODO add more
    func_name = func.__name__
    source_hash = str(md5(inspect.getsource(func).encode('utf-8')).hexdigest())
    return '@'.join((func_name, source_hash))


def ort_inference_func(func, *args):
    def export_onnx_model(model_name):
        outputs = func(*args)
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        model = TorchModule(func)
        torch.onnx.export(model, args, model_name,
                          # verbose=True,
                          input_names=[f"input_{i}" for i in range(len(args))],
                          output_names=[f"output_{i}" for i in range(len(outputs))])

    def signature2name(signature):
        return signature + ".onnx"

    signature = get_signature(func)
    model_name = signature2name(signature)

    # TODO clean out-of-date cache
    if not Path(model_name).is_file():
        export_onnx_model(model_name)

    import onnxruntime as ort
    ort_session = ort.InferenceSession(
        model_name, providers=['CUDAExecutionProvider'])

    def forward(*inputs):
        ort_inputs = {
            ort_session.get_inputs()[i].name: input_.cpu().numpy()
            for i, input_ in enumerate(inputs)
        }
        ort_outputs = ort_session.run(None, ort_inputs)
        return [torch.from_numpy(output) for output in ort_outputs]

    return forward


def ort_decorator(func):
    @functools.wraps(func)
    def wrapper(*args):  # TODO support kwargs?
        if wrapper.forward is None:
            wrapper.forward = ort_inference_func(func, *args)
        return wrapper.forward(*args)
    wrapper.forward = None
    return wrapper


def poisson(r0, p, phi, filter_):
    conv_out = F.conv2d(p, filter_, padding=1)
    alpha = torch.mul(r0, r0).sum() / torch.mul(p, conv_out).sum()
    phi = alpha * p + phi
    r1 = r0 - alpha * conv_out
    r1_sum = torch.mul(r1, r1).sum()
    beta = r1_sum / torch.mul(r0, r0).sum()
    p = r1 + beta * p
    return r1_sum, phi, p, r1


def compute(func, r0, p, phi, filter_, step=20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        p_p = p
        start.record()
        for i in range(step):
            print(i, end='\r')
            r1_sum_p, phi_p, p_p, r1_p = func(r0, p, phi, filter_)
            r0, p = r1_p, p_p
        end.record()
        print()

    torch.cuda.synchronize()
    time_elapsed = start.elapsed_time(end) / step
    result = r1_sum_p, phi_p, p_p, r1_p
    return time_elapsed, result


if __name__ == "__main__":
    M = 1024 * 4
    N = 1024 * 8

    torch.cuda.set_device(0)

    alpha = torch.randn([1], device="cuda")
    r0 = torch.randn(1, 1, M, N, device="cuda")
    p = r0
    phi = torch.randn(1, 1, M, N, device="cuda")
    filter_ = torch.tensor(
        [[0., 1., 0.],
         [1., -4., 1.],
         [0., 1., 0.]],
        device="cuda"
    ).view(1, 1, 3, 3)

    time_pytorch, res_pytorch = compute(poisson, r0, p, phi, filter_)
    print(f"PyTorch step time: {time_pytorch} ms")

    time_ort, res_ort = compute(ort_decorator(poisson), r0, p, phi, filter_)
    print(f"ORT step time: {time_ort} ms")

    for t1, t2 in zip(res_pytorch, res_ort):
        assert torch.allclose(t1.cpu(), t2.cpu(), rtol=1e-2, atol=1e-3), (t1, t2)