
Assignment1: RISC-V Assembly and Instruction Pipeline

contributed by <wx200010> ( Shao-Huan, Yu )

Abstract

In this assignment, I chose Problem C as my topic. Problem C counts leading zeros (CLZ) to find the position of the most significant bit (MSB) of the mantissa during the float16 to float32 conversion, and I found that the same technique can also speed up the sqrt(x) problem. When the square root is found by binary search, CLZ quickly gives the MSB position of the input x. Halving that position yields a much narrower initial search range, which reduces the number of search iterations.

Whenever an assembly listing includes changes that reduce instruction overhead or speed up the computation, I describe them just before the listing.

Problem C to RISC-V

Introduction to Float Formats

In problem C, we are tasked with implementing the conversion of a float16 value into a float32 format.

The float32 format is represented in binary as follows:

sign(1 bit) exponent(8 bits) mantissa(23 bits)
0 00000000 00000000000000000000000

And the float16 format is represented in binary as follows:

sign(1 bit) exponent(5 bits) mantissa(10 bits)
0 00000 0000000000

Sign bit

The Sign bit indicates the sign of the floating-point number and is interpreted as follows:

  • 0 : The number is positive
  • 1 : The number is negative

Exponent bits

These bits are used to store the exponent value in a "biased" format, meaning that a constant (called the bias) is added to the actual exponent to allow both positive and negative exponents to be represented.

The exponent is stored as an unsigned integer, and the actual exponent E is calculated as:

E = Exponent bits − bias

  • The bias for float32 is 127
  • The bias for float16 is 15

Mantissa bits

The mantissa holds the significant digits (the fraction) of the number and determines its precision. In both float32 and float16, normalized numbers have an implicit leading 1, so only the fractional part is stored in the mantissa bits.
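To make the field layout concrete, here is a minimal C sketch that splits a float16 bit pattern into its three fields and removes the bias. The names fp16_fields, fp16_decode, and fp16_unbiased_exponent are hypothetical helpers introduced for illustration; they are not part of Problem C.

```c
#include <stdint.h>

/* Hypothetical helper type: the three raw fields of a float16 value. */
typedef struct {
    uint16_t sign;     /* 1 bit */
    uint16_t exponent; /* 5 bits, still biased */
    uint16_t mantissa; /* 10 bits */
} fp16_fields;

/* Split a float16 bit pattern into sign, exponent bits, and mantissa bits. */
static fp16_fields fp16_decode(uint16_t h)
{
    fp16_fields f;
    f.sign     = (h >> 15) & 0x1;
    f.exponent = (h >> 10) & 0x1F;
    f.mantissa = h & 0x3FF;
    return f;
}

/* Actual exponent of a normalized float16: E = exponent bits - bias (15). */
static int fp16_unbiased_exponent(uint16_t h)
{
    return (int)fp16_decode(h).exponent - 15;
}
```

For example, 0x311F decodes to sign 0, exponent bits 12, and mantissa 0x11F, so E = 12 − 15 = −3.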

TODO

In this assignment, there are three main functions to be implemented: fabsf, my_clz, and fp16_to_fp32. Below are the descriptions of each function along with the relevant code.

fabsf

The fabsf function computes the absolute value of a floating-point number. It takes a single float as input and clears its sign bit, so a negative input becomes its magnitude and a non-negative input is returned unchanged.

C code in Problem C

The original fabsf code in Problem C is as follows.

static inline float fabsf(float x)
{
    uint32_t i = *(uint32_t *)&x; // Read the bits of the float into an integer
    i &= 0x7FFFFFFF;              // Clear the sign bit to get the absolute value
    x = *(float *)&i;             // Write the modified bits back into the float
    return x;
}
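The pointer casts in the code above read a float through a uint32_t pointer, which is not guaranteed to be safe under C's strict-aliasing rules. A possible alternative, sketched here under the same sign-bit-clearing idea, copies the bits with memcpy instead. The name my_fabsf is hypothetical and chosen only to avoid clashing with the standard fabsf in <math.h>.

```c
#include <stdint.h>
#include <string.h>

/* Sketch: absolute value of a float by clearing bit 31, using memcpy
 * for the float <-> integer reinterpretation. */
static inline float my_fabsf(float x)
{
    uint32_t bits;
    memcpy(&bits, &x, sizeof bits); /* copy the float's bit pattern */
    bits &= 0x7FFFFFFF;             /* clear the sign bit (bit 31) */
    memcpy(&x, &bits, sizeof x);    /* copy the bits back into the float */
    return x;
}
```

Compilers typically turn these fixed-size memcpy calls into plain register moves, so the generated code should be comparable to the cast-based version.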

Assembly code

Here is the RISC-V assembly code converted from the above C code:

fabsf:
    # a0 is the input parameter x
    li   t0, 0x7FFFFFFF
    and  a0, a0, t0         # a0 = a0 & 0x7FFFFFFF
    ret

my_clz (count leading zero)

The my_clz function counts the number of leading zeros in a 32-bit unsigned integer. It efficiently determines the position of the most significant bit (MSB).

C code in problem C

The original my_clz code in Problem C is as follows.

static inline int my_clz(uint32_t x)
{
    int count = 0;
    for (int i = 31; i >= 0; --i) {
        if (x & (1U << i))
            break;
        count++;
    }
    return count;
}

Optimized to branchless C code

Since the original my_clz implementation searches with a loop, the worst case takes 32 iterations. I therefore replaced it with a branchless version that combines a binary search (three compare-and-shift steps) with a small lookup table packed into the constant 0x55af. This version is easier to write in RISC-V and performs better.

static inline int my_clz(uint32_t x)
{
    int r = 0, c;
    c = (x < 0x00010000) << 4; r += c; x <<= c; // off 16
    c = (x < 0x01000000) << 3; r += c; x <<= c; // off 8
    c = (x < 0x10000000) << 2; r += c; x <<= c; // off 4
    c = (x >> (32 - 4 - 1)) & 0x1e;
    return r + ((0x55af >> c) & 3);
}
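As a quick sanity check, the sketch below places the loop version and the branchless version side by side and compares them on every single-bit input and on each single-bit input with the lowest bit also set. The names clz_loop, clz_branchless, and clz_versions_agree are hypothetical and used only for this comparison. One difference worth knowing: for x = 0 the loop version returns 32, while the branchless version returns 31, so callers need to treat zero separately.

```c
#include <stdint.h>

/* Loop version, as in Problem C. */
static int clz_loop(uint32_t x)
{
    int count = 0;
    for (int i = 31; i >= 0; --i) {
        if (x & (1U << i))
            break;
        count++;
    }
    return count;
}

/* Branchless version: binary search plus the 0x55af lookup table. */
static int clz_branchless(uint32_t x)
{
    int r = 0, c;
    c = (x < 0x00010000) << 4; r += c; x <<= c;
    c = (x < 0x01000000) << 3; r += c; x <<= c;
    c = (x < 0x10000000) << 2; r += c; x <<= c;
    c = (x >> (32 - 4 - 1)) & 0x1e;
    return r + ((0x55af >> c) & 3);
}

/* Returns 1 if both versions agree on 1 << bit and (1 << bit) | 1
 * for every bit position, 0 otherwise. */
static int clz_versions_agree(void)
{
    for (int bit = 0; bit < 32; ++bit) {
        uint32_t x = 1U << bit;
        if (clz_loop(x) != clz_branchless(x))
            return 0;
        if (clz_loop(x | 1U) != clz_branchless(x | 1U))
            return 0;
    }
    return 1;
}
```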

Assembly code for branchless my_clz

While translating the conditions x < 0x00010000, x < 0x01000000, and x < 0x10000000, I noticed that these constants do not fit in a 12-bit immediate, so they cannot be encoded directly in a compare instruction such as sltiu. Each one must first be loaded into a register with li, which can expand to as many as two instructions (lui plus addi) for a 32-bit constant.

However, each of these constants has only a single bit set. After loading 0x00010000 with li once, the later values 0x01000000 and 0x10000000 can be produced from the same register with a single slli each, which keeps the overhead of the assembly code low.

Here is the RISC-V assembly code converted from the above C code:

my_clz:
    # a0 is the input parameter x
    # t0 is r
    # t1 is c
    # t2 is tmp
    li   t0, 0              # r = 0
    li   t2, 0x00010000     # tmp = 0x00010000
    sltu t1, a0, t2         # unsigned compare, because x is uint32_t
    slli t1, t1, 4          # c = (x < 0x00010000) << 4
    add  t0, t0, t1         # r += c
    sll  a0, a0, t1         # x <<= c
    slli t2, t2, 8          # tmp = 0x01000000
    sltu t1, a0, t2
    slli t1, t1, 3          # c = (x < 0x01000000) << 3
    add  t0, t0, t1         # r += c
    sll  a0, a0, t1         # x <<= c
    slli t2, t2, 4          # tmp = 0x10000000
    sltu t1, a0, t2
    slli t1, t1, 2          # c = (x < 0x10000000) << 2
    add  t0, t0, t1         # r += c
    sll  a0, a0, t1         # x <<= c
    srli t1, a0, 27
    andi t1, t1, 0x1e       # c = (x >> (32 - 4 - 1)) & 0x1e
    li   a0, 0x55af
    srl  a0, a0, t1
    andi a0, a0, 3
    add  a0, t0, a0         # return r + ((0x55af >> c) & 3)
    ret

fp16_to_fp32

The fp16_to_fp32 function converts a half-precision floating-point number (float16) to single-precision floating-point number (float32).

C code in problem C

In the original C code, fp16_to_fp32 calls only my_clz, which locates the MSB of the mantissa. It does not call fabsf; instead, it performs the equivalent sign-bit masking inline.

static inline uint32_t fp16_to_fp32(uint16_t h)
{
    /*
     * Extends the 16-bit half-precision floating-point number to 32 bits
     * by shifting it to the upper half of a 32-bit word:
     *      +---+-----+------------+-------------------+
     *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
     *      +---+-----+------------+-------------------+
     * Bits  31  26-30    16-25            0-15
     *
     * S - sign bit, E - exponent bits, M - mantissa bits, 0 - zero bits.
     */
    const uint32_t w = (uint32_t) h << 16;

    /*
     * Isolates the sign bit from the input number, placing it in the most
     * significant bit of a 32-bit word:
     *
     *      +---+----------------------------------+
     *      | S |0000000 00000000 00000000 00000000|
     *      +---+----------------------------------+
     * Bits  31                 0-30
     */
    const uint32_t sign = w & UINT32_C(0x80000000);

    /*
     * Extracts the mantissa and exponent from the input number, placing
     * them in bits 0-30 of the 32-bit word:
     *
     *      +---+-----+------------+-------------------+
     *      | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
     *      +---+-----+------------+-------------------+
     * Bits  31  26-30    16-25            0-15
     */
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);

    /*
     * The renorm_shift variable indicates how many bits the mantissa
     * needs to be shifted to normalize the half-precision number.
     * For normalized numbers, renorm_shift will be 0. For denormalized
     * numbers, renorm_shift will be greater than 0. Shifting a
     * denormalized number will move the mantissa into the exponent,
     * normalizing it.
     */
    uint32_t renorm_shift = my_clz(nonsign);
    renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;

    /*
     * If the half-precision number has an exponent field of all ones (31),
     * adding 0x04000000 will cause overflow into bit 31, and the
     * arithmetic shift then turns the upper 9 bits into ones. Thus:
     *   inf_nan_mask ==
     *                   0x7F800000 if the half-precision number is
     *                   NaN or infinity (exponent field of 31)
     *                   0x00000000 otherwise
     */
    const int32_t inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) &
                                 INT32_C(0x7F800000);

    /*
     * If nonsign equals 0, subtracting 1 wraps around, setting
     * bit 31 to 1. Otherwise, bit 31 will be 0. Shifting this result
     * propagates bit 31 across all bits in zero_mask. Thus:
     *   zero_mask ==
     *                0xFFFFFFFF if the half-precision number is
     *                zero (+0.0h or -0.0h)
     *                0x00000000 otherwise
     */
    const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;

    /*
     * 1. Shifts nonsign left by renorm_shift to normalize it (for denormal
     *    inputs).
     * 2. Shifts nonsign right by 3, adjusting the exponent to fit in the
     *    8-bit exponent field and moving the mantissa into the correct
     *    position within the 23-bit mantissa field of the single-precision
     *    format.
     * 3. Adds 0x70 to the exponent to account for the difference in bias
     *    between half-precision and single-precision.
     * 4. Subtracts renorm_shift from the exponent to account for any
     *    renormalization that occurred.
     * 5. ORs with inf_nan_mask to set the exponent to 0xFF if the input
     *    was NaN or infinity.
     * 6. ANDs with the inverted zero_mask to set the mantissa and exponent
     *    to zero if the input was zero.
     * 7. Combines everything with the sign bit of the input number.
     */
    return sign | ((((nonsign << renorm_shift >> 3) +
                     ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask);
}
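The two masks are probably the least obvious part of this function, so the following sketch isolates them and lets them be evaluated for individual inputs. The helper names fp16_inf_nan_mask and fp16_zero_mask are hypothetical. Like the original code, the sketch assumes that right-shifting a negative int32_t is an arithmetic shift, which holds on common compilers but is implementation-defined in C.

```c
#include <stdint.h>

/* Sketch: 0x7F800000 when h is infinity or NaN, 0 otherwise.
 * Assumes >> on a negative int32_t is an arithmetic shift. */
static uint32_t fp16_inf_nan_mask(uint16_t h)
{
    const uint32_t nonsign = ((uint32_t)h << 16) & UINT32_C(0x7FFFFFFF);
    /* Adding 0x04000000 carries into bit 31 only if all 5 exponent bits are set. */
    return (uint32_t)(((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000));
}

/* Sketch: 0xFFFFFFFF when h is +0 or -0, 0 otherwise. */
static uint32_t fp16_zero_mask(uint16_t h)
{
    const uint32_t nonsign = ((uint32_t)h << 16) & UINT32_C(0x7FFFFFFF);
    /* nonsign - 1 wraps to 0xFFFFFFFF only when nonsign is 0. */
    return (uint32_t)((int32_t)(nonsign - 1) >> 31);
}
```

For instance, 0x7C00 (infinity) gives an inf_nan_mask of 0x7F800000, while the largest finite value 0x7BFF gives 0. Both 0x0000 and 0x8000 give a zero_mask of 0xFFFFFFFF, while the smallest denormal 0x0001 gives 0.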

Assembly code for fp16_to_fp32

When converting fp16_to_fp32 to assembly language, I employed several techniques:

  1. I noticed that my_clz is only called once in fp16_to_fp32, so I can directly incorporate the content of my_clz into the fp16_to_fp32 function, eliminating the need to handle function calls and save registers, thereby reducing overhead.
  2. Since my_clz has already been optimized to be branchless, the assembly implementation of fp16_to_fp32 will not involve any looping process, which enhances computational efficiency.
fp16_to_fp32:
    # a0 is h
    # t0 is w
    # t1 is sign
    # t2 is nonsign
    # t3 is renorm_shift
    # t4 is inf_nan_mask
    # t5 is zero_mask
    slli t0, a0, 16         # w = (uint32_t) h << 16
    li   t2, 0x80000000
    and  t1, t0, t2         # sign = w & UINT32_C(0x80000000)
    li   t2, 0x7FFFFFFF
    and  t2, t0, t2         # nonsign = w & UINT32_C(0x7FFFFFFF)
    mv   t3, t2             # t3 = nonsign, becomes my_clz(nonsign) at my_clz_end
my_clz:
    # t3 is the input parameter x
    # t4 is r
    # t5 is c
    # t6 is tmp
    li   t4, 0              # r = 0
    li   t6, 0x00010000     # tmp = 0x00010000
    sltu t5, t3, t6         # c = (x < 0x00010000)
    slli t5, t5, 4          # c = (x < 0x00010000) << 4
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    slli t6, t6, 8
    sltu t5, t3, t6         # c = (x < 0x01000000)
    slli t5, t5, 3          # c = (x < 0x01000000) << 3
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    slli t6, t6, 4
    sltu t5, t3, t6         # c = (x < 0x10000000)
    slli t5, t5, 2          # c = (x < 0x10000000) << 2
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    srli t5, t3, 27
    andi t5, t5, 0x1e       # c = (x >> (32 - 4 - 1)) & 0x1e
    li   t3, 0x55af
    srl  t3, t3, t5
    andi t3, t3, 3
    add  t3, t4, t3         # renorm_shift = r + ((0x55af >> c) & 3)
my_clz_end:
    # renorm_shift = my_clz(nonsign)
    li   t4, 5
    bleu t3, t4, renorm_shift_zero  # if renorm_shift <= 5, then renorm_shift = 0
renorm_shift_substract5:
    addi t3, t3, -5         # else renorm_shift = renorm_shift - 5
    j    renorm_shift_end
renorm_shift_zero:
    li   t3, 0              # renorm_shift = 0
renorm_shift_end:
    li   t5, 0x04000000
    add  t4, t2, t5         # inf_nan_mask = nonsign + 0x04000000
    srai t4, t4, 8          # inf_nan_mask = (int32_t)(nonsign + 0x04000000) >> 8
    li   t5, 0x7F800000
    and  t4, t4, t5         # inf_nan_mask &= INT32_C(0x7F800000)
    addi t5, t2, -1         # zero_mask = (int32_t)(nonsign - 1)
    srai t5, t5, 31         # zero_mask = (int32_t)(nonsign - 1) >> 31

    # Finally, return
    #   sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask)
    # First calculate (nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)
    sll  t2, t2, t3
    srli t2, t2, 3          # tmp1 = nonsign << renorm_shift >> 3
    li   t6, 0x70
    sub  t3, t6, t3         # tmp2 = 0x70 - renorm_shift
    slli t3, t3, 23         # tmp2 = (0x70 - renorm_shift) << 23
    add  t2, t2, t3         # tmp1 = tmp1 + tmp2
    # Then apply the two masks
    or   t2, t2, t4         # tmp1 = tmp1 | inf_nan_mask
    not  t5, t5             # zero_mask = ~zero_mask
    and  t2, t2, t5         # tmp1 = (tmp1 | inf_nan_mask) & ~zero_mask
    ori  a0, t1, 0          # result = sign
    or   a0, a0, t2         # result = sign | tmp1
    ret                     # return result

Complete C code

Here is the complete C code used for testing the results, where my_clz and fp16_to_fp32 are the previously mentioned functions with comments removed. The main function is added at the end to print the converted values:

#include <stdio.h>
#include <stdint.h>

static inline uint32_t my_clz(uint32_t x)
{
    int r = 0, c;
    c = (x < 0x00010000) << 4; r += c; x <<= c; // off 16
    c = (x < 0x01000000) << 3; r += c; x <<= c; // off 8
    c = (x < 0x10000000) << 2; r += c; x <<= c; // off 4
    c = (x >> (32 - 4 - 1)) & 0x1e;
    return r + ((0x55af >> c) & 3);
}

static inline uint32_t fp16_to_fp32(uint16_t h)
{
    const uint32_t w = (uint32_t)h << 16;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
    uint32_t renorm_shift = my_clz(nonsign);
    renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
    const int32_t inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) &
                                 INT32_C(0x7F800000);
    const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
    return sign | ((((nonsign << renorm_shift >> 3) +
                     ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask);
}

int main()
{
    uint32_t datas[] = {
        0x0710, // normalized number
        0x311F, // normalized number
        0x000F, // denormalized number
        0x0000, // positive zero
        0x8000, // negative zero
        0x7C00, // positive inf
        0xFC00, // negative inf
        0x7CFF  // NaN
    };
    uint32_t results[] = {
        fp16_to_fp32(datas[0]), fp16_to_fp32(datas[1]),
        fp16_to_fp32(datas[2]), fp16_to_fp32(datas[3]),
        fp16_to_fp32(datas[4]), fp16_to_fp32(datas[5]),
        fp16_to_fp32(datas[6]), fp16_to_fp32(datas[7])};
    printf("\nfp16_to_fp32(0x0710) is : 0x%x ", results[0]); // normalized number
    printf("\nfp16_to_fp32(0x311F) is : 0x%x ", results[1]); // normalized number
    printf("\nfp16_to_fp32(0x000F) is : 0x%x ", results[2]); // denormalized number
    printf("\nfp16_to_fp32(0x0000) is : 0x%x ", results[3]); // positive zero
    printf("\nfp16_to_fp32(0x8000) is : 0x%x ", results[4]); // negative zero
    printf("\nfp16_to_fp32(0x7C00) is : 0x%x ", results[5]); // positive inf
    printf("\nfp16_to_fp32(0xFC00) is : 0x%x ", results[6]); // negative inf
    printf("\nfp16_to_fp32(0x7CFF) is : 0x%x ", results[7]); // NaN
}

output messages:
fp16_to_fp32(0x0710) is : 0x38e20000
fp16_to_fp32(0x311F) is : 0x3e23e000
fp16_to_fp32(0x000F) is : 0x35700000
fp16_to_fp32(0x0000) is : 0x0
fp16_to_fp32(0x8000) is : 0x80000000
fp16_to_fp32(0x7C00) is : 0x7f800000
fp16_to_fp32(0xFC00) is : 0xff800000
fp16_to_fp32(0x7CFF) is : 0x7f9fe000
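One way to sanity-check these bit patterns is to reinterpret them as float values and compare them with the numbers the float16 inputs encode. The sketch below does the reinterpretation; bits_to_float is a hypothetical helper name, and the sketch assumes float is the 32-bit IEEE 754 single-precision type, as it is on common platforms.

```c
#include <stdint.h>
#include <string.h>

/* Sketch: reinterpret a 32-bit pattern as a float (assumes IEEE 754 binary32). */
static float bits_to_float(uint32_t bits)
{
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}
```

For the denormal input 0x000F, the encoded value is 15 × 2^-10 × 2^-14 = 15 × 2^-24, and bits_to_float(0x35700000) equals 15.0f / 16777216.0f. For 0x311F (exponent bits 12, mantissa 0x11F = 287), the value is (1 + 287/1024) × 2^-3 = 1311/8192, which matches bits_to_float(0x3e23e000).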

Complete assembly code

I added a main function and test data at the beginning to print the converted values, making it easy to compare them with the output of the C code.

Additionally, I stored the results produced by the C code in an ans array in the RISC-V program, so each converted value is automatically checked against the C result. If they differ, the program prints "\nthe answer is wrong!!!".

Here’s the modified assembly code:

.data
datas:    .word 0x0710, 0x311F, 0x000F, 0x0000, 0x8000, 0x7C00, 0xFC00, 0x7CFF
ans:      .word 0x38e20000, 0x3e23e000, 0x35700000, 0x0, 0x80000000, 0x7f800000, 0xff800000, 0x7f9fe000
str1:     .string "\nfp16_to_fp32(0x0710) is : "
str2:     .string "\nfp16_to_fp32(0x311F) is : "
str3:     .string "\nfp16_to_fp32(0x000F) is : "
str4:     .string "\nfp16_to_fp32(0x0000) is : "
str5:     .string "\nfp16_to_fp32(0x8000) is : "
str6:     .string "\nfp16_to_fp32(0x7C00) is : "
str7:     .string "\nfp16_to_fp32(0xFC00) is : "
str8:     .string "\nfp16_to_fp32(0x7CFF) is : "
strError: .string "\nthe answer is wrong!!!"
strs:     .word str1, str2, str3, str4, str5, str6, str7, str8

.text
main:
    la   s6, ans            # Load ans reference
    la   s7, datas          # Load datas reference
    la   s8, strs           # Load strs references
    li   s9, 8              # Load the loop count
print_numbers:
    lw   a0, 0(s8)          # Load string reference
    li   a7, 4              # print string
    ecall
    lw   a0, 0(s7)          # Load data
    jal  ra, fp16_to_fp32   # calculate fp16_to_fp32(data)
    li   a7, 34             # print the result in hex format
    ecall
validation:
    lw   t0, 0(s6)          # Load ans
    sub  t0, t0, a0         # calculate ans - result for validation
    beqz t0, check_loop     # if (ans - result) == 0 then skip
    la   a0, strError
    li   a7, 4              # print error message
    ecall
check_loop:
    addi s6, s6, 4          # shift ans index
    addi s7, s7, 4          # shift datas index
    addi s8, s8, 4          # shift strs index
    addi s9, s9, -1         # loop count - 1
    bnez s9, print_numbers
exit:
    # Exit the program
    li   a7, 10             # System call code for exiting the program
    ecall                   # Make the exit system call
    ret

fp16_to_fp32:
    # a0 is h
    # t0 is w
    # t1 is sign
    # t2 is nonsign
    # t3 is renorm_shift
    # t4 is inf_nan_mask
    # t5 is zero_mask
    slli t0, a0, 16         # w = (uint32_t) h << 16
    li   t2, 0x80000000
    and  t1, t0, t2         # sign = w & UINT32_C(0x80000000)
    li   t2, 0x7FFFFFFF
    and  t2, t0, t2         # nonsign = w & UINT32_C(0x7FFFFFFF)
    mv   t3, t2             # t3 = nonsign, becomes my_clz(nonsign) at my_clz_end
my_clz:
    # t3 is the input parameter x
    # t4 is r
    # t5 is c
    # t6 is tmp
    li   t4, 0              # r = 0
    li   t6, 0x00010000     # tmp = 0x00010000
    sltu t5, t3, t6         # c = (x < 0x00010000)
    slli t5, t5, 4          # c = (x < 0x00010000) << 4
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    slli t6, t6, 8
    sltu t5, t3, t6         # c = (x < 0x01000000)
    slli t5, t5, 3          # c = (x < 0x01000000) << 3
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    slli t6, t6, 4
    sltu t5, t3, t6         # c = (x < 0x10000000)
    slli t5, t5, 2          # c = (x < 0x10000000) << 2
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    srli t5, t3, 27
    andi t5, t5, 0x1e       # c = (x >> (32 - 4 - 1)) & 0x1e
    li   t3, 0x55af
    srl  t3, t3, t5
    andi t3, t3, 3
    add  t3, t4, t3         # renorm_shift = r + ((0x55af >> c) & 3)
my_clz_end:
    # renorm_shift = my_clz(nonsign)
    li   t4, 5
    bleu t3, t4, renorm_shift_zero  # if renorm_shift <= 5, then renorm_shift = 0
renorm_shift_substract5:
    addi t3, t3, -5         # else renorm_shift = renorm_shift - 5
    j    renorm_shift_end
renorm_shift_zero:
    li   t3, 0              # renorm_shift = 0
renorm_shift_end:
    li   t5, 0x04000000
    add  t4, t2, t5         # inf_nan_mask = nonsign + 0x04000000
    srai t4, t4, 8          # inf_nan_mask = (int32_t)(nonsign + 0x04000000) >> 8
    li   t5, 0x7F800000
    and  t4, t4, t5         # inf_nan_mask &= INT32_C(0x7F800000)
    addi t5, t2, -1         # zero_mask = (int32_t)(nonsign - 1)
    srai t5, t5, 31         # zero_mask = (int32_t)(nonsign - 1) >> 31
    sll  t2, t2, t3
    srli t2, t2, 3          # tmp1 = nonsign << renorm_shift >> 3
    li   t6, 0x70
    sub  t3, t6, t3         # tmp2 = 0x70 - renorm_shift
    slli t3, t3, 23         # tmp2 = (0x70 - renorm_shift) << 23
    add  t2, t2, t3         # tmp1 = tmp1 + tmp2
    or   t2, t2, t4         # tmp1 = tmp1 | inf_nan_mask
    not  t5, t5             # zero_mask = ~zero_mask
    and  t2, t2, t5         # tmp1 = (tmp1 | inf_nan_mask) & ~zero_mask
    ori  a0, t1, 0          # result = sign
    or   a0, a0, t2         # result = sign | tmp1
    ret                     # return result

Outputs:
fp16_to_fp32(0x0710) is : 0x38e20000
fp16_to_fp32(0x311F) is : 0x3e23e000
fp16_to_fp32(0x000F) is : 0x35700000
fp16_to_fp32(0x0000) is : 0x0000
fp16_to_fp32(0x8000) is : 0x80000000
fp16_to_fp32(0x7C00) is : 0x7f800000
fp16_to_fp32(0xFC00) is : 0xff800000
fp16_to_fp32(0x7CFF) is : 0x7f9fe000
Program exited with code: 0

Compare the cycle counts in 5-Stage RISC-V Processor

This section will compare the cycle count of my C code, after being compiled into assembly using the RISC-V GCC compiler, with the cycle count of my self-written RISC-V assembly code.

Ripes provides a RISC-V C/C++ compiler feature that allows users to paste C code and convert it into RISC-V assembly. The compiler version I used is risc-v-gcc10.1.0.exe with the -O3 argument.

Here are the execution results of the RISC-V code generated by the C compiler and of my own RISC-V code:

generated from RISC-V C compiler my RISC-V code
image image

I'm not sure why the performance of the RISC-V code generated by compiling the C code is so different from the RISC-V code I wrote myself. It could be because the compiler version is outdated, or perhaps it's just normal.

LeetCode Problem: Sqrt(x)

Leetcode 69. Sqrt(x)

Description

Given a non-negative integer x, return the square root of x rounded down to the nearest integer. The returned integer should be non-negative as well.

You must not use any built-in exponent function or operator.

For example, do not use pow(x, 0.5) in c++ or x ** 0.5 in python.

Example 1:

Input: x = 4
Output: 2
Explanation: The square root of 4 is 2, so we return 2.

Example 2:

Input: x = 8
Output: 2
Explanation: The square root of 8 is 2.82842, and since we round it down to the nearest integer, 2 is returned.

Constraints:

0 <= x <= (2^31) - 1

The ideas and inspiration for the following C code are derived from this hackmd

Since the range of x in this problem is 0 <= x <= (2^31) - 1, we can use binary search, with the search range set between 1 and x. The case x = 0 can be handled with a conditional statement at the beginning.

Therefore, the following is the first version of binary search based on the above idea:

int mySqrt(unsigned x)
{
    if (x == 0)
        return x;
    uint32_t L, R;
    L = 1;
    R = x;
    while (1) {
        uint64_t M = (L + R) >> 1;
        if (M * M > x)
            R = M;
        else if ((M + 1) * (M + 1) <= x) // can do M*M + 2*M + 1 in assembly
            L = M;
        else
            return M; // return M when M*M <= x < (M+1)*(M+1)
    }
}

However, this binary search can still be optimized in the following two ways:

  1. The search range of 1 to x is still too large. We can use the CLZ (count leading zeros) function to quickly find the MSB (most significant bit) of x, and use the MSB position to generate a power of two that is at most sqrt(x), thus reducing the search range.
  2. Based on the first point, M * M and (M+1)*(M+1) will not exceed (2^32) - 1 for inputs within the problem's constraints, so we can replace the uint64_t multiplication with uint32_t.
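The reasoning behind the narrower range can be spelled out. If p is the 0-based index of the MSB of x, then 2^p <= x < 2^(p+1). With k = floor(p / 2), the lower bound L = 2^k satisfies L*L = 2^(2k) <= 2^p <= x, and the upper bound R = 2L satisfies R*R = 2^(2k+2) >= 2^(p+1) > x. For x <= (2^31) - 1 we have p <= 30, so R <= 2^16 and M*M stays below 2^32. (M+1)*(M+1) is only evaluated when M*M <= x, so it is at most 46341^2 = 2147488281, which also fits in 32 bits. The following sketch checks the bracket for a given x; msb_index and sqrt_bracket_ok are hypothetical helper names, and the MSB is found with a simple loop to keep the example self-contained.

```c
#include <stdint.h>

/* Index of the most significant set bit of x (x must be non-zero). */
static int msb_index(uint32_t x)
{
    int p = 0;
    while (x >>= 1)
        ++p;
    return p;
}

/* Returns 1 if L = 2^(p/2) and R = 2L bracket sqrt(x), i.e. L*L <= x < R*R.
 * 64-bit arithmetic is used here only so the check itself cannot overflow. */
static int sqrt_bracket_ok(uint32_t x)
{
    const int p = msb_index(x);
    const uint64_t L = UINT64_C(1) << (p >> 1);
    const uint64_t R = L << 1;
    return L * L <= x && x < R * R;
}
```

Because R is always 2L, the binary search starts with a range whose width equals L, instead of a range as wide as x itself.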

Optimizing the C solution using the CLZ method:

The following is the code after optimization based on the above points, which reduces the number of iterations in the binary search loop.

#include <stdio.h>
#include <stdint.h>

uint32_t my_clz(uint32_t x)
{
    int r = 0, c;
    c = (x < 0x00010000) << 4; r += c; x <<= c; // off 16
    c = (x < 0x01000000) << 3; r += c; x <<= c; // off 8
    c = (x < 0x10000000) << 2; r += c; x <<= c; // off 4
    c = (x >> (32 - 4 - 1)) & 0x1e;
    return r + ((0x55af >> c) & 3);
}

int mySqrt(unsigned x)
{
    if (x == 0)
        return x;
    uint32_t temp, L, R, M;
    temp = 31 - my_clz(x); // 0-based index of the MSB of x, found with clz
    L = 1 << (temp >> 1);  // initial lower bound: 2 ^ (MSB index / 2)
    R = L << 1;            // initial upper bound: 2 ^ ((MSB index / 2) + 1)
    while (1) {
        M = (L + R) >> 1;
        if (M * M > x)
            R = M;
        else if ((M + 1) * (M + 1) <= x) // can do M*M + 2*M + 1 in assembly
            L = M;
        else
            return M; // return M when M*M <= x < (M+1)*(M+1)
    }
}

int main()
{
    uint32_t datas[] = {0, 1, 2, 4, 8, 2147483647};
    uint32_t results[] = {
        mySqrt(datas[0]), mySqrt(datas[1]), mySqrt(datas[2]),
        mySqrt(datas[3]), mySqrt(datas[4]), mySqrt(datas[5])};
    for (int i = 0; i < 6; ++i)
        printf("\nmySqrt(%u) is : %u", datas[i], results[i]);
}

Outputs:
mySqrt(0) is : 0
mySqrt(1) is : 1
mySqrt(2) is : 1
mySqrt(4) is : 2
mySqrt(8) is : 2
mySqrt(2147483647) is : 46340

Complete Assembly code

Here is the RISC-V assembly code converted from the optimized C code above. I also made the following optimizations during the conversion:

  1. (𝑀+1)×(𝑀+1) can be expanded into 𝑀×𝑀 + 2×𝑀 + 1, and 𝑀×𝑀 has already been calculated earlier. Therefore, we only need to add 2×𝑀 and then add 1 to the result, avoiding the need to use the multiplier again.
  2. Since my_clz is only called once in the mySqrt function, we can directly inline my_clz into mySqrt to eliminate the additional overhead of a function call.
.data
datas:    .word 0, 1, 2, 4, 8, 2147483647
ans:      .word 0, 1, 1, 2, 2, 46340
str1:     .string "\nmySqrt(0) is : "
str2:     .string "\nmySqrt(1) is : "
str3:     .string "\nmySqrt(2) is : "
str4:     .string "\nmySqrt(4) is : "
str5:     .string "\nmySqrt(8) is : "
str6:     .string "\nmySqrt(2147483647) is : "
strError: .string "\nthe answer is wrong!!!"
strs:     .word str1, str2, str3, str4, str5, str6

.text
main:
    la   s6, ans            # Load ans reference
    la   s7, datas          # Load datas reference
    la   s8, strs           # Load strs references
    li   s9, 6              # Load the loop count
print_numbers:
    lw   a0, 0(s8)          # Load string reference
    li   a7, 4              # print string
    ecall
    lw   a0, 0(s7)          # Load data
    jal  ra, mySqrt         # calculate mySqrt(data)
    li   a7, 36             # print the result in unsigned format
    ecall
validation:
    lw   t0, 0(s6)          # Load ans
    sub  t0, t0, a0         # calculate ans - result for validation
    beqz t0, check_loop     # if (ans - result) == 0 then skip
    la   a0, strError
    li   a7, 4              # print error message
    ecall
check_loop:
    addi s6, s6, 4          # shift ans index
    addi s7, s7, 4          # shift datas index
    addi s8, s8, 4          # shift strs index
    addi s9, s9, -1         # loop count - 1
    bnez s9, print_numbers
exit:
    # Exit the program
    li   a7, 10             # System call code for exiting the program
    ecall                   # Make the exit system call
    ret

mySqrt:
    # a0 is x
    # t0 is temp
    # t1 is L
    # t2 is R
    # t3 is M
    bnez a0, conditionSkip  # if (x == 0) then return x
    ret
conditionSkip:
    mv   t3, a0             # set t3 = x to calc my_clz(x)
my_clz:
    # t3 is the input parameter x
    # t4 is r
    # t5 is c
    # t6 is tmp
    li   t4, 0              # r = 0
    li   t6, 0x00010000     # tmp = 0x00010000
    sltu t5, t3, t6         # c = (x < 0x00010000)
    slli t5, t5, 4          # c = (x < 0x00010000) << 4
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    slli t6, t6, 8
    sltu t5, t3, t6         # c = (x < 0x01000000)
    slli t5, t5, 3          # c = (x < 0x01000000) << 3
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    slli t6, t6, 4
    sltu t5, t3, t6         # c = (x < 0x10000000)
    slli t5, t5, 2          # c = (x < 0x10000000) << 2
    add  t4, t4, t5         # r += c
    sll  t3, t3, t5         # x <<= c
    srli t5, t3, 27
    andi t5, t5, 0x1e       # c = (x >> (32 - 4 - 1)) & 0x1e
    li   t3, 0x55af
    srl  t3, t3, t5
    andi t3, t3, 3
    add  t0, t4, t3         # temp = r + ((0x55af >> c) & 3)
my_clz_end:
    # temp = my_clz(x)
    li   t4, 31
    sub  t0, t4, t0         # temp = 31 - my_clz(x)
    srli t0, t0, 1          # temp = temp >> 1
    li   t1, 1              # L = 1
    sll  t1, t1, t0         # L = 1 << (temp >> 1)
    slli t2, t1, 1          # R = L << 1
binary_search_loop:
    add  t3, t1, t2         # M = L + R
    srli t3, t3, 1          # M = (L + R) >> 1
    mv   t4, t3             # copy M to t4
    mv   t5, t3             # copy M to t5
    li   s0, 0              # set result = 0
multiple_loop:
    # calculate result = M * M
    andi t6, t4, 1          # check LSB of t4
    beqz t6, skip_add       # if LSB of t4 == 0 then skip
    add  s0, s0, t5         # result = result + t5
skip_add:
    srli t4, t4, 1          # t4 = t4 >> 1
    slli t5, t5, 1          # t5 = t5 << 1
    bnez t4, multiple_loop
multiple_end:
    bgtu s0, a0, squareM_is_bigger    # if (M * M > x) then jump
    add  s0, s0, t3         # result = M*M + M
    add  s0, s0, t3         # result = M*M + 2*M
    addi s0, s0, 1          # result = (M+1)*(M+1)
    bleu s0, a0, squareM1_is_smaller  # if ((M+1)*(M+1) <= x) then jump
breakLoop:
    # else return M
    mv   a0, t3
    ret                     # return M
squareM_is_bigger:
    mv   t2, t3             # R = M
    j    binary_search_loop # continue
squareM1_is_smaller:
    mv   t1, t3             # L = M
    j    binary_search_loop # continue
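The multiple_loop section of the assembly computes M * M with shifts and adds instead of a mul instruction, and the next square is then derived from it by addition. A minimal C sketch of both steps is given below; shift_add_mul and next_square are hypothetical names that mirror the assembly rather than anything in the original C code.

```c
#include <stdint.h>

/* Multiply a by b using only shifts and adds, mirroring multiple_loop:
 * for every set bit of a, add the correspondingly shifted b. */
static uint32_t shift_add_mul(uint32_t a, uint32_t b)
{
    uint32_t result = 0;
    while (a != 0) {
        if (a & 1U)      /* LSB of the multiplier set? */
            result += b; /* add the shifted multiplicand */
        a >>= 1;
        b <<= 1;
    }
    return result;
}

/* (M+1)^2 from an already computed M*M: M*M + 2*M + 1, no second multiply. */
static uint32_t next_square(uint32_t m, uint32_t m_squared)
{
    return m_squared + m + m + 1;
}
```

With M = 46340, shift_add_mul(M, M) gives 2147395600 and next_square(M, 2147395600) gives 2147488281, which is larger than 2147483647, so the search returns 46340 for the largest test input.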

Outputs
mySqrt(0) is : 0
mySqrt(1) is : 1
mySqrt(2) is : 1
mySqrt(4) is : 2
mySqrt(8) is : 2
mySqrt(2147483647) is : 46340
Program exited with code: 0

Compare the cycle counts in 5-Stage RISC-V Processor

The compiler version I used is risc-v-gcc10.1.0.exe with the -O3 argument.

Here are the execution results of the RISC-V code generated by the C compiler and of my own RISC-V code:

generated from RISC-V C compiler my RISC-V code
image image

The 5-stage pipelined processor in RISC-V

image

Description

The 5-stage pipelined RISC-V processor divides the execution of each instruction into five stages:

  1. Instruction Fetch(IF) : fetches the next instruction to be executed from the instruction memory
  2. Instruction Decode (ID) : decodes the instruction, and the registers required for the operation are identified and fetched.
  3. Execute (EX) : performs the appropriate operations based on the instruction type, such as arithmetic operations, logical operations, or address calculations
  4. Memory Access (MEM) : For instructions that require memory access (such as load and store instructions), this stage performs the read or write operations to memory.
  5. Write Back (WB) : writes the execution results back to the register file

The pipeline allows multiple instructions to be processed simultaneously at different stages, thereby increasing the overall throughput of the processor.
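The overlap can be illustrated with a small timing model. The sketch below assumes an ideal pipeline with no stalls, hazards, or flushes, so instruction i (counting from 0) enters IF in cycle i + 1 and advances one stage per cycle. The names stage_at and ideal_cycles are hypothetical, and real cycle counts in a simulator will usually be higher than this model predicts.

```c
#include <stddef.h>

static const char *const STAGES[5] = {"IF", "ID", "EX", "MEM", "WB"};

/* Stage occupied by instruction `instr` (0-based) during `cycle` (1-based)
 * in an ideal pipeline, or NULL if the instruction is not in the pipeline. */
static const char *stage_at(int instr, int cycle)
{
    const int s = cycle - 1 - instr;
    return (s >= 0 && s < 5) ? STAGES[s] : NULL;
}

/* Cycles needed to complete n instructions in an ideal 5-stage pipeline:
 * 5 cycles for the first instruction, then 1 more per extra instruction. */
static int ideal_cycles(int n)
{
    return n > 0 ? n + 4 : 0;
}
```

In cycle 3, for example, the first instruction is in EX, the second in ID, and the third in IF, while a fourth instruction has not been fetched yet. Under this model, five instructions finish in 9 cycles rather than the 25 a non-pipelined design would need at five cycles each.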

Example

Taking the following RISC-V code as an example, this program will perform the CLZ operation on the value in t3 :

.text
main:
    li   t3, 0x050
my_clz:
    # t3 is the input parameter x
    # t4 is r
    # t5 is c
    # t6 is tmp
    li   t4, 0              # r = 0
    li   t6, 0x00010000     # tmp = 0x00010000
    sltu t5, t3, t6         # c = (x < 0x00010000)
    slli t5, t5, 4          # c = (x < 0x00010000) << 4
    ...

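Below is the disassembled executable code. Note that each li pseudo-instruction has been expanded into a real instruction (addi or lui):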

00000000 <main>:
    0:        05000e13        addi x28 x0 80

00000004 <my_clz>:
    4:        00000e93        addi x29 x0 0
    8:        00010fb7        lui x31 0x10
    c:        01fe3f33        sltu x30 x28 x31
    10:       004f1f13        slli x30 x30 4
    ...
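To connect the hexadecimal words in the listing with the instructions next to them, here is a minimal sketch that extracts the fields of an I-type instruction such as addi, following the standard RV32I I-type layout (imm[11:0], rs1, funct3, rd, opcode). The type itype_fields and the function decode_itype are hypothetical names used only for this illustration.

```c
#include <stdint.h>

/* Hypothetical helper type: fields of an RV32I I-type instruction. */
typedef struct {
    uint32_t opcode; /* bits 6:0 */
    uint32_t rd;     /* bits 11:7 */
    uint32_t funct3; /* bits 14:12 */
    uint32_t rs1;    /* bits 19:15 */
    int32_t  imm;    /* bits 31:20, sign-extended */
} itype_fields;

static itype_fields decode_itype(uint32_t insn)
{
    itype_fields f;
    f.opcode = insn & 0x7F;
    f.rd     = (insn >> 7) & 0x1F;
    f.funct3 = (insn >> 12) & 0x7;
    f.rs1    = (insn >> 15) & 0x1F;
    f.imm    = (int32_t)(insn >> 20); /* 12-bit value in the range 0..4095 */
    if (f.imm & 0x800)                /* sign bit of the immediate set? */
        f.imm -= 0x1000;              /* sign-extend to a negative value */
    return f;
}
```

Decoding 0x05000e13 this way gives opcode 0x13, rd = 28, funct3 = 0, rs1 = 0, and imm = 80, which corresponds to addi x28 x0 80 in the listing.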

Next, we follow the first instruction (addi x28, x0, 80) through the five stages and note where the next two instructions are at each step.

1. Instruction Fetch(IF)

image

  • The processor fetches the next instruction to be executed (addi x28, x0, 80) based on the address in the PC.
  • The PC will increment by 4 to move to the address of the next instruction.

2. Instruction Decode (ID)

image

  • The processor decodes the instruction (addi x28, x0, 80)
  • The processor identifies the destination register(x28), source register(x0) and immediate value 80
  • The processor reads source register(x0) for the next-stage
  • At the same time, the second instruction (addi x29, x0, 0) enters the instruction fetch stage.

3. Execute (EX)

image

  • The processor executes the instruction (addi x28, x0, 80), adding the value from x0 to the immediate value 80.
  • The result (0x00000050) is stored in a pipeline register for the next stage.
  • At the same time, the second instruction (addi x29, x0, 0) enters the instruction decode stage.
  • At the same time, the third instruction (lui x31, 0x10) enters the instruction fetch stage.

4. Memory Access (MEM)

image

  • Since the addi instruction does not require memory access (it only modifies a register), this stage is effectively a no-operation (NOP) for this instruction.

5. Write Back (WB)

image

  • The processor writes the result 0x00000050 back to the destination register x28.
  • Therefore, the instruction addi x28, x0, 80 has been successfully executed.

Reference

從√2的存在談開平方根的快速運算 (binary search method)
數值系統 (count leading zero)
Branchless count-leading-zeros on 32-bit RISC-V without Zbb extension
Quiz1 of Computer Architecture (2024 Fall)