Brief topic: Bfloat16 Logarithm
Descriptive topic: Approximation of logarithm of bfloat16 numbers using the RV32I instruction set
contributed by < coding-ray (黃柏叡) >
Bfloat16, RISC-V, RV32I, Natural Logarithm, Remez Algorithm
The bfloat16 (brain floating point, bf16) floating-point format occupies only 16 bits in computer memory.
Bfloat16 is in the following format:
Compared with the standard IEEE 754 single-precision 32-bit float (fp32), bf16 has almost the same dynamic range. However, the precision of bf16 is about 65,000 times (\(2^{16}\)) coarser than that of fp32. The details are shown in the following table (John, 2021).
Format | Dynamic range | Precision (\(\epsilon\)) |
---|---|---|
Fp32 | 83.39dB | 0.00000012 |
Bf16 | 78.57dB | 0.00781250 |
The values above are calculated by the following steps.
IEEE Standard for Floating-Point Arithmetic (IEEE 754) defines the format of floating-point numbers.
For a given floating-point format with 1 sign bit \(s\), \(n_e\) exponent bits, and \(n_m\) mantissa (significand) bits, a number in such format is generally as follows.
\[ (-1)^s \times 1.\underbrace{xxx \cdots xxx}_{n_m \text{ bits}} \,_2 \times 2^{(exp - 2^{(n_e - 1)} + 1)} \]
, whose binary representation (encoding) in computer memory is as follows.
\[ \underbrace{s}_\text{1 bit} \ \ \underbrace{eee \cdots eee}_{exp_2 ,\ n_e \text{ bits}} \ \ \underbrace{xxx \cdots xxx}_{n_m \text{ bits}} \]
, where \(exp_2\) is the right-aligned, binary value of \(exp \in [1, (2^{n_e} - 2)]\).
For \(exp = 0\) with nonzero \(mantissa\), the number is subnormal. For \(n_e = 8\), as in both fp32 and bf16, it is as follows.
\[ (-1)^s \times 0.\underbrace{xxx \cdots xxx}_{n_m \text{ bits}} \,_2 \times 2^{-126} \]
For \(exp = mantissa = 0\), the number is either positive or negative zero, depending on the sign bit.
\[ (-1)^s \times 0 \]
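For \(exp = 2^{n_e} - 1\) (all exponent bits set), there are two cases:
- If \(mantissa = 0\), the number is positive or negative infinity, depending on the sign bit.
- If \(mantissa \ne 0\), the value is not a number (NaN).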
The largest normal value in fp32 format is as follows.
\[ \begin{align} L_\text{fp32} &= 0 \ \underbrace{11111110}_{8 \text{ bits}} \ \underbrace{111 \cdots 111}_{23 \text{ bits}} \\\ &= (2 - 2^{-23}) \times 2^{((2^8 - 2) - 2^7 + 1)} \\\ &= (2 - 2^{-23}) \times 2^{127} \\\ &\approx 3.403 \times 10^{38} \end{align} \]
The smallest positive subnormal number in fp32 format is as follows.
\[ \begin{align} S_\text{fp32} &= 0 \ \underbrace{00000000}_{8 \text{ bits}} \ \underbrace{000 \cdots 001}_{23 \text{ bits}} \\\ &= 2^{-23} \times 2^{-126} \\\ &= 2^{-149} \\\ &\approx 1.401 \times 10^{-45} \end{align} \]
So, the dynamic range of fp32 is
\[ \begin{align} \text{DR}_\text{fp32} & = \log_{10}\frac{L_\text{fp32}}{S_\text{fp32}} \\\ & \approx \log_{10} \frac {3.403 \times 10^{38}} {1.401 \times 10^{-45}} \\\ & \approx 83.39 \text{ dB} \end{align} \]
Similarly, for bf16, we have the following values.
\[ \begin{align} L_\text{bf16} &= 0 \ \underbrace{11111110}_{8 \text{ bits}} \ \underbrace{1111111}_{7 \text{ bits}} \\\ &= (2 - 2^{-7}) \times 2^{((2^8 - 2) - 2^7 + 1)} \\\ &= (2 - 2^{-7}) \times 2^{127} \\\ &\approx 3.390 \times 10^{38} \\\ \ \ \\\ S_\text{bf16} &= 0 \ \underbrace{00000000}_{8 \text{ bits}} \ \underbrace{0000001}_{7 \text{ bits}} \\\ &= 2^{-7} \times 2^{-126} \\\ &= 2^{-133} \\\ &\approx 9.184 \times 10^{-41} \\\ \ \ \\ \text{DR}_\text{bf16} & = \log_{10}\frac{L_\text{bf16}}{S_\text{bf16}} \\\ & \approx \log_{10} \frac {3.390 \times 10^{38}} {9.184 \times 10^{-41}} \\\ & \approx 78.57 \text{ dB} \end{align} \]
(Both sides in the following inequality relations are in the same format of either fp32 or bf16.)
\[ \begin{align} &\text{In fp32, } \underset{\epsilon} {\arg\min} \{ (1 + \epsilon) > 1 \} = 2^{-23} = 0.00000011921 \\\ &\text{In bf16, } \underset{\epsilon} {\arg\min} \{ (1 + \epsilon) > 1 \} = 2^{-7}\,\, = 0.0078125 \end{align} \]
, where \(2^{-23}\) in fp32 and \(2^{-7}\) in bf16 are the values of the least significant bit of mantissa when \(exp = (2^{(n_e - 1)} - 1)\). (John, 2021)
For \(x \in (0, 2]\), the natural logarithm can be approximated by the following Taylor series (Milton et al., 1965; Wikipedia contributors, 2023).
\[ \begin{align} ln(x) &= \sum_{k=1}^\infty { \frac {(-1)^{k-1} (x-1)^k }{k} } \\ \ &= (x-1) - \frac{(x-1)^2}{2} + \frac{(x-1)^3}{3} - \frac{(x-1)^4}{4} + \frac{(x-1)^5}{5} - \cdots \end{align} \]
\[ \forall x \in (0, 2] \]
Let \(la(x)\) denote the approximation of \(ln(x)\) with the first five terms in Taylor series approximation. That is,
\[ ln(x) \approx la(x) = \sum_{k=1}^5 { \frac {(-1)^{k-1} (x-1)^k }{k} } \]
To reduce the number of multiplications, at the cost of losing the ability to compute the terms in parallel (each step depends on the previous one), rewrite \(la(x)\) in Horner form:
\[
la(x) =
-2.28333 + (
5 - (
5 - (
3.33333 - (
1.25 - 0.2 x
) x
) x
) x
) x
\]
Applying this formula to \(x \in [0.1, 2]\), I got
\[ \begin{align} &\text{max(error)} \approx 0.47, \text{ for } x \in [0.1, 2] \\\ &\text{max(error)} \approx 0.09, \text{ for } x \in [1, 2] \end{align} \]
To achieve \(\text{max(error)} < 0.01\), \(x\) must be in the range \([0.44, 1.68]\).
As a result, to apply this algorithm on a float with significand \(s \in [1, 2)\) (by definition) and exponent \(p\), \(ln(x)\) can be approximated as follows.
\[ \begin{align} &\text{Given } x = 2^p \times s, \text{ where } s \in [1, 2) \\\ &ln(x) = p \times ln(2) + ln (s), \text{ where } \\\ &ln(s)\approx \begin{cases} la(s) &,\text{ if } s \le 1.68 \\ la(s/2) + ln(2) &, \text{ otherwise.} \end{cases} \end{align} \]
Note: \(ln(2) \approx 0.693\) can be stored statically in memory (data or text section).
I will calculate the approximation of \(ln(x)\) in the bfloat16 (bf16) format, whose precision is \(\epsilon \approx 0.0078\). In the worst case, each arithmetic operation contributes an error of up to \(\epsilon\), so the error accumulates as the number of arithmetic operations increases. As a rough upper bound, the maximal error from the arithmetic operations is
\[ ME(n) = 0.0078 n \]
, where \(n\) is the number of arithmetic operations.
Hence, to further decrease the maximum error, I need a polynomial of lower order that, moreover, covers the whole input range \(x \in [1, 2]\) with a single function.
Crouching (2017) utilized the computer algebra program Maple (1982) and the Remez algorithm (Remez, 1934; Wikipedia contributors, 2023) to generate the following 4th-order polynomial approximation of \(ln(x)\) for \(x \in [1, 2]\).
\[ ln(x) \approx -1.7417939 + (2.8212026 + (-1.4699568 + (0.44717955 - 0.056570851 x) x) x) x \]
It achieved
\[ \text{max(error)} = 6.101 \times 10^{-5}, \text{ for } x \in [1, 2] \]
Evil et al. (2021) utilized the Remez algorithm in Boost C++ Libraries (John, 2010) to get the following 3rd-order approximation for \(x \in [1, 2]\).
\[ ln(x) \approx -1.49278 + (2.11263 + (-0.729104 + 0.10969 x) x) x \]
And it achieved
\[ \text{max(error)} = 4.5 \times 10^{-4}, \text{ for } x \in [1, 2] \]
The following table summarizes, for each method of approximating the natural logarithm \(ln(x) \ \forall x \in [1, 2]\): (1) the method, (2) its minimal number of arithmetic operations (\(+\ -\ \times\)), (3) the maximal error (ME) introduced by the method itself, and (4) the maximal error introduced by the arithmetic operations.
Method | min. # of Arith. Op. | ME (Method) | ME (Arith. Op.) |
---|---|---|---|
5th-order Taylor | 10 | 0.09 | 0.078 |
4th-order Remez | 8 | 0.00006101 | 0.062 |
3rd-order Remez | 6 | 0.00045 | 0.047 |
According to the table above, I conclude that the 3rd-order polynomial approximation from the Remez method is the best algorithm in my case, for it gives the minimal overall ME (\(\approx 0.047\)).
The source code is hosted on my GitHub repository. Feel free to fork and modify it.
I adopted the 3rd-order solution from Evil et al. (2021). However, the precision of this algorithm is limited by the precision of bf16, so the following numbers are not as precise as the numbers in the original solution.
\[ \begin{align} &\text{Given } x = 2^p \times s, \text{ where } s \in [1, 2), \\\ &\ \ \ \ \ \ ln(x) \approx lnc0 + (lnc1 + (lnc2 + lnc3 s) s) s + ln2 \times p \end{align} \]
, where
Name | Bf16 Hexadecimal | Bf16 Decimal | Original |
---|---|---|---|
\(lnc0\) | 0xBFBF | \(-1.4921875\) | \(-1.49278\) |
\(lnc1\) | 0x4007 | \(+2.109375\) | \(+2.11263\) |
\(lnc2\) | 0xBF3B | \(-0.73046875\) | \(-0.729104\) |
\(lnc3\) | 0x3DE1 | \(+0.10986328125\) | \(+0.10969\) |
\(ln2\) | 0x3F31 | \(+0.69140625\) | \(+0.69314718\) |
The RV32I instruction set architecture (ISA) does not provide the multiplication (\(mul\)) functionality. Thus, in this section, I implement the multiplication for two unsigned 32-bit integers with the RV32I.
Given multiplier \(a_0\) and multiplicand \(a_1\), their product can be calculated by adding \(a_0\) to itself \(a_1\) times. That is,
\[ a_0 \times a_1 = \sum_{i = 1}^{a_1} a_0 \]
# --- mul_sum_u32 ---
# a0 = a0 * a1, computed by adding a0 to itself a1 times
# both a0 and a1 must be in [1, 2^31 - 1]; a zero operand is not handled
mul_sum_u32:
addi sp, sp, -4
sw ra, 0(sp)
bge a0, a1, muu_no_swap
# make a1 <= a0
mv t1, a1
mv a1, a0
mv a0, t1
muu_no_swap:
# t0 = result
li t0, 0
addi a1, a1, -1
muu_loop:
add t0, t0, a0
addi a1, a1, -1
bge a1, zero, muu_loop
muu_exit:
mv a0, t0
lw ra, 0(sp)
addi sp, sp, 4
ret
In lines 7-11, I ensure \(a_1 \le a_0\), so that the loop runs \(\min(a_0, a_1)\) times; the overhead of the swap is less than one iteration of the loop.
`muu_loop` takes 5 CPU cycles per iteration when it branches from line 19 back to line 17.
To improve the speed on large numbers, I made another implementation of multiplication that uses shift operations.
It is binary multiplication, as illustrated below.
\[ \begin{split} 110 \\ \times )\ \ \ \ \ \ \ \ 101 \\ \hline 110 \\ 000 \ \ \\ \,+ ) \ \ \ \ 110 \ \ \ \ \\ \hline 11110 \end{split} \]
# --- mul_shift_u32 ---
# binary multiplication of two u32 numbers
# input:
# a0: a (u32): multiplier
# a1: b (u32): multiplicand
# output:
# a0: r (u32): product of a and b (a * b)
mul_shift_u32:
mhu_prologue:
addi sp, sp, -4
sw ra, 0(sp)
bge a0, a1, mhu_no_swap
# make a1 <= a0
addi t0, a1, 0
mv a1, a0
mv a0, t0
mhu_no_swap:
# binary multiplication of t0 = a0 * a1
addi t0, zero, 0 # t0 = result
mhu_loop:
beq a1, zero, mhu_epilogue
andi t2, a1, 1 # the least significant bit of a1
beq t2, zero, mhu_next
add t0, t0, a0
mhu_next:
slli a0, a0, 1
srli a1, a1, 1
j mhu_loop
mhu_epilogue:
mv a0, t0
lw ra, 0(sp)
addi sp, sp, 4
ret
Given input number \(n\), the average time complexity of the summing method is \(O(n)\). For the shifting method, the average time complexity is \(O(\lg n)\). Comparing the two, the shifting method is faster than the summing method when \(n\) is large enough. How large \(n\) has to be is discussed in the following paragraphs.
# testing program
main:
# multiply two numbers
li a0, 5
li a1, 5
jal ra, mul_sum_u32 # or mul_shift_u32
# print result
li a7, 1 # to print integer
ecall # a0 = result of mul_sum_u32
# exit program
li a7, 10
ecall
When \(a_0 = a_1 = 5_{10} = 101_2\), summing takes fewer CPU cycles than shifting.
When \(a_0 = a_1 = 6_{10} = 110_2\), the summing method takes 5 more cycles, while the shifting method takes the same number of cycles as in the previous case.
By testing all values in \([1, 18]\), I obtained the numbers of CPU cycles for both the summing and shifting methods, which are shown in the following figure.
Referring to the figure above, I have the following two observations, given that \(a_0 = a_1\):
As a result, in the following code that requires unsigned integer multiplication, I will use `mul_shift_u32` by default.
There are no floating-point operations in the RV32I ISA, and thus no bf16 operations either. Hence, I have to implement my own bf16 arithmetic operations (\(+ - \times\)) in sections 2.3 and 2.4.
File 1: type_def.h
This file will be used in the following sections.
// type_def.h
#ifndef TYPE_DEF_H
#define TYPE_DEF_H
typedef float bf16;
typedef unsigned int u32;
typedef int i32;
#endif // TYPE_DEF_H
File 2: add_sub_bf16.c
// add_sub_bf16.c
/*
* This program implements and tests the following functionality:
* Addition and subtraction of bfloat16 (bf16) numbers.
*
* Definition of a bfloat16 (bf16) number:
* (1) 1 sign, 8 exp, 7 mantissa (significand) bits, in order.
* (2) Bf16 stored in a 32-bit memory chunk is at the
* highest (most significant) 16 bits.
*
* Reference: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
*
* Version: 0.0
* Tested: 2023-10-07T13:25:00+08:00
*/
#ifndef ADD_SUB_BF16_C
#define ADD_SUB_BF16_C
#include "type_def.h"
/* Addition or subtraction of two bf16 numbers.
* Returns (a + b) or (a - b), depending on to_add.
*
* Input format:
* a: bf16
* b: bf16
* to_add: 1 for addition, 0 for subtraction
* Output format: bf16
*/
bf16 add_sub_bf16(bf16 a, bf16 b, int to_add) {
u32 ba = *(u32 *)&a;
u32 bb = *(u32 *)&b;
u32 sa = ba & 0x80000000;
u32 sb = bb & 0x80000000;
i32 ea = ((ba & 0x7F800000) >> 23) - 127;
i32 eb = ((bb & 0x7F800000) >> 23) - 127;
i32 ma = ((ba & 0x007F0000) >> 16) | 0x0080;
i32 mb = ((bb & 0x007F0000) >> 16) | 0x0080;
u32 s = 0; // result sign (1 = negative, 0 = positive)
u32 e = 0; // result exponent
i32 m = 0; // result mantissa
// normalization: make 2 numbers have the same exponent
if (ea >= eb) {
e = ea;
mb >>= (ea - eb); // arithmetic right shift
} else {
e = eb;
ma >>= (eb - ea); // arithmetic right shift
}
// addition or subtraction;
// make abs(m) <= 0x1FE by implementation.
// note: negating numbers has to be postponed to after normalization.
// otherwise, it will lead to an error of -1 in mantissa.
ma = (sa != 0) ? -ma : ma;
mb = (sb != 0) ? -mb : mb;
mb = (to_add == 0) ? -mb : mb;
m = ma + mb;
// handle negative result
if (m < 0) {
m = -m;
s = 1;
} else
s = 0;
// handle carry bit; make m <= 0xFF
if (m & 0x100) {
m >>= 1;
e += 1;
}
// handle result of 0
if (m == 0) {
e = -127;
} else {
// handle result < 1
while (m < 0x80) {
e -= 1;
m <<= 1;
}
}
s = s << 31;
e = (e + 127) << 23;
m = (m & 0x7F) << 16;
u32 r = s | e | m;
return *(bf16 *)&r;
}
#endif // ADD_SUB_BF16_C
File 3: main.c
This file contains the main and testing functions that test the functionality of the above program.
// main.c
#include "type_def.h"
#include "add_sub_bf16.c"
// lib for tests
#include <stdio.h> // puts, printf
/* Test the functionalities in this unit.
* Return 0 on success. Otherwise, return a non-zero number,
* which indicates the first failed test.
*/
int test_add_sub_bf16() {
bf16 a, b, r;
u32 s;
u32 *pa = (u32 *)&a;
u32 *pb = (u32 *)&b;
u32 *pr = (u32 *)&r;
// 1: add, a > 0, b > 0, exp_a == exp_b, exp carry
*pa = 0x3F9A0000; // 0 01111111 0011010
*pb = 0x3FB30000; // 0 01111111 0110011
s = 0x40260000; // 0 10000000 0100110
r = add_sub_bf16(a, b, 1);
if (*pr != s) return 1;
// 2: add, a > 0, b > 0, exp_a < exp_b , no exp carry
*pa = 0x3F9A0000; // 0 01111111 0011010
*pb = 0x40140000; // 0 10000000 0010100
s = 0x40610000; // 0 10000000 1100001
r = add_sub_bf16(a, b, 1);
if (*pr != s) return 2;
// 3: add, a > 0, b > 0, exp_a > exp_b, exp carry
*pa = 0x40410000; // 0 10000000 1000001
*pb = 0x3FFF0000; // 0 01111111 1111111
s = 0x40A00000; // 0 10000001 0100000
r = add_sub_bf16(a, b, 1);
if (*pr != s) return 3;
// 4: add, a < 0, b > 0, (a + b) < 0, exp decreases
*pa = 0xC0410000; // 1 10000000 1000001
*pb = 0x3FFF0000; // 0 01111111 1111111
s = 0xBF840000; // 1 01111111 0000100
r = add_sub_bf16(a, b, 1);
if (*pr != s) return 4;
// 5: sub, a = b
*pa = 0x40000000; // 0 10000000 0000000
*pb = 0x40000000; // 0 10000000 0000000
s = 0; // 0 00000000 0000000
r = add_sub_bf16(a, b, 0);
if (*pr != s) return 5;
// 6: sub, a > 0, b > 0, (a - b) > 0, no exp carry
*pa = 0x40450000; // 0 10000000 1000101
*pb = 0x3F7F0000; // 0 01111110 1111111
s = 0x40060000; // 0 10000000 0000110
r = add_sub_bf16(a, b, 0);
if (*pr != s) return 6;
// 7: sub, a < 0, b > 0, exp_a < exp_b, exp carry
*pa = 0xBFC00000; // 1 01111111 1000000
*pb = 0x40400000; // 0 10000000 1000000
s = 0xC0900000; // 1 10000001 0010000
r = add_sub_bf16(a, b, 0);
if (*pr != s) return 7;
return 0;
}
int main() {
int error_code = test_add_sub_bf16();
if (error_code == 0) {
puts("Test for add_sub_bf16.c passed.");
return 0;
} else {
printf("Test %d for add_sub_bf16.c failed.\n", error_code);
return 1;
}
}
Result: Test for add_sub_bf16.c passed.
The major features in v0.1 of `add_sub_bf16.c` are the addition and subtraction wrappers `add_bf16` and `sub_bf16`. Calling `add_sub_bf16` through these wrappers incurs the overhead of an extra function call and of setting one more argument. In exchange, developers don't need to pass the `to_add` argument each time they call `add_bf16` or `sub_bf16`.
// add_sub_bf16.c
// ...
/* Addition of two bf16 numbers.
* Returns (a + b).
*
* Input format:
* a: bf16
* b: bf16
* Output format: bf16
*/
bf16 add_bf16(bf16 a, bf16 b) { return add_sub_bf16(a, b, 1); }
/* Subtraction of two bf16 numbers.
* Returns (a - b).
*
* Input format:
* a: bf16
* b: bf16
* Output format: bf16
*/
bf16 sub_bf16(bf16 a, bf16 b) { return add_sub_bf16(a, b, 0); }
// ...
# This program implements and tests bf16 additions
# and subtraction.
#
# For including as a library, include only codes in
# the "Library" section.
#
# Version: 0.1.0
# Tested: 2023-10-09T11:32:00+08:00
.text
# ┌-------------------------------------------------------┐
# | Testing Suite |
# └-------------------------------------------------------┘
main:
# test all functionalities
jal ra, add_sub_bf16_test
# print result
li a7, 1 # to print integer
ecall # a0 = 0 for success, or non-zero for index of failed test
# exit program
li a7, 10
ecall
# --- add_sub_bf16_test ---
# test the functionalities of add_sub_bf16
# input: nothing
# output:
# a0: error_code: 0 for success
# otherwise, index of the first failed test
add_sub_bf16_test:
asbt_prologue:
addi sp, sp, -4
sw ra, 0(sp)
asbt_t1:
li a0, 0x3F9A0000
li a1, 0x3FB30000
li a2, 1
jal ra, add_sub_bf16
li t0, 0x40260000
li t1, 1 # error code
bne t0, a0, asbt_epilogue
asbt_t2:
li a0, 0x3F9A0000
li a1, 0x40140000
jal ra, add_bf16
li t0, 0x40610000
li t1, 2 # error code
bne t0, a0, asbt_epilogue
asbt_t3:
li a0, 0x40410000
li a1, 0x3FFF0000
jal ra, add_bf16
li t0, 0x40A00000
li t1, 3 # error code
bne t0, a0, asbt_epilogue
asbt_t4:
li a0, 0xC0410000
li a1, 0x3FFF0000
jal ra, add_bf16
li t0, 0xBF840000
li t1, 4 # error code
bne t0, a0, asbt_epilogue
asbt_t5:
li a0, 0x40000000
li a1, 0x40000000
jal ra, sub_bf16
li t0, 0
li t1, 5 # error code
bne t0, a0, asbt_epilogue
asbt_t6:
li a0, 0x40450000
li a1, 0x3F7F0000
jal ra, sub_bf16
li t0, 0x40060000
li t1, 6 # error code
bne t0, a0, asbt_epilogue
asbt_t7:
li a0, 0xBFC00000
li a1, 0x40400000
jal ra, sub_bf16
li t0, 0xC0900000
li t1, 7 # error code
bne t0, a0, asbt_epilogue
asbt_all_passed:
li t1, 0
asbt_epilogue:
mv a0, t1 # error code
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Library |
# └-------------------------------------------------------┘
# --- add_sub_bf16 ---
# addition or subtraction of two bf16 numbers
# input:
# a0: a (bf16): add/sub candidate
# a1: b (bf16): add/sub candidate
# a2: to_add (int): 1 for addition; 0 for subtraction
# output:
# a0: r (bf16): result of (a + b) or (a - b)
# notes:
# t0: sa, s
# t1: sb
# t2: ea, e
# t3: eb
# t4: ma, m
# t5: mb
# t6: (always temp)
add_sub_bf16:
asb_prologue:
addi sp, sp, -4
sw ra, 0(sp)
asb_body:
# extract exponent and mantissa from a and b
li t6, 0x7F800000
and t2, a0, t6 # ea
srli t2, t2, 23
addi t2, t2, -127
li t6, 0x7F800000
and t3, a1, t6 # eb
srli t3, t3, 23
addi t3, t3, -127
li t6, 0x007F0000
and t4, a0, t6 # ma
srli t4, t4, 16
ori t4, t4, 0x80
li t6, 0x007F0000
and t5, a1, t6 # mb
srli t5, t5, 16
ori t5, t5, 0x80
# normalization: make 2 numbers have the same exponent
blt t2, t3, asb_normalization_1
mv t6, t2 # t6 = ea
sub t2, t2, t3 # t2 = ea - eb
srl t5, t5, t2 # mb >>= t2
mv t2, t6 # e = t6
j asb_normalization_end
asb_normalization_1:
mv t6, t3 # t6 = eb
sub t2, t3, t2 # t2 = eb - ea
srl t4, t4, t2 # ma >>= t2
mv t2, t6 # e = t6
asb_normalization_end:
# addition or subtraction
li t6, 0x80000000
and t0, a0, t6 # sa
beqz t0, asb_not_invert_ma
sub t4, zero, t4
asb_not_invert_ma:
li t6, 0x80000000
and t1, a1, t6 # sb
beqz t1, asb_not_invert_mb_1
sub t5, zero, t5
asb_not_invert_mb_1:
bnez a2, asb_not_invert_mb_2
sub t5, zero, t5
asb_not_invert_mb_2:
add t4, t4, t5 # m = ma + mb
# handle negative result
li t0, 0
bgez t4, asb_positive_m
sub t4, zero, t4
li t0, 1
asb_positive_m:
# handle carry bit
andi t5, t4, 0x100
beqz t5, asb_no_carry
srli t4, t4, 1
addi t2, t2, 1
asb_no_carry:
# handle result of 0
li t5, 0x80
bnez t4, asb_small
li t2, -127 # e = -127
j asb_small_end
asb_small:
bge t4, t5, asb_small_end # while (m < 0x80)
addi t2, t2, -1 # e -= 1
slli t4, t4, 1 # m <<= 1
j asb_small
asb_small_end:
# construct the result
slli t0, t0, 31 # s = s << 31
addi t2, t2, 127 # e = (e + 127) << 23
slli t2, t2, 23
andi t4, t4, 0x7F # m = (m & 0x7F) << 16
slli t4, t4, 16
or a0, t0, t2 # r = s | e | m
or a0, a0, t4
asb_epilogue:
lw ra, 0(sp)
addi sp, sp, 4
ret
# --- add_bf16 ---
# addition of two bf16 numbers.
# input:
# a0: a (bf16): addition candidate
# a1: b (bf16): addition candidate
# output:
# a0: r (bf16): result of (a + b)
add_bf16:
addi sp, sp, -4
sw ra, 0(sp)
li a2, 1
jal ra, add_sub_bf16
lw ra, 0(sp)
addi sp, sp, 4
ret
# --- sub_bf16 ---
# subtraction of two bf16 numbers.
# input:
# a0: a (bf16): subtraction candidate
# a1: b (bf16): subtraction candidate
# output:
# a0: r (bf16): result of (a - b)
sub_bf16:
addi sp, sp, -4
sw ra, 0(sp)
li a2, 0
jal ra, add_sub_bf16
lw ra, 0(sp)
addi sp, sp, 4
ret
File 1: mul_bf16.c
// mul_bf16.c
/*
* This program implements and tests the following functionality:
* Multiplication of bf16 numbers
*
* Definition of a bfloat16 (bf16) number:
* (1) 1 sign, 8 exp, 7 mantissa (significand) bits, in order.
* (2) Bf16 stored in a 32-bit memory chunk is at the
* highest (most significant) 16 bits.
*
* Reference: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
*
* Version: 0.0
* Tested: 2023-10-07T14:09:00+08:00
*/
#ifndef MUL_BF16_C
#define MUL_BF16_C
#include "type_def.h"
/* Multiplication of two bf16 numbers.
* Returns (a * b).
*
* Input format:
* a: bf16
* b: bf16
* Output format: bf16
*/
bf16 mul_bf16(bf16 a, bf16 b) {
u32 ba = *(u32 *)&a;
u32 bb = *(u32 *)&b;
u32 sa = ba & 0x80000000;
u32 sb = bb & 0x80000000;
i32 ea = ((ba & 0x7F800000) >> 23) - 127;
i32 eb = ((bb & 0x7F800000) >> 23) - 127;
i32 ma = ((ba & 0x007F0000) >> 16) | 0x0080;
i32 mb = ((bb & 0x007F0000) >> 16) | 0x0080;
u32 s = (sa ^ sb) >> 31; // result sign (1 = negative, 0 = positive)
u32 e = ea + eb; // result exponent
i32 m = (ma * mb) >> 7; // result mantissa
// notes:
// * 0x4000 <= (ma * mb) <= 0xFE01
// * 0x80 <= m <= 0x1FC
// handle carry bit; make m <= 0xFF
if (m & 0x100) {
m >>= 1;
e += 1;
}
// handle result of 0
if (m == 0) {
e = -127;
}
s = s << 31;
e = (e + 127) << 23;
m = (m & 0x7F) << 16;
u32 r = s | e | m;
return *(bf16 *)&r;
}
#endif // MUL_BF16_C
File 2: main.c
// main.c
#include <stdio.h> // puts, printf
#include "type_def.h"
#include "mul_bf16.c"
/* Test the functionalities in this unit.
* Return 0 on success. Otherwise, return a non-zero number,
* which indicates the first failed test.
*/
int test_mul_bf16() {
bf16 a, b, r;
u32 s;
u32 *pa = (u32 *)&a;
u32 *pb = (u32 *)&b;
u32 *pr = (u32 *)&r;
// 1: a = b = 1
*pa = 0x3F800000; // 0 01111111 0000000
*pb = 0x3F800000; // 0 01111111 0000000
s = 0x3F800000; // 0 01111111 0000000
r = mul_bf16(a, b);
if (*pr != s) return 1;
// 2: a = 0.5, b = 4
*pa = 0x3F000000; // 0 01111110 0000000
*pb = 0x40800000; // 0 10000001 0000000
s = 0x40000000; // 0 10000000 0000000
r = mul_bf16(a, b);
if (*pr != s) return 2;
// 3: a < 0, b > 0, mantissa carries
*pa = 0xBF400000; // 1 01111110 1000000
*pb = 0x40B00000; // 0 10000001 0110000
s = 0xC0840000; // 1 10000001 0000100
r = mul_bf16(a, b);
if (*pr != s) return 3;
return 0;
}
int main() {
int error_code = test_mul_bf16();
if (error_code == 0) {
puts("Test for mul_bf16.c passed.");
return 0;
} else {
printf("Test %d for mul_bf16.c failed.\n", error_code);
return 1;
}
}
In the `mul_bf16` function in `mul_bf16.c`, the major features are as follows.
if (ba == 0 || bb == 0) return 0;
if (m == 0) {
e = -127;
s = s << 31;
return *(bf16 *)&s;
}
// 4: a = 4, b = 0
*pa = 0x40800000; // 0 10000001 0000000
*pb = 0; // 0 00000000 0000000
s = 0; // 0 00000000 0000000
r = mul_bf16(a, b);
if (*pr != s) return 4;
# This program implements and tests multiplication
# of bf16 numbers.
#
# For including as a library, include only codes in…
# (1) the "Required Library" sections, and
# (2) the "Library" section.
#
# Library dependency graph:
# mul_shift_u32 -> **mul_bf16**
#
# Version: 0.1.0
# Tested: 2023-10-09T12:37:00+08:00
.text
# ┌-------------------------------------------------------┐
# | Testing Suite |
# └-------------------------------------------------------┘
main:
# test all functionalities
jal ra, mul_bf16_test
# print result
li a7, 1 # to print integer
ecall # a0 = 0 for success, or non-zero for index of failed test
# exit program
li a7, 10
ecall
# --- mul_bf16_test ---
# test the functionalities of mul_bf16
# input: nothing
# output:
# a0: error_code: 0 for success
# otherwise, index of the first failed test
mul_bf16_test:
mbt_prologue:
addi sp, sp, -4
sw ra, 0(sp)
mbt_t1:
li a0, 0x3F800000
li a1, 0x3F800000
jal ra, mul_bf16
li t0, 0x3F800000
li t1, 1 # error code
bne t0, a0, mbt_epilogue
mbt_t2:
li a0, 0x3F000000
li a1, 0x40800000
jal ra, mul_bf16
li t0, 0x40000000
li t1, 2 # error code
bne t0, a0, mbt_epilogue
mbt_t3:
li a0, 0xBF400000
li a1, 0x40B00000
jal ra, mul_bf16
li t0, 0xC0840000
li t1, 3 # error code
bne t0, a0, mbt_epilogue
mbt_t4:
li a0, 0x40800000
li a1, 0
jal ra, mul_bf16
li t0, 0
li t1, 4 # error code
bne t0, a0, mbt_epilogue
mbt_all_passed:
li t1, 0
mbt_epilogue:
mv a0, t1 # error code
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Required Library - mul_shift_u32 v0.0.0 |
# └-------------------------------------------------------┘
# --- mul_shift_u32 ---
# binary multiplication of two u32 numbers
# input:
# a0: a (u32): multiplier
# a1: b (u32): multiplicand
# output:
# a0: r (u32): product of a and b (a * b)
mul_shift_u32:
mhu_prologue:
addi sp, sp, -4
sw ra, 0(sp)
bge a0, a1, mhu_no_swap
# make a1 <= a0
addi t0, a1, 0
mv a1, a0
mv a0, t0
mhu_no_swap:
# binary multiplication of t0 = a0 * a1
addi t0, zero, 0 # t0 = result
mhu_loop:
beq a1, zero, mhu_epilogue
andi t2, a1, 1 # the least significant bit of a1
beq t2, zero, mhu_next
add t0, t0, a0
mhu_next:
slli a0, a0, 1
srli a1, a1, 1
j mhu_loop
mhu_epilogue:
mv a0, t0
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Library |
# └-------------------------------------------------------┘
# --- mul_bf16 ---
# multiplication of two bf16 numbers
# input:
# a0: a (bf16): multiplier
# a1: b (bf16): multiplicand
# output:
# a0: m, r (bf16): product of a and b (a * b)
# notes:
# s0: s
# s1: e
# t0: sa
# t1: sb
# t2: ea
# t3: eb
# t4: ma
# t5: mb
mul_bf16:
mb_prologue:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
mb_body:
beqz a0, mb_epilogue
bnez a1, mb_nonzero_input
mv a0, zero
j mb_epilogue
mb_nonzero_input:
# extract sign, exponent and mantissa of a and b
sltz t0, a0 # sa
sltz t1, a1 # sb
li t3, 0x7F800000
and t2, a0, t3
srli t2, t2, 23
addi t2, t2, -127 # ea
and t3, a1, t3
srli t3, t3, 23
addi t3, t3, -127 # eb
li t5, 0x007F0000
and t4, a0, t5
srli t4, t4, 16
ori t4, t4, 0x80 # ma
and t5, a1, t5
srli t5, t5, 16
ori t5, t5, 0x80 # mb
# calculate the initial result
xor s0, t0, t1 # s = sa ^ sb
add s1, t2, t3 # e = ea + eb
mv a0, t4
mv a1, t5
jal ra, mul_shift_u32
srli a0, a0, 7 # m = (ma * mb) >> 7
# handle carry bit
andi t1, a0, 0x100
beqz t1, mb_no_carry
srli a0, a0, 1
addi s1, s1, 1
mb_no_carry:
# handle result of +-0
bnez a0, mb_nonzero_result
slli a0, s0, 31 # r = s << 31
j mb_epilogue
mb_nonzero_result:
# construct the result
slli s0, s0, 31 # s = s << 31
addi s1, s1, 127
slli s1, s1, 23 # e = (e + 127) << 23
andi a0, a0, 0x7F
slli a0, a0, 16 # m = (m & 0x7F) << 16
or a0, a0, s0
or a0, a0, s1 # r = s | e | m
mb_epilogue:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
Evil et al. (2021) provided a C implementation of the above-mentioned 3rd-order polynomial approximation of \(ln(x)\); however, their implementation can only deal with fp32 numbers.
As a result, I modified their code to apply it to bf16 numbers. (In the following code, I also kept some of their comments and added my own refinements.)
File 1: ln_bf16.c
/*
* This program implements and tests the following functionality:
* Natural logarithm of fp32 and bf16 numbers.
*
* Version: 0.2
* Tested: 2023-10-07T22:39:00+08:00
*/
#include <math.h>
#include <stdio.h>
#include <string.h>
#include "fp32_bf16.c"
#include "add_sub_bf16.c"
#include "i32_bf16.c"
#include "mul_bf16.c"
#include "type_def.h"
/* ln(abs(x))
* Returns ln(abs(x)),
* which is calculated by the 3rd-order polynomial approximation
* obtained by the Remez algorithm.
*
* Input format: bf16
* Output format: bf16
*
* This function only works in a 32-bit runtime.
*/
bf16 ln_bf16(bf16 x) {
// constants for bf16 in the precision of bf16
const u32 u_lnc0 = 0xBFBF0000; // -1.49
const u32 u_lnc1 = 0x40070000; // 2.11
const u32 u_lnc2 = 0xBF3B0000; // -0.73
const u32 u_lnc3 = 0x3DE10000; // 0.109
const u32 u_ln2 = 0x3F310000; // 0.69
// constants for this function in the precision of bf16
const bf16 lnc0 = *(bf16 *)&u_lnc0; // -1.49
const bf16 lnc1 = *(bf16 *)&u_lnc1; // 2.11
const bf16 lnc2 = *(bf16 *)&u_lnc2; // -0.73
const bf16 lnc3 = *(bf16 *)&u_lnc3; // 0.109
const bf16 ln2 = *(bf16 *)&u_ln2; // 0.69
u32 *px = (u32 *)&x;
// remove the sign bit and extra bits (otherwise, an off-by-one bug occurs)
*px = *px & 0x7FFF0000;
// catch zero
if (*px == 0) {
*px = 0xFF800000;
return *(bf16 *)px;
}
bf16 exp = i32_to_bf16(((*px & 0x7F800000) >> 23) - 127);
// set x's exponent to 0, which is stored as 127 after adding the bias.
*px = 0x3F800000 | (*px & 0x7F0000);
// return lnc0 + (lnc1 + (lnc2 + lnc3 * x) * x) * x + ln2 * exp;
bf16 t;
t = add_bf16(lnc2, mul_bf16(lnc3, x)); // t = lnc2 + lnc3 * x
t = add_bf16(lnc1, mul_bf16(t, x)); // t = lnc1 + t * x
t = add_bf16(lnc0, mul_bf16(t, x)); // t = lnc0 + t * x
t = add_bf16(t, mul_bf16(ln2, exp)); // t = t + ln2 * exp
return t;
}
File 2: main.c
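// main.c
#include "ln_bf16.c" // assumed include, following the earlier main.c files; it also pulls in <math.h> and <stdio.h>
/*
* if x = 0, return 0
* if 0 < x <= 9, return 10
* if 10 <= x <= 99, return 100
* ...
*
* Note: a u32 number is less than or equal to 4,294,967,295,
* so the result is capped at 1,000,000,000.
*/
u32 get_smallest_decimal_number(u32 x) {
const u32 one_billion = 1000000000;
if (x >= one_billion) return one_billion; // avoid overflow of divisor
if (x == 0) return 0;
u32 divisor = 10;
while (x / divisor != 0) divisor *= 10;
return divisor;
}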
/* Test the functionalities in this unit.
* Return 0 on success. Otherwise, return a non-zero number,
* which indicates the first failed test.
*/
int test_ln_bf16(float n_rows, float *average_error, float *maximal_error) {
n_rows = roundf(n_rows);
*average_error = 0;
*maximal_error = 0;
float error_code_multiplier = (float)get_smallest_decimal_number((u32)n_rows);
float step = 2.0 / n_rows;
for (float f = 0; f <= 2.0001; f += step) {
float t = logf(f);
bf16 rb = ln_bf16(fp32_to_bf16(f));
bf16 error = fabsf(t - rb);
if (isfinite(error)) {
*average_error += error / n_rows;
if (error > 0.05) // something is wrong.
return (int)(f * error_code_multiplier);
if (error > *maximal_error) *maximal_error = error;
}
}
return 0;
}
int main() {
float average_error = 0, maximal_error = 0;
int error_code = test_ln_bf16(40, &average_error, &maximal_error);
if (error_code == 0) {
puts("Test for ln_bf16.c passed.");
printf("Average error: %.3f\n", average_error);
printf("Maximal error: %.3f\n", maximal_error);
return 0;
} else {
printf("Test %.2f for ln_bf16.c failed.\n", error_code / 100.0);
return 1;
}
}
Result:
Test for ln_bf16.c passed.
Average error: 0.007
Maximal error: 0.031
# This program implements natural logarithm of bf16 numbers.
#
# ... (omitted)
# Library dependency graph:
# add_sub_bf16 ↘
# mul_shift_u32 -> mul_bf16 ----> **ln_bf16**
# i32_bf16 ↗
#
# Version: 0.2.0
# Tested: 2023-10-00T12:37:00+08:00
.text
# ... (omitted)
# ┌-------------------------------------------------------┐
# | Library |
# └-------------------------------------------------------┘
# --- ln_bf16 ---
# return ln(abs(x))
# input:
# a0: x (bf16): number to transform
# output:
# a0: t (bf16): result of ln(abs(x))
# notes:
# s0: x, t (around last mul_bf16)
# s1: exp
# reference: ln_bf16.c
ln_bf16:
lb_prologue:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
lb_body:
# remove extra bits (otherwise, an off-by-one bug occurs)
li t0, 0xFFFF0000
and a0, a0, t0
# catch zero
bnez a0, lb_nonzero_input
li a0, 0xFF800000
j lb_epilogue
lb_nonzero_input:
mv s0, a0 # s0 = x, for x will be used later
li t0, 0x7F800000
and a0, s0, t0 # a0 = *px & 0x7F800000
srli a0, a0, 23
addi a0, a0, -127 # a0 = (a0 >> 23) - 127
jal i32_to_bf16 # a0 = (bf16) a0
mv s1, a0 # exp = a0
# set x's exponent to 0
li t1, 0x7F0000
li t2, 0x3F800000
and s0, s0, t1
or s0, s0, t2 # x = 0x3F800000 | (*px & 0x7F0000)
# calculate result (t)
mv a0, s0 # a0 = x
li a1, 0x3DE10000 # lnc3 = 0.109
jal mul_bf16 # a0 = lnc3 * x
li a1, 0xBF3B0000 # lnc2 = -0.73
jal add_bf16 # a0 = a0 + lnc2
mv a1, s0 # a1 = x
jal mul_bf16 # a0 = a0 * x
li a1, 0x40070000 # lnc1 = 2.11
jal add_bf16 # a0 = a0 + lnc1
mv a1, s0 # a1 = x
jal mul_bf16 # a0 = a0 * x
li a1, 0xBFBF0000 # lnc0 = -1.49
jal add_bf16 # a0 = a0 + lnc0
mv s0, a0 # s0 = t
li a0, 0x3F310000 # ln2 = 0.69
mv a1, s1 # a1 = exp
jal mul_bf16 # a0 = ln2 * exp
mv a1, s0 # a1 = t
jal add_bf16 # a0 = a0 + t (result)
lb_epilogue:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
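To check the data flow without stepping through the assembly, the following is a minimal fp32 sketch of the same steps: mask the input to its bf16 half, extract the unbiased exponent, force the mantissa into [1, 2), evaluate the cubic in Horner form, and add ln2 * exp. The names `bits_to_f32` and `ln_model` are hypothetical. Because the sketch uses fp32 arithmetic instead of the truncating bf16 helpers, its results should be close to, but not bit-identical with, those of `ln_bf16`.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Reinterpret a 32-bit pattern as an fp32 value (no numeric conversion). */
float bits_to_f32(uint32_t bits) {
    float f;
    assert(sizeof f == sizeof bits);
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* fp32 model of the ln_bf16 data flow (hypothetical helper, not the
 * author's code). The constants are the bf16 bit patterns loaded by the
 * assembly listing. */
float ln_model(uint32_t x_bits) {
    x_bits &= 0xFFFF0000u;          /* keep only the bf16 half */
    assert(x_bits != 0);            /* the assembly returns -inf for zero */

    /* unbiased exponent: ((x & 0x7F800000) >> 23) - 127 */
    int32_t exp = (int32_t)((x_bits & 0x7F800000u) >> 23) - 127;
    /* mantissa with the exponent forced to 0, so 1 <= m < 2 */
    float m = bits_to_f32(0x3F800000u | (x_bits & 0x007F0000u));

    const float c3 = bits_to_f32(0x3DE10000u);  /* about  0.109 */
    const float c2 = bits_to_f32(0xBF3B0000u);  /* about -0.73  */
    const float c1 = bits_to_f32(0x40070000u);  /* about  2.11  */
    const float c0 = bits_to_f32(0xBFBF0000u);  /* about -1.49  */
    const float ln2 = bits_to_f32(0x3F310000u); /* about  0.69  */

    /* Horner form: ((c3*m + c2)*m + c1)*m + c0 */
    float t = ((c3 * m + c2) * m + c1) * m + c0;
    return t + ln2 * (float)exp;
}
```

For example, `ln_model(0x40000000)` models ln(2.0): the exponent is 1 and the mantissa is 1.0, so the result is the polynomial value at 1.0 plus one multiple of the bf16 constant for ln 2.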
# This program implements and tests natural logarithm of bf16 numbers.
#
# To use this file as a library, include only the code in…
# (1) all of the "Required Library" sections, and
# (2) the "Library" section.
#
# Library dependency graph:
# add_sub_bf16 ↘
# mul_shift_u32 -> mul_bf16 ----> **ln_bf16**
# i32_bf16 ↗
#
# Version: 0.2.0
# Tested: 2023-10-00T12:37:00+08:00
.text
# ┌-------------------------------------------------------┐
# | Testing Suite |
# └-------------------------------------------------------┘
main:
# test all functionalities
jal ra, ln_bf16_test
# print result
li a7, 1 # to print integer
ecall # a0 = 0 for success, or non-zero for index of failed test
# exit program
li a7, 10
ecall
# --- ln_bf16_test ---
# test the functionality of ln_bf16
# input: nothing
# output:
# a0: error_code: 0 for success
# otherwise, index of the first failed test
# notes:
# the expected value of ln_bf16(x) is taken from the C
# program, since both implement the same algorithm and
# should give identical results
ln_bf16_test:
lbt_prologue:
addi sp, sp, -4
sw ra, 0(sp)
lbt_t1:
li a0, 0x00000000 # 0.00
jal ra, ln_bf16
li t0, 0xFF800000 # -inf
li t1, 1 # error code
bne t0, a0, lbt_epilogue
lbt_t2:
li a0, 0x3D4D0000 # 0.05
jal ra, ln_bf16
li t0, 0xC03E0000 # -2.969
li t1, 2 # error code
bne t0, a0, lbt_epilogue
lbt_t3:
li a0, 0x3E1A0000 # 0.15
jal ra, ln_bf16
li t0, 0xBFF20000 # -1.891
li t1, 3 # error code
bne t0, a0, lbt_epilogue
lbt_t4:
li a0, 0x3E4D0000 # 0.20
jal ra, ln_bf16
li t0, 0xBFCA0000 # -1.578
li t1, 4 # error code
bne t0, a0, lbt_epilogue
lbt_t5:
li a0, 0x3F260000 # 0.65
jal ra, ln_bf16
li t0, 0xBEDA0000 # -0.426
li t1, 5 # error code
bne t0, a0, lbt_epilogue
lbt_t6:
li a0, 0x3F800000 # 1.00
jal ra, ln_bf16
li t0, 0x3C000000 # 0.008
li t1, 6 # error code
bne t0, a0, lbt_epilogue
lbt_t7:
li a0, 0x40000000 # 2.00
jal ra, ln_bf16
li t0, 0x3F330000 # 0.699
li t1, 7 # error code
bne t0, a0, lbt_epilogue
lbt_all_passed:
li t1, 0
lbt_epilogue:
mv a0, t1 # error code
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Required Library - add_sub_bf16 v0.1.0 |
# └-------------------------------------------------------┘
# --- add_sub_bf16 ---
# addition or subtraction of two bf16 numbers
# input:
# a0: a (bf16): add/sub candidate
# a1: b (bf16): add/sub candidate
# a2: to_add (int): 1 for addition; 0 for subtraction
# output:
# a0: r (bf16): result of (a + b) or (a - b)
# notes:
# t0: sa, s
# t1: sb
# t2: ea, e
# t3: eb
# t4: ma, m
# t5: mb
# t6: (always temp)
add_sub_bf16:
asb_prologue:
addi sp, sp, -4
sw ra, 0(sp)
asb_body:
# extract exponent and mantissa from a and b
li t6, 0x7F800000
and t2, a0, t6 # ea
srli t2, t2, 23
addi t2, t2, -127
li t6, 0x7F800000
and t3, a1, t6 # eb
srli t3, t3, 23
addi t3, t3, -127
li t6, 0x007F0000
and t4, a0, t6 # ma
srli t4, t4, 16
ori t4, t4, 0x80
li t6, 0x007F0000
and t5, a1, t6 # mb
srli t5, t5, 16
ori t5, t5, 0x80
# normalization: make 2 numbers have the same exponent
blt t2, t3, asb_normalization_1
mv t6, t2 # t6 = ea
sub t2, t2, t3 # t2 = ea - eb
srl t5, t5, t2 # mb >>= t2
mv t2, t6 # e = t6
j asb_normalization_end
asb_normalization_1:
mv t6, t3 # t6 = eb
sub t2, t3, t2 # t2 = eb - ea
srl t4, t4, t2 # ma >>= t2
mv t2, t6 # e = t6
asb_normalization_end:
# addition or subtraction
li t6, 0x80000000
and t0, a0, t6 # sa
beqz t0, asb_not_invert_ma
sub t4, zero, t4
asb_not_invert_ma:
li t6, 0x80000000
and t1, a1, t6 # sb
beqz t1, asb_not_invert_mb_1
sub t5, zero, t5
asb_not_invert_mb_1:
bnez a2, asb_not_invert_mb_2
sub t5, zero, t5
asb_not_invert_mb_2:
add t4, t4, t5 # m = ma + mb
# handle negative result
li t0, 0
bgez t4, asb_positive_m
sub t4, zero, t4
li t0, 1
asb_positive_m:
# handle carry bit
andi t5, t4, 0x100
beqz t5, asb_no_carry
srli t4, t4, 1
addi t2, t2, 1
asb_no_carry:
# handle result of 0
li t5, 0x80
bnez t4, asb_small
li t2, -127 # e = -127
j asb_small_end
asb_small:
bge t4, t5, asb_small_end # while (m < 0x80)
addi t2, t2, -1 # e -= 1
slli t4, t4, 1 # m <<= 1
j asb_small
asb_small_end:
# construct the result
slli t0, t0, 31 # s = s << 31
addi t2, t2, 127 # e = (e + 127) << 23
slli t2, t2, 23
andi t4, t4, 0x7F # m = (m & 0x7F) << 16
slli t4, t4, 16
or a0, t0, t2 # r = s | e | m
or a0, a0, t4
asb_epilogue:
lw ra, 0(sp)
addi sp, sp, 4
ret
# --- add_bf16 ---
# addition of two bf16 numbers.
# input:
# a0: a (bf16): addition candidate
# a1: b (bf16): addition candidate
# output:
# a0: r (bf16): result of (a + b)
add_bf16:
addi sp, sp, -4
sw ra, 0(sp)
li a2, 1
jal ra, add_sub_bf16
lw ra, 0(sp)
addi sp, sp, 4
ret
# --- sub_bf16 ---
# subtraction of two bf16 numbers.
# input:
# a0: a (bf16): subtraction candidate
# a1: b (bf16): subtraction candidate
# output:
# a0: r (bf16): result of (a - b)
sub_bf16:
addi sp, sp, -4
sw ra, 0(sp)
li a2, 0
jal ra, add_sub_bf16
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Required Library - mul_shift_u32 v0.0.0 |
# └-------------------------------------------------------┘
# --- mul_shift_u32 ---
# binary multiplication of two u32 numbers
# input:
# a0: a (u32): multiplier
# a1: b (u32): multiplicand
# output:
# a0: r (u32): product of a and b (a * b)
mul_shift_u32:
mhu_prologue:
addi sp, sp, -4
sw ra, 0(sp)
bge a0, a1, mhu_no_swap
# make a1 <= a0
addi t0, a1, 0
mv a1, a0
mv a0, t0
mhu_no_swap:
# binary multiplication of t0 = a0 * a1
addi t0, zero, 0 # t0 = result
mhu_loop:
beq a1, zero, mhu_epilogue
andi t2, a1, 1 # the least significant bit of a1
beq t2, zero, mhu_next
add t0, t0, a0
mhu_next:
slli a0, a0, 1
srli a1, a1, 1
j mhu_loop
mhu_epilogue:
mv a0, t0
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Required Library - mul_bf16 v0.1.0 |
# └-------------------------------------------------------┘
# --- mul_bf16 ---
# multiplication of two bf16 numbers
# input:
# a0: a (bf16): multiplier
# a1: b (bf16): multiplicand
# output:
# a0: m, r (bf16): product of a and b (a * b)
# notes:
# s0: s
# s1: e
# t0: sa
# t1: sb
# t2: ea
# t3: eb
# t4: ma
# t5: mb
mul_bf16:
mb_prologue:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
mb_body:
beqz a0, mb_epilogue
bnez a1, mb_nonzero_input
mv a0, zero
j mb_epilogue
mb_nonzero_input:
# extract sign, exponent and mantissa of a and b
sltz t0, a0 # sa
sltz t1, a1 # sb
li t3, 0x7F800000
and t2, a0, t3
srli t2, t2, 23
addi t2, t2, -127 # ea
and t3, a1, t3
srli t3, t3, 23
addi t3, t3, -127 # eb
li t5, 0x007F0000
and t4, a0, t5
srli t4, t4, 16
ori t4, t4, 0x80 # ma
and t5, a1, t5
srli t5, t5, 16
ori t5, t5, 0x80 # mb
# calculate the initial result
xor s0, t0, t1 # s = sa ^ sb
add s1, t2, t3 # e = ea + eb
mv a0, t4
mv a1, t5
jal ra, mul_shift_u32
srli a0, a0, 7 # m = (ma * mb) >> 7
# handle carry bit
andi t1, a0, 0x100
beqz t1, mb_no_carry
srli a0, a0, 1
addi s1, s1, 1
mb_no_carry:
# handle result of +-0
bnez a0, mb_nonzero_result
slli a0, s0, 31 # r = s << 31
j mb_epilogue
mb_nonzero_result:
# construct the result
slli s0, s0, 31 # s = s << 31
addi s1, s1, 127
slli s1, s1, 23 # e = (e + 127) << 23
andi a0, a0, 0x7F
slli a0, a0, 16 # m = (m & 0x7F) << 16
or a0, a0, s0
or a0, a0, s1 # r = s | e | m
mb_epilogue:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
# ┌-------------------------------------------------------┐
# | Required Library - i32_bf16 v0.0.0 |
# └-------------------------------------------------------┘
# --- bf16_to_i32 ---
# (NOT IMPLEMENTED YET!)
# convert bf16 to i32
# input:
# a0: x (bf16): bf16 number to be processed
# output:
# a0: m, r (i32): 32-bit integer (without fraction)
bf16_to_i32:
ret
# --- i32_to_bf16 ---
# convert i32 to bf16
# input:
# a0: x (i32): integer to convert
# output:
# a0: m, r (bf16): float with roughly the same
# value as input
# notes:
# t0: s
# t1: e
i32_to_bf16:
itb_prologue:
addi sp, sp, -4
sw ra, 0(sp)
itb_body:
bnez a0, itb_nonzero_x
# x == 0
j itb_epilogue
itb_nonzero_x:
sltz t0, a0 # s = sign bit of x
li t1, 7 # e = 7
# `m = x` is `mv a0, a0`, which is nop
beqz t0, itb_positive_x
sub a0, zero, a0 # m = -x
itb_positive_x:
li t2, 0x80
itb_small_x:
bge a0, t2, itb_large_x_outer
addi t1, t1, -1
slli a0, a0, 1
j itb_small_x
itb_large_x_outer:
li t2, 0x100
itb_large_x_inner:
blt a0, t2, itb_result
addi t1, t1, 1
srli a0, a0, 1
j itb_large_x_inner
itb_result:
andi a0, a0, 0x7F
slli a0, a0, 16
addi t1, t1, 127
slli t1, t1, 23
slli t0, t0, 31
or a0, a0, t1
or a0, a0, t0
itb_epilogue:
lw ra, 0(sp)
addi sp, sp, 4
ret
# ┌-------------------------------------------------------┐
# | Library |
# └-------------------------------------------------------┘
# --- ln_bf16 ---
# return ln(abs(x))
# input:
# a0: x (bf16): number to transform
# output:
# a0: t (bf16): result of ln(abs(x))
# notes:
# s0: x, t (around last mul_bf16)
# s1: exp
# reference: ln_bf16.c
ln_bf16:
lb_prologue:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
lb_body:
# remove extra bits (otherwise, an off-by-one bug occurs)
li t0, 0xFFFF0000
and a0, a0, t0
# catch zero
bnez a0, lb_nonzero_input
li a0, 0xFF800000
j lb_epilogue
lb_nonzero_input:
mv s0, a0 # s0 = x, for x will be used later
li t0, 0x7F800000
and a0, s0, t0 # a0 = *px & 0x7F800000
srli a0, a0, 23
addi a0, a0, -127 # a0 = (a0 >> 23) - 127
jal i32_to_bf16 # a0 = (bf16) a0
mv s1, a0 # exp = a0
# set x's exponent to 0
li t1, 0x7F0000
li t2, 0x3F800000
and s0, s0, t1
or s0, s0, t2 # x = 0x3F800000 | (*px & 0x7F0000)
# calculate result (t)
mv a0, s0 # a0 = x
li a1, 0x3DE10000 # lnc3 = 0.109
jal mul_bf16 # a0 = lnc3 * x
li a1, 0xBF3B0000 # lnc2 = -0.73
jal add_bf16 # a0 = a0 + lnc2
mv a1, s0 # a1 = x
jal mul_bf16 # a0 = a0 * x
li a1, 0x40070000 # lnc1 = 2.11
jal add_bf16 # a0 = a0 + lnc1
mv a1, s0 # a1 = x
jal mul_bf16 # a0 = a0 * x
li a1, 0xBFBF0000 # lnc0 = -1.49
jal add_bf16 # a0 = a0 + lnc0
mv s0, a0 # s0 = t
li a0, 0x3F310000 # ln2 = 0.69
mv a1, s1 # a1 = exp
jal mul_bf16 # a0 = ln2 * exp
mv a1, s0 # a1 = t
jal add_bf16 # a0 = a0 + t (result)
lb_epilogue:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
Using the test suite in `ln_bf16.c`, I obtained the results shown in the following table.
x | x (bf16, hex) | ln(x) | ln_fp32(x) | ln_bf16(x) | ln_bf16(x) (hex) | error |
---|---|---|---|---|---|---|
0.0 | 0x00000000 | -inf | -inf | -inf | 0xFF800000 | -nan(ind) |
0.1 | 0x3DCCCCCD | -2.303 | -2.302 | -2.281 | 0xC0120000 | -0.021 |
0.2 | 0x3E4CCCCD | -1.609 | -1.609 | -1.578 | 0xBFCA0000 | -0.031 |
0.3 | 0x3E99999A | -1.204 | -1.204 | -1.203 | 0xBF9A0000 | -0.001 |
0.4 | 0x3ECCCCCD | -0.916 | -0.916 | -0.898 | 0xBF660000 | -0.018 |
0.5 | 0x3F000000 | -0.693 | -0.693 | -0.684 | 0xBF2F0000 | -0.010 |
0.6 | 0x3F19999A | -0.511 | -0.511 | -0.512 | 0xBF030000 | 0.001 |
0.7 | 0x3F333334 | -0.357 | -0.356 | -0.355 | 0xBEB60000 | -0.001 |
0.8 | 0x3F4CCCCE | -0.223 | -0.223 | -0.207 | 0xBE540000 | -0.016 |
0.9 | 0x3F666668 | -0.105 | -0.106 | -0.113 | 0xBDE80000 | 0.008 |
1.0 | 0x3F800001 | 0.000 | 0.000 | 0.008 | 0x3C000000 | -0.008 |
1.1 | 0x3F8CCCCE | 0.095 | 0.095 | 0.086 | 0x3DB00000 | 0.009 |
1.2 | 0x3F99999B | 0.182 | 0.182 | 0.180 | 0x3E380000 | 0.003 |
1.3 | 0x3FA66668 | 0.262 | 0.262 | 0.266 | 0x3E880000 | -0.003 |
1.4 | 0x3FB33335 | 0.336 | 0.337 | 0.336 | 0x3EAC0000 | 0.001 |
1.5 | 0x3FC00002 | 0.405 | 0.406 | 0.406 | 0x3ED00000 | -0.001 |
1.6 | 0x3FCCCCCF | 0.470 | 0.470 | 0.484 | 0x3EF80000 | -0.014 |
1.7 | 0x3FD9999C | 0.531 | 0.530 | 0.531 | 0x3F080000 | -0.001 |
1.8 | 0x3FE66669 | 0.588 | 0.587 | 0.578 | 0x3F140000 | 0.010 |
1.9 | 0x3FF33336 | 0.642 | 0.642 | 0.641 | 0x3F240000 | 0.001 |
2.0 | 0x40000001 | 0.693 | 0.694 | 0.699 | 0x3F330000 | -0.006 |
Average error: 0.007
Maximal error: 0.031
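The hex columns can be cross-checked with a few lines of C. The sketch below assumes that the x column holds raw fp32 bit patterns (their low 16 bits are nonzero) and that `ln_bf16` discards those low bits with the `0xFFFF0000` mask seen in the assembly. The helper names `f32_to_bits`, `bits_to_f32` and `truncate_to_bf16` are hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Return the raw 32-bit pattern of an fp32 value. */
uint32_t f32_to_bits(float f) {
    uint32_t bits;
    assert(sizeof f == sizeof bits);
    memcpy(&bits, &f, sizeof bits);
    return bits;
}

/* Reinterpret a 32-bit pattern as an fp32 value. */
float bits_to_f32(uint32_t bits) {
    float f;
    assert(sizeof f == sizeof bits);
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Keep only the upper 16 bits, as the mask at the start of ln_bf16 does.
 * This truncates toward zero; it does not round to nearest. */
uint32_t truncate_to_bf16(uint32_t fp32_bits) {
    return fp32_bits & 0xFFFF0000u;
}
```

With these helpers, `truncate_to_bf16(f32_to_bits(0.1f))` gives `0x3DCC0000`, which decodes to 0.099609375, and `bits_to_f32(0x3F330000)` gives 0.69921875, shown as 0.699 in the row for 2.0.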
As for the assembly, all the designed tests passed on the 5-stage processor in the Ripes simulator.
With the help of Ripes, I can see that pseudo-instructions (p.110) are translated into the equivalent base instructions that RISC-V processors can execute.
Moreover, ABI register names are translated into the architectural register names (x0 to x31).
In the figure above, `li t6, 0x80000000` is converted to `lui x31 0x80000`, and `beqz t1, asb_not_invert_mb_1` is converted to `beq x6 x0 8 <asb_not_invert_mb_1>`.
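The `li` case can be reproduced with a little arithmetic. The sketch below uses the common `lui` + `addi` idiom for loading a 32-bit constant on RV32; whether Ripes uses exactly this strategy for every constant is an assumption, and the names `li_parts` and `split_li` are hypothetical. Since `addi` sign-extends its 12-bit immediate, the upper part has to compensate when bit 11 of the constant is set.

```c
#include <assert.h>
#include <stdint.h>

/* The two immediates of a lui + addi pair that loads a 32-bit constant. */
typedef struct {
    uint32_t upper20; /* immediate for lui  (20 bits)            */
    int32_t lower12;  /* immediate for addi (signed, -2048..2047) */
} li_parts;

li_parts split_li(uint32_t value) {
    li_parts p;
    /* addi sign-extends its immediate, so read the low 12 bits as signed */
    p.lower12 = (int32_t)(value & 0xFFFu);
    if (p.lower12 >= 0x800) p.lower12 -= 0x1000;
    /* lui supplies whatever remains after removing the signed low part */
    p.upper20 = ((value - (uint32_t)p.lower12) >> 12) & 0xFFFFFu;
    /* the pair must rebuild the constant (modulo 2^32, as on RV32) */
    assert(((p.upper20 << 12) + (uint32_t)p.lower12) == value);
    return p;
}
```

For `0x80000000` the low part is 0 and the upper part is `0x80000`, so a single `lui` is enough, which matches the conversion described above. A constant such as `0x12345FFF` would instead need `lui` with `0x12346` followed by `addi` with -1.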
To test the assembly code in Ripes, I first run it on the "single cycle processor", because registers are updated as soon as each instruction executes. This is useful as a proof of concept.
Later, I switch to the "5-stage processor" to check whether the code works under potential hazards. However, either Ripes handles the hazards for me, or it cannot properly simulate them. What a pity.
The block diagram of the 5-stage processor in Ripes is shown in the following figure.
(not finished…)
`ret` is equivalent to `jr ra` and `jalr x0, ra, 0`. It is used to return from or end a function (subroutine). (Reference)
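To make the equivalence concrete, the sketch below assembles the machine word of `jalr` from its I-type fields (12-bit immediate, rs1, funct3 = 0, rd, opcode `0x67`). The function name `encode_jalr` is hypothetical, and the register numbers follow the standard ABI mapping, where `ra` is x1 and `zero` is x0.

```c
#include <assert.h>
#include <stdint.h>

/* Encode jalr rd, rs1, imm as a 32-bit RISC-V I-type instruction word. */
uint32_t encode_jalr(uint32_t rd, uint32_t rs1, int32_t imm) {
    assert(rd < 32 && rs1 < 32);
    assert(imm >= -2048 && imm <= 2047);
    return (((uint32_t)imm & 0xFFFu) << 20) /* imm[11:0]   */
         | (rs1 << 15)                      /* base register */
         | (0u << 12)                       /* funct3 = 000  */
         | (rd << 7)                        /* link register */
         | 0x67u;                           /* opcode JALR   */
}
```

`encode_jalr(0, 1, 0)` yields `0x00008067`, the word for `jalr x0, ra, 0`. Because rd is x0, the return address is discarded, so the instruction only jumps to the address held in `ra`.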