Try   HackMD

Assignment1: RISC-V Assembly and Instruction Pipeline

contributed by < chi0819 >

Sigmoid Function by BF16 without Floating Point Support

In the field of deep learning, the sigmoid function is widely used as the activation function for each layer of neural networks. OpenAI's project Triton is a language and compiler for writing highly efficient custom deep-learning primitives, aiming at higher productivity than CUDA for writing fast code. Triton utilizes BF16 to speed up computation and improve memory efficiency. For example, we can write custom GPU kernels for activation functions like the sigmoid function to accelerate computations in deep learning models

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

Sigmoid function maps any real-valued number to a value between 0 and 1, which makes it useful in machine learning for activating neurons in neural networks and in logistic regression for modeling probabilities

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

For the exponential function $\exp(x)$, which is the same as $e^x$, the Taylor series expansion around $x = 0$ (also known as the Maclaurin series) is particularly straightforward and widely used:

$$\exp(x) = \sum_{n=0}^{\infty} \frac{x^n}{n!} = 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \frac{x^4}{4!} + \cdots$$

For $\exp(-x)$:

$$\exp(-x) = \sum_{n=0}^{\infty} \frac{(-x)^n}{n!} = 1 - x + \frac{x^2}{2!} - \frac{x^3}{3!} + \frac{x^4}{4!} - \cdots$$

Although employing the Maclaurin series to calculate the natural exponential function requires maintaining input values within the range of [-1, 1] to ensure accuracy, precise computation of the exponential values can be achieved by normalizing the input data and controlling the magnitudes of weights and biases. Furthermore, because the inputs are confined to [-1, 1], numerical convergence occurs rapidly. Unlike methods based on ln 2, there is no need to restrict the input domain, which consequently conserves memory and reduces computational time.

Function Reference

With FP32 <-> BF16 conversion in 2024 Quiz 1 Problem B
I try to modify 2023 Quiz 1 Problem C and last year student's Homework1 function.
Create BF16 functions to calculate

σ(x) without floating point support.

  • fmul32 -> bfmul16
  • imul32 -> imul16
  • fdiv32 -> bfdiv16
  • idiv24 -> idiv7
  • fadd32 -> bfadd16

BF16 Layout

The structure of the bfloat16 floating-point format is as follows.

        ┌ sign 
        │
        │   ┌ exponent
        │   │
        │   │      ┌ mantissa 
        │   │      │
        │┌──┴───┐┌─┴───┐
      0b0000000000000000 bfloat16

Simple Definition of bf16_t Data Type

/* Storage for one bfloat16 value: 1 sign bit, 8 exponent bits and 7
 * mantissa bits packed into the low 16 bits of `bits` (fp32_to_bf16 only
 * ever stores a 16-bit result here).
 * NOTE(review): uint32_t is wider than the format needs — presumably for
 * convenience; a uint16_t field would suffice. Confirm before changing,
 * since every function below reads .bits into a uint16_t. */
typedef struct {
    uint32_t bits;
} bf16_t;

Convert FP32 to BF16

/* Convert an IEEE-754 single to bfloat16 with round-to-nearest-even.
 * NaN inputs are quieted (bit 6 of the result forced on); all other
 * values are rounded by adding 0x7FFF plus the parity of the kept LSB
 * before truncating to the top 16 bits. */
static inline bf16_t fp32_to_bf16(float s)
{
    union {
        float f;
        uint32_t i;
    } u = {.f = s};
    bf16_t out;

    /* NaN: exponent all ones with a non-zero mantissa */
    if ((u.i & 0x7fffffff) > 0x7f800000) {
        out.bits = (u.i >> 16) | 0x40; /* force quiet NaN */
        return out;
    }

    /* round to nearest, ties to even: bias by 0x7FFF + LSB of kept part */
    uint32_t round_bias = 0x7fff + ((u.i >> 16) & 1);
    out.bits = (u.i + round_bias) >> 16;
    return out;
}

Convert BF16 to FP32

/* Widen a bfloat16 to float: bf16 is the top half of an IEEE-754 single,
 * so shifting the bits into the high 16 positions is exact (no rounding). */
static inline float bf16_to_fp32(bf16_t h)
{
    union {
        uint32_t i;
        float f;
    } conv;
    conv.i = (uint32_t)h.bits << 16;
    return conv.f;
}

Check LSB getbit

/* Return bit n of value as 0 or 1 (n must be in [0, 15]). */
uint16_t getbit(uint16_t value, int n)
{
    uint16_t mask = (uint16_t)(1u << n);
    return (value & mask) ? 1u : 0u;
}

C Code

I modified the C code from last year student's homework
Change the data type to bf16_t, learn floating point format
Most of time spend to understand bitwise operation

bfmul16

/* BF16 multiplication without FPU support.
 * Multiplies the 8-bit significands (implicit leading 1 restored) with the
 * shift-and-add helper, renormalizes at most one bit, and rebiases the
 * exponent.  Fixes over the original:
 *  - zero operands: the implicit-1 trick treated 0 as 1.0 * 2^-127 and
 *    returned a bogus non-zero product; 0 * x now yields signed zero
 *  - exponent underflow flushes to signed zero (no subnormal support)
 *    instead of emitting exponent 0 with stale mantissa bits
 *  - exponent overflow clamps to signed infinity */
bf16_t bfmul16(bf16_t a, bf16_t b)
{
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;

    /* sign of the product */
    uint16_t sign = (ia ^ ib) & 0x8000;

    /* 0 * x = 0 (also covers x * 0) */
    if ((ia & 0x7FFF) == 0 || (ib & 0x7FFF) == 0)
        return (bf16_t) {.bits = sign};

    /* significands with the implicit leading 1 restored (8 bits each) */
    uint8_t ma = (ia & 0x7F) | 0x80;
    uint8_t mb = (ib & 0x7F) | 0x80;

    /* biased exponents */
    uint8_t exp_a = (ia >> 7) & 0xFF;
    uint8_t exp_b = (ib >> 7) & 0xFF;

    /* 8x8 -> 16 bit significand product; max 255*255 fits in uint16_t */
    uint16_t m = imul16((int16_t)ma, (int16_t)mb);

    /* product of two [1,2) significands lies in [1,4): renormalize 1 bit */
    uint8_t mshift = (m >> 15) & 1;
    m >>= mshift;

    /* rebias: exp_a + exp_b - 127, plus the normalization shift */
    int16_t exp_r = (int16_t)exp_a + (int16_t)exp_b - 127 + (int16_t)mshift;

    if (exp_r <= 0)
        return (bf16_t) {.bits = sign};            /* underflow -> signed 0 */
    if (exp_r >= 0xFF)
        return (bf16_t) {.bits = sign | 0x7F80};   /* overflow -> signed inf */

    /* combine sign, exponent, and the top 7 mantissa bits */
    return (bf16_t) {.bits = sign | (((uint16_t)exp_r << 7) & 0x7F80) | ((m >> 7) & 0x7F)};
}

imul16

/* Shift-and-add multiplication without the M extension: returns the low
 * 16 bits of a * b (two's-complement wraparound, like a hardware MUL).
 * Fixes over the original:
 *  - all 16 multiplier bits are scanned; the original only scanned the
 *    low 8 (enough for the 8-bit mantissas it was called with, but wrong
 *    for general 16-bit operands the signature promises)
 *  - the partial products are formed on an unsigned copy, avoiding the
 *    undefined behavior of left-shifting a negative signed value
 *  - the bit test is inlined, removing the getbit() dependency */
uint16_t imul16(int16_t a, int16_t b)
{
    uint16_t ua = (uint16_t)a;
    uint16_t ub = (uint16_t)b;
    uint16_t r = 0;

    for (int i = 0; i < 16; i++) {
        if ((ub >> i) & 1)                 /* multiplier bit i set? */
            r += (uint16_t)(ua << i);      /* add shifted multiplicand */
    }
    return r;
}

Use loop unrolling to reduce branch condition
TODO

bfdiv16

/* BF16 division without FPU support.
 * idiv7() yields the quotient of the two [128,255] significands in Q1.15
 * fixed point: floor(ma * 2^15 / mb), so the integer bit lands in bit 15.
 * Fixes over the original (which produced the wrong, sign-flipped terms
 * observed in the write-up):
 *  - the mantissa field was taken as (q & 0x7E00) >> 9, dropping one bit
 *    and halving the fraction; after normalization the 7 mantissa bits sit
 *    in bits 14..8, i.e. (q >> 8) & 0x7F
 *  - when the quotient was < 1 the exponent was "decremented" by
 *    subtracting 0x8000 — the SIGN bit — instead of one exponent ulp,
 *    which negated every such result; the exponent is now computed
 *    explicitly as ea - eb + 127 - mshift
 *  - exponent underflow flushes to signed zero, overflow clamps to inf */
bf16_t bfdiv16(bf16_t a, bf16_t b)
{
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;
    uint16_t sign = (ia ^ ib) & 0x8000;

    /* divisor = 0 then return signed infinity */
    if ((ib & 0x7FFF) == 0) return (bf16_t) {.bits = sign | 0x7F80};
    /* dividend = 0 then return signed zero */
    if ((ia & 0x7FFF) == 0) return (bf16_t) {.bits = sign};

    /* significands with the implicit leading 1 restored, range [128,255] */
    int16_t ma = (ia & 0x7F) | 0x80;
    int16_t mb = (ib & 0x7F) | 0x80;

    /* biased exponents */
    int16_t ea = (ia >> 7) & 0xFF;
    int16_t eb = (ib >> 7) & 0xFF;

    /* Q1.15 quotient; ma/mb is in (0.5, 2), so q is in (2^14, 2^16) */
    uint16_t q = (uint16_t)idiv7(ma, mb);

    /* quotient < 1.0 (bit 15 clear): normalize left, exponent drops by 1 */
    int16_t mshift = !((q >> 15) & 1);
    q <<= mshift;

    int16_t exp_r = ea - eb + 127 - mshift;
    if (exp_r <= 0)   return (bf16_t) {.bits = sign};          /* underflow */
    if (exp_r >= 0xFF) return (bf16_t) {.bits = sign | 0x7F80}; /* overflow */

    /* bit 15 of q is the implicit 1; bits 14..8 are the stored mantissa */
    return (bf16_t) {.bits = sign | ((uint16_t)exp_r << 7) | ((q >> 8) & 0x7F)};
}

idiv7

/* Restoring shift-subtract division for normalized significands.
 * For a, b in [128, 255] this returns the 16-bit Q1.15 fixed-point
 * quotient floor(a * 2^15 / b); the integer bit of a/b lands in bit 15.
 * (The running remainder never exceeds 2*b <= 510, so the signed shifts
 * stay in range for valid inputs.) */
int16_t idiv7(int16_t a, int16_t b){
    uint16_t quotient = 0;

    for (int step = 16; step-- > 0;) {
        quotient <<= 1;
        if (a >= b) {        /* remainder large enough: emit a 1 bit */
            quotient |= 1;
            a -= b;
        }
        a <<= 1;             /* expose the next quotient bit */
    }

    return quotient;
}

bfadd16

/* BF16 addition without FPU support.
 * Orders the operands by magnitude, aligns the smaller significand to the
 * larger exponent, adds or subtracts the significands depending on the
 * signs, then renormalizes using a count of leading zeros.
 * NOTE(review): relies on iswap() and my_clz() defined elsewhere in this
 * project; iswap must be a macro for the swap of the locals ia/ib to take
 * effect — TODO confirm.
 * NOTE(review): exact cancellation (ma == 0 after the subtract) is not
 * handled and yields a non-zero encoding — verify against callers. */
bf16_t bfadd16(bf16_t a, bf16_t b) {
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;

    /* compare the absolute values */
    /* ensure ia holds the operand with the larger magnitude */
    uint16_t cmp_a = ia & 0x7fff;
    uint16_t cmp_b = ib & 0x7fff;
    if (cmp_a < cmp_b)
        iswap(ia, ib);
    
    /* biased exponents */
    uint16_t ea = (ia >> 7) & 0xff;
    uint16_t eb = (ib >> 7) & 0xff;

    /* 7 bits is the size of the bf16_t mantissa; a larger shift would
     * discard every bit of mb anyway, so clamp the alignment shift */
    int16_t mantissa_align_shift = (ea - eb > 7) ? 7 : ea - eb;

    /* significands with the implicit leading 1 restored */
    uint16_t ma = (ia & 0x7F) | 0x80;
    uint16_t mb = (ib & 0x7F) | 0x80;

    mb >>= mantissa_align_shift;          /* align to the larger exponent */
    if((ia ^ ib) >> 15) ma -= mb;         /* opposite signs: magnitudes subtract */
    else ma += mb;                        /* same sign: magnitudes add */
    int16_t clz = my_clz(ma);             /* leading zeros of the 16-bit sum */
    int16_t shift = 0;
    if(clz <= 8) {
        /* carry-out (or already normalized): move the leading 1 to bit 7 */
        shift = 8 - clz;
        ma >>= shift;
        ea += shift;
    } else {
        /* cancellation shrank the sum: shift the leading 1 up to bit 7 */
        shift = clz - 8;
        ma <<= shift;
        ea -= shift;
    }
    
    /* sign of the larger-magnitude operand wins */
    return (bf16_t){.bits = (ia & 0x8000) | ((ea << 7) & 0x7f80) | (ma & 0x7f)};
}

bfsub16

/* BF16 subtraction (a - b) without FPU support.
 * Same alignment/normalization scheme as bfadd16, with extra bookkeeping
 * (swapped / negative_result) to recover the correct result sign.
 * NOTE(review): relies on iswap() and my_clz() defined elsewhere in this
 * project; iswap must be a macro for the swap of the locals ia/ib to take
 * effect — TODO confirm.
 * NOTE(review): exact cancellation (ma == 0 after the subtract) is not
 * handled and yields a non-zero encoding — verify against callers. */
bf16_t bfsub16(bf16_t a, bf16_t b) {
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;

    /* compare the absolute values */
    /* ensure ia has greater or equal magnitude than ib */
    uint16_t cmp_a = ia & 0x7FFF;
    uint16_t cmp_b = ib & 0x7FFF;
    int swapped = 0;
    if (cmp_a < cmp_b) {
        iswap(ia, ib);
        swapped = 1;
    }

    /* biased exponents */
    uint16_t ea = (ia >> 7) & 0xFF;
    uint16_t eb = (ib >> 7) & 0xFF;

    /* 7 bits is the size of the bf16_t mantissa; clamp the alignment shift */
    int16_t mantissa_align_shift = (ea - eb > 7) ? 7 : ea - eb;

    /* significands with the implicit leading 1 restored */
    int16_t ma = (ia & 0x7F) | 0x80;
    int16_t mb = (ib & 0x7F) | 0x80;

    mb >>= mantissa_align_shift;   /* align to the larger exponent */

    /* adjust mantissas based on signs */
    /* equal signs: a - b subtracts magnitudes; opposite signs: they add */
    if (!((ia ^ ib) >> 15)) ma -= mb;
    else ma += mb;

    /* handle negative mantissa result */
    int negative_result = 0;
    if (ma < 0) {
        ma = -ma;
        negative_result = 1;
    }

    /* normalize the result (see bfadd16 for the clz-based scheme) */
    int16_t clz = my_clz((uint16_t)ma);
    int16_t shift = 0;
    if (clz <= 8) {
        shift = 8 - clz;
        ma >>= shift;
        ea += shift;
    } else {
        shift = clz - 8;
        ma <<= shift;
        ea -= shift;
    }

    /* determine the sign of the result */
    uint16_t sign_a = ia & 0x8000; /* if swapped ia = original ib */
    uint16_t sign_b = ib & 0x8000; /* if swapped ib = original ia */
    uint16_t sign;
    if (negative_result == 0) sign = (swapped ? sign_b : sign_a);
    else sign = (swapped ? sign_b : sign_a) ^ 0x8000; /* flip the sign */
		
    if(swapped && !(sign_a ^ sign_b)) sign ^= 0x8000; /* TODO: simplify the three sign steps */

    /* assemble the final result */
    return (bf16_t){.bits = sign | ((ea << 7) & 0x7f80) | (ma & 0x7F)};
}

When I was solving the bfsub16 sign bug, I suddenly realized — just as I was ready to sleep — that I don't need to implement subtraction at all.
If I want to perform the $a - b$ operation, I can simply change
$b$
to
$-b$

and then use addition,
$a + (-b)$
, to carry out the subtraction.
I spent a lot of time implementing bfsub16, but I think it was a good experience

Think More
I need some time to figure out how to combine this three line together
It's a good chance to exercise my brain

if (negative_result == 0) sign = (swapped ? sign_b : sign_a);
else sign = (swapped ? sign_b : sign_a) ^ 0x8000; /* Flip the sign */
		
if(swapped && !(sign_a ^ sign_b)) sign ^= 0x8000; /* need to be optimize */

Outcome

Emm The outcome is really weird

input x = 0.500000
my bf16 sigmoid = 0.308594
STL sigmoid approch = 0.622459

Have you checked the error rate accordingly?

In main.c, floats are randomly sampled in the range

$[-1, 1]$
Sample size is 10
The number of terms in the Taylor expansion of the exponential function is 5

MSE : 0.698543

Wait me to fix the bfdiv16 problem
I fix bfmul16, bfadd16 and bfsub16 successfully
If I fix bfdiv16 successfully, I will come back to fix Sigmoid

Check Expansion Term Output

The bfmul16, bfadd16 and bfsub16 output is correct
It seems that the problem is on bfdiv16

term 1 : 0.500000 = 0.500000 / 1.000000
term 2 : 0.125000 = 0.250000 / 2.000000
term 3 : -0.036377 = 0.125000 / 6.000000
term 4 : -0.004547 = 0.062500 / 24.000000
term 5 : -0.000504 = 0.031250 / 120.000000
term 6 : -0.000037 = 0.015625 / 720.000000
term 7 : -0.000002 = 0.007812 / 5040.000000
term 8 : -0.000000 = 0.003906 / 40320.000000
term 9 : -0.000000 = 0.001953 / 362880.000000

Result

I have converted the functional parts into assembly code and executed them on Ripes, learning the actual process of how assembly code runs on the processor. After becoming familiar with bitwise operations, I will come back to fix the functionality of bfdiv16. I think besides bitwise operations, I also need to better understand the specifications and handling of floating-point numbers

Assembly Code

fp32_to_bf16

Use this simple function to learn the instructions flow in each stage

.data
argument: .word   0x3E800000 # float 0.25

.text

_start:
    j main
    
fp32_to_bf16:
    la a0, argument      # load argument address
    lw a1, 0(a0)         # load data (float 0.25) from argument
    li a2, 0x7F800000    # exponent-all-ones threshold for the NaN check
    li a3, 0x7FFFFFFF    # abs-value mask; was 0x7FFFFFF (one F short),
                         # which wrongly cleared bits 27..30 of u.i
    and a4, a1, a3       # u.i & 0x7FFFFFFF
    blt a2, a4, nan_case # handle NaN case
    li a2, 0x7FFF
    mv a3, a1
    srli a3, a3, 16      # u.i >> 16
    andi a3, a3, 1       # (u.i >> 16) & 1
    add a3, a3, a2       # 0x7FFF + ((u.i >> 16) & 1)
    add a3, a3, a1       # u.i + (0x7FFF + ((u.i >> 16) & 1))
    srli a0, a3, 16      # round to nearest even, keep the top 16 bits
    ret
nan_case:
    srli a1, a1, 16      # (u.i >> 16)
    li a4, 0x40
    or a0, a1, a4        # (u.i >> 16) | 0x40 : force quiet NaN
    ret
main:
    jal ra, fp32_to_bf16
    
end:
    nop

correspond disassembled code

00000000 <_start>:
    0:        0580006f        jal x0 88 <main>

00000004 <fp32_to_bf16>:
    4:        10000517        auipc x10 0x10000
    8:        ffc50513        addi x10 x10 -4
    c:        00052583        lw x11 0 x10
    10:        7f800637        lui x12 0x7f800
    14:        080006b7        lui x13 0x8000
    18:        fff68693        addi x13 x13 -1
    1c:        00d5f733        and x14 x11 x13
    20:        02e64463        blt x12 x14 40 <non_case>
    24:        00008637        lui x12 0x8
    28:        fff60613        addi x12 x12 -1
    2c:        00058693        addi x13 x11 0
    30:        0106d693        srli x13 x13 16
    34:        0016f693        andi x13 x13 1
    38:        00c686b3        add x13 x13 x12
    3c:        00b686b3        add x13 x13 x11
    40:        0106d513        srli x10 x13 16
    44:        00008067        jalr x0 x1 0

00000048 <non_case>:
    48:        0105d593        srli x11 x11 16
    4c:        04000713        addi x14 x0 64
    50:        00e5e533        or x10 x11 x14
    54:        00008067        jalr x0 x1 0

00000058 <main>:
    58:        fadff0ef        jal x1 -84 <fp32_to_bf16>

0000005c <end>:
    5c:        00000013        addi x0 x0 0

bfadd16

# RV32I implementation of bfadd16: adds two BF16 values (1.245 + 3.255)
# and prints the operands, the result and the expected answer via ecalls.
.data
data1:    .word    0x3f9f # BF16 1.245
data2:    .word    0x4050 # BF16 3.255
ans:      .word    0x40900000 # Expect ans 4.5
newline:  .string  "\n"
str1:     .string  "data1 : "
str2:     .string  "data2 : "
str3:     .string  "bfadd16 output : "
str4:     .string  "Expect ans : "

.text

bfadd16:
    la a0, data1
    la a1, data2

    # Load input
    lh t0, 0(a0)        # ia
    lh t1, 0(a1)        # ib

    # create bit mask 0x7fff
    li t2, 0x7fff
    and t3, t0, t2          # cmp_a = ia & 0x7fff
    and t4, t1, t2          # cmp_b = ib & 0x7fff
    
    bge t3, t4, no_swap     # if cmp_a >= cmp_b goto noswap
swap:
    mv t2, t0
    mv t0, t1
    mv t1, t2
no_swap:
    srli t3, t0, 7        # ia >> 7
    srli t4, t1, 7        # ib >> 7
    andi t3, t3, 0xff     # ea = (ia >> 7) & 0xff
    andi t4, t4, 0xff     # eb = (ib >> 7) & 0xff
    sub t2, t3, t4        # t2 = ea - eb
    li t5, 7
    bgt t2, t5, mshift_max # clamp the align shift to the mantissa width
    j mshift_done
mshift_max:
    li t2, 7               # mshift = 7
mshift_done:
    andi t5, t0, 0x7f
    ori t5, t5, 0x80        # ma = (ia & 0x7f) | 0x80
    andi t6, t1, 0x7f
    ori t6, t6, 0x80        # mb = (ib & 0x7f) | 0x80
    srl t6, t6, t2          # mb >>= mshift
    xor a1, t0, t1          # ia ^ ib
    srli a1, a1, 15         # (ia ^ ib) >> 15 : 1 if the signs differ
    beqz a1, ma_plus_mb
ma_minus_mb:
    sub t5, t5, t6
    j ma_mb_done
ma_plus_mb:
    add t5, t5, t6
ma_mb_done:
    jal ra, clz        # count leading zeros of ma (t5) -> a0
    li a2, 9
    blt a0, a2, shift_in_range # clz <= 8
    addi t4, a0, -8    # shift = clz - 8
    sll t5, t5, t4     # ma <<= shift
    sub t3, t3, t4     # ea -= shift
    j return
shift_in_range:
    li t4, 8
    sub t4, t4, a0    # shift = 8 - clz
    srl t5, t5, t4    # ma >>= shift
    add t3, t3, t4    # ea += shift
return:
    li a0, 0x8000    # create bit mask
    and a0, a0, t0   # sign = ia & 0x8000
    slli t3, t3, 7   # exponent = ea << 7
    li t2, 0x7F80
    and t3, t3, t2   # exponent = (ea << 7) & 0x7F80
    or a0, a0, t3
    li t2, 0x7F
    and t5, t5, t2   # mantissa = ma & 0x7F
    or a0, a0, t5
    slli t0, a0, 16  # extend bf16 to fp32
    j print

# Linear-scan count-leading-zeros of the 16-bit value in t5.
# Returns the count in a0; clobbers a1 and a2.
clz:
    li a0, 0            # count = 0
    li a1, 0x8000       # create bit mask 0x8000
while_loop:
    and a2, t5, a1
    bnez a2, clz_done   # stop at the first set bit
    addi a0, a0, 1
    srli a1, a1, 1    # mask >>= 1
    bnez a1, while_loop
clz_done:
    ret

# Print both operands, the bfadd16 result (in t0) and the expected answer.
# ecall 4 = print string, ecall 2 = print float.
print:
    la a0, str1  # data1 : 
    li a7, 4     # ecall 4 print string
    ecall
    la a0, data1
    lw a0, 0(a0)       # load data1
    slli a0, a0, 16    # Extend bf16 to fp32
    li a7, 2           # ecall 2 print float
    ecall
    la a0, newline
    li a7, 4
    ecall
    la a0, str2  # data2 :
    li a7, 4
    ecall
    la a0, data2
    lw a0, 0(a0)       # load data2
    slli a0, a0, 16    # extend bf16 to fp32
    li a7, 2           # ecall 2 print float
    ecall
    la a0, newline
    li a7, 4
    ecall
    la a0, str3        # bfadd16 output :
    li a7, 4
    ecall
    mv a0, t0          # bfadd16 result
    li a7, 2
    ecall
    la a0, newline
    li a7, 4
    ecall
    la a0, str4        # Expect ans : 
    li a7, 4
    ecall
    la a0, ans
    lw a0, 0(a0)       # load expect fp32 ans
    li a7, 2
    ecall

output

data1 : 1.24219
data2 : 3.25
bfadd16 output : 4.46875
Expect ans : 4.5

Instruction Pipeline Detail

Tracking the instruction addi x10, x10, -4 in fp32_to_bf16 assembly code

Fetch

In instruction fetch stage, the address ( 0x00000008 ) of the instruction send to Instr. Memory to fetch instruction, with the 32 bits data ( 0xffc50513 ) we know the instruction in this stage is addi x10, x10, -4, and the PC + 4 going to next instruction at ( 0x0000000c )
Screenshot 2024-10-10 at 10.15.33 PM
Screenshot 2024-10-10 at 10.16.22 PM

Decode

Screenshot 2024-10-11 at 9.57.36 AM
Screenshot 2024-10-11 at 8.57.09 AM

screen shot from Lecture video, it is talking about the RV32I instruction layout

Screenshot 2024-10-11 at 9.40.44 AM

The sign extension operation in Imm. to extend 12 bits immediate in instruction to int32_t value

Decode extracts the opcode and register numbers from the instruction. We can see that the opcode output is 0x13 for ADDI ( 0b0010011 ); with the opcode and instruction we can get the immediate from Imm., and with sign extension obtain the 32-bit signed integer -4, whose bit pattern is 0xfffffffc. The Decode stage also decodes the opcode in the Control block to determine what the ALU should do in the Execution stage.

The output for R1 inx is 0x0a ( 10 ) is for source register x10 in instruction, also send the 0x0a signal to next stage, because the instruction take data from x10, after the calculation we store the result back into x10 register

Why 0x1c ( 28 ) for R2 inx ? need some time to understand the design concept

load & store Hazard

I think the signal to R2 idx also used to hazard detection and check next instruction should be forward or not

I wrote a simple assembly program to understand the signal connected to R2 idx. After the instruction lw a1, 0(a0) comes addi a1, a1, 1; if there were no hazard detection, addi a1, a1, 1 in the Execution stage would produce a wrong result because the data has not yet been stored into register a1. In this example the signal connected to R2 idx is 0x0b

 .data
argument: .word   0x1

.text

_start:
    la a0, argument    # a0 = address of `argument`
    lw a1, 0(a0)       # load 0x1 into a1
    addi a1, a1, 1     # load-use hazard: consumes a1 right after the lw
    addi a1, a1, 1     # back-to-back dependency on a1 (forwarding case)

The yellow line connect 0x0b to Hazard Unit also store the data to ID/EX register, when lw x11, 0(x10) at Execution stage, 0x0b will be send to Forwarding Unit from ID/EX register
Screenshot 2024-10-11 at 12.05.33 PM

The signal connect to R2 idx also connect to Hazard Unit and store the signal to ID/EX register

When lw x11, 0(x10) forward to Execution stage, find the ID/EX register's clear signal is enable, notice the Hazard Unit lower input id_reg1_idx ( 0b01011, x11 ) is from next instruction at Decode stage, upper input is from ID/EX register indicate the register ( 0b01011, x11 ) now is used by lw x11, 0(x10), so Hazard Unit find these two instruction has conflict, so the ID/EX will be clear and make a stall, wait lw x11, 0(x10) forward to Data Memory to read out the data from x10, also notice the IF/ID register enable signal is false, so that addi x11, x11, 1 remain stay at Decoding stage, the next instruction won't advanced from Fetch stage
Screenshot 2024-10-11 at 1.33.07 PM

I think there has a technique to prevent one more stall called data forwarding
Screenshot 2024-10-11 at 1.53.26 PM-2
Notice the yellow line forward the data 0x00000001 back to Execution stage for addi x11, x11, 1 , also notice the Forwarding Unit signal ( WbStage ) to multiplexer, select the forwarding data from lw x11, 0(x10) as input data
Screenshot 2024-10-11 at 1.57.53 PM

More about Data Forwarding

We can execute two addi on same register continuously, because we can use data forwarding from Memory stage, notice the yellow line is the data passed from EX/MEM register which store the previous instruction output, with Forwarding Unit signal MemStage send to Execution stage multiplexer, the instruction at Execution stage can use correct data in x11
Screenshot 2024-10-11 at 2.34.43 PM
Screenshot 2024-10-11 at 2.40.59 PM

Data forwarding after Execution stage before Memory stage

Execution

In Execution stage, the ALU behavior is control by the 5 bits control signal ctrl, before the data send into ALU, there has two multiplexer with enable signal to select which data should be calculated by ALU, in this example, the input is 0x10000004 and 0xfffffffc, the ALU behavior is Add immediate, the ALU output is the calculation result ( 0x10000004 + -4 )
Screenshot 2024-10-11 at 10.20.55 AM
Screenshot 2024-10-11 at 10.28.11 AM

After the calculation result write back to the register destination x10, we can check the result is in x10 ( 0x10000000 = 0x10000004 + -4 )

In this example, when instruction addi x10, x10, -4 in Execution stage, the IMM control signal is enable, so the ALU input Op2 select the immediate data 0xfffffffc ( -4 )
Screenshot 2024-10-11 at 10.31.55 AM
Screenshot 2024-10-11 at 2.40.20 PM

Memory

From the previous simple assembly program, the data value is 0x1; the inputs of Data Memory are Addr. ( 0x10000000 ) and Data in ( 0x00000000 ) as byte offset, and the data is stored at address 0x10000000. Actually, in the disassembled code, before lw x11, 0(x10) we need to load the address of the data into x10 using auipc x10 0x10000, because lw is an I-Type instruction which only has 12 bits to represent the immediate
Screenshot 2024-10-11 at 2.57.48 PM

little endian : lower bit store in lower address

Data Memory get LW OP signal, Addr. and offset, so that Data Memory read out the data 0x00000001 from data segment by memory address
Screenshot 2024-10-12 at 11.54.35 AM

Write Back

In Write Back stage also has multiplexer to select the data from Execution, we found that the number of register 0x0a ( x10 ), recall that Decode stage extract the register destination and the pipeline forward the data to the Write Back stage, notice that there has a green line is connect back to Register, the green line signal extracted when the instruction in the Decode stage, the signal is for write register enable
Screenshot 2024-10-11 at 10.42.54 AM
The data 0x10000000 is connect to Wr data, and 0x0a is connect to Wr idx, Wr En is connect to the write register enable signal, after this stage we can observe that the x10 register has data 0x10000000
Screenshot 2024-10-11 at 10.51.03 AM

Reference

Further Discussion

Why Sigmoid ?

How many times activation function called in neural network ?

Assume we have a fully connected neural network used to classify whether a person is male or female from a picture. The network input size is $256$, the hidden layers are $[1024, 512, 256]$, and the output size is $2$; the total number of activation function calls is

$$1024 + 512 + 256 = 1792$$

So we actually pay lots of time on activation operation, if we can reduce the overhead of activation function we can accelerate the model.
Screenshot 2024-10-14 at 12.56.40 PM

Sigmoid Properties

The sigmoid function's output range is $(0, 1)$. With this small, bounded output range, outlier inputs are squashed into $(0, 1)$, so outliers won't cause a serious impact on the model's accuracy.
The sigmoid function is also differentiable and a real function for all real input values, so it is a good activation function for gradient descent. Compared with another famous activation function, ReLU: ReLU isn't a bounded activation function — if the input is a positive outlier, the output is the same value as the input — and the output of ReLU isn't probabilistically interpretable
The advantages of Sigmoid

  • Probability Interpretation: The sigmoid function's output range aligns well with probabilistic interpretations, which can enhance models that need to predict likelihoods.
  • Smooth Gradient: In some models, the smoothness of the sigmoid function helps in optimization, especially when dealing with probabilistic loss functions like cross-entropy.

BF16 Arithmetic Function Accuracy Evaluation

In this part, I don't evaluate the accuracy of BF16 Sigmoid
Because the bfdiv16 has some bugs I need to fix it
I will come back to complete BF16 Sigmoid when I successfully fix bfdiv16
I only test the accuracy of other BF16 functions used in BF16 Sigmoid
If the other function which used in BF16 Sigmoid has high accuracy
I think the BF16 Sigmoid also has high accuracy
Enhance basic BF16 arithmetic functions equal to enhance BF16 Sigmoid

Each array contain 100 sample

bfmul16 with Input Range $[-1, 1]$

/* Compare bfmul16 against native float multiplication over SAMPLE_SIZE
 * element pairs and print the mean-squared error.
 * Relies on MSE(), SAMPLE_SIZE and the BF16 conversion helpers defined
 * elsewhere in this project; both input arrays must hold SAMPLE_SIZE
 * elements. */
void testBF16MulMSE(float *arr1, float *arr2) {
	float BF16Mul[SAMPLE_SIZE], FP32Mul[SAMPLE_SIZE];
	for(int i = 0;i != SAMPLE_SIZE;i++) {
		BF16Mul[i] = bf16_to_fp32(bfmul16(fp32_to_bf16(arr1[i]), fp32_to_bf16(arr2[i])));
		FP32Mul[i] = arr1[i] * arr2[i];
	}
	float mse = MSE(BF16Mul, FP32Mul);
	printf("bfmul16 MSE : %f\n", mse);
}

The MSE in some condition will output inf and very large number
I think my bfmul16 has some problems to process overflow
Now I'm try to solve the bug

I found the bug, that is my MSE function :), arithmetic part is correct

bfmul16 MSE : 0.000002

bfadd16 with Input Range $[-1, 1]$

Random sample two array each contain 100 float in range [-1, 1]

/* Compare bfadd16 against native float addition over SAMPLE_SIZE element
 * pairs and print the mean-squared error.
 * Relies on MSE(), SAMPLE_SIZE and the BF16 conversion helpers defined
 * elsewhere in this project. */
void testBF16AddMSE(float *arr1, float *arr2) {
	float BF16Add[SAMPLE_SIZE], FP32Add[SAMPLE_SIZE];
	for(int i = 0;i != SAMPLE_SIZE;i++) {
		BF16Add[i] = bf16_to_fp32(bfadd16(fp32_to_bf16(arr1[i]), fp32_to_bf16(arr2[i])));
		FP32Add[i] = arr1[i] + arr2[i];
	}
	float mse = MSE(BF16Add, FP32Add);
	printf("bfadd16 MSE : %f\n", mse);
}
bfadd16 MSE : 0.000005

bfsub16 with Input Range $[-1, 1]$

/* Compare bfsub16 against native float subtraction over SAMPLE_SIZE
 * element pairs and print the mean-squared error.
 * Relies on MSE(), SAMPLE_SIZE and the BF16 conversion helpers defined
 * elsewhere in this project. */
void testBF16SubMSE(float *arr1, float *arr2) {
	float BF16Sub[SAMPLE_SIZE], FP32Sub[SAMPLE_SIZE];
	for(int i = 0;i != SAMPLE_SIZE;i++) {
		BF16Sub[i] = bf16_to_fp32(bfsub16(fp32_to_bf16(arr1[i]), fp32_to_bf16(arr2[i])));
		FP32Sub[i] = arr1[i] - arr2[i];
	}
	float mse = MSE(BF16Sub, FP32Sub);
	printf("bfsub16 MSE : %f\n", mse);
}
bfsub16 MSE : 0.000005

bfmul16 with Input Range $[-5, 5]$

bfmul16 MSE : 0.000651

bfadd16 with Input Range $[-5, 5]$

bfadd16 MSE : 0.000133

bfsub16 with Input Range $[-5, 5]$

bfsub16 MSE : 0.000114

Perform BF16 Arithmetic Function without Floating Point Support

If a processor only supports RV32I, it has no circuitry for multiplication and division instructions, so we can only use bitwise operations to reach the same result. There are three steps we must do before performing multiplication or division on two bf16_t values; this introduces some overhead because we must process the sign, exponent and mantissa with different bit masks and branch conditions

  • Extract Sign by bits & 0x8000 -> 0b1_00000000_0000000
  • Extract Exponent by bits & 0x7F80 -> 0b0_11111111_0000000
  • Extract Mantissa by (bits & 0x7F) | 0x80 -> 0x7F = 0b0_00000000_1111111 ; 0x80 = 0b0_00000001_0000000 is OR-ed in to restore the implicit leading one

My target is to reuse bit mask and reduce branch condition
I will use the assembly code of bfadd16 for testing
Because each BF16 arithmetic function use bit mask and branch condition
So the methods to optimize assembly code is similar

bfadd16 original

# Baseline bfadd16 listing (print routine omitted): every wide bit mask is
# materialized with its own li, which the disassembly shows costs two
# instructions (lui + addi) whenever the constant exceeds 12 bits.
.data
data1:    .word    0x3f9f # BF16 1.245
data2:    .word    0x4050 # BF16 3.255
ans:      .word    0x40900000 # Expect ans 4.5
.text

bfadd16:
    la a0, data1
    la a1, data2

    # Load input
    lh t0, 0(a0)        # ia
    lh t1, 0(a1)        # ib

    # create bit mask 0x7fff
    li t2, 0x7fff
    and t3, t0, t2          # cmp_a = ia & 0x7fff
    and t4, t1, t2          # cmp_b = ib & 0x7fff
    
    bge t3, t4, no_swap     # if cmp_a >= cmp_b goto noswap
swap:
    mv t2, t0
    mv t0, t1
    mv t1, t2
no_swap:
    srli t3, t0, 7        # ia >> 7
    srli t4, t1, 7        # ib >> 7
    andi t3, t3, 0xff     # ea = (ia >> 7) & 0xff
    andi t4, t4, 0xff     # eb = (ib >> 7) & 0xff
    sub t2, t3, t4        # t2 = ea - eb
    li t5, 7
    bgt t2, t5, mshift_max # clamp the align shift to the mantissa width
    j mshift_done
mshift_max:
    li t2, 7               # mshift = 7
mshift_done:
    andi t5, t0, 0x7f
    ori t5, t5, 0x80        # ma = (ia & 0x7f) | 0x80
    andi t6, t1, 0x7f
    ori t6, t6, 0x80        # mb = (ib & 0x7f) | 0x80
    srl t6, t6, t2          # mb >>= mshift
    xor a1, t0, t1          # ia ^ ib
    srli a1, a1, 15         # (ia ^ ib) >> 15 : 1 if the signs differ
    beqz a1, ma_plus_mb
ma_minus_mb:
    sub t5, t5, t6
    j ma_mb_done
ma_plus_mb:
    add t5, t5, t6
ma_mb_done:
    jal ra, clz        # count leading zeros of ma (t5) -> a0
    li a2, 9
    blt a0, a2, shift_in_range # clz <= 8
    addi t4, a0, -8    # shift = clz - 8
    sll t5, t5, t4     # ma <<= shift
    sub t3, t3, t4     # ea -= shift
    j return
shift_in_range:
    li t4, 8
    sub t4, t4, a0    # shift = 8 - clz
    srl t5, t5, t4    # ma >>= shift
    add t3, t3, t4    # ea += shift
return:
    li a0, 0x8000    # create bit mask
    and a0, a0, t0   # sign = ia & 0x8000
    slli t3, t3, 7   # exponent = ea << 7
    li t2, 0x7F80
    and t3, t3, t2   # exponent = (ea << 7) & 0x7F80
    or a0, a0, t3
    li t2, 0x7F
    and t5, t5, t2   # mantissa = ma & 0x7F
    or a0, a0, t5
    slli t0, a0, 16  # extend bf16 to fp32
    j done

# Linear-scan count-leading-zeros of the 16-bit value in t5.
# Returns the count in a0; clobbers a1 and a2.
clz:
    li a0, 0            # count = 0
    li a1, 0x8000       # create bit mask 0x8000
while_loop:
    and a2, t5, a1
    bnez a2, clz_done   # stop at the first set bit
    addi a0, a0, 1
    srli a1, a1, 1    # mask >>= 1
    bnez a1, while_loop
clz_done:
    ret

done:
    nop

Screenshot 2024-10-15 at 1.32.17 AM

Issue 1 : Create bit mask used li actually need two instructions

For example : Bit mask for exponent li t2, 0x7F80
You can check in disassembled code li t2, 0x7F80 actually need 2 instructions

c0:        000083b7        lui x7 0x8
c4:        f8038393        addi x7 x7 -128

Recall the I-Type instruction format: if the immediate fits in 12 bits, the li pseudo-instruction can simply be converted to a single addi instruction. But when we use li t2, 0x7F80 to create a bit mask, the disassembled code actually uses two instructions, lui and addi, because 0x7F80 is 16 bits wide.
Screenshot%202024-10-08%20at%2012.49.26%E2%80%AFPM

Although we can simply use li to create a 32 bits value, if we need bit mask which size is exceed 12 bits then we actually use 2 instruction per bit mask, I think we can reuse init bit mask to reduce li instruction to create a new bit mask again, for example if we want to extract sign, we need create first bit mask by li t2, 0x8000, after if we need extract exponent, we can simply use addi t3, t2, -128 to get bit mask 0x7F80 for exponent rather than use li instruction again, so that we can reduce one instruction per bit mask

Issue 2 : jal used to check mantissa shift will waste two cycles

Screenshot 2024-10-15 at 1.33.09 AM

I haven't yet figured out a method to optimize the branch conditions in bfadd16. Maybe there is a chance to reduce branch conditions in count leading zeros: my count-leading-zeros implementation has time complexity $O(n)$, and I found a method that can reduce it to $O(\log_2 n)$.

bfadd16 with bit mask reuse

.data
data1:    .word    0x3f9f # BF16 1.245
data2:    .word    0x4050 # BF16 3.255
ans:      .word    0x40900000 # Expect ans 4.5
.text

bfadd16:
    la a0, data1
    la a1, data2

    # Load input
    lh t0, 0(a0)        # ia
    lh t1, 0(a1)        # ib

    # create bit mask 0x7fff as init bit mask
    li a5, 0x7fff
    and t3, t0, a5          # cmp_a = ia & 0x7fff
    and t4, t1, a5          # cmp_b = ib & 0x7fff
    
    bge t3, t4, no_swap     # if cmp_a >= cmp_b goto noswap
swap:
    mv t2, t0
    mv t0, t1
    mv t1, t2
no_swap:
    srli t3, t0, 7        # ia >> 7
    srli t4, t1, 7        # ib >> 7
    andi t3, t3, 0xff     # ea = (ia >> 7) & 0xff
    andi t4, t4, 0xff     # eb = (ib >> 7) & 0xff
    sub t2, t3, t4        # t2 = ea - eb
    li t5, 7
    bgt t2, t5, mshift_max
    j mshift_done
mshift_max:
    li t2, 7               # mshift = 7
mshift_done:
    andi t5, t0, 0x7f
    ori t5, t5, 0x80        # ma = (ia & 0x7f) | 0x80
    andi t6, t1, 0x7f
    ori t6, t6, 0x80        # mb = (ib & 0x7f) | 0x80
    srl t6, t6, t2          # mb >>= mshift
    xor a1, t0, t1          # ia ^ ib
    srli a1, a1, 15         # (ia ^ ib) >> 15
    beqz a1, ma_plus_mb
ma_minus_mb:
    sub t5, t5, t6
    j ma_mb_done
ma_plus_mb:
    add t5, t5, t6
ma_mb_done:
    jal ra, clz
    li a2, 9
    blt a0, a2, shift_in_range # clz <= 8
    addi t4, a0, -8    # shift = clz - 8
    sll t5, t5, t4     # ma <<= shift
    sub t3, t3, t4     # ea -= shift
    j return
shift_in_range:
    li t4, 8
    sub t4, t4, a0     # shift = 8 - clz
    srl t5, t5, t4     # ma >>= shift
    add t3, t3, t4     # ea += shift
return:
-   li a0, 0x8000      # create bit mask for sign
+   addi a0, a5, 0x1   # Reuse bit mask
    and a0, a0, t0     # sign = ia & 0x8000
    slli t3, t3, 7     # exponent = ea << 7
-   li t2, 0x7F80      # create bit mask for exponent
+   addi t2, a5, -127  # Reuse bit mask
    and t3, t3, t2     # exponent = (ea << 7) & 0x7F80
    or a0, a0, t3
    andi t5, t5, 0x7F  # mantissa = ma & 0x7F
    or a0, a0, t5
    slli t0, a0, 16    # extend bf16 to fp32
    j done

clz:
    li a0, 0            # count = 0
-   li a1, 0x8000       # create bit mask for count leading zero
+   addi a1, a5, 1      # reuse bit mask
while_loop:
    and a2, t5, a1
    bnez a2, clz_done   # if((ma & mask) == 0)
    addi a0, a0, 1
    srli a1, a1, 1    # mask >>= 1
    bnez a1, while_loop
clz_done:
    ret

done:
    nop

Screenshot 2024-10-15 at 8.40.50 AM

I have seen there has another method to optimize count leading zero,