MSRA Intern
Preliminary syntax survey + a small demo implementation
In my understanding, the with-statement is not a good fit.
with EXPR as VAR:
    BLOCK
is roughly equivalent to
mgr = EXPR
VAR = mgr.__enter__()
try:
    BLOCK
finally:
    mgr.__exit__(...)  # receives the exception info, if any
The key point is that every line of BLOCK really does get executed. For NNFusionRT, as I understand it, we need to replace BLOCK entirely, so the with-statement is probably not suitable.
If we really wanted to force it, as I understand it the only option is something like
with nnfusion.jit() as ctx:
    ctx.init()  # Raise an exception here...
    BLOCK
i.e., skip BLOCK by throwing an exception and then run NNFusionRT in __exit__; moreover, the code and logic inside BLOCK could only be recovered by fairly involved means. This feels too tricky, and it would hurt both usability and development/maintenance.
Update: skipping BLOCK without explicitly raising an exception requires hack-level tricks.
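For reference, the best-known trick of this kind is a CPython-specific hack built on sys.settrace and sys._getframe; the sketch below is only illustrative (the class and exception names are made up) and is exactly the kind of fragility we would rather avoid:

import sys

class _SkipBlock(Exception):
    pass

class SkippableContext:
    """Illustrative sketch: a context manager whose with-block body never runs."""

    def __enter__(self):
        # Enable tracing, then install a local trace function on the caller's
        # frame; it fires on the first line of BLOCK and raises to skip it.
        sys.settrace(lambda frame, event, arg: None)
        sys._getframe(1).f_trace = self._trace
        return self

    def _trace(self, frame, event, arg):
        raise _SkipBlock

    def __exit__(self, exc_type, exc_value, traceback):
        sys.settrace(None)  # stop tracing again
        # Swallow the internal exception so the skip is invisible to the user.
        return exc_type is _SkipBlock

The fact that this relies on frame-tracing internals is exactly why it seems unsuitable for a user-facing API.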
Decorators, in contrast, are very natural. Something like the following is enough:
def jit(func):
    def wrapper(*args, **kwargs):
        if wrapper.forward is None:
            wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
        return wrapper.forward(*args, **kwargs)
    wrapper.forward = None
    return wrapper
This naturally captures the logic of func, tunes it on the first call, and then replaces func. (We could also tune asynchronously and swap in the tuned version once it is ready.)
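As a rough sketch of the async option (assuming the same hypothetical get_nnfusionrt_forward helper as above), the wrapper could keep calling the original func and only switch over once a background tuning thread finishes:

import threading

def jit_async(func):
    def wrapper(*args, **kwargs):
        if wrapper.forward is not None:
            return wrapper.forward(*args, **kwargs)
        if wrapper.tuning is None:
            def tune():
                # Hypothetical helper, same as in the sketch above.
                wrapper.forward = get_nnfusionrt_forward(func, *args, **kwargs)
            wrapper.tuning = threading.Thread(target=tune, daemon=True)
            wrapper.tuning.start()
        # Fall back to the original implementation until tuning is done.
        return func(*args, **kwargs)
    wrapper.forward = None
    wrapper.tuning = None
    return wrapper

A real implementation would also need to think about locking and whether a tuned kernel may safely replace the original mid-run.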
My understanding is that the design should be: after tuning, if nothing has changed, reuse the tuned result directly the next time. What counts as "no change" needs a clear rule. For reuse we could either impose strict signature restrictions (the demo below hashes the function source with md5 as a first attempt at such a signature), or keep the restrictions loose and, on the first call, check whether the NNFusionRT output matches the PyTorch ground truth, e.g.:
@nnfusion.jit(check_first=True)
def some_function(...):
    ...

# `async` itself is a reserved keyword in Python, so the real flag would need another spelling:
@nnfusion.jit(check_first=True, async_=True)
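A minimal sketch of what check_first could do under the hood, again assuming the hypothetical get_nnfusionrt_forward helper and a single tensor output: run the original PyTorch function once as ground truth and compare it with the tuned output on the first call.

import functools
import torch

def jit(check_first=False):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args):
            if wrapper.forward is None:
                # get_nnfusionrt_forward is hypothetical, as in the sketches above.
                wrapper.forward = get_nnfusionrt_forward(func, *args)
                if check_first:
                    expected = func(*args)            # PyTorch ground truth
                    actual = wrapper.forward(*args)   # tuned runtime output
                    assert torch.allclose(expected, actual, rtol=1e-3, atol=1e-5)
                    return actual
            return wrapper.forward(*args)
        wrapper.forward = None
        return wrapper
    return decorator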
For now the demo uses ORT as its runtime, because environment problems currently prevent the NNFusionRT example from running; design-wise, I believe this makes no difference for the moment.
Installing nnfusion with docker
[ERROR] 2022-01-28T06:16:13z src/nnfusion/util/errors.hpp 169 Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
terminate called after throwing an instance of 'nnfusion::errors::CheckError'
what(): Check failed: 'dtype == "float32"' at /root/nnfusion/src/nnfusion/core/operators/generic_op/generic_op_define/Sum.cpp:17:
(no explanation given)
Aborted (core dumped)
nnfusion main branch (related to double precision and ORT)
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 88, in <module>
ort_session.set_providers([args.provider])
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 155, in set_providers
self._reset_session(providers, provider_options)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 407, in _reset_session
self._create_inference_session(providers, provider_options)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Conv(1) node with name 'Conv_1'
ONNX model check passed!
Importing ONNX model into ONNX Runtime...
Traceback (most recent call last):
File "/usr/local/bin/templates/onnx/ort_run_frozen.py", line 85, in <module>
ort_session = ort.InferenceSession(args.file, sess_options, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 335, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 379, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains compiled nodes. Please disable any execution providers which generate compiled nodes.
Code
import functools
import inspect
from hashlib import md5
from pathlib import Path

import torch
import torch.nn.functional as F

# torch.set_default_tensor_type(torch.DoubleTensor)


class TorchModule(torch.nn.Module):
    """Wrap a plain function in an nn.Module so torch.onnx.export can trace it."""

    def __init__(self, func):
        super(TorchModule, self).__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def get_signature(func):
    """Identify a function by its name plus an md5 hash of its source code."""
    # TODO add more
    func_name = func.__name__
    source_hash = str(md5(inspect.getsource(func).encode('utf-8')).hexdigest())
    return '@'.join((func_name, source_hash))


def ort_inference_func(func, *args):
    """Export `func` to ONNX (cached by signature) and return an ORT-backed forward."""

    def export_onnx_model(model_name):
        outputs = func(*args)
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        model = TorchModule(func)
        torch.onnx.export(model,
                          args,
                          model_name,
                          # verbose=True,
                          input_names=[f"input_{i}" for i in range(len(args))],
                          output_names=[f"output_{i}" for i in range(len(outputs))])

    def signature2name(signature):
        return signature + ".onnx"

    signature = get_signature(func)
    model_name = signature2name(signature)
    # TODO clean out-of-date cache
    if not Path(model_name).is_file():
        export_onnx_model(model_name)

    import onnxruntime as ort
    ort_session = ort.InferenceSession(
        model_name,
        providers=['CUDAExecutionProvider'])

    def forward(*inputs):
        ort_inputs = {
            ort_session.get_inputs()[i].name: input_.cpu().numpy()
            for i, input_ in enumerate(inputs)
        }
        ort_outputs = ort_session.run(None, ort_inputs)
        return [
            torch.from_numpy(output)
            for output in ort_outputs
        ]

    return forward


def ort_decorator(func):
    @functools.wraps(func)
    def wrapper(*args):  # TODO support kwargs?
        if wrapper.forward is None:
            wrapper.forward = ort_inference_func(func, *args)
        return wrapper.forward(*args)
    wrapper.forward = None
    return wrapper


def poisson(r0, p, phi, filter_):
    # One conjugate-gradient-style iteration of a Poisson solver (the demo workload).
    conv_out = F.conv2d(p, filter_, padding=1)
    alpha = torch.mul(r0, r0).sum() / torch.mul(p, conv_out).sum()
    phi = alpha * p + phi
    r1 = r0 - alpha * conv_out
    r1_sum = torch.mul(r1, r1).sum()
    beta = r1_sum / torch.mul(r0, r0).sum()
    p = r1 + beta * p
    return r1_sum, phi, p, r1


def compute(func, r0, p, phi, filter_, step=20):
    # Run `func` for `step` iterations and report the average step time via CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        p_p = p
        start.record()
        for i in range(step):
            print(i, end='\r')
            r1_sum_p, phi_p, p_p, r1_p = func(r0, p, phi, filter_)
            r0, p = r1_p, p_p
        end.record()
        print()
        torch.cuda.synchronize()
    time_elapsed = start.elapsed_time(end) / step
    result = r1_sum_p, phi_p, p_p, r1_p
    return time_elapsed, result


if __name__ == "__main__":
    M = 1024 * 4
    N = 1024 * 8
    torch.cuda.set_device(0)
    alpha = torch.randn([1], device="cuda")
    r0 = torch.randn(1, 1, M, N, device="cuda")
    p = r0
    phi = torch.randn(1, 1, M, N, device="cuda")
    filter_ = torch.tensor(
        [[0., 1., 0.],
         [1., -4., 1.],
         [0., 1., 0.]], device="cuda"
    ).view(1, 1, 3, 3)

    time_pytorch, res_pytorch = compute(poisson, r0, p, phi, filter_)
    print(f"PyTorch step time: {time_pytorch} ms")

    time_ort, res_ort = compute(ort_decorator(poisson), r0, p, phi, filter_)
    print(f"ORT step time: {time_ort} ms")

    for t1, t2 in zip(res_pytorch, res_ort):
        assert torch.allclose(t1.cpu(), t2.cpu(), rtol=1e-2, atol=1e-3), (t1, t2)