Try   HackMD

Assignment1: RISC-V Assembly and Instruction Pipeline

contributed by <JordyMalone>

Quiz 1 - Problem C

Problem analysis

In this problem, we utilized three functions: fabsf, my_clz and fp16_to_fp32.

  1. The fabsf function, discussed in Problem A, calculates the absolute value of a floating-point number. Rather than using traditional arithmetic operations, it clears the sign bit through a bitwise operation to return the result.
  2. The my_clz function counts the number of leading zero bits in the binary representation of an unsigned integer, starting from the most significant bit. It then returns an integer that represents how many zero bits precede the first 1 in the binary representation of the number.
  3. The fp16_to_fp32 function converts a 16-bit floating-point number in IEEE half-precision format to a 32-bit floating-point number in single-precision format. The my_clz function is also used in fp16_to_fp32.

fabsf

// Return the absolute value of x by clearing bit 31 (the sign bit) of its
// 32-bit representation; all other bits are left untouched.
//
// The bits are reinterpreted through a union rather than through a
// `*(uint32_t *)&x` cast: reading a float object through a uint32_t lvalue
// breaks C's effective-type (strict aliasing) rule, whereas reading a union
// member other than the one last written is defined in C99 and later.
static inline float fabsf(float x) {
    union {
        float    as_float;
        uint32_t as_bits;
    } view = { .as_float = x };

    view.as_bits &= 0x7FFFFFFF;    // Clear the sign bit to get the absolute value
    return view.as_float;
}

assembly code

fabsf:
     # In:  a0 = bit pattern of the float x
     # Out: a0 = the same pattern with the sign bit (bit 31) forced to 0
     # Uses t0 for the mask.
     li t0, 0x7FFFFFFF		# every bit set except the sign bit
     and a0, a0, t0		# keep exponent and mantissa, drop the sign
     ret

my_clz

The C implementation of my_clz, a replacement for __builtin_clz, is given in Problem C as follows:

// Count the zero bits above the most significant 1 bit of x.
// A single-bit probe starts at bit 31 and moves down one position per miss;
// the number of misses is the answer. For x == 0 the probe falls off the low
// end after 32 misses, so the result is 32.
static inline int my_clz(uint32_t x) {
    int zeros = 0;
    uint32_t probe = 1U << 31;

    while (probe != 0 && (x & probe) == 0) {
        zeros++;
        probe >>= 1;
    }
    return zeros;
}

assembly code

# Test inputs for my_clz and the text fragments used by print.
# Comments use '#', like the rest of this program; '//' is not the
# line-comment character of RISC-V assembly.
.data
argument1: .word 0x01111111        # expected leading zeros: 7
argument2: .word 0x00008000        # expected leading zeros: 16
argument3: .word 0x00000001        # expected leading zeros: 31
newline: .string "\n"
str1: .string "The leading zero of "
str2: .string " is "

.text
main:
    # For each test word: a0 = my_clz(word), then print(word, count).
    # The word is loaded a second time after the call because my_clz
    # returns its result in a0, overwriting the argument.

    # ---- test word 1 ----
    lw   a0, argument1       # a0 = x
    jal  ra, my_clz          # a0 = leading zeros of x
    addi a1, a0, 0           # a1 = count (second argument of print)
    lw   a0, argument1       # a0 = x     (first argument of print)
    jal  ra, print

    # ---- test word 2 ----
    lw   a0, argument2
    jal  ra, my_clz
    addi a1, a0, 0
    lw   a0, argument2
    jal  ra, print

    # ---- test word 3 ----
    lw   a0, argument3
    jal  ra, my_clz
    addi a1, a0, 0
    lw   a0, argument3
    jal  ra, print

    # Leave the program: environment call number 10
    addi a7, zero, 10
    ecall

my_clz:
    # In:  a0 = x
    # Out: a0 = number of zero bits above the highest set bit of x
    #      (32 when x is 0, because the scan runs off bit 0)
    # Uses t0 = count, t1 = bit index i, t2 = probe mask, t3 = masked bit.
    li   t0, 0               # count = 0
    li   t1, 31              # i = 31, start at the top bit

clz_scan:
    bltz t1, clz_found       # all 32 bits inspected
    li   t2, 1
    sll  t2, t2, t1          # probe = 1U << i
    and  t3, a0, t2          # isolate bit i of x
    bnez t3, clz_found       # first 1 bit reached
    addi t0, t0, 1           # one more leading zero
    addi t1, t1, -1          # next lower bit
    j    clz_scan

clz_found:
    mv   a0, t0              # return count
    ret

print:
    # In: a0 = value that was analysed, a1 = its leading-zero count.
    # Emits: str1, the value, str2, the count, newline.
    # Both inputs are parked in t0/t1 because a0 doubles as the argument
    # register of every environment call below.
    addi t0, a0, 0
    addi t1, a1, 0

    addi a7, zero, 4         # a7 = 4: print the string at a0
    la   a0, str1
    ecall

    addi a7, zero, 1         # a7 = 1: print the number in a0
    addi a0, t0, 0
    ecall

    addi a7, zero, 4
    la   a0, str2
    ecall

    addi a7, zero, 1
    addi a0, t1, 0
    ecall

    addi a7, zero, 4
    la   a0, newline
    ecall

    ret                      # back to main

Optimization

Using branchless version

#include <stdint.h>

// Count the leading zero bits of x with a binary search over the bit
// positions: each step tests whether the upper 16/8/4/2/1 bits are all zero
// and, if so, adds that width to the count and shifts them away.
//
// Returns a value in 0..32. x == 0 is handled up front: the five steps can
// add at most 16+8+4+2+1 = 31, so without the guard the function would
// return 31 for zero, disagreeing with the bit-by-bit loop version (32).
int my_clz(uint32_t x) {
    if (x == 0) {
        return 32;
    }

    int count = 0;

    if (x <= 0x0000FFFF) {    // upper 16 bits are zero
        count += 16; x <<= 16;
    }
    if (x <= 0x00FFFFFF) {    // upper 8 bits are zero
        count += 8;  x <<= 8;
    }
    if (x <= 0x0FFFFFFF) {    // upper 4 bits are zero
        count += 4;  x <<= 4;
    }
    if (x <= 0x3FFFFFFF) {    // upper 2 bits are zero
        count += 2;  x <<= 2;
    }
    if (x <= 0x7FFFFFFF) {    // top bit is zero
        count += 1;
    }

    return count;
}
# Test inputs for my_clz and the text fragments that print stitches together.
.data
argument1: .word 0x01111111    # 7 leading zero bits
argument2: .word 0x00008000    # 16 leading zero bits
argument3: .word 0x00000001    # 31 leading zero bits
newline: .string "\n"
str1: .string "The leading zero of "
str2: .string " is "

.text
main:
    # Driver: run my_clz on the three test words and report each result
    # through print(a0 = word, a1 = count).

    # argument1
    lw   a0, argument1       # argument for my_clz
    jal  ra, my_clz
    addi a1, a0, 0           # count returned in a0 -> a1
    lw   a0, argument1       # reload the word, a0 now holds the count
    jal  ra, print

    # argument2
    lw   a0, argument2
    jal  ra, my_clz
    addi a1, a0, 0
    lw   a0, argument2
    jal  ra, print

    # argument3
    lw   a0, argument3
    jal  ra, my_clz
    addi a1, a0, 0
    lw   a0, argument3
    jal  ra, print

    # Environment call 10 ends the program
    addi a7, zero, 10
    ecall

my_clz:
    # int my_clz(uint32_t x) -- binary-search leading-zero count.
    # In:  a0 = x
    # Out: a0 = number of leading zero bits, 0..32
    # Uses t0 = count, t1 = working copy of x, t2 = comparison limit.
    #
    # x == 0 is answered up front: the five steps below can add at most
    # 16+8+4+2+1 = 31, which would disagree with the loop version (32).
    li   t0, 32
    beqz a0, clz_ret

    li   t0, 0                # count = 0
    mv   t1, a0               # t1 = x

    # if (x <= 0x0000FFFF) { count += 16; x <<= 16; }
    li   t2, 0x0000FFFF
    bgtu t1, t2, clz_check8   # something set in the upper 16 bits: skip
    addi t0, t0, 16
    slli t1, t1, 16

clz_check8:
    # if (x <= 0x00FFFFFF) { count += 8; x <<= 8; }
    li   t2, 0x00FFFFFF
    bgtu t1, t2, clz_check4
    addi t0, t0, 8
    slli t1, t1, 8

clz_check4:
    # if (x <= 0x0FFFFFFF) { count += 4; x <<= 4; }
    li   t2, 0x0FFFFFFF
    bgtu t1, t2, clz_check2
    addi t0, t0, 4
    slli t1, t1, 4

clz_check2:
    # if (x <= 0x3FFFFFFF) { count += 2; x <<= 2; }
    li   t2, 0x3FFFFFFF
    bgtu t1, t2, clz_check1
    addi t0, t0, 2
    slli t1, t1, 2

clz_check1:
    # if (x <= 0x7FFFFFFF) { count += 1; }
    # x <= 0x7FFFFFFF (unsigned) means bit 31 is clear, i.e. x is
    # non-negative when read as a signed 32-bit value.
    bltz t1, clz_ret
    addi t0, t0, 1

clz_ret:
    mv   a0, t0               # return count
    ret

print:
    # Print one result line: str1 <word> str2 <count> newline.
    # In: a0 = word, a1 = count. Uses t0/t1 as holding registers and a7 for
    # the environment-call number.
    addi t0, a0, 0           # t0 = word
    addi t1, a1, 0           # t1 = count

    # "The leading zero of "
    addi a7, zero, 4
    la   a0, str1
    ecall

    # the word itself
    addi a7, zero, 1
    addi a0, t0, 0
    ecall

    # " is "
    addi a7, zero, 4
    la   a0, str2
    ecall

    # the count
    addi a7, zero, 1
    addi a0, t1, 0
    ecall

    # line break
    addi a7, zero, 4
    la   a0, newline
    ecall

    ret                      # return to main

Execution state

Original version
image

Branchless version
image

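Replacing the bit-by-bit loop with the binary-search ("branchless") version reduces both the cycle count and the number of executed instructions: the loop may run up to 32 iterations, whereas the binary-search version always finishes after five comparisons. Note that the assembly still contains conditional branches, so it is loop-free rather than strictly branch-free.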

fp16_to_fp32

// Convert the IEEE half-precision bit pattern h into the bit pattern of the
// equivalent single-precision value.
//
// The half is moved into the top 16 bits of a word and split into sign and
// the remaining 31 bits. renorm_shift is my_clz(rest) - 5 when that is
// positive, otherwise 0. inf_nan_mask and zero_mask are bit masks derived
// from the remaining bits that are OR-ed in / AND-ed out at the end.
static inline uint32_t fp16_to_fp32(uint16_t h) {
    const uint32_t word = (uint32_t) h << 16;
    const uint32_t sign_bit = word & UINT32_C(0x80000000);
    const uint32_t rest = word & UINT32_C(0x7FFFFFFF);

    uint32_t renorm_shift = my_clz(rest);
    if (renorm_shift > 5) {
        renorm_shift -= 5;
    } else {
        renorm_shift = 0;
    }

    const int32_t inf_nan_mask =
        ((int32_t)(rest + 0x04000000) >> 8) & INT32_C(0x7F800000);
    const int32_t zero_mask = (int32_t)(rest - 1) >> 31;

    const uint32_t shifted = rest << renorm_shift >> 3;
    const uint32_t exponent_adjust = (0x70 - renorm_shift) << 23;

    return sign_bit |
           (((shifted + exponent_adjust) | inf_nan_mask) & ~zero_mask);
}

assembly code

# Half-precision test patterns (stored in the low 16 bits of each word) and
# the text fragments used by print.
.data
argument1: .word 0x3C00    # FP16 1.0
argument2: .word 0xC000    # FP16 -2.0
argument3: .word 0x7BFF    # FP16 65504.0, the largest finite half
str1: .string "\nThe FP32 value of FP16 number "
str2: .string " is "

.text
main:
    # For each test pattern: a0 = fp16_to_fp32(pattern), then
    # print(a0 = pattern, a1 = converted bits).

    # pattern 1
    lw   a0, argument1
    jal  ra, fp16_to_fp32
    addi a1, a0, 0           # converted bits -> second print argument
    lw   a0, argument1       # original pattern -> first print argument
    jal  ra, print

    # pattern 2
    lw   a0, argument2
    jal  ra, fp16_to_fp32
    addi a1, a0, 0
    lw   a0, argument2
    jal  ra, print

    # pattern 3
    lw   a0, argument3
    jal  ra, fp16_to_fp32
    addi a1, a0, 0
    lw   a0, argument3
    jal  ra, print

    # Environment call 10 ends the program
    addi a7, zero, 10
    ecall

fp16_to_fp32:
    # uint32_t fp16_to_fp32(uint16_t h)
    # In:  a0 = h, half-precision bit pattern in the low 16 bits
    # Out: a0 = single-precision bit pattern
    # Calls my_clz, which in this program returns (leading zeros - 5) in a0.
    # Uses s0 = sign and s1 = nonsign (both saved and restored here, so they
    # survive the call regardless of which temporaries my_clz uses) and
    # t0-t6 as scratch.
    #
    # Fixes relative to the earlier version:
    #   * the frame is released before returning (sp was decremented twice);
    #   * inf_nan_mask and zero_mask use arithmetic shifts (srai), matching
    #     the (int32_t) >> in the C code; with srli, Inf/NaN and zero inputs
    #     were converted incorrectly;
    #   * the my_clz result is taken from a0 instead of a leftover temporary.
    addi sp, sp, -16
    sw   ra, 12(sp)
    sw   s0, 8(sp)
    sw   s1, 4(sp)

    slli t0, a0, 16           # w = (uint32_t) h << 16
    li   t1, 0x80000000
    and  s0, t0, t1           # sign = w & 0x80000000
    li   t1, 0x7FFFFFFF
    and  s1, t0, t1           # nonsign = w & 0x7FFFFFFF

    mv   a0, s1               # pass nonsign in a0
    jal  ra, my_clz           # a0 = my_clz(nonsign) - 5
    mv   t3, a0
    bgt  t3, zero, fp16_shift_ready   # positive: renorm_shift = clz - 5
    li   t3, 0                # otherwise renorm_shift = 0

fp16_shift_ready:
    # inf_nan_mask = ((int32_t)(nonsign + 0x04000000) >> 8) & 0x7F800000
    li   t4, 0x04000000
    add  t4, s1, t4
    srai t4, t4, 8            # arithmetic: the sign bit must be replicated
    li   t5, 0x7F800000
    and  t4, t4, t5

    # ~zero_mask, where zero_mask = (int32_t)(nonsign - 1) >> 31
    addi t5, s1, -1
    srai t5, t5, 31           # all ones only when nonsign == 0
    xori t5, t5, -1           # bitwise NOT

    sll  t2, s1, t3           # nonsign << renorm_shift
    srli t2, t2, 3            # ... >> 3 (logical: operand is unsigned)
    li   t6, 0x70
    sub  t6, t6, t3           # 0x70 - renorm_shift
    slli t6, t6, 23           # (0x70 - renorm_shift) << 23
    add  t2, t2, t6
    or   t2, t2, t4           # | inf_nan_mask
    and  t2, t2, t5           # & ~zero_mask
    or   a0, s0, t2           # sign | (...)

    lw   s1, 4(sp)
    lw   s0, 8(sp)
    lw   ra, 12(sp)
    addi sp, sp, 16           # release the frame
    ret

my_clz:
    # Leading-zero counter specialised for fp16_to_fp32 in this program.
    # In:  a0 = x
    # Out: a0 = t3 = (number of leading zero bits of x) - 5; the value is
    #      negative when x has fewer than 5 leading zeros.
    # Uses t3 = count, t4 = bit index i, t5 = probe mask, t6 = masked bit.
    # NOTE(review): fp16_to_fp32 reads t3 directly after the call and keeps
    # values live in t1 and t2 across it, so this register allocation is
    # effectively part of the contract.
    addi t3, zero, 0          # count = 0
    addi t4, zero, 31         # Set i = 31

clz_loop:
    blt  t4, zero, clz_exit   # i < 0: all 32 bits were zero
    addi t5, zero, 1          # t5 assign 1U
    sll  t5, t5, t4           # t5 = (1U << i)
    and  t6, a0, t5           # t6 = x & t5
    bnez t6, clz_exit         # bit i is set: stop counting
    addi t3, t3, 1            # count + 1
    addi t4, t4, -1           # --i
    jal  x0, clz_loop

clz_exit:
    addi t3, t3, -5           # count - 5 (the caller clamps negatives to 0)
    mv   a0, t3               # return the adjusted count in a0
    ret
    
print:
    # Print: str1 <FP16 input> str2 <FP32 result>.
    # In: a0 = FP16 input, a1 = FP32 result; both are printed as numbers.
    # t0/t1 hold the inputs while a0 carries each environment-call argument.
    addi t0, a0, 0
    addi t1, a1, 0

    addi a7, zero, 4          # print string str1
    la   a0, str1
    ecall

    addi a7, zero, 1          # print the input
    addi a0, t0, 0
    ecall

    addi a7, zero, 4          # print string str2
    la   a0, str2
    ecall

    addi a7, zero, 1          # print the result
    addi a0, t1, 0
    ecall

    ret                       # back to main

Provide more tests for validations.

Undo

Use Case

LeetCode: 3011. Find if Array Can Be Sorted

Description:
You are given a 0-indexed array of positive integers nums.
In one operation, you can swap any two adjacent elements if they have the same number of set bits. You are allowed to do this operation any number of times (including zero).
Return true if you can sort the array, else return false.

Implementation

In Problem C, we selected the my_clz function. Initially, it counts leading zeros, but we modified it to count the set bits instead.

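Additionally, we defined a helper function, swap, which is called only when two adjacent elements are out of order and have the same number of set bits. The main function, canSortArray, is built on bubble sort: it repeats passes over the array until a pass performs no swap, and then checks whether the resulting array is sorted.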

C code

#include <stdbool.h>
#include <stdint.h>

// Return how many bits of `number` are 1.
// Bit 0 is examined and the value shifted right by one until nothing is left.
int countSetbits(uint32_t number) {
    int ones = 0;

    for (; number != 0; number >>= 1) {
        if (number & 1) {
            ones++;
        }
    }
    return ones;
}

// Exchange the two integers that a and b point to.
// Both values are read before either is written, so passing the same
// address twice leaves the value unchanged.
void swap(int *a, int *b) {
    const int first = *a;
    const int second = *b;

    *a = second;
    *b = first;
}

// Decide whether nums can be sorted when only adjacent elements with the
// same number of set bits may be swapped.
//
// Runs bubble-sort passes over nums (modifying it in place) in which a pair
// is exchanged only if it is out of order and both elements have equal
// countSetbits(). Passes stop after numsSize rounds or as soon as a round
// makes no swap. The array is then scanned once: any remaining descent
// means the answer is false.
bool canSortArray(int* nums, int numsSize) {
    for (int pass = 0; pass < numsSize; pass++) {
        bool swapped = false;

        for (int j = 1; j < numsSize; j++) {
            const bool same_bits =
                countSetbits(nums[j - 1]) == countSetbits(nums[j]);

            if (same_bits && nums[j - 1] > nums[j]) {
                swap(&nums[j - 1], &nums[j]);
                swapped = true;
            }
        }

        if (!swapped) {
            break;
        }
    }

    for (int i = 1; i < numsSize; i++) {
        if (nums[i - 1] > nums[i]) {
            return false;
        }
    }
    return true;
}

Assembly code

# Input array for canSortArray and the two possible output lines.
.data
nums1: .word 8, 4, 2, 30, 15   # Array under test (modified in place by canSortArray)
nums_size1: .word 5           # Number of elements in nums1
str1: .string "True\n"        # Printed when canSortArray reports true
str2: .string "False\n"       # Printed when canSortArray reports false

.text
main:
    # Run canSortArray(nums1, nums_size1) and print the verdict.
    la a0, nums1           # a0 = address of the first array element
    lw a1, nums_size1      # a1 = number of elements
    jal ra, canSortArray   # leaves its boolean result in a1

    # print takes the result in a1 and loads a0 itself, so the array
    # address does not need to be reloaded here.
    jal ra, print

    li a7, 10              # End the program
    ecall

canSortArray:
    # bool canSortArray(int *nums, int numsSize)
    # In:  a0 = address of nums, a1 = numsSize
    # Out: a1 = 1 if the array ends up sorted, 0 otherwise (read by print);
    #      the same value is also placed in a0.
    # Side effect: nums is reordered in place.
    # Uses t0 = swap flag, t1 = base address, t2 = index, t3/t4 = elements
    # or their set-bit counts, t5 = bit counter, t6 = extracted bit.
    #
    # The base address is kept in t1 rather than on the stack; the earlier
    # version allocated 4 bytes of stack that neither return path released.
    mv   t1, a0            # t1 = &nums[0]

outer_loop:
    mv   a0, t1            # restart at the first element
    li   t2, 1             # j = 1
    li   t0, 0             # flag = 0: no swap in this pass yet

inner_loop:
    bge  t2, a1, check_flag # j >= numsSize: pass finished
    lw   t3, 0(a0)         # t3 = nums[j-1]
    lw   t4, 4(a0)         # t4 = nums[j]
    bge  t4, t3, no_swap   # already in order: nothing to do

    # t3 = number of set bits in nums[j-1]
    li   t5, 0
countSetbits1:
    andi t6, t3, 1         # lowest bit
    add  t5, t5, t6
    srli t3, t3, 1         # logical shift, so the loop ends for negatives too
    bnez t3, countSetbits1
    mv   t3, t5

    # t4 = number of set bits in nums[j]
    li   t5, 0
countSetbits2:
    andi t6, t4, 1
    add  t5, t5, t6
    srli t4, t4, 1
    bnez t4, countSetbits2
    mv   t4, t5

    bne  t3, t4, no_swap   # different set-bit counts: swap not allowed

    # Swap nums[j-1] and nums[j]; the values are reloaded because the
    # counting loops consumed t3 and t4.
    lw   t3, 0(a0)
    lw   t4, 4(a0)
    sw   t4, 0(a0)
    sw   t3, 4(a0)
    li   t0, 1             # flag = 1: this pass changed the array

no_swap:
    addi a0, a0, 4         # advance to the next pair
    addi t2, t2, 1         # j++
    j    inner_loop

check_flag:
    bnez t0, outer_loop    # a swap happened: run another pass

    # No swap in the last pass: verify the array is in ascending order.
    mv   a0, t1            # back to the first element
    li   t2, 1             # i = 1

isSorted:
    bge  t2, a1, end_true  # every adjacent pair checked
    lw   t3, 0(a0)         # nums[i-1]
    lw   t4, 4(a0)         # nums[i]
    blt  t4, t3, end_false # descent found
    addi a0, a0, 4
    addi t2, t2, 1         # i++
    j    isSorted

end_true:
    li   a1, 1             # result = true
    mv   a0, a1
    ret

end_false:
    li   a1, 0             # result = false
    mv   a0, a1
    ret

print:
    # In: a1 = result of canSortArray (zero = false, non-zero = true).
    # Prints str1 for a true result and str2 for a false one.
    la   a0, str1          # assume true
    bnez a1, print_emit
    la   a0, str2          # result was zero: switch to the false text

print_emit:
    li   a7, 4             # environment call 4: print string
    ecall
    ret

Execution state

image