---
title: 'About Square Root operation on bfloat16 data'
---
<style>
.two-column-layout {
column-count: 2; /* Set column number */
column-gap: 20px;
max-width: 100%;
overflow: hidden;
}
/* Media query for mobile devices */
@media (max-width: 768px) {
.two-column-layout {
column-count: 1; /* Switch to single column on small screens */
column-gap: 0; /* Optional: Set gap to 0 for single column */
}
}
.markdown-body, .ui-infobar {
max-width: unset !important;
}
.two-column-layout ul,
.two-column-layout ol {
margin: 0;
padding-left: 20px;
}
.two-column-layout strong {
font-weight: bold;
}
.two-column-layout em {
font-style: italic;
}
.two-column-layout h1,
.two-column-layout h2,
.two-column-layout h3,
.two-column-layout h4,
.two-column-layout h5,
.two-column-layout h6 {
margin-top: 0;
}
</style>
# What is the Square Root
"How far is it from here to there?", "What is the distance between points A and B?", "How long does it take to drive to our school?" You often hear questions like these in daily life, but where does the concept of distance come from? We can start from the Pythagorean theorem, which gives the length of the hypotenuse of a right triangle. If the legs have lengths $a$ and $b$ and the hypotenuse has length $c$, then $c^2=a^2+b^2$. To get $c$ from $c^2$, we apply the function $y=\sqrt{x}$, so the hypotenuse is $c=\sqrt{c^2}=\sqrt{a^2+b^2}$. With this, we can implement graphical applications such as drawing, animation, and browsers on a computer. Moreover, as neural networks advance, we can model more complicated tasks using the distance between model outputs and expected results. In summary, without a proper implementation of the square root, we cannot develop applications that come close to human imagination (after all, we like to see objects with our own eyes).
# Why the bfloat16 data type
Besides integers, a computer can represent distances more precisely with floating-point numbers. Before implementing the square root, we therefore need to consider how floating-point numbers are represented. In the IEEE 754 standard, a single-precision number (float32) uses 32 bits: a 1-bit sign, an 8-bit exponent, and a 23-bit mantissa. Registers and digital circuits then have to process 32-bit data, which becomes costly when a high-throughput design is required. Therefore, Google introduced a floating-point format called "bfloat16", with a 1-bit sign, an 8-bit exponent, and a 7-bit mantissa. bfloat16 has the same range as float32 because both have an 8-bit exponent; however, bfloat16 gives up some resolution by narrowing the mantissa. bfloat16 pays off in highly parallel designs such as neural networks and artificial intelligence (AI) accelerators. In these designs, one 32-bit register can store two floating-point numbers. That is, "same storage cost but double throughput". We therefore need to study how to implement arithmetic operations on bfloat16; this note considers the square root.
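The relationship between the two formats can be sketched in C. This is a minimal illustration, assuming `float` is an IEEE 754 binary32 value; the helper names `f32_to_bf16_trunc`, `bf16_to_f32`, and `pack2_bf16` are hypothetical and do not come from the original note. The conversion simply truncates (no rounding).

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: keep the upper 16 bits of a float32 (truncation, no rounding). */
static uint16_t f32_to_bf16_trunc(float f)
{
    uint32_t u;
    memcpy(&u, &f, sizeof u);   /* reinterpret the float's bits without aliasing issues */
    return (uint16_t)(u >> 16); /* sign + 8-bit exponent + top 7 mantissa bits */
}

/* Hypothetical helper: widen a bfloat16 to float32 by zero-filling the low 16 bits. */
static float bf16_to_f32(uint16_t b)
{
    uint32_t u = (uint32_t)b << 16;
    float f;
    memcpy(&f, &u, sizeof f);
    return f;
}

/* Hypothetical helper: store two bfloat16 values in one 32-bit word. */
static uint32_t pack2_bf16(uint16_t hi, uint16_t lo)
{
    return ((uint32_t)hi << 16) | lo;
}
```

For instance, `1.0f` has the float32 encoding `0x3F800000`, so its truncated bfloat16 encoding is `0x3F80`.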
# How to efficiently perform square root on bfloat16 data type
## Observation on Square Root
Because the square root is monotonically increasing (a larger value has a larger square root), we can guess the square root by iterating over all candidate values within a pre-computed range.
Let $x\in\mathbb{R}^+$; the square root of $x$ is $\sqrt{x}$.
Since $x+1 > x$, we have $\sqrt{x+1} > \sqrt{x}$.
Next, after pre-computing an upper bound $O$ and a lower bound $\Omega$ of $\sqrt{x}$, we iterate a candidate $t$ from $\Omega$ to $O$ and calculate $t^2$.
Finally, we compare $t^2$ with $x$ and take the closest candidate as the answer.
This approach has time complexity $O(n)$, where $n$ is the number of candidates between $\Omega$ and $O$. To improve performance, we can replace the linear scan with a faster, more systematic search.
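The linear scan can be sketched as follows. This is an illustrative integer version, not code from the original note: the name `isqrt_linear` is hypothetical, and it returns the largest candidate whose square does not exceed `x` (the floor of the root) rather than the closest candidate.

```c
#include <stdint.h>

/* Hypothetical sketch: scan candidates t = lo, lo+1, ..., hi and keep the
 * largest t with t*t <= x. Returns lo if no candidate satisfies that. */
static uint32_t isqrt_linear(uint32_t x, uint32_t lo, uint32_t hi)
{
    uint32_t best = lo;
    for (uint32_t t = lo; t <= hi; t++) {
        if ((uint64_t)t * t > x)
            break;      /* monotonic: every later candidate is too large as well */
        best = t;
    }
    return best;
}
```

The number of loop iterations grows linearly with the distance between the bounds, which is the $O(n)$ cost described above.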
## Binary Search algorithm
A time complexity of $O(n)$ corresponds to searching for a specific number along a 1-D line. To reduce the number of iterations, we can imagine the search line reorganized as a balanced tree. With a balanced tree, the longest search path is reduced to $O(\log n)$. But how do we build such a tree for our goal? We can use the property that $x+1 > x$ implies $\sqrt{x+1} > \sqrt{x}$. That is, the target number can be found through a series of comparisons.
1. Let the root of the balanced tree be $med=median(O, \Omega)$
2. if $med^2 = x$, then return $med$
3. else if $med^2 < x$, the answer lies above $med$: set $\Omega=med+1$ and calculate $med=median(O, \Omega)$
4. else ($med^2 > x$), the answer lies below $med$: set $O=med-1$ and calculate $med=median(O, \Omega)$
5. repeat 2.~4. until $\Omega > O$
Since the numbers between $\Omega$ and $O$ are naturally sorted, the $median$ can be replaced by the average $Avg(O,\Omega)=\frac{O+\Omega}{2}$. Because the divisor is 2, no general division is needed: the square root can be obtained using only addition, subtraction, multiplication, and right-shifting (divide-by-2).
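A minimal integer sketch of this binary search is shown below. The function name `isqrt_binary` and the fixed bounds 0 and 65535 are assumptions chosen for a 32-bit input, not part of the original note; when no exact root exists it returns the floor of the root.

```c
#include <stdint.h>

/* Hypothetical sketch: binary search for floor(sqrt(x)) using only
 * addition, subtraction, multiplication, and a right shift. */
static uint32_t isqrt_binary(uint32_t x)
{
    uint32_t lo = 0;        /* lower bound of the root */
    uint32_t hi = 65535;    /* upper bound: 65535^2 < 2^32 <= 65536^2 */
    uint32_t best = 0;

    while (lo <= hi) {
        uint32_t mid = (lo + hi) >> 1;      /* average via right shift */
        uint64_t sq = (uint64_t)mid * mid;  /* 64-bit product avoids overflow */

        if (sq == x)
            return mid;                     /* exact root found */
        if (sq < x) {
            best = mid;                     /* mid is too small or the floor */
            lo = mid + 1;                   /* continue in the upper half */
        } else {
            hi = mid - 1;                   /* continue in the lower half; mid > 0 here */
        }
    }
    return best;
}
```

Each iteration halves the remaining range, so at most 17 iterations are needed for the 65536 candidates.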
Note that we never need to know what the "balanced tree" looks like, because we simply perform a binary search. The balanced tree is implicitly built as a binary search tree (BST). We can then use the structure of the BST to analyze the time complexity.
From this structure, the worst-case time complexity is $O(\log n)$, since the height of the BST is $O(\log n)$.
The best-case time complexity is $O(1)$, which occurs when the first guess in step 1 of the above algorithm already matches.
Finally, the average time complexity has the same order as the worst case, which can be confirmed with the recurrence technique from algorithm analysis. Each step halves the search range and does constant work, so the recurrence is
$T(n)=\begin{cases}O(1) & \text{if } n=1\\T(n/2)+O(1) & \text{if } n>1\end{cases}$
Next, we apply the Master Theorem, which covers recurrences of the form $T(n)=aT(n/b)+f(n)$.
Here $a=1$, $b=2$, and $f(n)=O(1)$.
Since $n^{\log_b a}=n^{\log_2 1}=n^0=1$ has the same order as $f(n)$, Case 2 applies.
Therefore, $T(n)=\Theta(n^{\log_b a}\log_2 n)=\Theta(\log_2 n)$.
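As a cross-check, the same recurrence can also be unrolled by hand. The sketch below assumes $n=2^k$ and writes the constant per-step cost as $c$ (a symbol introduced here only for illustration):

$$
\begin{aligned}
T(n) &= T(n/2) + c \\
     &= T(n/4) + 2c \\
     &\;\;\vdots \\
     &= T(n/2^k) + kc \\
     &= T(1) + c\log_2 n = \Theta(\log_2 n)
\end{aligned}
$$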
## Prerequisites
To implement any operation on bfloat16 or float32, it helps to extract each field of the target type and manipulate the fields as required. Therefore, the first thing is to know the layout of bfloat16.
|Sign (1-bit)|Exponent (8-bit)|Mantissa (7-bit)|
|------------|----------------|----------------|
|bit 15|bits 14 to 7|bits 6 to 0|

---
For readable code snippets, the struct bf16_t is defined as
```c=1
typedef struct {
    uint16_t bits; /* raw 16-bit encoding; uint16_t comes from <stdint.h> */
} bf16_t; /* we use bf16_t in the following code */
```
Therefore, when we need to manipulate the raw bits of a value `xxx`, we can use `xxx.bits`.
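To make the layout concrete, the sketch below assembles a value from its three fields. The helper `bf16_pack` is hypothetical (it is not part of the original note) and simply places each field at the bit positions listed in the layout.

```c
#include <stdint.h>

typedef struct {
    uint16_t bits;
} bf16_t;

/* Hypothetical helper: build a bf16_t from sign (1 bit), exponent (8 bits),
 * and mantissa (7 bits). Out-of-range inputs are masked to their field width. */
static bf16_t bf16_pack(uint16_t sign, uint16_t exp, uint16_t mant)
{
    bf16_t r;
    r.bits = (uint16_t)(((sign & 1u) << 15) |   /* bit 15 */
                        ((exp & 0xFFu) << 7) |  /* bits 14 to 7 */
                        (mant & 0x7Fu));        /* bits 6 to 0 */
    return r;
}
```

For example, sign 0, exponent 127, and mantissa 0 give `0x3F80`, which is the encoding of 1.0.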
## Basic information
As mentioned above, we need to manipulate the data according to its category. The categories are listed below:
- Normal value (+/-): exponent: 1~254, mantissa: any
- +/- infinity: exponent: 255, mantissa: zero
- +/- zero: exponent: 0, mantissa: zero
- NaN: exponent: 255, mantissa: non-zero
- Denormal (Denorm): exponent: 0, mantissa: non-zero
</div>
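The categories above translate directly into a small classifier. This is only a sketch: the enum `bf16_class_t` and the function `bf16_classify` are names made up for illustration and do not appear in the original note.

```c
#include <stdint.h>

/* Hypothetical category names for the cases listed above. */
typedef enum {
    BF16_CLS_ZERO,
    BF16_CLS_DENORM,
    BF16_CLS_NORMAL,
    BF16_CLS_INF,
    BF16_CLS_NAN
} bf16_class_t;

/* Hypothetical helper: classify a raw bfloat16 encoding by exponent and mantissa.
 * The sign bit is ignored, so +0 and -0 (or +Inf and -Inf) share a category. */
static bf16_class_t bf16_classify(uint16_t bits)
{
    uint16_t exp = (bits >> 7) & 0xFF;
    uint16_t mant = bits & 0x7F;

    if (exp == 0xFF)
        return mant ? BF16_CLS_NAN : BF16_CLS_INF;
    if (exp == 0)
        return mant ? BF16_CLS_DENORM : BF16_CLS_ZERO;
    return BF16_CLS_NORMAL;
}
```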
---
## Implementation
```c=4
static inline bf16_t bf16_sqrt(bf16_t a) {
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;
    ...
}
```
Here, we extract the fields using bitwise operations. Extraction means masking the relevant bits and aligning them with the LSB, so some shifts are required.
For the sign bit, code line 5 moves the sign bit to the LSB. In general, if two bfloat16 values have been packed into a 32-bit register, the additional mask ```& 1``` is needed. But in this case `a.bits` is only 16 bits wide, so nothing remains above the sign bit after the shift, and I think it can be optimized as
```c=4
static inline bf16_t bf16_sqrt(bf16_t a) {
    uint16_t sign = (a.bits >> 15);
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;
    ...
}
```
As for the exponent, one may think that masking with 0x7F80 and then right-shifting by 7 gives the same result. It does, but the mask 0x7F80 is a large immediate value (wider than 12 bits). On an instruction set whose immediates are limited to 12 bits (RISC-V, for example), the compiler has to spend extra instructions to build that constant, whereas the mask 0xFF fits directly. Thus, code line 6 is the better choice.
Last, the mantissa is already aligned with the LSB, so masking with 0x7F is enough and straightforward.
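Because a bfloat16 has only 65536 encodings, the claims about the alternative extraction forms can be checked exhaustively. The function `extraction_forms_agree` below is a hypothetical test helper written for this purpose, not code from the original note.

```c
#include <stdint.h>

/* Hypothetical check: returns 1 if, for every 16-bit pattern,
 *   - the sign is the same with and without the "& 1" mask, and
 *   - shift-then-mask and mask-then-shift give the same exponent. */
static int extraction_forms_agree(void)
{
    for (uint32_t v = 0; v <= 0xFFFF; v++) {
        uint16_t bits = (uint16_t)v;

        uint16_t sign_masked = (bits >> 15) & 1;
        uint16_t sign_plain = bits >> 15;

        uint16_t exp_shift_first = (bits >> 7) & 0xFF;
        uint16_t exp_mask_first = (bits & 0x7F80) >> 7;

        if (sign_masked != sign_plain || exp_shift_first != exp_mask_first)
            return 0;
    }
    return 1;
}
```

The two exponent forms always agree in value; they differ only in the size of the constant the compiler has to encode.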
---
Next, let's move on to the special cases. Not a number ($NaN$), infinity ($\pm\infty$), and zero ($0$) results arise in the following situations:
- $x=NaN$, $\sqrt{x}=NaN$.
- $x=+\infty$, $\sqrt{x}=+\infty$.
- $x=-\infty$, $\sqrt{x}=NaN$.
- $x=\pm0$, $\sqrt{x}=0$.
- $x<0$, $\sqrt{x}=NaN$.
- $x\in Denorm$, $\sqrt{x}=0$ (denormals are flushed to zero).

The corresponding code is
```c=8
    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a; /* sqrt(+Inf) = +Inf */
    }
    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();
    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();
    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();
    ...
```
---
After handling the special cases, we look at how the square root acts on each field. Taking the square root halves the exponent, which for a bfloat16 value is easily done by right-shifting by one bit. Note that a right shift discards the LSB, so we need to check whether the discarded bit is 1 or 0 to avoid losing precision. It is just like 2 people sharing a pizza: when the pizza is initially cut into 7 pieces, each can fairly get 3.5 pieces (by whatever way as long as they agree... LOL). The leftover 0.5 is carried over to the mantissa.
Now turn to the mantissa, which is the quantity guessed in the binary search, so we need to pre-compute its upper and lower bounds. To keep precision, we add the implicit 1 to the mantissa, so $m$ lies in [128, 256), where 128 represents 1.0. If the unbiased exponent (the pizza) is odd, the leftover factor of 2 is moved into the mantissa by doubling it, so $m$ lies in [256, 512). Because the search compares $mid^2/128$ with $m$, the guess is bounded above by $O=\sqrt{512\cdot128}=256$ and below by $\Omega=\sqrt{128\cdot128}=128$ (the code starts from a looser lower bound of 90).
```c=28
    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;
    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }
    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */
    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90; /* Loose lower bound; the true minimum is sqrt(128 * 128) = 128 */
    uint32_t high = 256; /* Upper bound: sqrt(512 * 128) = 256 */
    uint32_t result = 128; /* Default */
    ...
```
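The exponent and mantissa preparation can be tried on a few concrete inputs. The sketch below wraps the same steps in a hypothetical helper `sqrt_prepare`; `BF16_EXP_BIAS` is assumed to be 127 (the bfloat16 exponent bias), and `/ 2` is used in place of `>> 1` because the dividend is always even at that point, which avoids right-shifting a negative value.

```c
#include <stdint.h>

#define BF16_EXP_BIAS 127 /* assumed value: the bfloat16 exponent bias */

/* Hypothetical helper: from the raw bits of a positive normal bfloat16,
 * compute the biased exponent of the root and the scaled mantissa m. */
static void sqrt_prepare(uint16_t bits, int32_t *new_exp, uint32_t *m)
{
    int32_t e = (int32_t)((bits >> 7) & 0xFF) - BF16_EXP_BIAS;

    *m = 0x80u | (bits & 0x7Fu);    /* add the implicit 1: [128, 256) */
    if (e & 1) {
        *m <<= 1;                   /* odd exponent: double m, now [256, 512) */
        *new_exp = (e - 1) / 2 + BF16_EXP_BIAS;
    } else {
        *new_exp = e / 2 + BF16_EXP_BIAS;
    }
}
```

For 8.0 (`0x4100`, unbiased exponent 3) this gives `m = 256` and `new_exp = 128`, matching $\sqrt{8}=2^1\cdot\sqrt{2}$. For 0.5 (`0x3F00`, unbiased exponent -1) it gives `m = 256` and `new_exp = 126`, matching $\sqrt{0.5}=2^{-1}\cdot\sqrt{2}$.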
Now, we can perform the binary search.
```c=51
    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */
        if (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    ...
```
Here, we see a strange term ```/128```. It comes from the earlier operation ```uint32_t m = 0x80 | mant;```, which stores the value in a fixed-point scale where 128 represents 1.0 and leads to the range [128, 256). ```mid``` uses the same scale, so ```mid*mid``` carries the scale factor twice ($128^2$) and has double the bit width, which makes a direct comparison with ```m``` invalid. Therefore, we need to scale ```mid*mid``` back by ```/128```. After the search, we also need to remove the implicit 1.
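The search loop can be exercised on its own with a few scaled inputs. The wrapper `sqrt_q7` is a hypothetical name for this sketch; the loop body and the bounds 90 and 256 follow the snippet above, and `m` is assumed to lie in [128, 512).

```c
#include <stdint.h>

/* Hypothetical wrapper: m is fixed-point with 128 == 1.0.
 * Returns the largest mid in [90, 256] with (mid * mid) / 128 <= m. */
static uint32_t sqrt_q7(uint32_t m)
{
    uint32_t low = 90;
    uint32_t high = 256;
    uint32_t result = 128;

    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128;    /* back to the 128 == 1.0 scale */

        if (sq <= m) {
            result = mid;                   /* candidate answer, try larger */
            low = mid + 1;
        } else {
            high = mid - 1;                 /* too large, try smaller */
        }
    }
    return result;
}
```

For `m = 256` (2.0) the search returns 181, and 181/128 = 1.4140625 approximates $\sqrt{2}$. For `m = 288` (2.25) it returns 192, which is exactly 1.5.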
```c=63
    /* result now approximates sqrt(m * 128), i.e. sqrt(m) in the 128 = 1.0 scale */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */
    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }
    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;
    ...
```
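Although we have already handled the special cases on the input, we still need to check the output range for code integrity. For normal inputs `new_exp` stays between 64 and 190, so these checks act as a defensive guard.
The original code from our instructor: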
```c=80
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();
    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
```
These checks basically follow the section **Basic information**.
In my opinion, the code can be extended by adding a $NaN$ check, so the code becomes
```c=80
    if (new_exp > 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp == 0xFF && new_mant != 0)
        return BF16_NAN(); /* NaN */
    if (new_exp <= 0)
        return BF16_ZERO();
    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
```
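For reference, the snippets can be assembled into one compilable function. This is a sketch under stated assumptions rather than the definitive version: `BF16_EXP_BIAS`, `BF16_NAN()`, and `BF16_ZERO()` are used in the note without their definitions, so the values below (127, `0x7FC0`, `0x0000`) are assumed. The final range check follows the instructor's version, and `/ 2` replaces `>> 1` on the signed exponent because the dividend is always even there.

```c
#include <stdint.h>

typedef struct {
    uint16_t bits;
} bf16_t;

/* Assumed definitions; the note does not show them. */
#define BF16_EXP_BIAS 127
#define BF16_NAN()  ((bf16_t){ .bits = 0x7FC0 })
#define BF16_ZERO() ((bf16_t){ .bits = 0x0000 })

static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = a.bits >> 15;
    int16_t exp = (a.bits >> 7) & 0xFF;
    uint16_t mant = a.bits & 0x7F;

    /* Special cases: NaN, +/-Inf, +/-0, negative values, denormals */
    if (exp == 0xFF) {
        if (mant)
            return a;               /* NaN propagation */
        if (sign)
            return BF16_NAN();      /* sqrt(-Inf) = NaN */
        return a;                   /* sqrt(+Inf) = +Inf */
    }
    if (!exp && !mant)
        return BF16_ZERO();
    if (sign)
        return BF16_NAN();
    if (!exp)
        return BF16_ZERO();         /* flush denormals to zero */

    /* Halve the exponent; an odd exponent moves a factor of 2 into m */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;
    uint32_t m = 0x80u | mant;
    if (e & 1) {
        m <<= 1;
        new_exp = (e - 1) / 2 + BF16_EXP_BIAS;
    } else {
        new_exp = e / 2 + BF16_EXP_BIAS;
    }

    /* Binary search in the 128 == 1.0 fixed-point scale */
    uint32_t low = 90, high = 256, result = 128;
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128;
        if (sq <= m) {
            result = mid;
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* Normalize so that result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    uint16_t new_mant = result & 0x7F;  /* drop the implicit 1 */
    if (new_exp >= 0xFF)
        return (bf16_t){ .bits = 0x7F80 };  /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();
    return (bf16_t){ .bits = (uint16_t)(((new_exp & 0xFF) << 7) | new_mant) };
}
```

With these assumptions, 4.0 (`0x4080`) maps to 2.0 (`0x4000`), 9.0 (`0x4110`) maps to 3.0 (`0x4040`), and 2.0 (`0x4000`) maps to `0x3FB5`, which is 1.4140625.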
---
## Reference
- [從 √2 的存在談開平方根的快速運算 (Fast square root computation, starting from the existence of √2)](https://hackmd.io/@sysprog/sqrt)
- [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1-sol)