# TVM with Fixed-point Type

> Bing-Sung Lu

---

## 1. Fixed-Point

The arithmetic instructions in the RISC-V P extension only work on integer types. Therefore, we implement a fixed-point type in TVM so that an inference program can be quantized from floating-point to integer arithmetic.

### Transform Floating-point to Fixed-point

The value of a fixed-point variable is determined by two things: the value of its floating-point counterpart, and the point position, which the user can set. Writing Fxp for the fixed-point value, Fp for the floating-point value and PP for the point position, the relationship is:

$$Fxp = Fp \times 2^{PP}$$

The point position is therefore the exponent of 2. We multiply the floating-point value by $2^{PP}$ and keep the integer part as the fixed-point value. For example, with PP = 13, the value 0.75 becomes $0.75 \times 2^{13} = 6144$.

### Implementing Fixed-point type in TVM

To build an NN model with the fixed-point type through NNVM (as below) or Relay, we first set the fixed-point dtype in the Python code:

```python
import nnvm.compiler

# sym, target, shape_dict and params come from the model being compiled.
with nnvm.compiler.build_config(opt_level=0):
    graph, lib, params = nnvm.compiler.build(
        sym, target, shape_dict, dtype='fxp16_13', params=params)
```

Here the fixed-point type is written as `'fxp16_13'`. 16 is the number of bits in the fixed-point variable, and 13 is the point position.

#### NNVM/Relay Python

First, we modify `graph_attr.py` to add `fxp16` to the dtype table:

```python
DTYPE_TO_TCODE = {
    "default": -1,
    "float32": 0,
    "float64": 1,
    "float16": 2,
    "uint8": 3,
    "int32": 4,
    "int8": 5,
    "int64": 6,
    "int16": 7,
    "uint16": 8,
    "uint32": 9,
    "uint64": 10,
    "bool": 11,
    "fxp16": 12,
}
```

Because an fxp type can have different point positions, we need extra conditional statements to translate the dtype from a string to a code. We add them in the function `set_dtype_inputs`:
```python
def set_dtype_inputs(g, dtype):
    """Set the dtype code of every input of graph g."""
    if isinstance(dtype, dict):
        list_dtype = [
            DTYPE_TO_TCODE[str(dtype.get(name, "default"))]
            for name in g.index.input_names]
    else:
        dtype = str(dtype)
        if dtype.startswith('fxp'):
            # 'fxp16_13' -> base type 'fxp16' (code 12) and point position 13.
            # Each point position gets its own code: 12 + (PP - 1) = 11 + PP.
            fxp_pos = int(dtype[dtype.find('_') + 1:])
            base = dtype[:dtype.find('_')]
            list_dtype = [DTYPE_TO_TCODE[base] + fxp_pos - 1] * len(g.index.input_names)
        else:
            list_dtype = [DTYPE_TO_TCODE[dtype]] * len(g.index.input_names)
    g._set_json_attr("dtype_inputs", list_dtype, "list_int")
    return g
```

#### NNVM/Relay C++

The fxp type is passed down from the Python code to the C++ code, so we also have to modify `compile_engine.cc` to support it. The modified functions are `GetTypeFlag` and `GetTVMType`:

```cpp
int GetTypeFlag(tvm::Type type) {
  if (type == tvm::Float(32)) return 0;
  if (type == tvm::Float(64)) return 1;
  if (type == tvm::Float(16)) return 2;
  if (type == tvm::UInt(8)) return 3;
  if (type == tvm::Int(32)) return 4;
  if (type == tvm::Int(8)) return 5;
  if (type == tvm::Int(64)) return 6;
  if (type == tvm::Int(16)) {
    // An fxp16 type is an Int(16) whose handle_type carries a name such as
    // "fxp16_13"; map it to the flag 11 + point position.
    if (type.handle_type != nullptr) {
      std::string fxp_name = type.handle_type->inner_name.name;
      if (fxp_name.find("fxp") != std::string::npos) {
        int fxp_pos = std::stoi(fxp_name.substr(fxp_name.find("_") + 1));
        return 11 + fxp_pos;
      }
    }
    return 7;
  }
  if (type == tvm::UInt(16)) return 8;
  if (type == tvm::UInt(32)) return 9;
  if (type == tvm::UInt(64)) return 10;
  if (type == tvm::UInt(1)) return 11;
  LOG(FATAL) << "cannot convert " << type;
  return 0;
}

Type GetTVMType(int type_flag) {
  switch (type_flag) {
    case 0: return tvm::Float(32);
    case 1: return tvm::Float(64);
    case 2: return tvm::Float(16);
    case 3: return tvm::UInt(8);
    case 4: return tvm::Int(32);
    case 5: return tvm::Int(8);
    case 6: return tvm::Int(64);
    case 7: return tvm::Int(16);
    case 8: return tvm::UInt(16);
    case 9: return tvm::UInt(32);
    case 10: return tvm::UInt(64);
    case 11: return tvm::UInt(1);
    // Flags 12..27 encode fxp16 with point positions 1..16.
    case 12: case 13: case 14: case 15:
    case 16: case 17: case 18: case 19:
    case 20: case 21: case 22: case 23:
    case 24: case 25: case 26:
    case 27: return tvm::Fxp(16, type_flag - 11);
    default:
      LOG(FATAL) << "unknown type_flag=" << type_flag;
      return tvm::Float(32);
  }
}
```

#### Halide IR

The next step is to add the fixed-point type to HalideIR. We add the new `Fxp` type in HalideIR's `type.h`. It is essentially a 16-bit integer type. The only difference is that the point position is stored as extra information in `halideir_handle_cplusplus_type`:

```cpp
inline Type Fxp(int bits, int fxp_pos, int lanes = 1) {
  // Encode the point position in the type name, e.g. "fxp16_13".
  // The object is allocated on every call and never freed.
  halideir_handle_cplusplus_type *fxp_info = new halideir_handle_cplusplus_type{
      halideir_cplusplus_type_name(halideir_cplusplus_type_name::Simple,
                                   "fxp16_" + std::to_string(fxp_pos)),
      {}, {}, {}};
  return Type(Type::Int, bits, lanes, fxp_info);
}
```

> Note: This is clearly not how `halideir_handle_cplusplus_type` is intended to be used.

Back in TVM, we can now use the new HalideIR type. In `expr.h` we add the fxp type:

```cpp
using HalideIR::Fxp;
```

TVM converts the HalideIR type into a `TVMType` and passes it down to the runtime. `TVMType` has no field for the point position, so the only option is to give each point position its own type code. We also need to modify the functions that convert a HalideIR type to a `TVMType` and vice versa:
```cpp
inline Type TVMType2Type(TVMType t) {
  // Codes kFxp16Begin..kFxp16End (112..127) encode fxp16 with
  // point position 128 - code (16..1).
  if (t.code >= kFxp16Begin && t.code <= kFxp16End) {
    return Fxp(16, 128 - t.code, t.lanes);
  }
  return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
}

inline TVMType Type2TVMType(Type t) {
  TVMType ret;
  ret.code = static_cast<uint8_t>(t.code());
  ret.bits = static_cast<uint8_t>(t.bits());
  ret.lanes = static_cast<uint16_t>(t.lanes());
  if (t.code() == Type::Int && t.handle_type != nullptr) {
    std::string fxp_name = t.handle_type->inner_name.name;
    if (fxp_name.find("fxp") != std::string::npos) {
      int fxp_pos = std::stoi(fxp_name.substr(fxp_name.find("_") + 1));
      ret.code = static_cast<uint8_t>(128 - fxp_pos);
    }
  }
  return ret;
}
```

We also need to define `kFxp16Begin` and `kFxp16End` in the runtime header `c_runtime_api.h`:

```cpp
typedef enum {
  kHandle = 3U,
  kNull = 4U,
  kTVMType = 5U,
  kTVMContext = 6U,
  kArrayHandle = 7U,
  kNodeHandle = 8U,
  kModuleHandle = 9U,
  kFuncHandle = 10U,
  kStr = 11U,
  kBytes = 12U,
  kNDArrayContainer = 13U,
  kExtBegin = 15U,
  kNNVMFirst = 16U,
  kNNVMLast = 20U,
  // The following section of code is used for non-reserved types.
  kExtReserveEnd = 64U,
  // fxp16 types: code = 128 - point position.
  kFxp16Begin = 112U,
  kFxp16End = 127U,
  kExtEnd = 128U
} TVMTypeCode;
```

#### Runtime

Our program runs on our own modified DLR, so this part must be modified in both DLR and TVM. The file is `packed_func.h`. We mainly modify the functions that convert a type to a string and vice versa:

```cpp
inline std::string TVMType2String(TVMType t) {
  if (t.bits == 0) return "";
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
  // Note: this branch does not append the point position.
  std::ostringstream os;
  os << t;
  return os.str();
#else
  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
    return "bool";
  }
  std::string repr = "";
  // Assumes TypeCode2Str maps the fxp codes (112..127) to "fxp";
  // the stock version aborts on unknown type codes.
  repr += TypeCode2Str(t.code);
  if (t.code == kHandle) return repr;
  repr += std::to_string(static_cast<int>(t.bits));
  if (t.lanes != 1) {
    repr += "x" + std::to_string(static_cast<int>(t.lanes));
  }
  // Append the point position, e.g. "fxp16_13".
  if (t.code >= kFxp16Begin && t.code <= kFxp16End) {
    repr += "_" + std::to_string(static_cast<int>(128 - t.code));
  }
  return repr;
#endif
}
```

The point position is already carried in the LLVM IR.
Therefore, it is not necessary to pass this information to DLR: DLR can treat fxp as a plain integer type, so `String2TVMType` parses `fxp16_13` as a 16-bit `int`.

```cpp
inline TVMType String2TVMType(std::string s) {
  TVMType t;
  // handle None type
  if (s.length() == 0) {
    t.bits = 0; t.lanes = 0; t.code = kHandle;
    return t;
  }
  t.bits = 32; t.lanes = 1;
  const char* scan;
  if (s.substr(0, 3) == "int") {
    t.code = kDLInt;  scan = s.c_str() + 3;
  } else if (s.substr(0, 3) == "fxp") {
    // Treat "fxp16_13" as a plain int16: drop the "_<point position>" suffix.
    t.code = kDLInt;
    s = s.substr(0, s.find("_"));
    scan = s.c_str() + 3;
  } else if (s.substr(0, 4) == "uint") {
    t.code = kDLUInt; scan = s.c_str() + 4;
  } else if (s.substr(0, 5) == "float") {
    t.code = kDLFloat; scan = s.c_str() + 5;
  } else if (s.substr(0, 6) == "handle") {
    t.code = kHandle;
    t.bits = 64;  // handle uses 64 bit by default.
    scan = s.c_str() + 6;
  } else if (s == "bool") {
    t.code = kDLUInt;
    t.bits = 1;
    t.lanes = 1;
    return t;
  } else {
    scan = s.c_str();
    LOG(FATAL) << "unknown type " << s;
  }
  char* xdelim;  // emulate sscanf("%ux%u", bits, lanes)
  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
  if (bits != 0) t.bits = bits;
  char* endpt = xdelim;
  if (*xdelim == 'x') {
    t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
  }
  CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
  return t;
}
```
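The pieces above share a small amount of arithmetic, and the following Python sketch pulls it together outside TVM. It covers four things:

- quantization with $Fxp = Fp \times 2^{PP}$;
- parsing a dtype string such as `fxp16_13`;
- the NNVM/Relay flag `11 + PP` used by `set_dtype_inputs`, `GetTypeFlag` and `GetTVMType`;
- the `TVMType` code `128 - PP` used by `Type2TVMType` and `TVMType2Type`.

All function and constant names are hypothetical. Saturating out-of-range values to the signed 16-bit range is an assumption, because the text does not say how overflow is handled.

```python
from typing import Tuple

# Hypothetical helpers mirroring the fxp16 encodings described above.
# Valid point positions are 1..16, i.e. NNVM/Relay flags 12..27 and
# TVMType codes 112..127 (kFxp16Begin..kFxp16End).
FXP16_FLAG_BASE = 11   # NNVM/Relay flag = 11 + PP
FXP16_CODE_TOP = 128   # TVMType code    = 128 - PP
PP_MIN, PP_MAX = 1, 16


def _check_pp(pp: int) -> int:
    if not PP_MIN <= pp <= PP_MAX:
        raise ValueError(f"point position {pp} outside {PP_MIN}..{PP_MAX}")
    return pp


def to_fxp(fp: float, pp: int, bits: int = 16) -> int:
    """Fxp = Fp * 2**PP, keeping only the integer part (truncation toward zero)."""
    value = int(fp * (1 << pp))
    # Assumption: saturate to the signed range of the given bit width.
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return max(lo, min(hi, value))


def from_fxp(fxp: int, pp: int) -> float:
    """Approximate inverse: Fp = Fxp / 2**PP."""
    return fxp / (1 << pp)


def parse_fxp_dtype(dtype: str) -> Tuple[int, int]:
    """Split a dtype string such as 'fxp16_13' into (bits, point position)."""
    if not dtype.startswith("fxp") or "_" not in dtype:
        raise ValueError(f"not a fixed-point dtype: {dtype!r}")
    bits, pp = dtype[3:].split("_", 1)
    return int(bits), _check_pp(int(pp))


def pp_to_flag(pp: int) -> int:
    """Flag produced by set_dtype_inputs / GetTypeFlag: 12 + (PP - 1)."""
    return FXP16_FLAG_BASE + _check_pp(pp)


def flag_to_pp(flag: int) -> int:
    """Inverse applied by GetTVMType to flags 12..27."""
    return _check_pp(flag - FXP16_FLAG_BASE)


def pp_to_code(pp: int) -> int:
    """TVMType code chosen by Type2TVMType."""
    return FXP16_CODE_TOP - _check_pp(pp)


def code_to_pp(code: int) -> int:
    """Inverse applied by TVMType2Type to codes 112..127."""
    return _check_pp(FXP16_CODE_TOP - code)


if __name__ == "__main__":
    bits, pp = parse_fxp_dtype("fxp16_13")
    x = to_fxp(0.75, pp)
    print(bits, pp, x, from_fxp(x, pp))    # 16 13 6144 0.75
    print(pp_to_flag(pp), pp_to_code(pp))  # 24 115
```

Both encodings are affine in PP, so each direction is a single addition or subtraction. The range checks make one collision explicit: PP = 0 would map to flag 11, which is already `bool`.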