contributed by <chi0819>
In the field of deep learning, the sigmoid function is widely used as the activation function for each layer of neural networks. OpenAI's project Triton is a language and compiler for writing highly efficient custom deep-learning primitives, enabling fast code at higher productivity than CUDA. Triton utilizes BF16 to speed up computation and improve memory efficiency. For example, we can write custom GPU kernels for activation functions like the sigmoid function to accelerate computations in deep learning models
Sigmoid function maps any real-valued number to a value between 0 and 1, which makes it useful in machine learning for activating neurons in neural networks and in logistic regression for modeling probabilities
For the exponential function
For
Although employing the Maclaurin series to calculate the natural exponential function requires maintaining input values within the range of [-1, 1] to ensure accuracy, precise computation of the exponential values can be achieved by normalizing the input data and controlling the magnitudes of weights and biases. Furthermore, because the inputs are confined to [-1, 1], numerical convergence occurs rapidly. Unlike methods based on ln 2, there is no need to restrict the input domain, which consequently conserves memory and reduces computational time.
With FP32 <-> BF16 conversion in 2024 Quiz 1 Problem B
I try to modify 2023 Quiz 1 Problem C and last year student's Homework1 function.
Create BF16 functions to calculate
fmul32
-> bfmul16
imul32
-> imul16
fdiv32
-> bfdiv16
idiv24
-> idiv7
fadd32
-> bfadd16
The structure of the bfloat16 floating-point format is as follows.
┌ sign
│
│ ┌ exponent
│ │
│ │ ┌ mantissa
│ │ │
│┌──┴───┐┌─┴───┐
0b0000000000000000 bfloat16
bf16_t
Data Type
typedef struct {
uint32_t bits;
} bf16_t;
FP32
to BF16
/* Convert an FP32 value to BF16 with round-to-nearest-even.
 * NaN inputs are quieted (bit 6 of the BF16 payload is forced on). */
static inline bf16_t fp32_to_bf16(float s)
{
    /* Reinterpret the float's bit pattern as a 32-bit integer. */
    union {
        float f;
        uint32_t i;
    } u = {.f = s};
    uint32_t bits = u.i;
    bf16_t h;
    /* NaN: anything above the infinity pattern once the sign is cleared. */
    if ((bits & 0x7fffffff) > 0x7f800000) {
        h.bits = (bits >> 16) | 64; /* force to quiet */
        return h;
    }
    /* Round to nearest even: add 0x7fff plus the LSB of the kept field. */
    uint32_t rounding = 0x7fff + ((bits >> 16) & 1);
    h.bits = (bits + rounding) >> 16;
    return h;
}
BF16
to FP32
/* Widen a BF16 value to FP32: the 16 BF16 bits become the upper half
 * of the FP32 bit pattern; the low 16 mantissa bits are zero. */
static inline float bf16_to_fp32(bf16_t h)
{
    union {
        uint32_t i;
        float f;
    } u;
    u.i = (uint32_t)h.bits << 16;
    return u.f;
}
getbit
/* Return bit n (0 = least significant) of value as 0 or 1. */
uint16_t getbit(uint16_t value, int n)
{
    uint16_t shifted = value >> n;
    return shifted & 1u;
}
I modified the C code from last year student's homework
Change the data type to bf16_t
, learn floating point format
Most of time spend to understand bitwise operation
bfmul16
/*
 * BF16 multiplication built on the shift-and-add multiplier imul16().
 *
 * Fixes over the original version:
 *  - a zero operand now yields a signed zero (the old code OR'ed in the
 *    hidden 1 bit even for zero, so 0 * x produced a small nonzero value);
 *  - exponent overflow now saturates to infinity instead of wrapping the
 *    exponent into the sign bit.
 * Underflow is still flushed by clamping the exponent to 0, as before.
 */
bf16_t bfmul16(bf16_t a, bf16_t b)
{
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;
    /* sign of the product: XOR of the operand signs */
    uint16_t sign = (ia ^ ib) & 0x8000;
    /* zero operand -> signed zero */
    if ((ia & 0x7FFF) == 0 || (ib & 0x7FFF) == 0)
        return (bf16_t) {.bits = sign};
    /* mantissa: 7 stored bits plus the implicit leading 1 in bit 7 */
    uint8_t ma = (ia & 0x7F) | 0x80;
    uint8_t mb = (ib & 0x7F) | 0x80;
    /* exponent: 8 bits */
    uint8_t exp_a = (ia >> 7) & 0xFF;
    uint8_t exp_b = (ib >> 7) & 0xFF;
    /* 8-bit x 8-bit mantissa product fits in 16 bits (max 0xFF * 0xFF) */
    uint16_t m = imul16((int16_t)ma, (int16_t)mb);
    /* normalization: bit 15 set means the product is in [2, 4) */
    uint8_t mshift = (m >> 15) & 1;
    m >>= mshift; /* shift mantissa right by 1 if it overflowed */
    /* result exponent: (exp_a + exp_b - bias) + mshift */
    int16_t exp_r = (int16_t)exp_a + (int16_t)exp_b - 127 + (int16_t)mshift;
    /* flush underflow toward zero */
    if (exp_r <= 0) exp_r = 0;
    /* saturate overflow to signed infinity */
    if (exp_r >= 0xFF)
        return (bf16_t) {.bits = sign | 0x7F80};
    /* combine sign, exponent, and the top 7 mantissa bits */
    return (bf16_t) {.bits = sign | (((uint16_t)exp_r << 7) & 0x7F80) | ((m >> 7) & 0x7F)};
}
imul16
/*
 * 16-bit shift-and-add multiplication (RV32I has no MUL instruction);
 * returns a * b truncated to 16 bits.
 *
 * Fixes over the original version:
 *  - scans all 16 bits of b instead of only the low 8, so the result is
 *    correct (mod 2^16) for any operands, not only 8-bit mantissas;
 *  - accumulates in unsigned arithmetic, avoiding the undefined
 *    behaviour of left-shifting a negative signed value.
 */
uint16_t imul16(int16_t a, int16_t b)
{
    uint16_t ua = (uint16_t)a;
    uint16_t ub = (uint16_t)b;
    uint16_t r = 0;
    for (int i = 0; i < 16; i++) {
        if ((ub >> i) & 1u)
            r += (uint16_t)(ua << i);
    }
    return r;
}
Use loop unrolling to reduce branch condition
TODO
bfdiv16
/*
 * BF16 division built on the restoring integer divider idiv7().
 *
 * idiv7(ma, mb) returns floor(ma * 2^15 / mb); for normalized mantissas
 * (both in [0x80, 0xFF]) the quotient lies in [2^14, 2^16), so either
 * bit 15 is already set or a single left shift normalizes it.
 *
 * Fixes over the original version:
 *  - the sign is computed by XOR instead of subtracting the packed
 *    sign+exponent fields (which broke when the operand signs differed);
 *  - a full, correctly aligned 7-bit mantissa is extracted (the old
 *    (mantissa & 0x7E00) >> 9 kept only 6 misaligned bits);
 *  - the normalization step adjusts the exponent by one unit (0x80 at
 *    bit 7) instead of flipping bit 15, and overflow/underflow saturate
 *    to signed infinity / signed zero.
 */
bf16_t bfdiv16(bf16_t a, bf16_t b)
{
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;
    uint16_t sign = (ia ^ ib) & 0x8000;
    /* divisor = 0 -> signed infinity */
    if ((ib & 0x7FFF) == 0) return (bf16_t) {.bits = sign | 0x7F80};
    /* dividend = 0 -> signed zero */
    if ((ia & 0x7FFF) == 0) return (bf16_t) {.bits = sign};
    /* mantissas with the implicit leading 1 (bit 7) */
    int16_t ma = (ia & 0x7F) | 0x80;
    int16_t mb = (ib & 0x7F) | 0x80;
    /* exponent difference, re-biased */
    int16_t exp_r = (int16_t)((ia >> 7) & 0xFF) - (int16_t)((ib >> 7) & 0xFF) + 127;
    uint16_t q = (uint16_t)idiv7(ma, mb); /* floor(ma * 2^15 / mb) */
    if (!(q & 0x8000)) { /* ma < mb: ratio in [0.5, 1), normalize */
        q <<= 1;
        exp_r--;
    }
    if (exp_r >= 0xFF) return (bf16_t) {.bits = sign | 0x7F80}; /* overflow */
    if (exp_r <= 0)    return (bf16_t) {.bits = sign};          /* underflow */
    /* bits 14..8 of q are the 7 stored mantissa bits below the leading 1 */
    return (bf16_t) {.bits = sign | ((uint16_t)exp_r << 7) | ((q >> 8) & 0x7F)};
}
idiv7
/*
 * Restoring fixed-point division: returns floor(a * 2^15 / b) truncated
 * to 16 bits.  Intended for normalized BF16 mantissas (a, b in
 * [0x80, 0xFF]), where the running remainder stays below 2*b and the
 * shifts cannot overflow.
 *
 * Fix over the original version: all internal arithmetic is unsigned,
 * so the repeated left shifts are well defined (left-shifting a signed
 * value into or past the sign bit is undefined behaviour).  The
 * interface and results for the intended domain are unchanged.
 */
int16_t idiv7(int16_t a, int16_t b)
{
    uint16_t rem = (uint16_t)a;
    uint16_t div = (uint16_t)b;
    uint16_t q = 0;
    for (int i = 0; i < 16; i++) {
        q <<= 1;
        if (rem >= div) { /* quotient bit is 1: subtract and keep going */
            q |= 1;
            rem -= div;
        }
        rem <<= 1;
    }
    return (int16_t)q;
}
bfadd16
/*
 * BF16 addition.  The operand with the larger magnitude is kept in ia
 * so the smaller mantissa can be aligned by right shifting, and the
 * larger operand's sign becomes the result sign.
 *
 * Fix over the original version: exact cancellation (a + (-a)) left
 * ma == 0, and my_clz(0) then drove the normalization step to produce
 * a garbage exponent with a zero mantissa.  A zero mantissa now
 * returns +0 directly.
 */
bf16_t bfadd16(bf16_t a, bf16_t b) {
    uint16_t ia = a.bits;
    uint16_t ib = b.bits;
    /* compare the absolute values */
    /* ensure ia holds the larger magnitude */
    uint16_t cmp_a = ia & 0x7fff;
    uint16_t cmp_b = ib & 0x7fff;
    if (cmp_a < cmp_b)
        iswap(ia, ib);
    /* exponents */
    uint16_t ea = (ia >> 7) & 0xff;
    uint16_t eb = (ib >> 7) & 0xff;
    /* cap the align shift at 7, the width of the bf16_t mantissa */
    int16_t mantissa_align_shift = (ea - eb > 7) ? 7 : ea - eb;
    /* mantissas with the implicit leading 1 (bit 7) */
    uint16_t ma = (ia & 0x7F) | 0x80;
    uint16_t mb = (ib & 0x7F) | 0x80;
    mb >>= mantissa_align_shift;
    /* different signs subtract magnitudes, same signs add them */
    if((ia ^ ib) >> 15) ma -= mb;
    else ma += mb;
    /* exact cancellation -> zero; my_clz(0) would break normalization */
    if (ma == 0)
        return (bf16_t){.bits = 0};
    /* normalize so the leading 1 sits at bit 7 */
    int16_t clz = my_clz(ma);
    int16_t shift = 0;
    if(clz <= 8) {
        shift = 8 - clz;
        ma >>= shift;
        ea += shift;
    } else {
        shift = clz - 8;
        ma <<= shift;
        ea -= shift;
    }
    /* the larger operand's sign wins */
    return (bf16_t){.bits = (ia & 0x8000) | ((ea << 7) & 0x7f80) | (ma & 0x7f)};
}
bfsub16
/*
 * BF16 subtraction (a - b).  Mirrors bfadd16 with the sign test
 * inverted: equal operand signs subtract the aligned mantissas,
 * opposite signs add them.  The magnitude swap plus the three sign
 * corrections at the end reconstruct the correct result sign.
 *
 * NOTE(review): exact cancellation (ma == 0) still falls through to
 * normalization; my_clz(0) presumably returns 16, which shifts the
 * exponent by 8 and yields a nonzero exponent with a zero mantissa --
 * confirm and return zero explicitly, as the same issue exists in
 * bfadd16.
 */
bf16_t bfsub16(bf16_t a, bf16_t b) {
uint16_t ia = a.bits;
uint16_t ib = b.bits;
/* compare the absolute values */
/* ensure a has greater or equal magnitude than b */
uint16_t cmp_a = ia & 0x7FFF;
uint16_t cmp_b = ib & 0x7FFF;
int swapped = 0;
if (cmp_a < cmp_b) {
iswap(ia, ib);
swapped = 1;
}
/* exponents */
uint16_t ea = (ia >> 7) & 0xFF;
uint16_t eb = (ib >> 7) & 0xFF;
/* cap the align shift at 7, the size of the bf16_t mantissa */
int16_t mantissa_align_shift = (ea - eb > 7) ? 7 : ea - eb;
/* mantissas with the implicit leading 1 (bit 7) */
int16_t ma = (ia & 0x7F) | 0x80;
int16_t mb = (ib & 0x7F) | 0x80;
mb >>= mantissa_align_shift;
/* adjust mantissas based on signs */
/* same signs: a - b subtracts the magnitudes; different signs: adds */
if (!((ia ^ ib) >> 15)) ma -= mb;
else ma += mb;
/* handle a negative mantissa result (defensive: after the swap ia
 * holds the larger magnitude, so ma - mb should not go negative) */
int negative_result = 0;
if (ma < 0) {
ma = -ma;
negative_result = 1;
}
/* normalize the result so the leading 1 sits at bit 7 */
int16_t clz = my_clz((uint16_t)ma);
int16_t shift = 0;
if (clz <= 8) {
shift = 8 - clz;
ma >>= shift;
ea += shift;
} else {
shift = clz - 8;
ma <<= shift;
ea -= shift;
}
/* determine the sign of the result */
uint16_t sign_a = ia & 0x8000; /* if swapped ia = original ib */
uint16_t sign_b = ib & 0x8000; /* if swapped ib = original ia */
uint16_t sign;
if (negative_result == 0) sign = (swapped ? sign_b : sign_a);
else sign = (swapped ? sign_b : sign_a) ^ 0x8000; /* flip the sign */
if(swapped && !(sign_a ^ sign_b)) sign ^= 0x8000; /* a - b with |a| < |b| and equal signs flips the sign */
/* assemble the final result */
return (bf16_t){.bits = sign | ((ea << 7) & 0x7f80) | (ma & 0x7F)};
}
When I solve the bfsub16
bug about sign
I suddenly found that I don't need to implement subtraction when I ready to sleep
If I want to make
Then make addition
I paid lots of time to implement bfsub16
but I think it's a good experience
Think More …
I need some time to figure out how to combine this three line together
It's a good chance to exercise my brain
if (negative_result == 0) sign = (swapped ? sign_b : sign_a);
else sign = (swapped ? sign_b : sign_a) ^ 0x8000; /* Flip the sign */
if(swapped && !(sign_a ^ sign_b)) sign ^= 0x8000; /* need to be optimize */
Emm… The outcome is really weird
input x = 0.500000
my bf16 sigmoid = 0.308594
STL sigmoid approach = 0.622459
Have you checked the error rate accordingly?
In the main.c
random sample float in range
Sample size is 10
The number of terms in the Taylor expansion of the exponential function is 5
MSE : 0.698543
Wait me to fix the bfdiv16
problem
I fix bfmul16
, bfadd16
and bfsub16
successfully
If I fix bfdiv16
successfully, I will come back to fix Sigmoid
The bfmul16
, bfadd16
and bfsub16
output is correct
It seems that the problem is on bfdiv16
term 1 : 0.500000 = 0.500000 / 1.000000
term 2 : 0.125000 = 0.250000 / 2.000000
term 3 : -0.036377 = 0.125000 / 6.000000
term 4 : -0.004547 = 0.062500 / 24.000000
term 5 : -0.000504 = 0.031250 / 120.000000
term 6 : -0.000037 = 0.015625 / 720.000000
term 7 : -0.000002 = 0.007812 / 5040.000000
term 8 : -0.000000 = 0.003906 / 40320.000000
term 9 : -0.000000 = 0.001953 / 362880.000000
I have converted the functional parts into assembly code and executed them on Ripes, learning the actual process of how assembly code runs on the processor. After becoming familiar with bitwise operations, I will come back to fix the functionality of bfdiv16
. I think besides bitwise operations, I also need to better understand the specifications and handling of floating-point numbers
fp32_to_bf16
Use this simple function to learn the instructions flow in each stage
.data
argument: .word 0x3E800000 # float 0.25
.text
_start:
j main
fp32_to_bf16:
la a0, argument # load argument address
lw a1, 0(a0) # load data (float 0.25) from argument
li a2, 0x7F800000 # exponent mask, used to detect NaN
li a3, 0x7FFFFFFF # bit mask (fix: was 0x7FFFFFF, one F short, so the NaN check used a truncated mask)
and a4, a1, a3 # u.i & 0x7FFFFFFF
blt a2, a4, non_case # handle NaN case
li a2, 0x7FFF
mv a3, a1
srli a3, a3, 16 # u.i >> 16
andi a3, a3, 1 # (u.i >> 16) & 1
add a3, a3, a2 # 0x7FFF + ((u.i >> 16) & 1)
add a3, a3, a1 # u.i + (0x7FFF + ((u.i >> 16) & 1))
srli a0, a3, 16 # (u.i + (0x7FFF + ((u.i >> 16) & 1))) >> 16
ret
non_case:
srli a1, a1, 16 # (u.i >> 16)
li a4, 0x40
or a0, a1, a4 # (u.i >> 16) | 0x40, force quiet NaN
ret
main:
jal ra, fp32_to_bf16
end:
nop
correspond disassembled code
00000000 <_start>:
0: 0580006f jal x0 88 <main>
00000004 <fp32_to_bf16>:
4: 10000517 auipc x10 0x10000
8: ffc50513 addi x10 x10 -4
c: 00052583 lw x11 0 x10
10: 7f800637 lui x12 0x7f800
14: 080006b7 lui x13 0x8000
18: fff68693 addi x13 x13 -1
1c: 00d5f733 and x14 x11 x13
20: 02e64463 blt x12 x14 40 <non_case>
24: 00008637 lui x12 0x8
28: fff60613 addi x12 x12 -1
2c: 00058693 addi x13 x11 0
30: 0106d693 srli x13 x13 16
34: 0016f693 andi x13 x13 1
38: 00c686b3 add x13 x13 x12
3c: 00b686b3 add x13 x13 x11
40: 0106d513 srli x10 x13 16
44: 00008067 jalr x0 x1 0
00000048 <non_case>:
48: 0105d593 srli x11 x11 16
4c: 04000713 addi x14 x0 64
50: 00e5e533 or x10 x11 x14
54: 00008067 jalr x0 x1 0
00000058 <main>:
58: fadff0ef jal x1 -84 <fp32_to_bf16>
0000005c <end>:
5c: 00000013 addi x0 x0 0
bfadd16
.data
data1: .word 0x3f9f # BF16 1.245
data2: .word 0x4050 # BF16 3.255
ans: .word 0x40900000 # Expect ans 4.5
newline: .string "\n"
str1: .string "data1 : "
str2: .string "data2 : "
str3: .string "bfadd16 output : "
str4: .string "Expect ans : "
.text
bfadd16:
la a0, data1
la a1, data2
# Load input
lh t0, 0(a0) # ia
lh t1, 0(a1) # ib
# create bit mask 0x7fff
li t2, 0x7fff
and t3, t0, t2 # cmp_a = ia & 0x7fff
and t4, t1, t2 # cmp_b = ib & 0x7fff
bge t3, t4, no_swap # if cmp_a >= cmp_b goto noswap
swap:
mv t2, t0
mv t0, t1
mv t1, t2
no_swap:
srli t3, t0, 7 # ia >> 7
srli t4, t1, 7 # ib >> 7
andi t3, t3, 0xff # ea = (ia >> 7) & 0xff
andi t4, t4, 0xff # eb = (ib >> 7) & 0xff
sub t2, t3, t4 # t2 = ea - eb
li t5, 7
bgt t2, t5, mshift_max
j mshift_done
mshift_max:
li t2, 7 # mshift = 7
mshift_done:
andi t5, t0, 0x7f
ori t5, t5, 0x80 # ma = (ia & 0x7f) | 0x80
andi t6, t1, 0x7f
ori t6, t6, 0x80 # mb = (ib & 0x7f) | 0x80
srl t6, t6, t2 # mb >>= mshift
xor a1, t0, t1 # ia ^ ib
srli a1, a1, 15 # (ia ^ ib) >> 15
beqz a1, ma_plus_mb
ma_minus_mb:
sub t5, t5, t6
j ma_mb_done
ma_plus_mb:
add t5, t5, t6
ma_mb_done:
jal ra, clz
li a2, 9
blt a0, a2, shift_in_range # clz <= 8
addi t4, a0, -8 # shift = clz - 8
sll t5, t5, t4 # ma <<= shift
sub t3, t3, t4 # ea -= shift
j return
shift_in_range:
li t4, 8
sub t4, t4, a0 # shift = 8 - clz
srl t5, t5, t4 # ma >>= shift
add t3, t3, t4 # ea += shift
return:
li a0, 0x8000 # create bit mask
and a0, a0, t0 # sign = ia & 0x8000
slli t3, t3, 7 # exponent = ea << 7
li t2, 0x7F80
and t3, t3, t2 # exponent = (ea << 7) & 0x7F80
or a0, a0, t3
li t2, 0x7F
and t5, t5, t2 # mantissa = ma & 0x7F
or a0, a0, t5
slli t0, a0, 16 # extend bf16 to fp32
j print
clz:
li a0, 0 # count = 0
li a1, 0x8000 # create bit mask 0x8000
while_loop:
and a2, t5, a1
bnez a2, clz_done # if((ma & mask) == 0)
addi a0, a0, 1
srli a1, a1, 1 # mask >>= 1
bnez a1, while_loop
clz_done:
ret
print:
la a0, str1 # data1 :
li a7, 4 # ecall 4 print string
ecall
la a0, data1
lw a0, 0(a0) # load data1
slli a0, a0, 16 # Extend bf16 to fp32
li a7, 2 # ecall 2 print float
ecall
la a0, newline
li a7, 4
ecall
la a0, str2 # data2 :
li a7, 4
ecall
la a0, data2
lw a0, 0(a0) # load data2
slli a0, a0, 16 # extend bf16 to fp32
li a7, 2 # ecall 2 print float
ecall
la a0, newline
li a7, 4
ecall
la a0, str3 # bfadd16 output :
li a7, 4
ecall
mv a0, t0 # bfadd16 result
li a7, 2
ecall
la a0, newline
li a7, 4
ecall
la a0, str4 # Expect ans :
li a7, 4
ecall
la a0, ans
lw a0, 0(a0) # load expect fp32 ans
li a7, 2
ecall
data1 : 1.24219
data2 : 3.25
bfadd16 output : 4.46875
Expect ans : 4.5
Tracking the instruction addi x10, x10, -4
in fp32_to_bf16
assembly code
In instruction fetch stage, the address ( 0x00000008
) of the instruction send to Instr. Memory
to fetch instruction, with the 32 bits data ( 0xffc50513
) we know the instruction in this stage is addi x10, x10, -4
, and the PC
+ 4 going to next instruction at ( 0x0000000c
)
screen shot from Lecture video, it is talking about the RV32I instruction layout
The sign extension operation in
Imm.
to extend 12 bits immediate in instruction toint32_t
value
Decode
extract the opcode
and register number from the instruction; we can find that the opcode
output is 0x13
for ADDI
( 0b0010011
), with opcode
and instruction we can get immediate
from Imm.
, with sign extension get the 32 bits signed integer -4
in bit pattern is 0xfffffffc
, in the Decode
stage also decode opcode
in Control
block to know what should ALU
do in Execution
stage.
The output for R1 inx
is 0x0a
( 10
) is for source register x10
in instruction, also send the 0x0a
signal to next stage, because the instruction take data from x10
, after the calculation we store the result back into x10
register
Why 0x1c
( 28
) for R2 inx
? need some time to understand the design concept
I think the signal to R2 idx
also used to hazard detection and check next instruction should be forward or not
I write a simple assembly program to understand signal connect to R2 idx
, after instruction lw a1, 0(a0)
is the addi a1, a1, 1
, if there isn't any hazard detection, addi a1, a1, 1
in Execution
stage will cause error because the data hasn't store into register a1
, in this example the signal connect to R2 idx
is 0x0b
.data
argument: .word 0x1
.text
_start:
la a0, argument
lw a1, 0(a0)
addi a1, a1, 1
addi a1, a1, 1
The yellow line connect 0x0b
to Hazard Unit
also store the data to ID/EX
register, when lw x11, 0(x10)
at Execution
stage, 0x0b
will be send to Forwarding Unit
from ID/EX
register
The signal connect to
R2 idx
also connect toHazard Unit
and store the signal toID/EX
register
When lw x11, 0(x10)
forward to Execution
stage, find the ID/EX
register's clear signal is enable, notice the Hazard Unit
lower input id_reg1_idx
( 0b01011
, x11
) is from next instruction at Decode
stage, upper input is from ID/EX
register indicate the register ( 0b01011
, x11
) now is used by lw x11, 0(x10)
, so Hazard Unit
find these two instruction has conflict, so the ID/EX
will be clear and make a stall, wait lw x11, 0(x10)
forward to Data Memory
to read out the data from x10
, also notice the IF/ID
register enable signal is false, so that addi x11, x11, 1
remain stay at Decoding
stage, the next instruction won't advanced from Fetch
stage
I think there has a technique to prevent one more stall called data forwarding
Notice the yellow line forward the data 0x00000001
back to Execution
stage for addi x11, x11, 1
, also notice the Forwarding Unit
signal ( WbStage
) to multiplexer, select the forwarding data from lw x11, 0(x10)
as input data
We can execute two addi
on same register continuously, because we can use data forwarding from Memory
stage, notice the yellow line is the data passed from EX/MEM
register which store the previous instruction output, with Forwarding Unit
signal MemStage
send to Execution
stage multiplexer, the instruction at Execution
stage can use correct data in x11
Data forwarding after
Execution
stage beforeMemory
stage
In Execution
stage, the ALU
behavior is control by the 5 bits control signal ctrl
, before the data send into ALU
, there has two multiplexer with enable signal to select which data should be calculated by ALU
, in this example, the input is 0x10000004
and 0xfffffffc
, the ALU
behavior is Add immediate, the ALU
output is the calculation result ( 0x10000004
+ -4
)
After the calculation result write back to the register destination
x10
, we can check the result is inx10
(0x10000000
=0x10000004
+-4
)
In this example, when instruction addi x10, x10, -4
in Execution
stage, the IMM
control signal is enable, so the ALU
input Op2
select the immediate data 0xfffffffc
( -4
)
From previous simple assembly program, the data value is 0x1
, the input of Data Memory
is Addr.
( 0x10000000
) and Data in
( 0x00000000
) as byte offset, the data is store in address 0x10000000
, actually, in disassembled code, before lw x11, 0(x10)
we need load the address of data to x10
use auipc x10 0x10000
, because lw
is I-Type instruction which only has 12 bits to represent immediate
little endian : lower bit store in lower address
Data Memory
get LW
OP signal, Addr.
and offset, so that Data Memory
read out the data 0x00000001
from data segment by memory address
In Write Back
stage also has multiplexer to select the data from Execution
, we found that the number of register 0x0a
( x10
), recall that Decode
stage extract the register destination and the pipeline forward the data to the Write Back
stage, notice that there has a green line is connect back to Register
, the green line signal extracted when the instruction in the Decode
stage, the signal is for write register enable
The data 0x10000000
is connect to Wr data
, and 0x0a
is connect to Wr idx
, Wr En
is connect to the write register enable signal, after this stage we can observe that the x10
register has data 0x10000000
Assume we have a fully connected neural network used to classify whether a person is male or female from a picture; the network input size is
So we actually pay lots of time on activation operation, if we can reduce the overhead of activation function we can accelerate the model.
Sigmoid function output range is
And Sigmoid function is differentiable also a real function for all real input values, so that is good activation function for model to gradient descent, compare with another famous activation function called ReLU, ReLU isn't a bounded activation function, if the input is positive outlier, the output is same value as input, also the output of ReLU isn't probabilistic interpretable
The advantages of Sigmoid
In this part, I don't evaluate the accuracy of BF16 Sigmoid
Because the bfdiv16
has some bugs I need to fix it
I will come back to complete BF16 Sigmoid when I successfully fix bfdiv16
I only test the accuracy of other BF16 functions used in BF16 Sigmoid
If the other function which used in BF16 Sigmoid has high accuracy
I think the BF16 Sigmoid also has high accuracy
Enhance basic BF16 arithmetic functions equal to enhance BF16 Sigmoid
Each array contain 100 sample
bfmul16
with Input Range
void testBF16MulMSE(float *arr1, float *arr2) {
float BF16Mul[SAMPLE_SIZE], FP32Mul[SAMPLE_SIZE];
for(int i = 0;i != SAMPLE_SIZE;i++) {
BF16Mul[i] = bf16_to_fp32(bfmul16(fp32_to_bf16(arr1[i]), fp32_to_bf16(arr2[i])));
FP32Mul[i] = arr1[i] * arr2[i];
}
float mse = MSE(BF16Mul, FP32Mul);
printf("bfmul16 MSE : %f\n", mse);
}
The MSE in some condition will output inf
and very large number
I think my bfmul16
has some problems to process overflow
Now I'm try to solve the bug
I found the bug, that is my MSE function :), arithmetic part is correct
bfmul16 MSE : 0.000002
bfadd16
with Input Range Random sample two array each contain 100 float in range [-1, 1]
/* Measure the mean-squared error of BF16 addition against FP32
 * addition over SAMPLE_SIZE element pairs and print it. */
void testBF16AddMSE(float *arr1, float *arr2) {
    float BF16Add[SAMPLE_SIZE], FP32Add[SAMPLE_SIZE];
    for (int i = 0; i < SAMPLE_SIZE; ++i) {
        BF16Add[i] = bf16_to_fp32(bfadd16(fp32_to_bf16(arr1[i]), fp32_to_bf16(arr2[i])));
        FP32Add[i] = arr1[i] + arr2[i];
    }
    printf("bfadd16 MSE : %f\n", MSE(BF16Add, FP32Add));
}
bfadd16 MSE : 0.000005
bfsub16
with Input Range
void testBF16SubMSE(float *arr1, float *arr2) {
float BF16Sub[SAMPLE_SIZE], FP32Sub[SAMPLE_SIZE];
for(int i = 0;i != SAMPLE_SIZE;i++) {
BF16Sub[i] = bf16_to_fp32(bfsub16(fp32_to_bf16(arr1[i]), fp32_to_bf16(arr2[i])));
FP32Sub[i] = arr1[i] - arr2[i];
}
float mse = MSE(BF16Sub, FP32Sub);
printf("bfsub16 MSE : %f\n", mse);
}
bfsub16 MSE : 0.000005
bfmul16
with Input Range bfmul16 MSE : 0.000651
bfadd16
with Input Range bfadd16 MSE : 0.000133
bfsub16
with Input Range bfsub16 MSE : 0.000114
If a processor only supports RV32I, it doesn't have circuitry to perform multiplication and division instructions, so we can only use bitwise operations to reach the same result. There are three steps we must do before performing multiplication and division on two bf16_t
values; there is some overhead because we must process the sign, exponent and mantissa with different bit masks and branch conditions
bits & 0x8000
-> 0b1_00000000_0000000
bits & 0x7F80
-> 0b0_11111111_0000000
(bits & 0x7F) | 0x80
-> 0x7F = 0b0_00000000_1111111
; 0x80 = 0b0_00000001_0000000
, used to set the implicit leading 1 of the mantissa. My target is to reuse bit masks and reduce branch conditions
I will use the assembly code of bfadd16
for testing
Because each BF16 arithmetic function use bit mask and branch condition
So the methods to optimize assembly code is similar
bfadd16
original.data
data1: .word 0x3f9f # BF16 1.245
data2: .word 0x4050 # BF16 3.255
ans: .word 0x40900000 # Expect ans 4.5
.text
bfadd16:
la a0, data1
la a1, data2
# Load input
lh t0, 0(a0) # ia
lh t1, 0(a1) # ib
# create bit mask 0x7fff
li t2, 0x7fff
and t3, t0, t2 # cmp_a = ia & 0x7fff
and t4, t1, t2 # cmp_b = ib & 0x7fff
bge t3, t4, no_swap # if cmp_a >= cmp_b goto noswap
swap:
mv t2, t0
mv t0, t1
mv t1, t2
no_swap:
srli t3, t0, 7 # ia >> 7
srli t4, t1, 7 # ib >> 7
andi t3, t3, 0xff # ea = (ia >> 7) & 0xff
andi t4, t4, 0xff # eb = (ib >> 7) & 0xff
sub t2, t3, t4 # t2 = ea - eb
li t5, 7
bgt t2, t5, mshift_max
j mshift_done
mshift_max:
li t2, 7 # mshift = 7
mshift_done:
andi t5, t0, 0x7f
ori t5, t5, 0x80 # ma = (ia & 0x7f) | 0x80
andi t6, t1, 0x7f
ori t6, t6, 0x80 # mb = (ib & 0x7f) | 0x80
srl t6, t6, t2 # mb >>= mshift
xor a1, t0, t1 # ia ^ ib
srli a1, a1, 15 # (ia ^ ib) >> 15
beqz a1, ma_plus_mb
ma_minus_mb:
sub t5, t5, t6
j ma_mb_done
ma_plus_mb:
add t5, t5, t6
ma_mb_done:
jal ra, clz
li a2, 9
blt a0, a2, shift_in_range # clz <= 8
addi t4, a0, -8 # shift = clz - 8
sll t5, t5, t4 # ma <<= shift
sub t3, t3, t4 # ea -= shift
j return
shift_in_range:
li t4, 8
sub t4, t4, a0 # shift = 8 - clz
srl t5, t5, t4 # ma >>= shift
add t3, t3, t4 # ea += shift
return:
li a0, 0x8000 # create bit mask
and a0, a0, t0 # sign = ia & 0x8000
slli t3, t3, 7 # exponent = ea << 7
li t2, 0x7F80
and t3, t3, t2 # exponent = (ea << 7) & 0x7F80
or a0, a0, t3
li t2, 0x7F
and t5, t5, t2 # mantissa = ma & 0x7F
or a0, a0, t5
slli t0, a0, 16 # extend bf16 to fp32
j done
clz:
li a0, 0 # count = 0
li a1, 0x8000 # create bit mask 0x8000
while_loop:
and a2, t5, a1
bnez a2, clz_done # if((ma & mask) == 0)
addi a0, a0, 1
srli a1, a1, 1 # mask >>= 1
bnez a1, while_loop
clz_done:
ret
done:
nop
li
actually need two instructionsFor example : Bit mask for exponent li t2, 0x7F80
You can check in disassembled code li t2, 0x7F80
actually need 2 instructions
c0: 000083b7 lui x7 0x8
c4: f8038393 addi x7 x7 -128
Recall the I-Type instruction format: if the immediate size is less than or equal to 12 bits, a li
instruction can be simply converted to an addi
instruction, but when we use li t2, 0x7F80
to create bit mask, in disassembled code we actually use two instruction lui
and addi
, because 0x7F80
is 16 bits width.
Although we can simply use li
to create a 32 bits value, if we need bit mask which size is exceed 12 bits then we actually use 2 instruction per bit mask, I think we can reuse init bit mask to reduce li
instruction to create a new bit mask again, for example if we want to extract sign, we need create first bit mask by li t2, 0x8000
, after if we need extract exponent, we can simply use addi t3, t2, -128
to get bit mask 0x7F80
for exponent rather than use li
instruction again, so that we can reduce one instruction per bit mask
jal
used to check the mantissa shift will waste two cycles. I haven't yet figured out a method to optimize the branch conditions in bfadd16
; maybe there is a chance to reduce the branch conditions in counting leading zeros — in my count-leading-zeros implementation the time complexity is
bfadd16
with bit mask reuse.data
data1: .word 0x3f9f # BF16 1.245
data2: .word 0x4050 # BF16 3.255
ans: .word 0x40900000 # Expect ans 4.5
.text
bfadd16:
la a0, data1
la a1, data2
# Load input
lh t0, 0(a0) # ia
lh t1, 0(a1) # ib
# create bit mask 0x7fff as init bit mask
li a5, 0x7fff
and t3, t0, a5 # cmp_a = ia & 0x7fff
and t4, t1, a5 # cmp_b = ib & 0x7fff
bge t3, t4, no_swap # if cmp_a >= cmp_b goto noswap
swap:
mv t2, t0
mv t0, t1
mv t1, t2
no_swap:
srli t3, t0, 7 # ia >> 7
srli t4, t1, 7 # ib >> 7
andi t3, t3, 0xff # ea = (ia >> 7) & 0xff
andi t4, t4, 0xff # eb = (ib >> 7) & 0xff
sub t2, t3, t4 # t2 = ea - eb
li t5, 7
bgt t2, t5, mshift_max
j mshift_done
mshift_max:
li t2, 7 # mshift = 7
mshift_done:
andi t5, t0, 0x7f
ori t5, t5, 0x80 # ma = (ia & 0x7f) | 0x80
andi t6, t1, 0x7f
ori t6, t6, 0x80 # mb = (ib & 0x7f) | 0x80
srl t6, t6, t2 # mb >>= mshift
xor a1, t0, t1 # ia ^ ib
srli a1, a1, 15 # (ia ^ ib) >> 15
beqz a1, ma_plus_mb
ma_minus_mb:
sub t5, t5, t6
j ma_mb_done
ma_plus_mb:
add t5, t5, t6
ma_mb_done:
jal ra, clz
li a2, 9
blt a0, a2, shift_in_range # clz <= 8
addi t4, a0, -8 # shift = clz - 8
sll t5, t5, t4 # ma <<= shift
sub t3, t3, t4 # ea -= shift
j return
shift_in_range:
li t4, 8
sub t4, t4, a0 # shift = 8 - clz
srl t5, t5, t4 # ma >>= shift
add t3, t3, t4 # ea += shift
return:
- li a0, 0x8000 # create bit mask for sign
+ addi a0, a5, 0x1 # Reuse bit mask
and a0, a0, t0 # sign = ia & 0x8000
slli t3, t3, 7 # exponent = ea << 7
- li t2, 0x7F80 # create bit mask for exponent
+ addi t2, a5, -127 # Reuse bit mask
and t3, t3, t2 # exponent = (ea << 7) & 0x7F80
or a0, a0, t3
andi t5, t5, 0x7F # mantissa = ma & 0x7F
or a0, a0, t5
slli t0, a0, 16 # extend bf16 to fp32
j done
clz:
li a0, 0 # count = 0
- li a1, 0x8000 # create bit mask for count leading zero
+ addi a1, a5, 1 # reuse bit mask
while_loop:
and a2, t5, a1
bnez a2, clz_done # if((ma & mask) == 0)
addi a0, a0, 1
srli a1, a1, 1 # mask >>= 1
bnez a1, while_loop
clz_done:
ret
done:
nop
I have seen there has another method to optimize count leading zero,