contributed by <wx200010> (Shao-Huan Yu)
In this assignment, I chose Problem C as my topic. Problem C uses a count-leading-zeros (CLZ) technique to find the MSB position of the mantissa when converting float16 to float32. I found that the same technique can also speed up the sqrt(x) problem. When the square root is found by binary search, the CLZ function quickly gives the MSB position of the input 𝑥. Halving that position (a right shift by one) gives a much smaller initial search range, which reduces the number of search iterations.
Before each assembly listing, I describe any changes that reduce RISC-V instruction overhead or improve computational efficiency.
In Problem C, we are tasked with converting a float16 value into float32 format.
The float32 format is represented in binary as follows:
sign(1 bit) | exponent(8 bits) | mantissa(23 bits) |
---|---|---|
0 | 00000000 | 00000000000000000000000 |
And the float16 format is represented in binary as follows:
sign(1 bit) | exponent(5 bits) | mantissa(10 bits) |
---|---|---|
0 | 00000 | 0000000000 |
The sign bit indicates the sign of the floating-point number:
- 0: the number is positive
- 1: the number is negative

The exponent bits store the exponent in a "biased" format. A constant (called the bias) is added to the actual exponent so that both positive and negative exponents can be represented. The exponent field is stored as an unsigned integer, and the actual exponent $E$ is calculated as:

$$E = \text{exponent bits} - \text{bias}$$

where the bias depends on the format:

format | bias |
---|---|
float32 | 127 |
float16 | 15 |
The mantissa in floating-point formats represents the precision of the number. In both float32 and float16 formats, the mantissa bits encode the significant digits of the number.
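To make the field layout concrete, here is a small sketch that extracts the three fields from a float32 value and from a float16 bit pattern. The helper names f32_fields and f16_fields and the struct fp_fields are hypothetical. The float16 value is passed as its raw uint16_t bit pattern because standard C has no portable half-precision type.

```c
#include <stdint.h>
#include <string.h>

/* Raw fields of an IEEE 754 binary32 / binary16 value (hypothetical helper type). */
struct fp_fields {
    uint32_t sign;     /* 1 bit */
    uint32_t exponent; /* biased exponent field */
    uint32_t mantissa; /* fraction (mantissa) field */
};

/* float32: 1 sign bit | 8 exponent bits | 23 mantissa bits, bias 127 */
static struct fp_fields f32_fields(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits); /* read the bit pattern without pointer casts */
    struct fp_fields r = { bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF };
    return r;
}

/* float16: 1 sign bit | 5 exponent bits | 10 mantissa bits, bias 15 */
static struct fp_fields f16_fields(uint16_t h)
{
    struct fp_fields r = { (uint32_t)h >> 15, ((uint32_t)h >> 10) & 0x1F,
                           (uint32_t)h & 0x3FF };
    return r;
}
```

For example, -6.5 = -1.625 × 2^2, so f32_fields(-6.5f) gives sign 1, exponent field 129 (E = 129 - 127 = 2) and mantissa 0x500000. The float16 pattern 0x3C00 (1.0) gives sign 0, exponent field 15 (E = 15 - 15 = 0) and mantissa 0.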
In this assignment, there are three main functions to implement: fabsf, my_clz, and fp16_to_fp32. Below are the descriptions of each function along with the relevant code.
The fabsf function computes the absolute value of a floating-point number. It takes a single float as input and returns its non-negative magnitude. It does this by clearing the sign bit, so a negative input becomes positive and a non-negative input is returned unchanged.
The original fabsf code in Problem C is as follows:
static inline float fabsf(float x) {
uint32_t i = *(uint32_t *)&x; // Read the bits of the float into an integer
i &= 0x7FFFFFFF; // Clear the sign bit to get the absolute value
x = *(float *)&i; // Write the modified bits back into the float
return x;
}
Here is the RISC-V assembly code converted from the above C code:
fabsf:
# a0 is the input parameter x
li t0, 0x7FFFFFFF
and a0, a0, t0 # a0 = a0 & 0x7FFFFFFF
ret
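The cast *(uint32_t *)&x in the quiz's fabsf reads a float through a uint32_t pointer. This breaks C's strict-aliasing rule. It usually works in practice, but compilers are allowed to miscompile it under optimization. A common portable alternative is to copy the bits with memcpy, which compilers typically reduce to a plain register move. The sketch below uses the hypothetical name fabsf_bits so that it does not clash with fabsf from <math.h>.

```c
#include <stdint.h>
#include <string.h>

/* Absolute value of a float by clearing its sign bit (bit 31). */
static inline float fabsf_bits(float x)
{
    uint32_t i;
    memcpy(&i, &x, sizeof i); /* float -> raw bits */
    i &= 0x7FFFFFFFu;         /* clear the sign bit */
    memcpy(&x, &i, sizeof x); /* raw bits -> float */
    return x;
}
```

For example, fabsf_bits(-2.5f) returns 2.5f, and fabsf_bits(-0.0f) returns +0.0f (all bits zero).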
The my_clz function counts the number of leading zeros in a 32-bit unsigned integer. This directly gives the position of the most significant bit (MSB): for a non-zero x, the MSB is at bit 31 - my_clz(x).
The original my_clz code in Problem C is as follows:
static inline int my_clz(uint32_t x) {
int count = 0;
for (int i = 31; i >= 0; --i) {
if (x & (1U << i))
break;
count++;
}
return count;
}
Since the original my_clz implementation searches with a loop, it needs up to 32 iterations (for example, when x is 0 or 1). Therefore, I replaced it with a branchless version. The new version combines a binary search over the bit positions with a small lookup table packed into the constant 0x55af. This approach is easier to write in RISC-V and performs better.
static inline int my_clz(uint32_t x) {
int r = 0, c;
c = (x < 0x00010000) << 4;
r += c;
x <<= c; // off 16
c = (x < 0x01000000) << 3;
r += c;
x <<= c; // off 8
c = (x < 0x10000000) << 2;
r += c;
x <<= c; // off 4
c = (x >> (32 - 4 - 1)) & 0x1e;
return r + ((0x55af >> c) & 3);
}
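A quick way to gain confidence in the branchless version is to compare it with the original loop for every MSB position. The sketch below does this using the hypothetical names clz_loop, clz_branchless and clz_mismatches. One difference shows up for x == 0. In that case the three compare steps add 16 + 8 + 4 = 28, and the table lookup for an all-zero nibble returns 3, so the branchless version returns 31 rather than 32. In this note, zero inputs do not matter: zero_mask overrides the result in fp16_to_fp32, and mySqrt returns early for x == 0.

```c
#include <stdint.h>

/* Reference: the original loop-based my_clz from Problem C (returns 32 for 0). */
static int clz_loop(uint32_t x)
{
    int count = 0;
    for (int i = 31; i >= 0; --i) {
        if (x & (1U << i))
            break;
        count++;
    }
    return count;
}

/* The branchless my_clz shown above. */
static int clz_branchless(uint32_t x)
{
    int r = 0, c;
    c = (x < 0x00010000) << 4; r += c; x <<= c; /* top 16 bits all zero? */
    c = (x < 0x01000000) << 3; r += c; x <<= c; /* top 8 bits all zero?  */
    c = (x < 0x10000000) << 2; r += c; x <<= c; /* top 4 bits all zero?  */
    c = (x >> (32 - 4 - 1)) & 0x1e;             /* c = 2 * (top nibble n) */
    return r + ((0x55af >> c) & 3);             /* 2-bit entry = clz of n (n >= 1) */
}

/* For every MSB position i, test x = 1 << i and x with all bits below i set.
   Returns the number of inputs on which the two versions disagree. */
static int clz_mismatches(void)
{
    int bad = 0;
    for (int i = 0; i < 32; ++i) {
        uint32_t msb = 1U << i;
        uint32_t inputs[2] = { msb, msb | (msb - 1) };
        for (int k = 0; k < 2; ++k)
            if (clz_loop(inputs[k]) != clz_branchless(inputs[k]))
                bad++;
    }
    return bad;
}
```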
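During the implementation, I noticed that the constants in the conditions x < 0x00010000, x < 0x01000000, and x < 0x10000000 do not fit in the 12-bit signed immediate of an instruction such as sltiu. Each of them therefore has to be loaded into a register with li before the comparison. Since each of these constants has exactly one bit set, only the first one needs li. After loading 0x00010000, the following constants 0x01000000 and 0x10000000 can be produced from it with slli, reusing the same register.

Note, however, that the low 12 bits of these constants are all zero. This means li expands to a single lui here: the disassembly later in this note shows li t6, 0x00010000 assembled as lui x31 0x10. The slli replacement therefore keeps the instruction count the same rather than reducing it. A two-instruction sequence (lui + addi) is needed only for constants whose low 12 bits are non-zero, such as 0x55af.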
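To check how many instructions li needs for a given constant, the following sketch models the usual RV32 expansion of li. A constant that fits in a signed 12-bit immediate becomes one addi. Otherwise it is split into hi20 and lo12, with lo12 sign-extended. It becomes one lui when lo12 is 0, and lui + addi otherwise. The names li_split and li_instruction_count are hypothetical, and a given assembler may choose a different but equivalent sequence.

```c
#include <stdint.h>

/* Split imm so that imm == (hi20 << 12) + lo12 (mod 2^32), lo12 in [-2048, 2047].
   lui loads hi20 << 12, and addi then adds the sign-extended lo12. */
static void li_split(uint32_t imm, uint32_t *hi20, int32_t *lo12)
{
    int32_t lo = (int32_t)(imm & 0xFFF);
    if (lo >= 2048)
        lo -= 4096;                                /* sign-extend the low 12 bits */
    *lo12 = lo;
    *hi20 = ((imm - (uint32_t)lo) >> 12) & 0xFFFFF; /* compensate for the sign extension */
}

/* Number of instructions in the usual "li rd, imm" expansion on RV32. */
static int li_instruction_count(uint32_t imm)
{
    if (imm <= 2047u || imm >= 0xFFFFF800u) /* -2048..2047 as a 32-bit pattern */
        return 1;                           /* addi rd, x0, imm */
    uint32_t hi20;
    int32_t lo12;
    li_split(imm, &hi20, &lo12);
    return lo12 == 0 ? 1 : 2;               /* lui alone, or lui + addi */
}
```

With this model, 0x00010000 splits into hi20 = 0x10 and lo12 = 0 (a single lui), while 0x55af needs lui 5 followed by addi 0x5af.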
Here is the RISC-V assembly code converted from the above C code:
my_clz:
# a0 is the input parameter x
# t0 is r
# t1 is c
# t2 is tmp
li t0, 0 # r = 0
li t2, 0x00010000 # tmp = 0x00010000
sltu t1, a0, t2 # unsigned compare, as in C (slt would treat x >= 0x80000000 as negative)
slli t1, t1, 4 # c = (x < 0x00010000) << 4;
add t0, t0, t1 # r += c
sll a0, a0, t1 # x <<= c
slli t2, t2, 8 # tmp = 0x01000000
sltu t1, a0, t2
slli t1, t1, 3 # c = (x < 0x01000000) << 3;
add t0, t0, t1 # r += c
sll a0, a0, t1 # x <<= c
slli t2, t2, 4 # tmp = 0x10000000
sltu t1, a0, t2
slli t1, t1, 2 # c = (x < 0x10000000) << 2;
add t0, t0, t1 # r += c
sll a0, a0, t1 # x <<= c
srli t1, a0, 27
andi t1, t1, 0x1e # c = (x >> (32 - 4 - 1)) & 0x1e
li a0, 0x55af
srl a0, a0, t1
andi a0, a0, 3
add a0, t0, a0 # return r + ((0x55af >> c) & 3)
ret
The fp16_to_fp32 function converts a half-precision floating-point number (float16) to a single-precision floating-point number (float32).
In the original C code, fp16_to_fp32 calls only my_clz, which it uses to find the MSB of the mantissa. It does not call fabsf. Instead, it clears the sign bit directly inside the function (w & 0x7FFFFFFF).
static inline uint32_t fp16_to_fp32(uint16_t h) {
/*
* Extends the 16-bit half-precision floating-point number to 32 bits
* by shifting it to the upper half of a 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - exponent bits, M - mantissa bits, 0 - zero bits.
*/
const uint32_t w = (uint32_t) h << 16;
/*
* Isolates the sign bit from the input number, placing it in the most
* significant bit of a 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-30
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extracts the mantissa and exponent from the input number, placing
* them in bits 0-30 of the 32-bit word:
*
* +---+-----+------------+-------------------+
* | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
* The renorm_shift variable indicates how many bits the mantissa
* needs to be shifted to normalize the half-precision number.
* For normalized numbers, renorm_shift will be 0. For denormalized
* numbers, renorm_shift will be greater than 0. Shifting a
* denormalized number will move the mantissa into the exponent,
* normalizing it.
*/
uint32_t renorm_shift = my_clz(nonsign);
renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
/*
* If the half-precision number has an exponent field of 0x1F (all ones),
* adding 0x04000000 carries into bit 31, and the arithmetic right shift
* by 8 then turns the upper 9 bits into ones. Thus:
* inf_nan_mask ==
* 0x7F800000 if the half-precision number is
* NaN or infinity (exponent field 0x1F)
* 0x00000000 otherwise
*/
const int32_t inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) &
INT32_C(0x7F800000);
/*
* If nonsign equals 0, subtracting 1 will cause overflow, setting
* bit 31 to 1. Otherwise, bit 31 will be 0. Shifting this result
* propagates bit 31 across all bits in zero_mask. Thus:
* zero_mask ==
* 0xFFFFFFFF if the half-precision number is
* zero (+0.0h or -0.0h)
* 0x00000000 otherwise
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shifts nonsign left by renorm_shift to normalize it (for denormal
* inputs).
* 2. Shifts nonsign right by 3, adjusting the exponent to fit in the
* 8-bit exponent field and moving the mantissa into the correct
* position within the 23-bit mantissa field of the single-precision
* format.
* 3. Adds 0x70 to the exponent to account for the difference in bias
* between half-precision and single-precision.
* 4. Subtracts renorm_shift from the exponent to account for any
* renormalization that occurred.
* 5. ORs with inf_nan_mask to set the exponent to 0xFF if the input
* was NaN or infinity.
* 6. ANDs with the inverted zero_mask to set the mantissa and exponent
* to zero if the input was zero.
* 7. Combines everything with the sign bit of the input number.
*/
return sign | ((((nonsign << renorm_shift >> 3) +
((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask);
}
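As a worked example of the seven steps in the last comment, here is a hand trace of the denormalized test input h = 0x000F. Here n stands for nonsign and s for renorm_shift. For this input, inf_nan_mask and zero_mask are both 0, so steps 5 and 6 leave the value unchanged:

$$
\begin{aligned}
n &= \texttt{0x000F0000}, \qquad \mathrm{clz}(n) = 12, \qquad s = 12 - 5 = 7 \\
(n \ll s) \gg 3 &= \texttt{0x07800000} \gg 3 = \texttt{0x00F00000} \\
(\texttt{0x70} - s) \ll 23 &= \texttt{0x69} \ll 23 = \texttt{0x34800000} \\
\texttt{0x00F00000} + \texttt{0x34800000} &= \texttt{0x35700000}
\end{aligned}
$$

The exponent field of 0x35700000 is 0x6A = 106, so E = 106 - 127 = -21. Its mantissa field is 0x700000, i.e. 0.875. The float32 value is therefore 1.875 × 2^-21 = 15 × 2^-24, which equals the float16 subnormal value 15 × 2^-10 × 2^-14.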
When converting fp16_to_fp32 to assembly language, I applied two techniques:
1. my_clz is called only once in fp16_to_fp32, so I inlined its body directly into fp16_to_fp32. This removes the function-call overhead: the jal/ret pair and saving/restoring registers such as ra.
2. Because my_clz has already been made branchless, the assembly implementation of fp16_to_fp32 contains no loops, which improves computational efficiency.
fp16_to_fp32:
# a0 is h
# t0 is w
# t1 is sign
# t2 is nonsign
# t3 is renorm_shift
# t4 is inf_nan_mask
# t5 is zero_mask
slli t0, a0, 16 # w = (uint32_t) h << 16;
li t2, 0x80000000
and t1, t0, t2 # sign = w & UINT32_C(0x80000000);
li t2, 0x7FFFFFFF
and t2, t0, t2 # nonsign = w & UINT32_C(0x7FFFFFFF);
mv t3, t2 # renorm_shift = nonsign
# renorm_shift = my_clz(nonsign) after my_clz labels
my_clz:
# t3 is the input parameter x
# t4 is r
# t5 is c
# t6 is tmp
li t4, 0 # r = 0
li t6, 0x00010000 # tmp = 0x00010000
sltu t5, t3, t6 # c = (x < 0x00010000)
slli t5, t5, 4 # c = (x < 0x00010000) << 4;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
slli t6, t6, 8
sltu t5, t3, t6 # c = (x < 0x01000000)
slli t5, t5, 3 # c = (x < 0x01000000) << 3;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
slli t6, t6, 4
sltu t5, t3, t6 # c = (x < 0x10000000)
slli t5, t5, 2 # c = (x < 0x10000000) << 2;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
srli t5, t3, 27
andi t5, t5, 0x1e # c = (x >> (32 - 4 - 1)) & 0x1e
li t3, 0x55af
srl t3, t3, t5
andi t3, t3, 3
add t3, t4, t3 # renorm_shift = r + ((0x55af >> c) & 3)
my_clz_end: # renorm_shift = my_clz(nonsign)
li t4, 5
bleu t3, t4, renorm_shift_zero # if renorm_shift <= 5, then renorm_shift = 0
renorm_shift_substract5:
addi t3, t3, -5 # else renorm_shift = renorm_shift - 5
j renorm_shift_end
renorm_shift_zero:
li t3, 0 # renorm_shift = 0
renorm_shift_end:
li t5, 0x04000000
add t4, t2, t5 # inf_nan_mask = nonsign + 0x04000000
srai t4, t4, 8 # inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8)
li t5, 0x7F800000
and t4, t4, t5 # inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
addi t5, t2, -1 # zero_mask = (int32_t)(nonsign - 1)
srai t5, t5, 31 # zero_mask = (int32_t)(nonsign - 1) >> 31;
# finally, we want to return sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask)
# we will first calculate (nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)
sll t2, t2, t3
srli t2, t2, 3 # tmp1 = nonsign << renorm_shift >> 3
li t6, 0x70
sub t3, t6, t3 # tmp2 = 0x70 - renorm_shift
slli t3, t3, 23 # tmp2 = (0x70 - renorm_shift) << 23
add t2, t2, t3 # tmp2 = (nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)
# and calculate (((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23) | inf_nan_mask) & ~zero_mask)
or t2, t2, t4 # tmp2 = ((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23) | inf_nan_mask)
not t5, t5 # zero_mask = ~zero_mask
and t2, t2, t5 # tmp2 = ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask)
ori a0, t1, 0 # result = sign
or a0, a0, t2 # result = sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask)
ret # return result
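The bleu/j pair that computes renorm_shift is the only branch left in this routine. If one wanted to remove it as well, one possible branchless formulation is shown below. It is not part of the original code, and whether it is faster in Ripes has not been measured here. The name renorm_shift_branchless is hypothetical. The comments show the RV32I instructions each line would map to.

```c
#include <stdint.h>

/* Branchless version of: clz > 5 ? clz - 5 : 0
   mask is all ones when clz > 5 and all zeros otherwise. */
static uint32_t renorm_shift_branchless(uint32_t clz)
{
    uint32_t mask = (uint32_t)(clz < 6) - 1u; /* sltiu m, clz, 6 ; addi m, m, -1 */
    return (clz - 5u) & mask;                 /* addi d, clz, -5 ; and d, d, m   */
}
```

When clz <= 5, the subtraction may wrap around, but the zero mask discards the result. This gives four instructions in place of the compare, branch and jump.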
Here is the complete C code used for testing. The my_clz and fp16_to_fp32 functions are the ones described above, with their comments removed. A main function is added at the end to print the converted values:
#include <stdio.h>
#include <stdint.h>
static inline uint32_t my_clz(uint32_t x)
{
int r = 0, c;
c = (x < 0x00010000) << 4;
r += c;
x <<= c; // off 16
c = (x < 0x01000000) << 3;
r += c;
x <<= c; // off 8
c = (x < 0x10000000) << 2;
r += c;
x <<= c; // off 4
c = (x >> (32 - 4 - 1)) & 0x1e;
return r + ((0x55af >> c) & 3);
}
static inline uint32_t fp16_to_fp32(uint16_t h)
{
const uint32_t w = (uint32_t)h << 16;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = my_clz(nonsign);
renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
const int32_t inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) &
INT32_C(0x7F800000);
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
return sign | ((((nonsign << renorm_shift >> 3) +
((0x70 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
}
int main()
{
uint32_t datas[] = {
0x0710, // normalized number
0x311F, // normalized number
0x000F, // denormalized number
0x0000, // positive zero
0x8000, // negative zero
0x7C00, // positive inf
0xFC00, // negative inf
0x7CFF // NaN
};
uint32_t results[] = {
fp16_to_fp32(datas[0]),
fp16_to_fp32(datas[1]),
fp16_to_fp32(datas[2]),
fp16_to_fp32(datas[3]),
fp16_to_fp32(datas[4]),
fp16_to_fp32(datas[5]),
fp16_to_fp32(datas[6]),
fp16_to_fp32(datas[7])};
printf("\nfp16_to_fp32(0x0710) is : 0x%x ", results[0]); // normalized number
printf("\nfp16_to_fp32(0x311F) is : 0x%x ", results[1]); // normalized number
printf("\nfp16_to_fp32(0x000F) is : 0x%x ", results[2]); // denormalized number
printf("\nfp16_to_fp32(0x0000) is : 0x%x ", results[3]); // positive zero
printf("\nfp16_to_fp32(0x8000) is : 0x%x ", results[4]); // negative zero
printf("\nfp16_to_fp32(0x7C00) is : 0x%x ", results[5]); // positive inf
printf("\nfp16_to_fp32(0xFC00) is : 0x%x ", results[6]); // negative inf
printf("\nfp16_to_fp32(0x7CFF) is : 0x%x ", results[7]); // NaN
}
Output:
fp16_to_fp32(0x0710) is : 0x38e20000
fp16_to_fp32(0x311F) is : 0x3e23e000
fp16_to_fp32(0x000F) is : 0x35700000
fp16_to_fp32(0x0000) is : 0x0
fp16_to_fp32(0x8000) is : 0x80000000
fp16_to_fp32(0x7C00) is : 0x7f800000
fp16_to_fp32(0xFC00) is : 0xff800000
fp16_to_fp32(0x7CFF) is : 0x7f9fe000
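As an independent sanity check of these expected outputs, the finite test values can be decoded directly from the float16 definition and compared with the printed float32 bit patterns. This is a sketch. The names half_to_double, bits_to_float and pow2 are hypothetical helpers, and infinities and NaN are intentionally not handled.

```c
#include <stdint.h>
#include <string.h>

/* Exact power of two as a double, without needing libm. */
static double pow2(int k)
{
    double r = 1.0;
    for (; k > 0; --k) r *= 2.0;
    for (; k < 0; ++k) r *= 0.5;
    return r;
}

/* Decode a FINITE float16 bit pattern (exponent field != 0x1F):
   normal:    (1024 + m) * 2^(e - 25) == (1 + m/1024) * 2^(e - 15)
   subnormal: m * 2^-24              == (m/1024) * 2^-14           */
static double half_to_double(uint16_t h)
{
    uint32_t e = (h >> 10) & 0x1F, m = h & 0x3FF;
    double v = (e == 0) ? m * pow2(-24) : (1024.0 + m) * pow2((int)e - 25);
    return (h >> 15) ? -v : v;
}

/* Reinterpret a float32 bit pattern (e.g. a result of fp16_to_fp32) as a float. */
static float bits_to_float(uint32_t bits)
{
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}
```

Every float16 value is exactly representable as a float, so an exact == comparison is appropriate here.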
Before the function itself, I added a main routine and the test data to print the converted values. This makes it easier to compare the results with those of the C program.
I also stored the results of the C program in an ans array in the RISC-V code. The program compares each computed result with the corresponding entry and prints "\nthe answer is wrong!!!" if they differ.
Here’s the modified assembly code:
.data
datas: .word 0x0710, 0x311F, 0x000F, 0x0000, 0x8000, 0x7C00, 0xFC00, 0x7CFF
ans: .word 0x38e20000, 0x3e23e000, 0x35700000, 0x0, 0x80000000, 0x7f800000, 0xff800000, 0x7f9fe000
str1: .string "\nfp16_to_fp32(0x0710) is : "
str2: .string "\nfp16_to_fp32(0x311F) is : "
str3: .string "\nfp16_to_fp32(0x000F) is : "
str4: .string "\nfp16_to_fp32(0x0000) is : "
str5: .string "\nfp16_to_fp32(0x8000) is : "
str6: .string "\nfp16_to_fp32(0x7C00) is : "
str7: .string "\nfp16_to_fp32(0xFC00) is : "
str8: .string "\nfp16_to_fp32(0x7CFF) is : "
strError: .string "\nthe answer is wrong!!!"
strs: .word str1, str2, str3, str4, str5, str6, str7, str8
.text
main:
la s6, ans # Load ans reference
la s7, datas # Load datas reference
la s8, strs # Load strs references
li s9, 8 # Load the loop count
print_numbers:
lw a0, 0(s8) # Load string reference
li a7, 4 # print string
ecall
lw a0, 0(s7) # Load data
jal ra, fp16_to_fp32 # calculate fp16_to_fp32(data)
li a7, 34 # print the result in hex format
ecall
validation:
lw t0, 0(s6) # Load ans
sub t0, t0, a0 # calculate ans - result for validation
beqz t0, check_loop # if (ans - result) == 0 then skip
la a0, strError
li a7, 4 # print error message!!!
ecall
check_loop:
addi s6, s6, 4 # shift ans index
addi s7, s7, 4 # shift datas index
addi s8, s8, 4 # shift strs index
addi s9, s9, -1 # loop count - 1
bnez s9, print_numbers
exit:
# Exit the program
li a7, 10 # System call code for exiting the program
ecall # Make the exit system call
ret
fp16_to_fp32:
# a0 is h
# t0 is w
# t1 is sign
# t2 is nonsign
# t3 is renorm_shift
# t4 is inf_nan_mask
# t5 is zero_mask
slli t0, a0, 16 # w = (uint32_t) h << 16;
li t2, 0x80000000
and t1, t0, t2 # sign = w & UINT32_C(0x80000000);
li t2, 0x7FFFFFFF
and t2, t0, t2 # nonsign = w & UINT32_C(0x7FFFFFFF);
mv t3, t2 # renorm_shift = nonsign
# renorm_shift = my_clz(nonsign) after my_clz labels
my_clz:
# t3 is the input parameter x
# t4 is r
# t5 is c
# t6 is tmp
li t4, 0 # r = 0
li t6, 0x00010000 # tmp = 0x00010000
sltu t5, t3, t6 # c = (x < 0x00010000)
slli t5, t5, 4 # c = (x < 0x00010000) << 4;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
slli t6, t6, 8
sltu t5, t3, t6 # c = (x < 0x01000000)
slli t5, t5, 3 # c = (x < 0x01000000) << 3;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
slli t6, t6, 4
sltu t5, t3, t6 # c = (x < 0x10000000)
slli t5, t5, 2 # c = (x < 0x10000000) << 2;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
srli t5, t3, 27
andi t5, t5, 0x1e # c = (x >> (32 - 4 - 1)) & 0x1e
li t3, 0x55af
srl t3, t3, t5
andi t3, t3, 3
add t3, t4, t3 # renorm_shift = r + ((0x55af >> c) & 3)
my_clz_end: # renorm_shift = my_clz(nonsign)
li t4, 5
bleu t3, t4, renorm_shift_zero # if renorm_shift <= 5, then renorm_shift = 0
renorm_shift_substract5:
addi t3, t3, -5 # else renorm_shift = renorm_shift - 5
j renorm_shift_end
renorm_shift_zero:
li t3, 0 # renorm_shift = 0
renorm_shift_end:
li t5, 0x04000000
add t4, t2, t5 # inf_nan_mask = nonsign + 0x04000000
srai t4, t4, 8 # inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8)
li t5, 0x7F800000
and t4, t4, t5 # inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
addi t5, t2, -1 # zero_mask = (int32_t)(nonsign - 1)
srai t5, t5, 31 # zero_mask = (int32_t)(nonsign - 1) >> 31;
sll t2, t2, t3
srli t2, t2, 3 # tmp1 = nonsign << renorm_shift >> 3
li t6, 0x70
sub t3, t6, t3 # tmp2 = 0x70 - renorm_shift
slli t3, t3, 23 # tmp2 = (0x70 - renorm_shift) << 23
add t2, t2, t3 # tmp2 = (nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)
or t2, t2, t4 # tmp2 = ((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23) | inf_nan_mask)
not t5, t5 # zero_mask = ~zero_mask
and t2, t2, t5 # tmp2 = ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask)
ori a0, t1, 0 # result = sign
or a0, a0, t2 # result = sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask)
ret # return result
Outputs:
fp16_to_fp32(0x0710) is : 0x38e20000
fp16_to_fp32(0x311F) is : 0x3e23e000
fp16_to_fp32(0x000F) is : 0x35700000
fp16_to_fp32(0x0000) is : 0x0000
fp16_to_fp32(0x8000) is : 0x80000000
fp16_to_fp32(0x7C00) is : 0x7f800000
fp16_to_fp32(0xFC00) is : 0xff800000
fp16_to_fp32(0x7CFF) is : 0x7f9fe000
Program exited with code: 0
This section compares two cycle counts. The first is for my C code, compiled to RISC-V assembly with the RISC-V GCC compiler. The second is for my hand-written RISC-V assembly code.
Ripes provides a feature that compiles pasted C/C++ code into RISC-V assembly using a RISC-V GCC toolchain. The compiler I used is risc-v-gcc10.1.0.exe with the -O3 option.
Here are the execution results of the RISC-V code generated by the C compiler and of my RISC-V code:
generated from RISC-V C compiler | my RISC-V code |
---|---|
(image not available) | (image not available) |
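I am not sure why the performance of the compiler-generated RISC-V code differs so much from that of my hand-written code. It could be because the compiler version is outdated, or the difference may simply be normal. One possible contributor is that the compiled C program also runs startup code and the printf library routine, while my assembly version prints through Ripes ecalls.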
Given a non-negative integer x, return the square root of x rounded down to the nearest integer. The returned integer should be non-negative as well.
You must not use any built-in exponent function or operator.
For example, do not use pow(x, 0.5) in c++ or x ** 0.5 in python.
Example 1:
Input: x = 4
Output: 2
Explanation: The square root of 4 is 2, so we return 2.
Example 2:
Input: x = 8
Output: 2
Explanation: The square root of 8 is 2.82842…, and since we round it down to the nearest integer, 2 is returned.
Constraints:
0 <= x <= (2^31) - 1
The ideas and inspiration for the following C code come from this HackMD note.
Since x is in the range 0 <= x <= (2^31) - 1, we can use binary search with the search range set to [1, x]. The case x = 0 can be handled with a conditional statement at the beginning.
Therefore, the following is the first version of binary search based on the above idea:
#include <stdint.h>

int mySqrt(unsigned x)
{
if (x == 0)
return x;
uint32_t L, R;
L = 1;
R = x;
while (1)
{
uint64_t M = (L + R) >> 1;
if (M * M > x)
R = M;
else if ((M + 1) * (M + 1) <= x) // can do M*M + 2*M + 1 in assembly
L = M;
else
return M; // return M when M*M <= x < (M+1)*(M+1)
}
}
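However, this binary search can still be optimized. Compared with the first version, the code below makes two changes:
1. my_clz is used to find the MSB position of x. The initial search range is then set to [2^(temp/2), 2^(temp/2 + 1)), where temp = 31 - my_clz(x) and the division is integer division. This range always contains √x and is much smaller than [1, x].
2. Because this range keeps M below 2^16, M is declared as uint32_t instead of uint64_t. For 0 <= x <= 2^31 - 1, M*M, and (M+1)*(M+1) whenever it is evaluated, still fit in 32 bits.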
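The narrowed range follows from the MSB position. For x > 0, let p = 31 - my_clz(x) (temp in the code below) be the index of the most significant set bit of x. Then:

$$
2^{p} \le x < 2^{p+1}
\quad\Longrightarrow\quad
2^{\lfloor p/2 \rfloor} \le 2^{p/2} \le \sqrt{x} < 2^{(p+1)/2} \le 2^{\lfloor p/2 \rfloor + 1}
$$

The last inequality holds for even p, where (p+1)/2 = ⌊p/2⌋ + 1/2, and for odd p, where (p+1)/2 = ⌊p/2⌋ + 1. So the search can start from L = 2^⌊p/2⌋ = 1 << (temp >> 1) and R = 2L. For x = 2^31 - 1, p = 30, so the search starts from [32768, 65536) instead of [1, 2^31 - 1].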
The following is the code after optimization based on the above points, which reduces the number of iterations in the binary search loop.
#include <stdio.h>
#include <stdint.h>
uint32_t my_clz(uint32_t x)
{
int r = 0, c;
c = (x < 0x00010000) << 4;
r += c;
x <<= c; // off 16
c = (x < 0x01000000) << 3;
r += c;
x <<= c; // off 8
c = (x < 0x10000000) << 2;
r += c;
x <<= c; // off 4
c = (x >> (32 - 4 - 1)) & 0x1e;
return r + ((0x55af >> c) & 3);
}
int mySqrt(unsigned x)
{
if (x == 0)
return x;
uint32_t temp, L, R, M;
temp = 31 - my_clz(x); // temp = bit index (0-based) of the MSB of x
L = 1 << (temp >> 1); // initial lower bound: 2^(temp / 2)
R = L << 1; // initial upper bound (exclusive): 2^(temp / 2 + 1)
while (1)
{
M = (L + R) >> 1;
if (M * M > x)
R = M;
else if ((M + 1) * (M + 1) <= x) // can do M*M + 2*M + 1 in assembly
L = M;
else
return M; // return M when M*M <= x < (M+1)*(M+1)
}
}
int main()
{
uint32_t datas[] = {0, 1, 2, 4, 8, 2147483647};
uint32_t results[] = {
mySqrt(datas[0]),
mySqrt(datas[1]),
mySqrt(datas[2]),
mySqrt(datas[3]),
mySqrt(datas[4]),
mySqrt(datas[5])};
for (int i = 0; i < 6; ++i)
printf("\nmySqrt(%u) is : %u", datas[i], results[i]);
}
Outputs:
mySqrt(0) is : 0
mySqrt(1) is : 1
mySqrt(2) is : 1
mySqrt(4) is : 2
mySqrt(8) is : 2
mySqrt(2147483647) is : 46340
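To see how much the narrowed range saves, both versions can be instrumented with an iteration counter. The sketch below reuses the same loop body for both search ranges. The names clz32, sqrt_search, sqrt_full and sqrt_narrow are hypothetical, and L, R and M are 64-bit here only so that the full-range version cannot overflow.

```c
#include <stdint.h>

/* Branchless CLZ from this note (only called with x != 0 below). */
static uint32_t clz32(uint32_t x)
{
    uint32_t r = 0, c;
    c = (x < 0x00010000u) << 4; r += c; x <<= c;
    c = (x < 0x01000000u) << 3; r += c; x <<= c;
    c = (x < 0x10000000u) << 2; r += c; x <<= c;
    c = (x >> 27) & 0x1e;
    return r + ((0x55afu >> c) & 3);
}

/* The mySqrt binary-search loop, counting iterations in *iters. */
static uint32_t sqrt_search(uint32_t x, uint64_t L, uint64_t R, int *iters)
{
    *iters = 0;
    for (;;) {
        uint64_t M = (L + R) >> 1;
        ++*iters;
        if (M * M > x)
            R = M;
        else if ((M + 1) * (M + 1) <= x)
            L = M;
        else
            return (uint32_t)M; /* M*M <= x < (M+1)*(M+1) */
    }
}

/* First version: search range [1, x]. */
static uint32_t sqrt_full(uint32_t x, int *iters)
{
    if (x == 0) { *iters = 0; return 0; }
    return sqrt_search(x, 1, x, iters);
}

/* Optimized version: range [2^(p/2), 2^(p/2 + 1)), p = MSB index of x. */
static uint32_t sqrt_narrow(uint32_t x, int *iters)
{
    if (x == 0) { *iters = 0; return 0; }
    uint32_t p = 31 - clz32(x);
    uint64_t L = 1ull << (p >> 1);
    return sqrt_search(x, L, L << 1, iters);
}
```

For x = 2147483647, the full-range version first spends 16 iterations shrinking [1, x] to L = 32768, R = 65536, which is exactly where sqrt_narrow starts. From that point on, both versions perform the same iterations.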
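Here is the RISC-V assembly code converted from the optimized C code above. Compared with the C code, the assembly version also:
1. inlines my_clz into mySqrt, as was done in fp16_to_fp32;
2. computes M*M with a shift-and-add loop (multiple_loop) instead of a multiply instruction;
3. computes (M+1)*(M+1) as M*M + 2*M + 1, reusing M*M instead of multiplying again.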
.data
datas: .word 0, 1, 2, 4, 8, 2147483647
ans: .word 0, 1, 1, 2, 2, 46340
str1: .string "\nmySqrt(0) is : "
str2: .string "\nmySqrt(1) is : "
str3: .string "\nmySqrt(2) is : "
str4: .string "\nmySqrt(4) is : "
str5: .string "\nmySqrt(8) is : "
str6: .string "\nmySqrt(2147483647) is : "
strError: .string "\nthe answer is wrong!!!"
strs: .word str1, str2, str3, str4, str5, str6
.text
main:
la s6, ans # Load ans reference
la s7, datas # Load datas reference
la s8, strs # Load strs references
li s9, 6 # Load the loop count
print_numbers:
lw a0, 0(s8) # Load string reference
li a7, 4 # print string
ecall
lw a0, 0(s7) # Load data
jal ra, mySqrt # calculate mySqrt(data)
li a7, 36 # print the result in unsigned format
ecall
validation:
lw t0, 0(s6) # Load ans
sub t0, t0, a0 # calculate ans - result for validation
beqz t0, check_loop # if (ans - result) == 0 then skip
la a0, strError
li a7, 4 # print error message!!!
ecall
check_loop:
addi s6, s6, 4 # shift ans index
addi s7, s7, 4 # shift datas index
addi s8, s8, 4 # shift strs index
addi s9, s9, -1 # loop count - 1
bnez s9, print_numbers
exit:
# Exit the program
li a7, 10 # System call code for exiting the program
ecall # Make the exit system call
ret
mySqrt:
# a0 is x
# t0 is temp
# t1 is L
# t2 is R
# t3 is M
bnez a0, conditionSkip # if(x==0) then return x
ret
conditionSkip:
mv t3, a0 # set t3 = x to calc my_clz(x)
my_clz:
# t3 is the input parameter x
# t4 is r
# t5 is c
# t6 is tmp
li t4, 0 # r = 0
li t6, 0x00010000 # tmp = 0x00010000
sltu t5, t3, t6 # c = (x < 0x00010000)
slli t5, t5, 4 # c = (x < 0x00010000) << 4;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
slli t6, t6, 8
sltu t5, t3, t6 # c = (x < 0x01000000)
slli t5, t5, 3 # c = (x < 0x01000000) << 3;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
slli t6, t6, 4
sltu t5, t3, t6 # c = (x < 0x10000000)
slli t5, t5, 2 # c = (x < 0x10000000) << 2;
add t4, t4, t5 # r += c
sll t3, t3, t5 # x <<= c
srli t5, t3, 27
andi t5, t5, 0x1e # c = (x >> (32 - 4 - 1)) & 0x1e
li t3, 0x55af
srl t3, t3, t5
andi t3, t3, 3
add t0, t4, t3 # temp = r + ((0x55af >> c) & 3)
my_clz_end: # temp = my_clz(x)
li t4, 31
sub t0, t4, t0 # temp = 31 - my_clz(x)
srli t0, t0, 1 # tmp = temp >> 1
li t1, 1 # L = 1
sll t1, t1, t0 # L = 1 << (temp >> 1)
slli t2, t1, 1 # R = L << 1
binary_search_loop:
add t3, t1, t2 # M = L + R
srli t3, t3, 1 # M = (L + R) >> 1
mv t4, t3 # copy M to t4
mv t5, t3 # copy M to t5
li t0, 0 # result = 0 (t0 is free here; s0 is callee-saved and would need to be preserved)
multiple_loop: # calculate result = M * M
andi t6, t4, 1 # check LSB of t4
beqz t6, skip_add # if LSB of t4 == 0 then skip
add t0, t0, t5 # result = result + t5
skip_add:
srli t4, t4, 1 # t4 = t4 >> 1
slli t5, t5, 1 # t5 = t5 << 1
bnez t4, multiple_loop
multiple_end:
bgtu t0, a0, squareM_is_bigger # if(M * M > x) then jump
add t0, t0, t3 # result = M*M + M
add t0, t0, t3 # result = M*M + 2*M
addi t0, t0, 1 # result = (M+1)*(M+1)
bleu t0, a0, squareM1_is_smaller # if((M+1)*(M+1) <= x) then jump
breakLoop: # else return M
mv a0, t3
ret # return M
squareM_is_bigger:
mv t2, t3 # R = M
j binary_search_loop # continue
squareM1_is_smaller:
mv t1, t3 # L = M
j binary_search_loop # continue
Outputs
mySqrt(0) is : 0
mySqrt(1) is : 1
mySqrt(2) is : 1
mySqrt(4) is : 2
mySqrt(8) is : 2
mySqrt(2147483647) is : 46340
Program exited with code: 0
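The multiple_loop section above computes M*M with shift-and-add instead of a mul instruction. mul belongs to the RV32M extension, not the base RV32I ISA. The following C sketch mirrors that loop and the M*M + 2*M + 1 step. The names mul_shift_add and next_square are hypothetical.

```c
#include <stdint.h>

/* Shift-and-add multiply, mirroring multiple_loop: for every set bit of b
   (scanned from the LSB), add the correspondingly shifted a. */
static uint32_t mul_shift_add(uint32_t a, uint32_t b)
{
    uint32_t result = 0;
    while (b != 0) {
        if (b & 1)       /* andi t6, t4, 1 ; beqz t6, skip_add */
            result += a; /* add the shifted multiplicand (t5) */
        b >>= 1;         /* srli t4, t4, 1 */
        a <<= 1;         /* slli t5, t5, 1 */
    }
    return result;       /* product modulo 2^32, like a 32-bit register */
}

/* (M+1)*(M+1) derived from M*M without a second multiplication. */
static uint32_t next_square(uint32_t m, uint32_t m_squared)
{
    return m_squared + 2 * m + 1;
}
```

The loop runs once per bit of b up to its highest set bit, so for M < 2^16 each squaring takes at most 16 iterations.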
The compiler I used is risc-v-gcc10.1.0.exe with the -O3 option.
Here are the execution results of the RISC-V code generated by the C compiler and of my RISC-V code:
generated from RISC-V C compiler | my RISC-V code |
---|---|
(image not available) | (image not available) |
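A classic 5-stage pipelined RISC-V processor, such as the 5-stage processor model in Ripes, divides instruction execution into five stages:
1. IF (Instruction Fetch): fetch the instruction at the address in the PC.
2. ID (Instruction Decode): decode the instruction and read its source registers.
3. EX (Execute): perform the ALU operation.
4. MEM (Memory Access): read or write data memory for load/store instructions.
5. WB (Write Back): write the result to the destination register.

The pipeline allows multiple instructions to be processed simultaneously in different stages, which increases the overall throughput of the processor.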
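The overlap of instructions can be modeled with a tiny function. This sketch assumes an ideal pipeline with no hazards, stalls or flushes, so instruction i enters IF in cycle i and moves forward one stage per cycle. The names STAGES and stage_of are hypothetical.

```c
#include <stddef.h>

static const char *const STAGES[5] = { "IF", "ID", "EX", "MEM", "WB" };

/* Stage occupied by instruction `instr` in cycle `cycle` (both counted from 0)
   in an ideal 5-stage pipeline; NULL if it is not in the pipeline yet or anymore. */
static const char *stage_of(int instr, int cycle)
{
    int s = cycle - instr;
    return (s >= 0 && s < 5) ? STAGES[s] : NULL;
}
```

In cycle 2, the first instruction is in EX, the second in ID and the third in IF, so three instructions are in flight at once. Once the pipeline is full, one instruction completes every cycle.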
Take the following RISC-V code as an example. This program performs the CLZ operation on the value in t3:
.text
main:
li t3, 0x050
my_clz:
# t3 is the input parameter x
# t4 is r
# t5 is c
# t6 is tmp
li t4, 0 # r = 0
li t6, 0x00010000 # tmp = 0x00010000
sltu t5, t3, t6 # c = (x < 0x00010000)
slli t5, t5, 4 # c = (x < 0x00010000) << 4;
...
Below is the executable code (address, machine word, and disassembled instruction):
00000000 <main>:
0: 05000e13 addi x28 x0 80
00000004 <my_clz>:
4: 00000e93 addi x29 x0 0
8: 00010fb7 lui x31 0x10
c: 01fe3f33 sltu x30 x28 x31
10: 004f1f13 slli x30 x30 4
...
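The machine words in this listing can be decoded by hand with the standard RISC-V I-type layout (imm[11:0] | rs1 | funct3 | rd | opcode) and U-type layout (imm[31:12] | rd | opcode). The sketch below shows this. The names decode_itype and decode_utype and the struct itype are hypothetical.

```c
#include <stdint.h>

/* Fields of an I-type instruction such as addi. */
struct itype { uint32_t opcode, rd, funct3, rs1; int32_t imm; };

static struct itype decode_itype(uint32_t w)
{
    struct itype d;
    d.opcode = w & 0x7F;
    d.rd     = (w >> 7) & 0x1F;
    d.funct3 = (w >> 12) & 0x7;
    d.rs1    = (w >> 15) & 0x1F;
    d.imm    = (int32_t)((w >> 20) ^ 0x800) - 0x800; /* sign-extend imm[11:0] */
    return d;
}

/* Fields of a U-type instruction such as lui. */
static void decode_utype(uint32_t w, uint32_t *opcode, uint32_t *rd, uint32_t *imm20)
{
    *opcode = w & 0x7F;
    *rd     = (w >> 7) & 0x1F;
    *imm20  = w >> 12;
}
```

For example, 0x05000e13 decodes to opcode 0x13 (OP-IMM), rd = 28, funct3 = 0 (addi), rs1 = 0 and imm = 80, which is addi x28 x0 80. Likewise, 0x00010fb7 decodes to opcode 0x37 (LUI), rd = 31 and imm20 = 0x10.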
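Next, we examine the execution of these three instructions, following addi x28, x0, 80 through each stage of the pipeline:
1. Instruction fetch (IF): the processor fetches the instruction (addi x28, x0, 80) based on the address in the PC.
2. Instruction decode (ID): the instruction is decoded into its operands (x28, x0, 80): destination register (x28), source register (x0) and immediate value 80. The value of the source register (x0) is read for the next stage. Meanwhile, the next instruction (addi x29, x0, 0) enters the instruction fetch stage.
3. Execute (EX): the ALU executes the instruction (addi x28, x0, 80), adding the value from x0 to the immediate value 80. The result (0x00000050) is stored in a temporary register for the next stage. Meanwhile, addi x29, x0, 0 enters the instruction decode stage.
4. Memory access (MEM): lui x31, 0x10 enters the instruction decode stage. Since the addi instruction does not access memory (it only modifies a register), this stage is effectively a no-operation (NOP) for this instruction.
5. Write back (WB): the result 0x00000050 is written back to the destination register x28.

At this point, addi x28, x0, 80 has been successfully executed.

References:
- Fast square-root computation, starting from the existence of √2 (從√2的存在談開平方根的快速運算), binary search method
- Number systems (數值系統), count leading zero
- Branchless count-leading-zeros on 32-bit RISC-V without Zbb extension
- Quiz1 of Computer Architecture (2024 Fall)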