# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by < [winterchen](https://github.com/kstoko02) >
## quiz1 - problem B
This program implements a custom 8-bit floating-point format (uf8) for encoding and decoding, and verifies the correctness of the encoding and decoding process. It uses the `clz()` function to accelerate the estimation of the uf8 exponent.
### C code
```C=1
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
typedef uint8_t uf8;

/*
 * Count the leading zero bits of a 32-bit word.
 *
 * Halving search: at each step, if anything survives a right shift by
 * `shift`, the top `shift` bits cannot all be zero, so the count shrinks by
 * that amount and the search continues in the upper part. After the last
 * step x is 0 or 1, which accounts for the final bit. Returns 32 for x == 0.
 */
static inline unsigned clz(uint32_t x)
{
    unsigned zeros = 32;
    for (unsigned shift = 16; shift != 0; shift /= 2) {
        const uint32_t upper = x >> shift;
        if (upper != 0) {
            zeros -= shift;
            x = upper;
        }
    }
    return zeros - x;
}
/*
 * Decode a uf8 byte (high nibble = exponent, low nibble = mantissa) into the
 * 32-bit value it represents:
 *     value = (mantissa << exponent) + 16 * (2^exponent - 1)
 * The second term is the first value that uses this exponent.
 */
uint32_t uf8_decode(uf8 fl)
{
    const unsigned exp = (unsigned) fl >> 4;
    const uint32_t mant = (uint32_t) fl & 0x0Fu;
    /* (0x7FFF >> (15 - exp)) == 2^exp - 1, scaled by 16 */
    const uint32_t base = (uint32_t) (0x7FFF >> (15 - exp)) << 4;
    return base + (mant << exp);
}
/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
if (value < 16)
return value;
/* Find appropriate exponent using CLZ hint */
int lz = clz(value);
int msb = 31 - lz;
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0;
if (msb >= 5) {
/* Estimate exponent - the formula is empirical */
exponent = msb - 4;
if (exponent > 15)
exponent = 15;
/* Calculate overflow for estimated exponent */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
/* Adjust if estimate was off */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1;
exponent--;
}
}
/* Find exact exponent */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16;
if (value < next_overflow)
break;
overflow = next_overflow;
exponent++;
}
uint8_t mantissa = (value - overflow) >> exponent;
return (exponent << 4) | mantissa;
}
/*
 * Exhaustively check all 256 uf8 codes:
 *   - decode followed by encode must give back the same code;
 *   - decoded values must be strictly increasing with the code.
 * Every violation is printed; returns true only if none were found.
 */
static bool test(void)
{
    bool ok = true;
    int32_t prev = -1;

    for (int code = 0; code <= 255; code++) {
        const uint8_t fl = (uint8_t) code;
        const int32_t value = (int32_t) uf8_decode(fl);
        const uint8_t fl2 = uf8_encode((uint32_t) value);

        if (fl2 != fl) {
            printf("%02x: produces value %d but encodes back to %02x\n", fl,
                   value, fl2);
            ok = false;
        }
        if (!(value > prev)) {
            printf("%02x: value %d <= previous_value %d\n", fl, value,
                   prev);
            ok = false;
        }
        prev = value;
    }
    return ok;
}
/* Run the round-trip test; exit status 0 on success, 1 on any failure. */
int main(void)
{
    if (!test())
        return 1;
    printf("All tests passed.\n");
    return 0;
}
```
### Assembly code
```assembly=1
.data
str1: .string "failed at fl = "
str2: .string "passed"
.text
.globl main
# main: round-trip test for uf8_decode / uf8_encode_fast over fl = 0..255.
# Prints "passed", or "failed at fl = N" for the first failing code.
# s6 holds the loop bound because every callee clobbers t1.
main:
    li s0, 0                  # fl = 0
    li s6, 255                # max fl (not touched by any callee)
    li s4, -1                 # previous_value
test_loop:
    mv a0, s0
    jal ra, uf8_decode
    mv s5, a0                 # value = uf8_decode(fl)
    jal ra, uf8_encode_fast   # a0 still holds value -> a0 = fl2
    bne a0, s0, fail          # fl2 != fl
    bge s4, s5, fail          # value <= previous_value
    mv s4, s5
    addi s0, s0, 1
    ble s0, s6, test_loop     # while (fl <= 255)
    la a0, str2
    li a7, 4                  # print string
    ecall
    j end
fail:
    la a0, str1
    li a7, 4                  # print string
    ecall
    mv a0, s0
    li a7, 1                  # print int (failing fl)
    ecall
end:
    li a7, 10                 # exit
    ecall
# clz start
# clz: a0 = number of leading zero bits of the 32-bit value in a0.
# Halving search with shifts of 16, 8, 4, 2, 1: whenever x >> c is non-zero
# the top c bits are not all zero, so n shrinks by c and the search moves to
# the upper part. Returns 32 for a0 == 0.
# Leaf routine; clobbers t1, t2, t3, t4.
clz:
li t4, 32 # n = 32
li t1, 16 # c = 16
mv t2, a0 # t2 = x
clz_loop:
srl t3, t2, t1 # y = x >> c
beqz t3, skip # if y == 0, skip
sub t4, t4, t1 # n -= c
mv t2, t3 # x = y
skip:
srli t1, t1, 1 # c >>= 1
bnez t1, clz_loop # while (c != 0)
sub a0, t4, t2 # return n - x
jr ra
#clz end
# uf8_decode: a0 = uf8 code (assumed 0..255) -> a0 = decoded value
#   value = (mantissa << exponent) + ((0x7FFF >> (15 - exponent)) << 4)
# Leaf routine; clobbers t1, t2, t3, t4.
uf8_decode:
andi t4, a0, 0x0F # t4 = mantissa
srli t1, a0, 4 # t1 = exponent
li t2, 0x7FFF
li t3, 15
sub t3, t3, t1 # t3 = 15 - exponent
srl t2, t2, t3 # t2 = 2^exponent - 1
slli t2, t2, 4 # offset = 16 * (2^exponent - 1)
sll t4, t4, t1 # mantissa << exponent
add a0, t4, t2 # value = (mantissa << exponent) + offset
jr ra
# uf8_encode_fast: a0 = value (uint32) -> a0 = uf8 code.
# Mirrors uf8_encode() in C:
#   1. estimate exponent = msb - 4 from clz(value), capped at 15;
#   2. overflow = 16 * (2^exponent - 1), then correct the estimate down while
#      value < overflow and up while value >= next overflow (exponent < 15);
#   3. mantissa = (value - overflow) >> exponent, saturating to 0xFF when the
#      value is above the uf8 range.
# Uses a0, a1, t0-t6 and one stack word for ra; calls clz (clobbers t1-t4).
uf8_encode_fast:
    li t1, 16
    bltu a0, t1, small_val    # value < 16: the code is the value itself
    mv a1, a0                 # a1 = value (clz does not touch a1)
    addi sp, sp, -4
    sw ra, 0(sp)
    jal ra, clz               # a0 = clz(value)
    li t2, 31
    sub t3, t2, a0            # msb = 31 - clz(value)
    li t4, 0                  # exponent = 0
    li t5, 0                  # overflow = 0
    li t6, 5
    blt t3, t6, compute_overflow
    addi t4, t3, -4           # exponent = msb - 4
    li t6, 15
    ble t4, t6, compress
    li t4, 15                 # cap the estimate at the largest exponent
compress:
    li t6, 0                  # e = 0
loopA:                        # overflow = 16 * (2^exponent - 1)
    bge t6, t4, loopB
    slli t5, t5, 1
    addi t5, t5, 16
    addi t6, t6, 1
    j loopA
loopB:                        # while (exponent > 0 && value < overflow)
    blez t4, compute_overflow
    bgeu a1, t5, compute_overflow   # compare value (was clz(value))
    addi t5, t5, -16
    srli t5, t5, 1
    addi t4, t4, -1
    j loopB
compute_overflow:             # while (exponent < 15)
    li t6, 15                 # the C loop stops at 15 (was 16)
    bge t4, t6, end_compute
    slli t0, t5, 1
    addi t0, t0, 16           # next_overflow = (overflow << 1) + 16
    bltu a1, t0, end_compute  # compare value (was clz(value))
    mv t5, t0
    addi t4, t4, 1
    j compute_overflow
end_compute:
    sub t1, a1, t5
    srl t1, t1, t4            # mantissa = (value - overflow) >> exponent
    li t2, 15
    bleu t1, t2, pack
    li t1, 15                 # above the uf8 range: saturate to 0xFF
pack:
    slli t4, t4, 4
    or a0, t4, t1             # (exponent << 4) | mantissa
    lw ra, 0(sp)
    addi sp, sp, 4
    jr ra
small_val:
    jr ra                     # a0 already holds the code
```
## quiz1 - problem C
This program includes functions for determining whether a bfloat16 is zero, infinite, or NaN, converting float32 to bfloat16 and bfloat16 to float32, and performing bfloat16 addition, subtraction, multiplication, division, and square root operations.
### C code
```C=1
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
/* bfloat16: the upper 16 bits of an IEEE-754 binary32 value
 * (1 sign bit, 8 exponent bits, 7 mantissa bits). */
typedef struct {
    uint16_t bits;
} bf16_t;
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})

/* True when the exponent field is all ones and the mantissa is non-zero. */
static inline bool bf16_isnan(bf16_t a)
{
    const unsigned exp = a.bits & BF16_EXP_MASK;
    const unsigned mant = a.bits & BF16_MANT_MASK;
    return exp == BF16_EXP_MASK && mant != 0;
}

/* True when the exponent field is all ones and the mantissa is zero. */
static inline bool bf16_isinf(bf16_t a)
{
    const unsigned exp = a.bits & BF16_EXP_MASK;
    const unsigned mant = a.bits & BF16_MANT_MASK;
    return exp == BF16_EXP_MASK && mant == 0;
}

/* True for +0 and -0 (every bit except the sign is clear). */
static inline bool bf16_iszero(bf16_t a)
{
    return (a.bits & 0x7FFF) == 0;
}

/*
 * Convert a float to bfloat16.
 * Finite values are rounded to nearest, ties to even. Inf and NaN are
 * truncated; for NaN the quiet bit is forced on so that a NaN whose payload
 * lies only in the 16 discarded bits stays a NaN instead of becoming Inf.
 */
static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF) {
        uint16_t bits = (uint16_t) ((f32bits >> 16) & 0xFFFF);
        if (f32bits & 0x007FFFFFu)
            bits |= 0x0040; /* NaN: keep it a NaN after truncation */
        return (bf16_t) {.bits = bits};
    }
    /* Adding 0x7FFF plus the lowest surviving bit rounds ties to even. */
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = (uint16_t) (f32bits >> 16)};
}

/* Widen bfloat16 to float exactly: the 16 bits become the high half. */
static inline float bf16_to_f32(bf16_t val)
{
    const uint32_t f32bits = (uint32_t) val.bits << 16;
    float result;
    memcpy(&result, &f32bits, sizeof result);
    return result;
}
/*
 * bf16_add - add two bfloat16 values.
 *
 * Special cases, checked in this order:
 *   - a is NaN -> a; a is Inf -> a, unless b is NaN (-> b) or b is Inf of
 *     the opposite sign (-> canonical NaN 0x7FC0);
 *   - b is Inf or NaN -> b;
 *   - a is +-0 -> b; b is +-0 -> a.
 * Otherwise the mantissas (with the implicit 1 for normal numbers) are
 * aligned by right-shifting the operand with the smaller exponent. This
 * truncates the shifted-out bits (no rounding). If the exponents differ by
 * more than 8, the operand with the larger exponent is returned unchanged.
 * Same-sign sums carry into the exponent and overflow to signed Inf.
 * Different-sign results are renormalised and flushed to +0 once the
 * exponent would reach 0; an exact cancellation also returns +0.
 *
 * NOTE(review): subnormal operands are aligned using their raw exponent 0
 * (not 1). The sum of two subnormals that reaches 0x80 keeps exponent 0, so
 * the carried bit is dropped. Confirm this is acceptable for the intended use.
 */
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
/* Unpack sign, biased exponent and 7-bit fraction of both operands. */
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
/* a is Inf or NaN. */
if (exp_a == 0xFF) {
if (mant_a)
return a;
/* Inf + Inf: same sign keeps Inf, opposite signs give NaN. */
if (exp_b == 0xFF)
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a;
}
/* b is Inf or NaN while a is finite. */
if (exp_b == 0xFF)
return b;
/* A zero operand leaves the other one unchanged. */
if (!exp_a && !mant_a)
return b;
if (!exp_b && !mant_b)
return a;
/* Restore the implicit leading 1 of normal numbers. */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
/* Align to the larger exponent. A gap above 8 shifts the smaller mantissa
 * out completely, so the larger operand is returned as-is. */
if (exp_diff > 0) {
result_exp = exp_a;
if (exp_diff > 8)
return a;
mant_b >>= exp_diff;
} else if (exp_diff < 0) {
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else {
result_exp = exp_a;
}
if (sign_a == sign_b) {
/* Magnitude addition; a carry into bit 8 bumps the exponent. */
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b;
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
} else {
/* Magnitude subtraction; the sign follows the larger mantissa. */
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
if (!result_mant)
return BF16_ZERO();
/* Shift left until the implicit-1 position is set; flush to +0 if the
 * exponent runs out (no subnormal results). */
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO();
}
}
/* Repack; the implicit bit is dropped by the 0x7F mask. */
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
/* a - b, computed as a + (-b): flip b's sign bit and reuse the adder. */
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    const bf16_t neg_b = {.bits = (uint16_t) (b.bits ^ BF16_SIGN_MASK)};
    return bf16_add(a, neg_b);
}
/*
 * bf16_mul - multiply two bfloat16 values.
 *
 * Special cases:
 *   - a NaN operand is returned as-is (a is checked first);
 *   - Inf * 0 gives the canonical NaN;
 *   - Inf * non-zero gives Inf with the XOR of the signs;
 *   - any zero operand gives a zero with the XOR of the signs.
 * Subnormal operands are shifted left until bit 7 is set; the shift count is
 * accumulated in exp_adjust and the exponent is treated as 1.
 * The 8x8-bit mantissa product is truncated to 7 fraction bits (no rounding).
 * Exponents >= 0xFF overflow to signed Inf, exponents below -6 underflow to a
 * signed zero, and exponents in [-6, 0] are shifted right into the subnormal
 * range.
 *
 * NOTE(review): in the subnormal path the implicit bit has already been
 * removed by the 0x7F mask before the right shift, so the leading 1 is lost.
 * Confirm whether this is intended.
 */
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* a is Inf/NaN: NaN propagates, Inf * 0 is NaN, else signed Inf. */
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (!exp_b && !mant_b)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Same rules with the roles of a and b swapped. */
if (exp_b == 0xFF) {
if (mant_b)
return b;
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Finite * zero = signed zero. */
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15};
/* Normalise subnormals (recording the shift) or add the implicit 1. */
int16_t exp_adjust = 0;
if (!exp_a) {
while (!(mant_a & 0x80)) {
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80;
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--;
}
exp_b = 1;
} else
mant_b |= 0x80;
/* Product of two 8-bit mantissas in [128, 255]: 15 or 16 significant bits. */
uint32_t result_mant = (uint32_t) mant_a * mant_b;
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
/* Keep the 7 bits below the leading 1 (truncation); a product >= 2.0 bumps
 * the exponent. */
if (result_mant & 0x8000) {
result_mant = (result_mant >> 8) & 0x7F;
result_exp++;
} else
result_mant = (result_mant >> 7) & 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
/* Underflow: flush far-too-small results, else denormalise. */
if (result_exp <= 0) {
if (result_exp < -6)
return (bf16_t) {.bits = result_sign << 15};
result_mant >>= (1 - result_exp);
result_exp = 0;
}
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)};
}
/*
 * bf16_div - divide a by b (bfloat16).
 *
 * Special cases, in this order:
 *   - b NaN -> b;
 *   - b Inf -> NaN if a is Inf, otherwise a signed zero;
 *   - b zero -> NaN if a is zero, otherwise signed Inf;
 *   - a NaN -> a; a Inf -> signed Inf; a zero -> signed zero.
 * The mantissa quotient is computed by 16-step restoring (shift-and-subtract)
 * division of (mant_a << 15) by mant_b. It is then normalised so that bit 15
 * is set (stopping at exponent 1) and truncated to 7 fraction bits.
 * Exponents >= 0xFF give signed Inf; exponents <= 0 flush to a signed zero.
 *
 * NOTE(review): because b is checked first, NaN / Inf returns a signed zero
 * and NaN / 0 returns Inf rather than NaN. Confirm this is intended.
 */
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* Divisor is Inf or NaN. */
if (exp_b == 0xFF) {
if (mant_b)
return b;
/* Inf/Inf = NaN */
if (exp_a == 0xFF && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = result_sign << 15};
}
/* Division by zero: 0/0 is NaN, x/0 is signed Inf. */
if (!exp_b && !mant_b) {
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Dividend is Inf or NaN. */
if (exp_a == 0xFF) {
if (mant_a)
return a;
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* 0 / finite = signed zero. */
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15};
/* Implicit 1 for normal operands (subnormals keep their raw fraction). */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
/* Restoring division: one quotient bit per step, most significant first. */
uint32_t dividend = (uint32_t) mant_a << 15;
uint32_t divisor = mant_b;
uint32_t quotient = 0;
for (int i = 0; i < 16; i++) {
quotient <<= 1;
if (dividend >= (divisor << (15 - i))) {
dividend -= (divisor << (15 - i));
quotient |= 1;
}
}
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
/* Adjustments applied when an operand is subnormal. */
if (!exp_a)
result_exp--;
if (!exp_b)
result_exp++;
/* Normalise so bit 15 is the leading 1, then keep the next 7 bits. */
if (quotient & 0x8000)
quotient >>= 8;
else {
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
}
quotient >>= 8;
}
quotient &= 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15};
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F)};
}
/*
 * bf16_sqrt - square root of a bfloat16 value.
 *
 * Special cases: NaN propagates, sqrt(+Inf) = +Inf, sqrt(-Inf) = NaN,
 * sqrt(+-0) = +0, negative inputs give NaN, and subnormal inputs are flushed
 * to +0.
 * For normal inputs the unbiased exponent is halved. For an odd exponent the
 * mantissa is doubled first. The mantissa root is found by binary search with
 * 128 representing 1.0; the result is truncated (no rounding).
 *
 * NOTE(review): IEEE 754 specifies sqrt(-0) = -0; this returns +0.
 */
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Handle special cases */
if (exp == 0xFF) {
if (mant)
return a; /* NaN propagation */
if (sign)
return BF16_NAN(); /* sqrt(-Inf) = NaN */
return a; /* sqrt(+Inf) = +Inf */
}
/* sqrt(+0) and sqrt(-0) both return +0 here */
if (!exp && !mant)
return BF16_ZERO();
/* sqrt of negative number is NaN */
if (sign)
return BF16_NAN();
/* Flush denormals to zero */
if (!exp)
return BF16_ZERO();
/* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
/* Get full mantissa with implicit 1 */
uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
/* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
if (e & 1) {
m <<= 1; /* Double mantissa for odd exponent */
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
} else {
new_exp = (e >> 1) + BF16_EXP_BIAS;
}
/* Now m is in range [128, 256) or [256, 512) if exponent was odd */
/* Binary search for the largest mid with (mid * mid) / 128 <= m, i.e.
 * roughly sqrt(m * 128): the square root of m/128, again scaled by 128 */
uint32_t low = 90; /* Safe lower bound; the answer is at least 128 */
uint32_t high = 256; /* sqrt(512 * 128) = 256 bounds the answer from above */
uint32_t result = 128; /* Default */
while (low <= high) {
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; /* Square and rescale to m's units */
if (sq <= m) {
result = mid; /* This could be our answer */
low = mid + 1;
} else {
high = mid - 1;
}
}
/* result is now the root scaled so that 128 represents 1.0 */
/* Normalize to ensure result is in [128, 256). For m in [128, 512) the
 * search already yields a value in that range, so these branches appear
 * unreachable and act as a safeguard */
if (result >= 256) {
result >>= 1;
new_exp++;
} else if (result < 128) {
while (result < 128 && new_exp > 1) {
result <<= 1;
new_exp--;
}
}
/* Extract 7-bit mantissa (remove implicit 1) */
uint16_t new_mant = result & 0x7F;
/* Check for overflow/underflow */
if (new_exp >= 0xFF)
return (bf16_t) {.bits = 0x7F80}; /* +Inf */
if (new_exp <= 0)
return BF16_ZERO();
return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
### Assembly code
```Assembly=1
.data
# Shared bfloat16 constants (defined once; the section was duplicated and
# ".globl main.data" fused two directives, defining every label twice).
BF16_SIGN_MASK: .half 0x8000
BF16_EXP_MASK:  .half 0x7F80
BF16_MANT_MASK: .half 0x007F
.align 2                      # three .half values end at offset 6: realign for lw
BF16_EXP_BIAS:  .word 127
.text
.globl main
# BF16_NAN: a0 = 0x7FC0 (canonical quiet NaN); clobbers t1.
# Also used as a branch/jump target by the arithmetic routines (e.g.
# "j BF16_NAN" in bf16_add), so it returns through the caller's ra.
BF16_NAN:
li t1, 0x7FC0
mv a0, t1
jr ra
# BF16_ZERO: a0 = 0x0000 (+0); clobbers t1. Also a branch target.
BF16_ZERO:
li t1, 0x0000
mv a0, t1
jr ra
# bf16_isnan: a0 = 1 if the bfloat16 in a0 has an all-ones exponent and a
# non-zero mantissa, else 0. Clobbers t0-t4.
bf16_isnan:
la t0, BF16_EXP_MASK
lh t1, 0(t0) # t1 = 0x7F80
la t0, BF16_MANT_MASK
lh t2, 0(t0) # t2 = 0x007F
and t3, a0, t1
bne t3, t1, not_nan # exponent not all ones
and t4, a0, t2
beqz t4, not_nan # mantissa zero -> Inf, not NaN
li a0, 1
jr ra
not_nan:
li a0, 0
jr ra
# bf16_isinf: a0 = 1 if the exponent is all ones and the mantissa is zero,
# else 0. Clobbers t0-t4.
bf16_isinf:
la t0, BF16_EXP_MASK
lh t1, 0(t0) # t1 = 0x7F80
la t0, BF16_MANT_MASK
lh t2, 0(t0) # t2 = 0x007F
and t3, a0, t1
bne t3, t1, not_inf # exponent not all ones
and t4, a0, t2
bnez t4, not_inf # mantissa non-zero -> NaN, not Inf
li a0, 1
jr ra
not_inf:
li a0, 0
jr ra
# bf16_iszero: a0 = 1 for +0 or -0 (all bits except the sign clear), else 0.
# Clobbers t1, t2.
bf16_iszero:
li t1, 0x7FFF
and t2, a0, t1
seqz a0, t2
jr ra
# f32_to_bf16: a0 = IEEE-754 binary32 bits -> a0 = bfloat16 bits.
# Finite values are rounded to nearest, ties to even, by adding
# 0x7FFF + (bit 16) before dropping the low half. Inf/NaN (exponent 0xFF)
# are truncated; for NaN the bf16 quiet bit (0x40) is forced on, so a NaN
# whose payload lives only in the discarded bits does not turn into Inf.
# Clobbers t1-t4.
f32_to_bf16:
    li t1, 0xFF
    srli t2, a0, 23
    and t3, t1, t2            # exponent field
    beq t1, t3, re_bf16       # Inf or NaN
    li t1, 0x7FFF
    srli t2, a0, 16
    andi t3, t2, 1            # lowest bit that survives
    add t4, t3, t1
    add a0, a0, t4            # round to nearest even
    srli a0, a0, 16
    jr ra
re_bf16:
    slli t3, a0, 9            # t3 != 0  <=>  f32 mantissa != 0 (NaN)
    li t1, 0xFFFF
    srli t2, a0, 16
    and a0, t1, t2            # truncate
    beqz t3, re_bf16_done     # Inf: truncation is exact
    ori a0, a0, 0x40          # NaN: keep it a NaN
re_bf16_done:
    jr ra
# bf16_to_f32: a0 = bfloat16 bits -> a0 = binary32 bits (exact widening).
bf16_to_f32:
    slli t1, a0, 16
    mv a0, t1
    jr ra
# set_value: unpack two bfloat16 operands for the arithmetic routines.
#   in : a0 = a, a1 = b
#   out: s0/s1 = sign of a/b, s2/s3 = exponent field of a/b,
#        s4/s5 = 7-bit mantissa of a/b (implicit bit not added)
# Clobbers t1, t2. NOTE(review): s0-s5 are callee-saved in the standard ABI
# but are written here without being saved; the routines in this file use
# them as scratch.
set_value:
srli t1, a0, 15
andi s0, t1, 1 #sign_a
srli t1, a1, 15
andi s1, t1, 1 #sign_b
li t1, 0x7F
and s4, a0, t1 #mant_a
and s5, a1, t1 #mant_b
li t1, 0xFF
srai t2, a0, 7
and s2, t2, t1 #exp_a
srai t2, a1, 7
and s3, t2, t1 #exp_b
jr ra
# bf16_add: a0 = a0 + a1 (bfloat16 bit patterns), following bf16_add() in C.
# Operands are unpacked by set_value; s0-s9 are used as scratch and not
# preserved. Special-case exits branch to re_a / re_b / BF16_NAN /
# BF16_ZERO / re_bits, which all return through ra.
bf16_add:
    addi sp, sp, -4
    sw ra, 0(sp)
    jal ra, set_value
    lw ra, 0(sp)
    addi sp, sp, 4
    li t1, 0xFF
    bne s2, t1, skip_expa
    bnez s4, re_a             # a is NaN (was "beq s4, 1": missed most NaNs)
    bne s3, t1, re_a          # a is Inf, b finite
    bnez s5, re_b             # b is NaN (was "beq s5, 1")
    beq s0, s1, re_b          # Inf + Inf with the same sign
    j BF16_NAN                # Inf + (-Inf)
skip_expa:
    beq s3, t1, re_b          # b is Inf or NaN
    bnez s2, skipA
    bnez s4, skipA
    j re_b                    # a == +-0
skipA:
    bnez s3, skipB
    bnez s5, skipB
    j re_a                    # b == +-0
skipB:
    li t2, 0x80
    beqz s2, skipC
    or s4, s4, t2             # implicit 1 for a normal a
skipC:
    beqz s3, skipD
    or s5, s5, t2             # implicit 1 for a normal b
skipD:
    sub s6, s2, s3            # exp_diff
    blez s6, diff_zero
    mv s8, s2                 # result_exp = exp_a
    li t3, 8
    bgt s6, t3, re_a          # b is too small to matter
    srl s5, s5, s6
    j skip_diff
diff_zero:
    bnez s6, diff_neg
    mv s8, s2                 # equal exponents
    j skip_diff
diff_neg:
    mv s8, s3                 # result_exp = exp_b
    li t3, -8
    blt s6, t3, re_b          # test exp_diff (was result_exp)
    sub t4, x0, s6
    srl s4, s4, t4
skip_diff:
    bne s0, s1, sign_neq
    mv s7, s0                 # result_sign
    add s9, s4, s5            # result_mant
    li t3, 0x100
    and t4, s9, t3
    beqz t4, skip_sign
    srli s9, s9, 1            # carry: renormalise
    addi s8, s8, 1
    blt s8, t1, skip_sign
    j re_bits                 # overflow -> signed Inf
sign_neq:
    blt s4, s5, mant_lt
    mv s7, s0
    sub s9, s4, s5
    j re_mant
mant_lt:
    mv s7, s1
    sub s9, s5, s4
re_mant:
    beqz s9, BF16_ZERO        # exact cancellation
loopA:                        # shift left until bit 7 (t2 = 0x80) is set
    and t3, s9, t2
    bnez t3, skip_sign
    slli s9, s9, 1
    addi s8, s8, -1
    blez s8, BF16_ZERO        # flush to +0
    j loopA
skip_sign:
    slli t3, s7, 15
    and t4, s8, t1
    slli t4, t4, 7
    li t5, 0x7F
    and t6, s9, t5
    or a0, t3, t4
    or a0, a0, t6
    jr ra
# Shared exits, also used by bf16_mul, bf16_div and bf16_sqrt.
re_a:                         # return a (a0 is still the first operand)
    jr ra
re_b:                         # return b
    mv a0, a1
    jr ra
# bf16_sub: a0 = a0 - a1, computed as a0 + (a1 with its sign bit flipped).
# Tail-jumps into bf16_add, which returns directly to this routine's caller.
# Clobbers t0, t1 here, plus whatever bf16_add clobbers.
bf16_sub:
la t0, BF16_SIGN_MASK
lhu t1, 0(t0) # t1 = 0x8000 (zero-extended)
xor a1, a1, t1 # negate b
j bf16_add
# bf16_mul: a0 = a0 * a1 (bfloat16), following bf16_mul() in C.
# Uses s0-s10 as scratch (not preserved). ra is saved only around the calls
# to set_value and prod, so every early exit leaves sp balanced.
bf16_mul:
    addi sp, sp, -4
    sw ra, 0(sp)
    jal ra, set_value
    lw ra, 0(sp)
    addi sp, sp, 4            # pop now: the early returns below used to leak 4 bytes
    xor s7, s0, s1            # result_sign
    li t1, 0xFF
    bne s2, t1, skip_a
    bnez s4, re_a             # a is NaN
    bnez s3, re_bits          # Inf * non-zero
    bnez s5, re_bits
    j BF16_NAN                # Inf * 0
skip_a:
    bne s3, t1, skip_b
    bnez s5, re_b             # b is NaN
    bnez s2, re_bits          # non-zero * Inf
    bnez s4, re_bits
    j BF16_NAN                # 0 * Inf
# Shared exits, also used by bf16_add, bf16_div and bf16_sqrt.
re_bits:                      # signed Inf
    li t3, 0x7F80
    slli t4, s7, 15
    or a0, t4, t3
    jr ra
re_bits2:                     # signed zero
    slli a0, s7, 15
    jr ra
skip_b:
    bnez s2, check_b          # a has a non-zero exponent: not zero
    beqz s4, re_bits2         # a == +-0
check_b:                      # b was previously skipped whenever exp_a != 0,
    bnez s3, skip_zero        # and b == 0 then looped forever in loopC
    beqz s5, re_bits2         # b == +-0
skip_zero:
    li s10, 0                 # exp_adjust
    li t2, 0x80
    bnez s2, skipE
    li s2, 1                  # subnormal a: exponent 1, normalise below
loopB:
    and t3, s4, t2
    bnez t3, skipF
    slli s4, s4, 1
    addi s10, s10, -1
    j loopB
skipE:
    or s4, s4, t2             # implicit 1
skipF:
    bnez s3, skipG
    li s3, 1                  # subnormal b: exponent 1, normalise below
loopC:
    and t3, s5, t2
    bnez t3, skip_exp
    slli s5, s5, 1
    addi s10, s10, -1
    j loopC
skipG:
    or s5, s5, t2             # implicit 1
skip_exp:
    mv a0, s4
    mv a1, s5
    addi sp, sp, -4
    sw ra, 0(sp)
    jal ra, prod
    mv s9, a0                 # result_mant = mant_a * mant_b
    lw ra, 0(sp)
    addi sp, sp, 4
    la t0, BF16_EXP_BIAS
    lw t3, 0(t0)
    add s8, s2, s3
    add s8, s8, s10
    sub s8, s8, t3            # result_exp
    li t3, 0x8000
    and t4, s9, t3
    li t5, 0x7F
    beqz t4, skipH
    addi s8, s8, 1            # product >= 2.0
    srli s9, s9, 8
    and s9, s9, t5
    j skip_mant
skipH:
    srli s9, s9, 7
    and s9, s9, t5
skip_mant:
    bge s8, t1, re_bits       # overflow
    blt x0, s8, skipJ         # normal result
    li t3, -6
    bge s8, t3, skipI
    j re_bits2                # underflow
skipI:
    li t3, 1
    sub t4, t3, s8
    srl s9, s9, t4            # denormalise
    mv s8, x0
skipJ:
    slli t3, s7, 15
    and t4, s8, t1
    slli t4, t4, 7
    and t6, s9, t5
    or a0, t3, t4
    or a0, a0, t6
    jr ra
# bf16_div: a0 = a0 / a1 (bfloat16), following bf16_div() in C.
# Uses s0-s11 as scratch (not preserved). Special cases exit through
# re_a / re_b / re_bits / re_bits2 / BF16_NAN.
bf16_div:
    addi sp, sp, -4
    sw ra, 0(sp)
    jal ra, set_value
    lw ra, 0(sp)
    addi sp, sp, 4
    xor s7, s0, s1            # result_sign
    li t1, 0xFF
    li t2, 0x80
    li t5, 0x7F
    bne s3, t1, skipK
    bnez s5, re_b             # b is NaN
    bne s2, t1, re_bits2      # x / Inf = signed zero
    bnez s4, re_bits2
    j BF16_NAN                # Inf / Inf
skipK:
    bnez s3, skipL
    bnez s5, skipL
    bnez s2, re_bits          # x / 0 = signed Inf
    bnez s4, re_bits
    j BF16_NAN                # 0 / 0
skipL:
    bne s2, t1, skipM
    bnez s4, re_a             # a is NaN
    j re_bits                 # Inf / finite
skipM:
    bnez s2, skipN
    bnez s4, skipN
    j re_bits2                # 0 / finite
skipN:
    beqz s2, skipO
    or s4, s4, t2             # implicit 1
skipO:
    beqz s3, skipP
    or s5, s5, t2             # implicit 1
skipP:
    slli s11, s4, 15          # dividend
    mv s6, s5                 # divisor
    li s9, 0                  # quotient
    li t3, -1                 # i
    li t4, 16
    li t2, 15
loopD:                        # restoring division, one bit per step
    addi t3, t3, 1
    bge t3, t4, skipQ
    slli s9, s9, 1
    sub t6, t2, t3
    sll t6, s6, t6            # divisor << (15 - i)
    blt s11, t6, loopD
    sub s11, s11, t6
    ori s9, s9, 1
    j loopD
skipQ:
    la t0, BF16_EXP_BIAS
    lw t2, 0(t0)
    sub s8, s2, s3
    add s8, s8, t2            # result_exp
    bnez s2, skipR
    addi s8, s8, -1
skipR:
    bnez s3, skipS
    addi s8, s8, 1
skipS:
    li t2, 0x8000
    li t4, 1
    and t3, s9, t2
    beqz t3, loopE
    srli s9, s9, 8
    j skip_quo
loopE:                        # while (!(q & 0x8000) && result_exp > 1)
    and t3, s9, t2            # re-test bit 15 each time (was never updated)
    bnez t3, skipT
    ble s8, t4, skipT
    slli s9, s9, 1
    addi s8, s8, -1
    j loopE
skipT:
    srli s9, s9, 8
skip_quo:
    and s9, s9, t5
    bge s8, t1, re_bits       # overflow
    ble s8, x0, re_bits2      # underflow
    slli t3, s7, 15
    and t4, s8, t1
    slli t4, t4, 7
    and t6, s9, t5
    or a0, t3, t4
    or a0, a0, t6
    jr ra
# bf16_sqrt: a0 = sqrt(a0) (bfloat16), following bf16_sqrt() in C.
# Uses s0-s11 as scratch (not preserved); saves ra around the prod calls.
bf16_sqrt:
    li t1, 0xFF
    li t2, 0x80
    li t5, 0x7F
    srli t3, a0, 15
    andi s0, t3, 1            # sign
    srli t3, a0, 7
    and s1, t3, t1            # exp
    and s2, a0, t5            # mant
    bne s1, t1, skip_exp2
    bnez s2, re_a             # NaN propagates
    bnez s0, BF16_NAN         # sqrt(-Inf)
    j re_a                    # sqrt(+Inf)
skip_exp2:
    bnez s1, skip_zero2
    bnez s2, skip_zero2
    j BF16_ZERO               # sqrt(+-0) = +0
skip_zero2:
    bnez s0, BF16_NAN         # negative input
    beqz s1, BF16_ZERO        # flush subnormals
    la t0, BF16_EXP_BIAS
    lw t3, 0(t0)
    sub s3, s1, t3            # e
    or s5, t2, s2             # m = 0x80 | mant
    andi t4, s3, 1
    beqz t4, skipU
    slli s5, s5, 1            # odd exponent: double m
    addi t4, s3, -1
    srai t4, t4, 1
    add s4, t4, t3            # new_exp = ((e - 1) >> 1) + bias
    j skip_e
skipU:
    srai t4, s3, 1
    add s4, t4, t3            # new_exp = (e >> 1) + bias
skip_e:
    li s6, 90                 # low
    li s7, 256                # high
    li s8, 128                # result
    addi sp, sp, -4
    sw ra, 0(sp)
loopF:
    blt s7, s6, skip_loopF
    add t3, s6, s7
    srli s10, t3, 1           # mid
    mv a0, s10
    mv a1, s10
    jal ra, prod
    srli s11, a0, 7           # sq = mid * mid / 128
    blt s5, s11, skipV
    mv s8, s10                # sq <= m: candidate
    addi s6, s10, 1
    j loopF
skipV:
    addi s7, s10, -1
    j loopF
skip_loopF:
    lw ra, 0(sp)
    addi sp, sp, 4
    li t2, 256
    blt s8, t2, skip_256
    srli s8, s8, 1
    addi s4, s4, 1
    j skip_result
skip_256:
    li t6, 128                # lower normalisation bound (t6 was never set)
    bge s8, t6, skip_result
    li t3, 1
loopG:
    bge s8, t6, skip_result
    bge t3, s4, skip_result
    slli s8, s8, 1
    addi s4, s4, -1
    j loopG
skip_result:
    and s9, s8, t5            # new_mant
    blt s4, t1, skip_inf
    li t3, 0x7F80
    mv a0, t3                 # +Inf
    jr ra
skip_inf:
    bge x0, s4, BF16_ZERO
    and t3, s4, t1
    slli t3, t3, 7
    or a0, t3, s9
    jr ra
# prod: a0 = a0 * a1 (low 32 bits), shift-and-add multiplication.
# Presumably written this way to avoid depending on the M extension's mul —
# TODO confirm. Leaf routine; clobbers a1 (shifted to 0), t3, t4.
prod:
li t3, 0 # accumulator
loop_mul:
beqz a1, skip_mul
andi t4, a1, 1
beqz t4, next_mul # multiplier bit clear: skip the add
add t3, t3, a0
next_mul:
slli a0, a0, 1 # multiplicand <<= 1
srli a1, a1, 1 # multiplier >>= 1
bnez a1, loop_mul
skip_mul:
mv a0, t3
jr ra
```
## Maximum Product of Word Lengths ([LeetCode318](https://leetcode.com/problems/maximum-product-of-word-lengths/description/))
>Given a string array words, return the maximum value of length(word[i]) * length(word[j]) where the two words do not share common letters. If no such two words exist, return 0.
>* 2 <= words.length <= 1000
>* 1 <= words[i].length <= 1000
>* words[i] consists only of lowercase English letters.
>Example 1:
>Input: words = ["abcw","baz","foo","bar","xtfn","abcdef"]
Output: 16
Explanation: The two words can be "abcw", "xtfn".
>Example 2:
>Input: words = ["a","ab","abc","d","cd","bcd","abcd"]
Output: 4
Explanation: The two words can be "ab", "cd".
>Example 3:
>Input: words = ["a","aa","aaa","aaaa"]
Output: 0
Explanation: No such pair of words.
## Solution
### Idea for problem solving
This solution uses bitmasks to determine efficiently whether two words share any common letter.
1. Each letter corresponds to one bit in an integer
(a → bit 0, b → bit 1, …, z → bit 25).
Example:
"abc" → binary 000...000111 → 0x7
"w" → bit 22 → 0x400000
2. Build a bitmask for every word
For each letter in the word, set the corresponding bit to 1.
This creates a compact 26-bit representation of all letters in that word.
3. Compare each pair of words
```C=1
if ((mask[i] & mask[j]) == 0)
```
* If the bitwise AND is 0 → the two words share no common letters.
* If no common letters, calculate the product of their lengths.
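Checking for common letters is O(1) (a single AND operation).
Building all masks takes O(L) time, where L is the total number of characters, and comparing every pair takes O(n²). This is much faster than comparing the letters of each pair one by one.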
### C code (Without clz)
```C=1
/*
 * LeetCode 318 (brute force): the maximum of len(words[i]) * len(words[j])
 * over pairs of words that share no letter, or 0 if there is no such pair.
 *
 * Each word is summarised by a 26-bit letter mask (bit k set <=> letter
 * 'a' + k occurs), so "no common letters" is a single AND. Word lengths are
 * counted while the masks are built, so no strlen call is repeated per pair.
 *
 * Precondition (from the problem statement): 0 <= n <= 1000 and every word
 * consists only of 'a'..'z'.
 */
int maxProduct(char *words[], int n) {
    int mask[1000] = {0};
    int len[1000] = {0};
    int ans = 0;
    // build bitmask and length of every word
    for (int i = 0; i < n; i++) {
        int m = 0;
        int j = 0;
        for (; words[i][j] != '\0'; j++)
            m |= 1 << (words[i][j] - 'a');
        mask[i] = m;
        len[i] = j;
    }
    // compare every pair of words
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            if ((mask[i] & mask[j]) == 0) { // no common letters
                int prod = len[i] * len[j];
                if (prod > ans) ans = prod;
            }
        }
    }
    return ans;
}
```
### Assembly code(Without clz)
```Assembly=1
# int maxProduct(char *words[] /* a0 */, int n /* a1 */) -- brute force.
# Prints the answer and a newline, and returns it in a0.
# Expects .data labels (not shown here): mask (>= n words), ans (1 word),
# newline (string). Saves and restores ra and every s register it uses; the
# loop state that must survive strlen lives in s registers.
maxProduct:
    addi sp, sp, -32          # keep sp 16-byte aligned
    sw ra, 0(sp)
    sw s0, 4(sp)
    sw s2, 8(sp)
    sw s3, 12(sp)
    sw s4, 16(sp)
    sw s5, 20(sp)
    sw s6, 24(sp)
    mv s0, a0                 # s0 = words
    la s2, mask
    li s3, 0                  # i = 0
bitmask_loop:
    bge s3, a1, compare_pairs # i == n -> compare
    li t6, 0                  # m = 0
    slli t0, s3, 2
    add t1, s0, t0
    lw t2, 0(t1)              # t2 = words[i]
    li s4, 0                  # j = 0
char_loop:
    add t3, t2, s4
    lbu t4, 0(t3)
    beqz t4, store_mask       # == '\0'
    addi t4, t4, -97          # t4 = c - 'a'
    li t5, 1
    sll t5, t5, t4            # 1 << (c - 'a')
    or t6, t6, t5             # m |= ...
    addi s4, s4, 1            # j++
    j char_loop
store_mask:
    add t1, s2, t0
    sw t6, 0(t1)              # mask[i] = m
    addi s3, s3, 1
    j bitmask_loop
compare_pairs:
    li s4, 0                  # i = 0
    la s3, ans
    sw zero, 0(s3)            # ans = 0 on every call
outer_loop:
    bge s4, a1, done
    addi s5, s4, 1            # j = i + 1 (s5 survives strlen)
inner_loop:
    bge s5, a1, next_i
    slli t2, s4, 2
    add t4, s2, t2
    lw t5, 0(t4)              # mask[i]
    slli t2, s5, 2
    add t4, s2, t2
    lw t6, 0(t4)              # mask[j]
    and t2, t5, t6
    bnez t2, skip_pair        # common letters -> skip
    slli t2, s4, 2
    add t4, s0, t2
    lw a0, 0(t4)              # words[i]
    jal ra, strlen
    mv s6, a0                 # len1 (s6 survives the next strlen)
    slli t2, s5, 2
    add t4, s0, t2
    lw a0, 0(t4)              # words[j]
    jal ra, strlen            # a0 = len2
    mul t2, s6, a0
    lw t3, 0(s3)
    bge t3, t2, skip_pair
    sw t2, 0(s3)              # ans = len1 * len2
skip_pair:
    addi s5, s5, 1
    j inner_loop
next_i:
    addi s4, s4, 1
    j outer_loop
done:
    lw a0, 0(s3)              # a0 = ans
    li a7, 1
    ecall
    la a0, newline
    li a7, 4
    ecall
    lw a0, 0(s3)              # return ans
    lw ra, 0(sp)
    lw s0, 4(sp)
    lw s2, 8(sp)
    lw s3, 12(sp)
    lw s4, 16(sp)
    lw s5, 20(sp)
    lw s6, 24(sp)
    addi sp, sp, 32           # matches the prologue (was -8 / +4)
    jr ra
# strlen: a0 = number of bytes before the NUL terminator of the string at a0.
# Leaf routine; clobbers t0, t2, t3.
strlen:
    mv t0, a0
    li t3, 0
strlen_loop:
    lbu t2, 0(t0)
    beqz t2, strlen_end
    addi t3, t3, 1
    addi t0, t0, 1
    j strlen_loop
strlen_end:
    mv a0, t3
    jr ra
```
The program above is a simple brute-force solution. I designed a second version that adds sorting and pruning to improve its efficiency.
### C code(With Prune)
```C=1
/* Fallback capacity when the surrounding code does not define MAX_WORDS
 * (the problem guarantees at most 1000 words). */
#ifndef MAX_WORDS
#define MAX_WORDS 1000
#endif

// Return the index of the most significant set bit of x, or -1 if x == 0.
int highest_bit(uint32_t x) {
    if (x == 0) return -1;
    return 31 - __builtin_clz(x);
}

/*
 * LeetCode 318 with sorting and pruning.
 *
 * Words are sorted by length in descending order; equal lengths are ordered
 * by highest letter bit, descending. For a fixed first word, the product with
 * every later partner is therefore non-increasing, so the inner loop can stop
 * as soon as a product no longer beats the best answer. Sorting by highest
 * bit alone would not give this guarantee, and the early break would then
 * miss pairs: ["z","yx","w","abcde"] would yield 5 instead of 10.
 *
 * Precondition: 0 <= n <= MAX_WORDS and every word consists only of 'a'..'z'.
 */
int maxProduct(char *words[], int n) {
    int mask[MAX_WORDS] = {0};
    int len[MAX_WORDS] = {0};
    int highest[MAX_WORDS];
    int indices[MAX_WORDS];
    int ans = 0;

    // build bitmask, length and highest letter bit of every word in one pass
    for (int i = 0; i < n; i++) {
        int j = 0;
        for (; words[i][j]; j++)
            mask[i] |= 1 << (words[i][j] - 'a');
        len[i] = j;
        highest[i] = highest_bit((uint32_t) mask[i]);
        indices[i] = i;
    }

    // sort: longer words first, ties broken by larger highest bit
    for (int i = 0; i < n - 1; i++) {
        for (int j = i + 1; j < n; j++) {
            int a = indices[i], b = indices[j];
            if (len[a] < len[b] ||
                (len[a] == len[b] && highest[a] < highest[b])) {
                indices[i] = b;
                indices[j] = a;
            }
        }
    }

    // compare words
    for (int i = 0; i < n; i++) {
        int idx1 = indices[i];
        for (int j = i + 1; j < n; j++) {
            int idx2 = indices[j];
            int prod = len[idx1] * len[idx2];
            // Prune: later partners are no longer, so their products are no larger
            if (prod <= ans) break;
            if ((mask[idx1] & mask[idx2]) == 0)
                ans = prod;
        }
    }
    return ans;
}
```
### Assembly code(With Prune)
```Assembly=1
# int maxProduct(char *words[] /* a0 */, int n /* a1 */) -- sorted + pruned.
# Prints the answer and a newline, and returns it in a0.
# Expects .data labels (not shown here): mask, highest, indices, len_arr
# (>= n words each) and newline (string).
# Words are sorted by length, descending (ties: highest letter bit,
# descending), so for a fixed first word the products never increase and
# the inner loop can stop at the first product <= ans.
# Saves and restores ra and every s register it uses.
maxProduct:
    addi sp, sp, -64          # keep sp 16-byte aligned
    sw ra, 0(sp)
    sw a0, 4(sp)              # words base, reloaded in the mask loop
    sw s0, 8(sp)
    sw s2, 12(sp)
    sw s3, 16(sp)
    sw s4, 20(sp)
    sw s5, 24(sp)
    sw s6, 28(sp)
    sw s7, 32(sp)
    sw s8, 36(sp)
    sw s9, 40(sp)
    sw s10, 44(sp)
    sw s11, 48(sp)
    la s2, mask
    la s5, highest
    la s6, indices
    la s7, len_arr
    li s3, 0                  # i = 0
bitmask_loop:
    bge s3, a1, sort          # i == n -> sort
    li t6, 0                  # m = 0
    lw a0, 4(sp)
    slli t0, s3, 2
    add t1, a0, t0
    lw t2, 0(t1)              # t2 = words[i]
    li s4, 0                  # j = 0
char_loop:
    add t3, t2, s4
    lbu t4, 0(t3)
    li t5, 1                  # independent of t4: fills the load-use slot
    beqz t4, store_mask       # == '\0'
    addi t4, t4, -97          # t4 = c - 'a'
    sll t5, t5, t4            # 1 << (c - 'a')
    or t6, t6, t5             # m |= ...
    addi s4, s4, 1            # j++
    j char_loop
store_mask:
    add t1, s7, t0
    sw s4, 0(t1)              # len_arr[i] = j
    add t1, s2, t0
    sw t6, 0(t1)              # mask[i] = m
    mv a0, t6
    jal ra, highest_bit       # clobbers t2-t6 only; t0 and a1 survive
    add t1, s5, t0
    sw a0, 0(t1)              # highest[i]
    add t1, s6, t0
    sw s3, 0(t1)              # indices[i] = i
    addi s3, s3, 1
    j bitmask_loop
sort:
    li s3, 0                  # i = 0
out_loop:
    addi t0, a1, -1
    bge s3, t0, compare_pairs # i < n - 1 (safe for n == 0)
    addi t1, s3, 1            # j = i + 1
in_loop:
    bge t1, a1, next_loop
    slli t2, s3, 2
    add s8, s6, t2            # &indices[i]
    lw s10, 0(s8)             # a = indices[i]
    slli t2, t1, 2
    add s9, s6, t2            # &indices[j]
    lw s11, 0(s9)             # b = indices[j]
    slli t2, s10, 2
    add t4, s7, t2
    lw t5, 0(t4)              # len_arr[a]
    slli t3, s11, 2
    add t4, s7, t3
    lw t6, 0(t4)              # len_arr[b]
    blt t5, t6, do_swap       # shorter word first -> swap
    bne t5, t6, skip_swap     # already longer -> keep
    add t4, s5, t2
    lw t5, 0(t4)              # highest[a]
    add t4, s5, t3
    lw t6, 0(t4)              # highest[b]
    bge t5, t6, skip_swap
do_swap:
    sw s11, 0(s8)
    sw s10, 0(s9)
skip_swap:
    addi t1, t1, 1
    j in_loop
next_loop:
    addi s3, s3, 1
    j out_loop
compare_pairs:
    li s4, 0                  # i = 0
    li s3, 0                  # s3 = ans
outer_loop:
    bge s4, a1, done
    slli t2, s4, 2
    add t4, s6, t2
    lw s10, 0(t4)             # idx1
    slli t2, s10, 2
    add t4, s7, t2
    lw t5, 0(t4)              # len_arr[idx1] (invariant in the inner loop)
    add t4, s2, t2
    lw s8, 0(t4)              # mask[idx1] (invariant in the inner loop)
    addi t1, s4, 1            # j = i + 1
inner_loop:
    bge t1, a1, next_i
    slli t2, t1, 2
    add t4, s6, t2
    lw s11, 0(t4)             # idx2
    slli t2, s11, 2
    add t4, s7, t2
    lw t6, 0(t4)              # len_arr[idx2]
    mul s0, t5, t6            # prod
    bge s3, s0, next_i        # prod <= ans: later partners are no longer -> break
    add t4, s2, t2
    lw t6, 0(t4)              # mask[idx2]
    and t3, s8, t6
    bnez t3, skip_pair        # common letters
    mv s3, s0                 # prod > ans already established
skip_pair:
    addi t1, t1, 1
    j inner_loop
next_i:
    addi s4, s4, 1
    j outer_loop
done:
    mv a0, s3                 # a0 = ans
    li a7, 1
    ecall
    la a0, newline
    li a7, 4
    ecall
    mv a0, s3                 # return ans
    lw ra, 0(sp)
    lw s0, 8(sp)
    lw s2, 12(sp)
    lw s3, 16(sp)
    lw s4, 20(sp)
    lw s5, 24(sp)
    lw s6, 28(sp)
    lw s7, 32(sp)
    lw s8, 36(sp)
    lw s9, 40(sp)
    lw s10, 44(sp)
    lw s11, 48(sp)
    addi sp, sp, 64
    jr ra
# highest_bit: a0 = index of the most significant set bit of a0, or -1 when
# a0 == 0 (computed as 31 - clz). Leaf; clobbers t2-t6.
highest_bit:
    bnez a0, clz
    li a0, -1
    jr ra
clz:
    li t4, 32                 # n = 32
    li t5, 16                 # c = 16
    mv t2, a0                 # t2 = x
clz_loop:
    srl t3, t2, t5            # y = x >> c
    beqz t3, skip             # if y == 0, skip
    sub t4, t4, t5            # n -= c
    mv t2, t3                 # x = y
skip:
    srli t5, t5, 1            # c >>= 1
    bnez t5, clz_loop         # while (c != 0)
    sub a0, t4, t2            # n - x
    li t6, 31
    sub a0, t6, a0            # return 31 - clz(x)
    jr ra
```
#### Idea
Each word’s bitmask represents which letters it contains.
**highest_bit(mask)** gives the position of the highest letter in that word (e.g., z → 25).
We use this value to sort the words.
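1. Sort the words so that longer words come first; highest_bit (in descending order) can be used to order words of equal length.
2. When comparing pairs of words, if **prod <= ans** (the current best result), you can break out of the inner loop early. With the words sorted by length in descending order, every later partner is no longer, so its product cannot be larger.
   Sorting by highest_bit alone does not give this guarantee. For `["z","yx","w","abcde"]` (already in descending highest_bit order), the loop breaks at the pair ("yx", "w") because 2 <= 2, never tries ("yx", "abcde"), and returns 5 instead of 10.
* Pruning does not change the worst-case time complexity (it remains O(n²)), but it can substantially reduce the number of comparisons in practice.
Using clz speeds up the highest-bit calculation. In addition, the length of each word is counted while its mask is built, so it does not need to be computed separately, and the strlen calls of the version without pruning can be removed.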
Without Prune:

With Prune:

* The two figures above show that the optimized program needs noticeably fewer cycles.
### Hazard
After optimizing the program, some hazards still exist in the assembly code. The following are the hazards I found and how to resolve them.


```Assembly=19
char_loop:
# Before reordering: lbu writes t4 and the very next instruction (beqz)
# reads it. This load-use hazard forces a one-cycle stall.
add t3, t2, s4
lbu t4, 0(t3)
beqz t4, store_mask # == '\0'
addi t4, t4, -97 # t4 = t4 - 'a'
li t5, 1
sll t5, t5, t4 # 1 << (t4)
or t6, t6, t5 # m |= ...
addi s4, s4, 1 # j++
j char_loop
```
* The `lbu` on line 21 writes t4, and the `beqz` on line 22 reads it immediately (a load-use data hazard), so the pipeline must stall for one cycle between the two instructions.
The solution is to reorder the code and move the `li` instruction from line 24 between `lbu` and `beqz`. `li` only writes t5 and neither reads nor writes t4, so it can safely fill that slot.
* The following is the rearranged code
```Assembly=19
char_loop:
# After reordering: "li t5, 1" neither reads nor writes t4, so it can sit
# between the load and the branch that uses the loaded value, removing the
# stall. It now also runs on the exit path to store_mask. That is harmless
# as long as store_mask does not read t5 — TODO confirm against the full listing.
add t3, t2, s4
lbu t4, 0(t3)
li t5, 1
beqz t4, store_mask # == '\0'
addi t4, t4, -97 # t4 = t4 - 'a'
sll t5, t5, t4 # 1 << (t4)
or t6, t6, t5 # m |= ...
addi s4, s4, 1 # j++
j char_loop
```

* Because an independent instruction now fills the slot after the load, no cycle is wasted on a stall.

* The total cycle count also dropped by more than 100 (original: 5888 cycles).