## Week 48 notes/report
Still working on optimizing and benchmarking the polynomial-evaluation and FFT kernels.
## FFT Development
Radix-2 Decimation-in-Time (DIT)
$$
\begin{aligned}
X(k) = \sum_{n=0}^{N-1} x(n)W_{N}^{kn} \;\;\;\;\; k =0,1,2,\dots,N-1 \;\;\;\;\;\; W_N = e^{-j2\pi / N}
\end{aligned}
$$
An $N$-point DFT can be written as the weighted sum of two $N/2$-point DFTs, one over the even-indexed samples and one over the odd-indexed samples:
$$
\begin{aligned}
X(k) &= \sum_{n=0}^{(N/2)-1} x(2n)W_{N}^{k2n} \; + \; \sum_{n=0}^{(N/2)-1} x(2n+1)W_{N}^{k(2n+1)} \\
&= \sum_{n=0}^{(N/2)-1} x_{\textrm{even}}(n)W_{N/2}^{kn} \; + \; W_{N}^{k} \sum_{n=0}^{(N/2)-1} x_{\textrm{odd}}(n)W_{N/2}^{kn}
\end{aligned}
$$
for $k =0,1,2,\dots,N-1$.
Building upon the split-radix idea, this can be extended to other radix pairs. Such an example is the radix-2/8 algorithm for length-$2^m$. The radix-2/8 split-radix FFT algorithm has the same asymptotic arithmetic complexity as the conventional split-radix FFT algorithm, but with the advantage of fewer loads and stores.
The idea of the extended split-radix FFT is to apply a radix-2 index map to the even-indexed terms and a radix-8 index map to the odd-indexed terms, i.e., it is the combination of one half-length DFT and four eighth-length DFTs.
If $N = 2^m$, then for the even-indexed terms:
$$
\begin{aligned}
X(2k) = \sum_{n=0}^{(N/2) -1} \big[ x(n) + x(n + N/2) \big] W_{N}^{2kn}
\end{aligned}
$$
and for the odd indexed terms:
$$
\begin{aligned}
X(8k+1) ={} & \sum_{n=0}^{(N/8)-1} \Big[ \big(x(n) - x(n + N/2)\big) - j\big(x(n+N/4) - x(n+3N/4)\big) \\
& + \frac{1}{\sqrt{2}} \Big\{ (1-j)\big(x(n + N/8) - x(n + 5N/8)\big) \\
& \quad - (1+j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{n} W_N^{8nk}
\end{aligned}
$$
-----
$$
\begin{aligned}
X(8k+3) ={} & \sum_{n=0}^{(N/8)-1} \Big[ \big(x(n) - x(n + N/2)\big) + j\big(x(n+N/4) - x(n+3N/4)\big) \\
& - \frac{1}{\sqrt{2}} \Big\{ (1+j)\big(x(n + N/8) - x(n + 5N/8)\big) \\
& \quad - (1-j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{3n} W_N^{8nk}
\end{aligned}
$$
-----
$$
\begin{aligned}
X(8k+5) ={} & \sum_{n=0}^{(N/8)-1} \Big[ \big(x(n) - x(n + N/2)\big) - j\big(x(n+N/4) - x(n+3N/4)\big) \\
& - \frac{1}{\sqrt{2}} \Big\{ (1-j)\big(x(n + N/8) - x(n + 5N/8)\big) \\
& \quad - (1+j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{5n} W_N^{8nk}
\end{aligned}
$$
-----
$$
\begin{aligned}
X(8k+7) ={} & \sum_{n=0}^{(N/8)-1} \Big[ \big(x(n) - x(n + N/2)\big) + j\big(x(n+N/4) - x(n+3N/4)\big) \\
& + \frac{1}{\sqrt{2}} \Big\{ (1+j)\big(x(n + N/8) - x(n + 5N/8)\big) \\
& \quad - (1-j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{7n} W_N^{8nk}
\end{aligned}
$$
Looking to build this in a kernel and bench against current impl.
### kernel development:
Testing on the base field of BN256.
```bn256.cu
#include "./fp_u256.cuh"
#include "../fft/fft.cuh"
#include "../fft/twiddles.cuh"
#include "../fft/bitrev_permutation.cuh"
#include "../utils.h"
namespace p256
{
// BN256 base field element type.
// P = 21888242871839275222246405745257275088696311157297823662689037894645226208583
//
// Fp256 is instantiated with twelve integer literals arranged as three groups
// of four; the inline comments label the groups N, R_SQUARED and N_PRIME and
// mark each group as one u256.
// NOTE(review): presumably N is the modulus P split into 64-bit limbs and
// R_SQUARED / N_PRIME are the Montgomery constants Fp256 expects; limb order
// and exact definitions are not visible here — confirm against fp_u256.cuh.
using Fp = Fp256<
// Template arguments: N, then R_SQUARED, then N_PRIME (four literals each).
/* =N **/ /*u256(*/ 3486998266802970665, 13281191951274694749, 10917124144477883021, 4332616871279656263 /*)*/,
/* =R_SQUARED **/ /*u256(*/ 1011752739694698287, 7381016538464732716,
754611498739239741, 15230403791020821917 /*)*/,
/* =N_PRIME **/ /*u256(*/ 9786893198990664585, 11447725176084130505,
15613922527736486528, 17688488658267049067 /*)*/
>;
} // namespace p256
extern "C"
{
// Kernel entry point for one radix-2 DIT butterfly stage over p256::Fp.
// All four arguments are forwarded unchanged to the templated device helper
// _radix2_dit_butterfly<p256::Fp>.
__global__ void radix2_dit_butterfly(p256::Fp *input,
                                     const p256::Fp *twiddles,
                                     const int stage,
                                     const int butterfly_count)
{
    _radix2_dit_butterfly<p256::Fp>(input, twiddles, stage, butterfly_count);
}
// Kernel entry point that forwards to _calc_twiddles<p256::Fp>, writing into
// `result`.
// NOTE: to obtain the inverse twiddles, call with _omega = _omega.inverse().
// NOTE(review): `_omega` is a reference parameter of a __global__ function, so
// the launcher presumably has to supply an address the device can read —
// confirm at the call site.
__global__ void calc_twiddles(p256::Fp *result,
                              const p256::Fp &_omega,
                              const int count)
{
    _calc_twiddles<p256::Fp>(result, _omega, count);
}
// Kernel entry point that forwards to _calc_twiddles_bitrev<p256::Fp>, writing
// into `result`.
// NOTE: to obtain the inverse twiddles, call with _omega = _omega.inverse().
// NOTE(review): `_omega` is a reference parameter of a __global__ function, so
// the launcher presumably has to supply an address the device can read —
// confirm at the call site.
__global__ void calc_twiddles_bitrev(p256::Fp *result,
                                     const p256::Fp &_omega,
                                     const int count)
{
    _calc_twiddles_bitrev<p256::Fp>(result, _omega, count);
}
// Kernel entry point that forwards to _bitrev_permutation<p256::Fp>.
// `input` is read-only, `result` is the writable buffer, and `len` is passed
// through unchanged.
__global__ void bitrev_permutation(const p256::Fp *input,
                                   p256::Fp *result,
                                   const int len)
{
    _bitrev_permutation<p256::Fp>(input, result, len);
}
}
```
Basic radix-2. This is still a WIP as there is a bug somewhere, likely in my value for the 2-adicity (is it 192?); see below.
```fft.cuh
#pragma once
// Performs the single butterfly owned by the calling thread for one stage of
// an in-place radix-2 decimation-in-time FFT.
//
//   input           - buffer updated in place; the thread reads and then
//                     overwrites input[i] and input[i + half_group_size].
//   twiddles        - twiddle factors, indexed by the thread's group number.
//   stage           - selects the butterfly span:
//                     half_group_size = butterfly_count >> stage.
//   butterfly_count - total number of butterflies; threads whose global index
//                     is >= butterfly_count return without touching memory.
//
// Precondition: butterfly_count >> stage must be a non-zero power of two.
// The position inside a group is taken with a bit mask, which equals the
// remainder only for power-of-two sizes, and the group number is obtained by
// dividing by that same value.
template <class Fp>
inline __device__ void _radix2_dit_butterfly(Fp *input,
                                             const Fp *twiddles,
                                             const int stage,
                                             const int butterfly_count)
{
    const int thread_pos = blockDim.x * blockIdx.x + threadIdx.x;
    if (thread_pos >= butterfly_count) return;

    const int half_group_size = butterfly_count >> stage;
    const int group = thread_pos / half_group_size;
    const int pos_in_group = thread_pos & (half_group_size - 1);

    // Equivalent to group * (2 * half_group_size) + pos_in_group: each group
    // spans 2 * half_group_size consecutive elements.
    const int i = thread_pos * 2 - pos_in_group;

    Fp w = twiddles[group];
    Fp a = input[i];
    Fp b = input[i + half_group_size];

    // Both outputs share the product w * b; evaluate the field multiplication
    // once instead of once per output.
    Fp wb = w * b;

    input[i] = a + wb;
    input[i + half_group_size] = a - wb;
}
```
```rust!
impl IsFFTField for BN252PrimeField {
    // 2-adicity of the field: the largest s such that 2^s divides p - 1.
    //
    // For this modulus p ends in ...583 (decimal), so p - 1 ends in ...82 and
    // p - 1 = 2 (mod 4): it is divisible by 2 but not by 4, hence s = 1.
    // The previous value 192 was not a valid 2-adicity for this field.
    // Consistently, the root below ends in hex ...46 = p - 1, i.e. -1, whose
    // multiplicative order is exactly 2 = 2^1.
    //
    // NOTE(review): with s = 1 this field only admits transforms of length 2;
    // larger FFTs need a field with higher 2-adicity (presumably the curve's
    // scalar field is the one intended — confirm).
    const TWO_ADICITY: u64 = 1;

    // Same value in limb form, least-significant limb first:
    // `0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46`
    // const ROOT_OF_UNITY: Fq = Fq::from_raw([
    //     0x3c208c16d87cfd46,
    //     0x97816a916871ca8d,
    //     0xb85045b68181585d,
    //     0x30644e72e131a029,
    // ]);
    const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: U256 = UnsignedInteger::from_hex_unchecked(
        "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46",
    );

    // Name reported for this field.
    fn field_name() -> &'static str {
        "bn256"
    }
}
```