# [RFC] Enable TVM QNN on RISC-V with Subword SIMD Computation

###### tags: `PL-Lab` `RISC-V` `TVM`

Hello, we're the team from NTHU (National Tsing-Hua University), Taiwan. Our team mainly focuses on supporting TVM on the RISC-V architecture with SIMD instructions. In this RFC, we target the RISC-V P extension (RVP), the extension for RISC-V DSP and subword SIMD. Note that a preliminary version of this work was reported at the RISC-V Global Forum, Sep. 3, 2020, lightning talk session ([video link](https://www.youtube.com/watch?v=1nCn619cXJw&list=PL85jopFZCnbNDtFbl72oU0_8vANrljnh7&index=7)).

## Intro of RISC-V P extension (RVP)

RISC-V is an open-source ISA with multiple extensions for different application needs. For vector computation, RISC-V provides the "V" and "P" extensions to support superword SIMD and subword SIMD, respectively. Here we target RVP, which is designed for embedded processors and DSP-like devices. All computation uses general-purpose registers (32 or 64 bits) with lower-precision numerics such as fixed point and integer. In our previous work, presented at the TVM conference and linked below, we gave a fixed-point flow. Since TVM provides a QNN flow, we now devise a TVM QNN flow for the RISC-V P extension, which makes our work more compatible with the existing TVM flow.

The previous work for RVP in TVM, covering the fixed-point flow, is given below.
- Supporting TVM on RISC-V Architectures with SIMD computation ([video link](https://youtu.be/7-EaUUC6QZs?list=PLTPQEx-31JXjA2ZmvYT5s0RqDXFXTSjyL&t=2078))

The specification of the RISC-V P extension is as follows.
- https://github.com/riscv/riscv-p-spec/blob/master/P-ext-proposal.adoc

## Motivation

While looking for an application well suited to RVP, we found QNN to be the best fit. In particular, in the pre-quantized flow from the QNN dialect, most ops are either in int8/uint8 or int32, which makes them suitable for subword SIMD computation. Since TVM upstream has no RISC-V support in its TOPI implementations or any scheduling handling, we propose our work, which mainly focuses on enabling tensorization for `conv2d_nchw_int8` and vectorization for ops in int32.

## Approach

### Outline
1. New target: `riscv_cpu`.
2. Introduce an intrinsic for the dot product in convolution, and enable it by tensorization.
3. Vectorize ops to generate SIMD patterns (add).
4. Introduce a custom runtime to easily generate executable files for the Spike simulator.
5. Run Spike to get the result.

A minimal sketch of how these steps fit together is shown below.
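To make the outline concrete, here is a rough sketch of the build side of the flow (steps 1, 4, and 5), assuming a pre-quantized TFLite model. The model path, file names, and input shape are illustrative assumptions; the actual script is `build_model.py` in our example repository.

```python
import tflite
import tvm
from tvm import relay

# Load a pre-quantized TFLite model (QNN dialect); path/shape are placeholders.
buf = open("mobilenet_v1_1.0_224_quant.tflite", "rb").read()
tflite_model = tflite.Model.GetRootAsModel(buf, 0)
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"input": (1, 224, 224, 3)},
    dtype_dict={"input": "uint8"})

# `riscv_cpu` is the target proposed in this RFC; it lowers through the LLVM
# backend with --mtriple=riscv64-unknown-elf --system-lib.
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(mod, target="riscv_cpu", params=params)

# Save the artifacts consumed by the custom runtime (RISC-V DLR); the LLVM IR
# is then compiled with our RVP-enabled LLVM and the riscv-gnu-toolchain.
lib.save("model.ll")
with open("model.graph", "w") as f:
    f.write(graph)
with open("model.params", "wb") as f:
    f.write(relay.save_param_dict(params))
```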
### RISC-V Target
- Register a new TVM target: `riscv_cpu`.
- Use LLVM as our target backend with `--mtriple=riscv64-unknown-elf --system-lib`.
- Add `codegen_riscv.cc` as the RISC-V-specific code generator.
- Register `target_riscv32` and `target_riscv64`.
- Register the strategy; specially handle the schedule for `conv2d_nchw_int8` with tensorize.
- Reuse x86's compute/schedule for the other ops.
- Since Spike doesn't support parallel computing, we use an empty schedule in `schedule_injective()`, except for ops that will be vectorized.

### Intrinsic for dot product

To execute convolution efficiently, we propose to use the following instructions in RVP:
- **smaqa**: [Signed Multiply Four Bytes with 32-bit Adds](https://github.com/riscv/riscv-p-spec/blob/master/P-ext-proposal.adoc#5109-smaqa-signed-multiply-four-bytes-with-32-bit-adds)
- **smaqa.su**: [Signed and Unsigned Multiply Four Bytes with 32-bit Adds](https://github.com/riscv/riscv-p-spec/blob/master/P-ext-proposal.adoc#5110-smaqasu-signed-and-unsigned-multiply-four-bytes-with-32-bit-adds)
  - With `helper_change_dtypes_to_uint8_int8()` from the x86 legalize flow, we can get uint8 x int8.

The instructions above accumulate the products directly into 32 bits; this saves the effort of storing intermediate results as 16 bits, and preserves accuracy compared with using an 8-bit SIMD multiply. `int32_lanes` is fixed at **2** since the maximum register length in RVP is 64 bits. Each dot product is done in one instruction with plenty of subword parallelism.

The intrinsic function is declared as:
```python
# num_int8_elements = 4
# int32_lanes = 2
def _intrin_func(ins, outs):
    def _instr(index):
        ib = tvm.tir.ir_builder.create()
        if index == 1:  # reset: zero the 2-lane int32 accumulator
            ib.emit(outs[0].vstore(0, tvm.tir.const(0, 'int32x%d' % (int32_lanes))))
            return ib.get()

        dtype_a = '%s8x%d' % (data_dtype, num_int8_elements)
        dtype_b = '%s8x%d' % (kernel_dtype, int32_lanes * num_int8_elements)
        dtype_c = 'int32x%d' % (int32_lanes)

        a_int8 = ins[0].vload([0], dtype_a)
        re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
        # Broadcast the packed 4x8-bit data to both lanes, then view it as 8x int8.
        vec_ai32 = re_int32.astype(dtype_c)
        vec_a = tvm.tir.call_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
        vec_b = ins[1].vload([0, 0], dtype_b)

        # Call the intrinsic for RVP (smaqa / smaqa.su via LLVM).
        d_dtype = 's' if data_dtype == 'int' else 'u'
        k_dtype = 's' if kernel_dtype == 'int' else 'u'

        if d_dtype == 'u' and k_dtype == 's':
            # The "su" form takes the signed operand first.
            inst = 'llvm.riscv.simd.%s%sdot.v%di32' % (
                k_dtype, d_dtype, int32_lanes)
            vdot = tvm.tir.call_llvm_pure_intrin(
                dtype_c, inst, tvm.tir.const(0, 'uint32'), vec_b, vec_a)
        else:
            inst = 'llvm.riscv.simd.%s%sdot.v%di32' % (
                d_dtype, k_dtype, int32_lanes)
            vdot = tvm.tir.call_llvm_pure_intrin(
                dtype_c, inst, tvm.tir.const(0, 'uint32'), vec_a, vec_b)

        if index == 0:  # body: overwrite the accumulator
            ib.emit(outs[0].vstore(0, vdot))
        else:           # update: accumulate into the existing value
            ib.emit(outs[0].vstore(0, vdot + outs[0].vload([0], 'int32x%d' % (int32_lanes))))
        return ib.get()

    # body, reset, update
    return _instr(0), _instr(1), _instr(2)
```
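For reference, a sketch of how the `_intrin_func` above would be wrapped into a tensor intrinsic and applied during scheduling; the compute declaration mirrors TVM's x86 `dot_16x1x16_uint8_int8_int32` pattern. The function name `dot_2x1x4`, the buffer declarations, and the loop axis in the final comment are illustrative assumptions, and `_intrin_func` (normally nested inside this declaration) refers to the body shown above.

```python
import tvm
from tvm import te

def dot_2x1x4(data_dtype="uint", kernel_dtype="int"):
    int32_lanes = 2        # two 32-bit accumulators fill one 64-bit register
    num_int8_elements = 4  # four 8-bit products per accumulator (smaqa)
    data = te.placeholder((num_int8_elements,), dtype=data_dtype + "8", name="data")
    kernel = te.placeholder((int32_lanes, num_int8_elements),
                            dtype=kernel_dtype + "8", name="kernel")
    k = te.reduce_axis((0, num_int8_elements), name="k")
    C = te.compute(
        (int32_lanes,),
        lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"),
                         axis=k),
        name="C")
    a_buf = tvm.tir.decl_buffer(data.shape, data.dtype, name="a_buf",
                                offset_factor=1, strides=[1])
    b_buf = tvm.tir.decl_buffer(kernel.shape, kernel.dtype, name="b_buf",
                                offset_factor=1, strides=[te.var("s"), 1])
    # _intrin_func is the body shown in the previous code block.
    return te.decl_tensor_intrin(C.op, _intrin_func,
                                 binds={data: a_buf, kernel: b_buf})

# In the conv2d_nchw_int8 schedule, the innermost loops are then replaced with
# the RVP instruction sequence, e.g.:
#   s[conv].tensorize(oc_inner, dot_2x1x4())
```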
### Vectorization

In addition to enabling tensorization for convolution, we can also improve the performance of other ops by vectorizing them. For example, most `add` ops in a pre-quantized model are in `int32`; we can vectorize them with 2 lanes to utilize SIMD instructions such as **[kadd32](https://github.com/riscv/riscv-p-spec/blob/master/P-ext-proposal.adoc#65-kadd32-simd-32-bit-signed-saturating-addition)** (SIMD 32-bit Signed Saturating Addition).

```python
# python/tvm/topi/riscv_cpu/injective.py
# In schedule_injective(): if op is an `add` on two int32 tensors, split the
# innermost axis by 2 and vectorize it so LLVM can match RVP instructions.
A = op.output(0)
if op.input_tensors[0].dtype == 'int32' and op.input_tensors[1].dtype == 'int32':
    if A.shape[-1] % 2 == 0:
        o, i = s[A].split(A.op.axis[-1], 2)
        s[A].vectorize(i)
```

We're also considering enabling vectorization for other ops; that work is in progress (SIMD Mul, Max, ...).

### RISC-V Custom Runtime (RISC-V DLR)

We use TVM's LLVM backend to generate the implementation of the model and plan to run it on Spike. We therefore need a corresponding LLVM to generate correct assembly, plus a C++ program that calls the generated code (including setting the input, running, and getting the output). Finally, this program is compiled with the riscv-gnu-toolchain.

To get a minimal runtime for such an environment, we use a custom runtime derived from TVM's GraphRuntime with extra features. We remade a C++ interface to invoke the GraphRuntime functions, so that the runtime inputs are clean and directly usable after `relay.build`. We found this concept quite similar to **bundle_deploy**, and we're currently looking for advice on which flow we should follow, or a possible approach to reuse.

Thus, for this C++ code, we need:
- the output of TVM's `relay.build`: `.graph`, `.params`
- data/label reading functions in C++
- calls to a custom interface that invokes the GraphRuntime functions
  - sample code `host.cpp` is available [here](https://github.com/nthu-pllab/RISCV-DLR/blob/master/example/pre_quant_mobilenet_v1_tflite/host.cc)
  - this file behaves similarly to `demo_static.cc` in **bundle_deploy**

In this flow, we use a `build()` function which collects the needed information from `.graph` to generate a `kernel.inc` file. This file lists the functions to be called and the order in which they are called. The following is an example of `kernel.inc`:

```cpp
// kernel.inc
extern "C" int32_t fused_transpose(void* args, void* arg_type_ids, int32_t num_args);
extern "C" int32_t fused_nn_conv2d_7(void* args, void* arg_type_ids, int32_t num_args);
extern "C" int32_t fused_nn_bias_add_5(void* args, void* arg_type_ids, int32_t num_args);
extern "C" int32_t fused_nn_relu_4(void* args, void* arg_type_ids, int32_t num_args);
// ...

void dlr::DLR::Runx() {
  int32_t ret;
  ret = fused_transpose(opa[0].values, opa[0].tcodes, opa[0].num);
  assert(ret == 0);
  ret = fused_nn_conv2d_7(opa[1].values, opa[1].tcodes, opa[1].num);
  assert(ret == 0);
  ret = fused_nn_bias_add_5(opa[2].values, opa[2].tcodes, opa[2].num);
  assert(ret == 0);
  ret = fused_nn_relu_4(opa[3].values, opa[3].tcodes, opa[3].num);
  // ...
}
```

This `Runx()` is similar to `tvm_runtime_run()` from **bundle_deploy**. Other functions such as `set_input()` and `get_output()` are also provided in this runtime.

### Execute

Once we have prepared `build_model.py` and `host.cpp`, we can simply build and execute on Spike. The example is provided [here](https://github.com/nthu-pllab/RISCV-DLR/tree/master/example/pre_quant_mobilenet_v1_tflite), and a trimmed sketch of `host.cpp` follows below.
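The sketch below shows the shape of the host program. `Runx()`, `set_input()`, and `get_output()` are the interfaces named in this RFC; the header name, the `Init()` call, the tensor name, and the buffer shapes are assumptions — see the linked `host.cc` for the real code.

```cpp
// host.cpp -- trimmed sketch of the DLR host program (names partly assumed).
#include <cstdint>
#include <vector>
#include "dlr.h"  // assumed header exposing dlr::DLR

int main() {
  dlr::DLR runtime;
  // Initialize from the artifacts produced by relay.build (call name assumed).
  runtime.Init("model.graph", "model.params");

  // Data/label reading is elided; a 1x224x224x3 uint8 input is assumed.
  std::vector<uint8_t> input(1 * 224 * 224 * 3, 0);
  runtime.set_input("input", input.data());

  // Invokes the fused_* kernels in the order generated into kernel.inc.
  runtime.Runx();

  // Fetch the first output (e.g. 1001 class scores for Mobilenet v1).
  std::vector<uint8_t> output(1001);
  runtime.get_output(0, output.data());
  return 0;
}
```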
### Overview

![](https://i.imgur.com/aHbsVgc.png)

## Evaluation

When evaluating on Spike, the only metric we can compare is the **instruction count**, either per op or over the entire model. Thus, we compare the instruction count for the entire model between the pre-quantized model with tensorization/vectorization and the FP32 model. The models are downloaded from the TFLite hosted models.

The following table shows the instruction count for the entire model:

| Model name | FP32 | Pre-quantized with tensorization/vectorization | Speedup |
| -------- | -------- | -------- | --- |
| Mobilenet_v1 | 3763660304 | 804941581 | 4.67 |
| Mobilenet_v2 | 2384447853 | 686688571 | 3.47 |
| Inception_v3 | 39933768309 | 4909041142 | 8.14 |
| Inception_v4 | 92224223497 | 10434567078 | 8.83 |

- More models
  - inception v1/v2
    - https://tfhub.dev/google/imagenet/inception_v1/classification/1
    - https://tfhub.dev/google/imagenet/inception_v2/classification/1
  - quantized inception v1/v2
    - https://tfhub.dev/tensorflow/lite-model/inception_v1_quant/1/default/1
    - https://tfhub.dev/tensorflow/lite-model/inception_v2_quant/1/default/1
- https://github.com/riscv/riscv-isa-sim/pull/572

## Related project

### LLVM

To match the intrinsics we designed and call from TVM, we need an LLVM that handles the RISC-V P extension properly. Since there is no official version so far, we implemented it ourselves on top of v9.0.0; please refer to the [project on GitHub](https://github.com/nthu-pllab/llvm9-project-rvp).

### RISC-V Custom Runtime (RISC-V DLR)

As mentioned before, this is still an alternative approach and, we believe, optimizable. Our team is still working on making this part more flexible and cleaner. We're also looking for better ideas, so please leave comments. Since our evaluation needs only Spike on our own devices, we didn't consider any remote/host constraints. Once you create an executable file with this runtime, you can immediately run it with Spike to get the result (without going through the OpenOCD flow). The project is open source [here](https://github.com/nthu-pllab/RISCV-DLR). The details and the build steps are described in the README.

### riscv-gnu-toolchain & Spike

Neither binutils nor Spike (riscv-isa-sim) currently supports the RISC-V P extension officially. Thus, we implemented it, adding the instructions that may be used in this flow. Please refer to the projects below.
- riscv-binutils: https://github.com/nthu-pllab/riscv-binutils-rvp
- riscv-isa-sim: https://github.com/nthu-pllab/riscv-isa-sim

## Next step

We plan to collect comments from the community, reorganize our code accordingly, and then send the PR.

---

# RISC-V Custom Runtime (DLR)

This repository is a lightweight runtime utility for RISC-V. We use the [TVM](https://github.com/apache/incubator-tvm) runtime to create a C++ interface that can be called from `host.cpp` to invoke runtime functions. After saving a TVM module as `.ll`, `.graph` and `.params`, with support from LLVM and `host.cpp`, users can compile the program with the riscv-gnu-toolchain and easily run inference on Spike.

## How to use

![](https://i.imgur.com/HJ4nOv2.png)

## Build steps

There are 2 targets in DLR.
- A static library: `libDLR.a`
  - for compiling with the host code
- An executable: `DumpKernel`
  - for generating `kernel.inc` from the graph

Because `libDLR.a` is compiled together with the host code and eventually executed on Spike, it needs to be compiled by the gcc/g++ of the riscv-gnu-toolchain. On the other hand, `DumpKernel` is executed directly on the host, so it does not need to be compiled by the riscv-gnu-toolchain.

```sh
##############################
#  RISCV-build for libDLR.a  #
##############################
# set RISCV to where riscv-gnu-toolchain is installed
mkdir build-riscv
cd build-riscv
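# Cross-compile with the riscv-gnu-toolchain gcc/g++ so that the resulting
# libDLR.a can be linked with the host code and executed on Spike.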
CC=$RISCV/bin/riscv64-unknown-elf-gcc \
CXX=$RISCV/bin/riscv64-unknown-elf-g++ \
cmake .. \
  -DCMAKE_INSTALL_PREFIX=../install-riscv
cmake --build .
cmake --install .
cd ..

#################################
#  normal build for DumpKernel  #
#################################
mkdir build && cd build
cmake .. \
  -DCMAKE_INSTALL_PREFIX=../install
cmake --build .
cmake --install .
cd ..
```

## Example - Pre-quantized mobilenet v1 model

The example is shown in `example/pre_quant_mobilenet_v1_tflite`.
- In `build_model.py`, we run the TVM routine and save the TVM module.
- In `host.cpp`, we invoke DLR functions for initialization, loading the input, and running.

After setting up the paths of the dependencies in the Makefile, you can run the whole flow easily, as illustrated in the figure above.
- **RISCV_INSTALL**: where the riscv-gnu-toolchain is installed.
- **DLR_INSTALL**: as explained in the "Build steps" section, we need to build DLR twice for different purposes. This one is for the `DumpKernel` executable that generates `kernel.inc`.
- **DLR_RISCV_INSTALL**: this one is for the `libDLR.a` static library, which is compiled by the gcc/g++ of the riscv-gnu-toolchain. The static library is compiled together with the host code and runs on Spike.
- **LLVM_INSTALL**: the modified LLVM with RISC-V P extension support.

```shell
make
make run
```

## Dependent projects
- LLVM
  - The version we use as the target for TVM QNN, with RISC-V P extension support, is open source [here](https://github.com/nthu-pllab/llvm9-project-rvp).
- riscv-gnu-toolchain
  - To support our flow for TVM QNN, please use the [riscv-binutils](https://github.com/nthu-pllab/riscv-binutils-rvp) and [spike](https://github.com/nthu-pllab/riscv-isa-sim) that we modified to support the P extension.

### Acknowledgement
- Apache TVM

### CONTRIBUTORS.md
- National Tsing-Hua University, PLLab
  - Jenq-Kuen Lee (jklee@cs.nthu.edu.tw)
  - Chia-Hsuan Chang (chchang@pllab.cs.nthu.edu.tw)
  - Yi-Ru Chen (yrchen@pllab.cs.nthu.edu.tw)
  - Hui-Hsin Liao (hhliao@pllab.cs.nthu.edu.tw)
- Peakhills Group
  - Chao-Lin Lee (clli@peakhillsgroup.com)
- Andes Technology
  - Yuan-Ming Chang (ymchang@andestech.com)
  - Chun-Chieh Yang (jet4085@andestech.com)
  - Allen Lu (allen@andestech.com)
  - Pi-You Chen (piyou@andestech.com)

---

# RFC Q&A

## comaniac

### Q1
Thanks for the RFC! While I'm not familiar with the current RISC-V applications, I'm curious about the purpose of running the Spike simulator and what the usual next step after it would be.

### Q2
I also have some questions/thoughts about the implementation. In general I'm thinking if it would be better to integrate this flow via BYOC to provide more flexibility and opportunities for future heterogeneous execution.
- Preliminary consensus: NO BYOC!

### Q3
You mentioned to use LLVM as the backend. How does this LLVM backend overlap to the current TVM LLVM backend? Will you reuse most of it, or you almost build another backend using LLVM?
- We basically reuse the TVM LLVM backend; the file `codegen_riscv.cc` is similar to the files that support other backends in `src/target/llvm` (`codegen_arm.cc`, for example).

### Q4
I didn't quite get the point of "since Spike doesn't support parallel computing, we use an empty schedule for schedule_injective(), except for Ops that gonna be vectorized". Does that mean you still have schedules for the ops that can be vectorized? If so, do we need someone to write schedules for RISC-V P on Spike in TOPI?
- Yes, we can still enable vectorization on ops after splitting (like `add` in the example). LLVM will then pattern-match the vectorized operations to specific RVP instructions. In TOPI, we may need to register schedules for specific ops.
  Actually, most of our work was done before `op_strategy`; we're still trying to integrate our code with the `op_strategy` version.

### Q5
In terms of the runtime, currently TVM graph runtime includes several modules, such as the metadata module and external runtime modules (for the case of BYOC). Where would your custom runtime be?

## areusch

### Q1
Is your eventual target bare metal devices, or does your runtime require a kernel?
- Yes, the runtime is designed to run on bare metal.

### Q2
`riscv_cpu` target: in the past we had introduced a special micro_dev target for µTVM work. Recently, we deprecated that in favor of llvm and c targets. Then, when creating the list of candidate schedules for a given op, we (for ARM) analyze the ISA supported by the CPU in -mcpu. Is it possible to do something similar with RISC-V (i.e. encode the P extension in some flag -mcpu=rv32p)?
- Thank you for your suggestion; our team hasn't checked the µTVM flow before, we'll look into it.

### Q3
LLVM support for the RISC-V P extension, and codegen: since you will need to build TVM against a forked LLVM, is it possible to use the c backend for any tests in the CI, until LLVM formally supports RISC-V P? It could be possible then to include a forked LLVM compiler in one of the CI docker images, but still compile TVM against mainline LLVM. You could take a look at the GEMM impl for cortex-m7 as an example of how to do that.

### Q4
I'm also beginning to look at AOT compilation, which looks somewhat similar to your kernel.inc code (but would be generated from TVM). There are some additional considerations such as memory planning that may depend more on the device layout. Do you have a full example of the kernel.inc anywhere I could look at?
- Yes, sure! We put it in the same repository, please check at ....

### Q5
Looks like the function signatures in your DLR differ from the typically generated signature:
```cpp
typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args,
                                     TVMValue* out_ret_value, int* out_ret_tcode,
                                     void* resource_handle);
```
Seems like the main difference between this func and the DLR func is the lack of the out_* and resource_handle params?

### Q6
Did you try using the new µTVM RPC server-based runtime with Spike? This would allow you to use the graph runtime in the TVM python binary and perform autotuning. Would it be possible to use that to submit the schedules as one PR and then split any runtime changes into another? We modified the micro_tflite tutorial to demonstrate one use of that runtime.
- To follow up: Zhe-Jia, Ming-Han.

### Q7
I don't quite understand your evaluation numbers. Are these measured over a fixed time period? Otherwise, it seems like there should be fewer instructions executed using the intrinsic for one inference run, correct?
- Yes, sorry for the wrong table title; the column with fewer instructions should be "Pre-quantized with tensorization/vectorization" and the one with more is the FP32 version.

### Q8
What is your plan for upstreaming the binutils and riscv-isa-sim work?
- Upstreaming of `riscv-isa-sim` is in progress; we have recently been working with Andes Technology (chair for RVP). As for `riscv-binutils`, ...

### Q9
For testing in CI, would we need to build a Spike docker image?
---

# Next Step

## Faced problems
- Should we use their flow for the runtime instead of DLR?
  - Try to use the apps included in TVM upstream:
    - bundle_deploy
    - µTVM
      - may need a board
      - goes through OpenOCD
- They declined to use our own LLVM:
  - can't use the intrinsics for tensorization
    - where should the code in `tensor_intrin` match to??
  - other solution for tensorization: using [C++](https://github.com/apache/incubator-tvm/blob/master/python/tvm/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py)
    - for Arm, the intrinsics they use are included in upstream LLVM
- Organize the code
  - Handle passes
    - [ ] alterOpLayout
    - [ ] Legalize
    - [ ] DynamicToStatic
  - Disable getting configuration values from AutoTVM
  - Op strategy
    - [ ] convolution related
    - [ ] dense
    - [ ] ...
- Handle int32 operations with RVP codegen
  - either by intrinsics or LLVM vector handling

## Bundle_deploy
- https://hackmd.io/43P068RfRdS7739FKwN1CQ
- Trying to go through this flow with:
  - the llvm backend (from our own LLVM)
  - the C/C++ runtime from µTVM