## Week 48 notes/report

Still working on optimizing and benchmarking the polynomial-evaluation and FFT kernels.

## FFT development

### Radix-2 decimation-in-time

$$
\begin{aligned}
X(k) = \sum_{n=0}^{N-1} x(n)W_{N}^{kn}, \qquad k = 0,1,2,\dots,N-1, \qquad W_N = e^{-j2\pi / N}
\end{aligned}
$$

An $N$-point DFT can be written as the weighted sum of two $N/2$-point DFTs: one DFT of the even-indexed samples, and one DFT of the odd-indexed samples weighted by the twiddle factor $W_N^{k}$:

$$
\begin{aligned}
X(k) &= \sum_{n=0}^{(N/2)-1} x(2n)W_{N}^{2kn} \; + \; \sum_{n=0}^{(N/2)-1} x(2n+1)W_{N}^{k(2n+1)} \\
&= \sum_{n=0}^{(N/2)-1} x_{\textrm{even}}(n)W_{N/2}^{kn} \; + \; W_{N}^{k} \sum_{n=0}^{(N/2)-1} x_{\textrm{odd}}(n)W_{N/2}^{kn}
\end{aligned}
$$

for $k = 0,1,2,\dots,N-1$, using $W_N^{2kn} = W_{N/2}^{kn}$ (both $N/2$-point DFTs are periodic in $k$ with period $N/2$).

Building on the split-radix idea — which applies a radix-2 decomposition to the even-indexed outputs and a radix-4 decomposition to the odd-indexed outputs — this can be extended to other radix pairs. One such example is the radix-2/8 algorithm for lengths $2^m$. The radix-2/8 split-radix FFT has the same asymptotic arithmetic complexity as the conventional split-radix FFT, but with the advantage of fewer loads and stores. The idea of this extended split-radix FFT is to apply a radix-2 index map to the even-indexed terms and a radix-8 index map to the odd-indexed terms, i.e., it combines one half-length DFT and four eighth-length DFTs.
If $N = 2^m$, then for the even-indexed terms

$$
\begin{aligned}
X(2k) = \sum_{n=0}^{(N/2)-1} \big[ x(n) + x(n + N/2) \big] W_{N}^{2kn}, \qquad k = 0,1,\dots,N/2-1,
\end{aligned}
$$

and for the odd-indexed terms ($k = 0,1,\dots,N/8-1$), using $W_N^{N/8} = e^{-j\pi/4} = (1-j)/\sqrt{2}$:

$$
\begin{aligned}
X(8k+1) = \sum_{n=0}^{(N/8)-1} \Big[ & \big(x(n) - x(n + N/2)\big) - j\big(x(n+N/4) - x(n+3N/4)\big) \\
& + \tfrac{1}{\sqrt{2}} \Big\{ (1-j)\big(x(n + N/8) - x(n + 5N/8)\big) - (1+j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{n} W_N^{8nk}
\end{aligned}
$$

-----

$$
\begin{aligned}
X(8k+3) = \sum_{n=0}^{(N/8)-1} \Big[ & \big(x(n) - x(n + N/2)\big) + j\big(x(n+N/4) - x(n+3N/4)\big) \\
& - \tfrac{1}{\sqrt{2}} \Big\{ (1+j)\big(x(n + N/8) - x(n + 5N/8)\big) - (1-j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{3n} W_N^{8nk}
\end{aligned}
$$

-----

$$
\begin{aligned}
X(8k+5) = \sum_{n=0}^{(N/8)-1} \Big[ & \big(x(n) - x(n + N/2)\big) - j\big(x(n+N/4) - x(n+3N/4)\big) \\
& - \tfrac{1}{\sqrt{2}} \Big\{ (1-j)\big(x(n + N/8) - x(n + 5N/8)\big) - (1+j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{5n} W_N^{8nk}
\end{aligned}
$$

-----

$$
\begin{aligned}
X(8k+7) = \sum_{n=0}^{(N/8)-1} \Big[ & \big(x(n) - x(n + N/2)\big) + j\big(x(n+N/4) - x(n+3N/4)\big) \\
& + \tfrac{1}{\sqrt{2}} \Big\{ (1+j)\big(x(n + N/8) - x(n + 5N/8)\big) - (1-j)\big(x(n + 3N/8) - x(n + 7N/8)\big) \Big\} \Big] W_N^{7n} W_N^{8nk}
\end{aligned}
$$

Since $W_N^{8nk} = W_{N/8}^{nk}$, each bracketed sequence, multiplied by $W_N^{rn}$ ($r = 1, 3, 5, 7$), is the input to one of the four $N/8$-point DFTs.

Looking to build this as a kernel and benchmark it against the current implementation.
### Kernel development: testing on the BN256 base field

```bn256.cu
#include "./fp_u256.cuh"
#include "../fft/fft.cuh"
#include "../fft/twiddles.cuh"
#include "../fft/bitrev_permutation.cuh"
#include "../utils.h"

namespace p256
{
// Field element type for the BN256 (alt_bn128 / BN254) base field:
//   P = 21888242871839275222246405745257275088696311157297823662689037894645226208583
//     = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47
// The N limbs below spell out P as four 64-bit words, most significant first.
// The R_SQUARED / N_PRIME names suggest Montgomery arithmetic inside Fp256 —
// see fp_u256.cuh.
using Fp = Fp256<
    // N (the modulus P)
    3486998266802970665, 13281191951274694749, 10917124144477883021, 4332616871279656263,
    // R_SQUARED
    1011752739694698287, 7381016538464732716, 754611498739239741, 15230403791020821917,
    // N_PRIME
    9786893198990664585, 11447725176084130505, 15613922527736486528, 17688488658267049067>;
} // namespace p256

// Kernel entry points with C linkage (unmangled names), instantiated for p256::Fp.
extern "C"
{
  // Runs one radix-2 DIT butterfly stage in place on `input`
  // (see _radix2_dit_butterfly in fft.cuh).
  __global__ void radix2_dit_butterfly(p256::Fp *input,
                                       const p256::Fp *twiddles,
                                       const int stage,
                                       const int butterfly_count)
  {
    _radix2_dit_butterfly<p256::Fp>(input, twiddles, stage, butterfly_count);
  }

  // Presumably writes `count` twiddle factors derived from `_omega` into `result`
  // (see _calc_twiddles in twiddles.cuh).
  // For the inverse twiddles, call with _omega = _omega.inverse().
  // NOTE(review): `_omega` is a reference parameter of a __global__ function, so the
  // kernel receives an address; the host must pass device-accessible memory — confirm.
  __global__ void calc_twiddles(p256::Fp *result, const p256::Fp &_omega, const int count)
  {
    _calc_twiddles<p256::Fp>(result, _omega, count);
  }

  // Like calc_twiddles, presumably emitting the twiddles in bit-reversed order
  // (see _calc_twiddles_bitrev in twiddles.cuh).
  // For the inverse twiddles, call with _omega = _omega.inverse().
  __global__ void calc_twiddles_bitrev(p256::Fp *result, const p256::Fp &_omega, const int count)
  {
    _calc_twiddles_bitrev<p256::Fp>(result, _omega, count);
  }

  // Presumably writes the bit-reversal permutation of the first `len` elements of
  // `input` into `result` (see _bitrev_permutation in bitrev_permutation.cuh).
  __global__ void bitrev_permutation(const p256::Fp *input, p256::Fp *result, const int len)
  {
    _bitrev_permutation<p256::Fp>(input, result, len);
  }
}
```

Basic radix-2 kernel. This is still a WIP: there is a bug somewhere, most likely in my value for the 2-adicity. Is it 192? (See the `IsFFTField` impl below.)

Note: it cannot be 192. The prime above satisfies $P \equiv 3 \pmod 4$, so $P - 1 = 2t$ with $t$ odd and the 2-adicity of the base field is 1. Its only power-of-two roots of unity are $\pm 1$, so radix-2 FFTs longer than 2 are impossible over this field; FFT-based polynomial arithmetic on BN254 normally uses the scalar field, whose 2-adicity is 28.
```fft.cuh
#pragma once

// One stage of an in-place radix-2 decimation-in-time (DIT) FFT, one thread per
// butterfly. Thread `t` (t < butterfly_count) handles the pair
// (input[i], input[i + half_group_size]) and writes back
//     input[i]                   = a + w * b
//     input[i + half_group_size] = a - w * b
// with a = input[i], b = input[i + half_group_size], w = twiddles[group].
//
// Parameters:
//   input           - buffer transformed in place.
//   twiddles        - twiddle factors indexed by group number (one per group of the
//                     stage); presumably the bit-reversed table from
//                     calc_twiddles_bitrev — TODO confirm.
//   stage           - stage `s` partitions the butterflies into 2^s groups of
//                     butterfly_count >> s butterflies each.
//   butterfly_count - butterflies per stage (N/2 for an N-point transform). Must be
//                     a power of two: pos_in_group is computed with a bit mask.
template <class Fp>
inline __device__ void _radix2_dit_butterfly(Fp *input,
                                             const Fp *twiddles,
                                             const int stage,
                                             const int butterfly_count)
{
  const int thread_pos = blockDim.x * blockIdx.x + threadIdx.x;
  if (thread_pos >= butterfly_count) return;

  // Butterflies per group at this stage. This becomes 0 once stage exceeds
  // log2(butterfly_count); bail out instead of dividing by zero below
  // (integer division by zero is undefined on the device).
  const int half_group_size = butterfly_count >> stage;
  if (half_group_size == 0) return;

  const int group = thread_pos / half_group_size;
  const int pos_in_group = thread_pos & (half_group_size - 1);
  // First element of this thread's pair: group * 2 * half_group_size + pos_in_group.
  const int i = thread_pos * 2 - pos_in_group;

  Fp w = twiddles[group];
  Fp a = input[i];
  Fp b = input[i + half_group_size];

  // Compute the twiddled product once: the field multiplication dominates the
  // cost of a butterfly, and both outputs need the same value.
  Fp wb = w * b;

  input[i] = a + wb;
  input[i + half_group_size] = a - wb;
}
```

```rust!
// NOTE(review): the type name says BN252; BN254/BN256 is presumably intended —
// confirm the actual type name before renaming anything.
impl IsFFTField for BN252PrimeField {
    // The base-field prime
    //   p = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47
    // satisfies p ≡ 3 (mod 4), so p - 1 = 2 * t with t odd: the 2-adicity is 1,
    // not 192. With 192, any lower-order root derived from the root below by
    // repeated squaring (TODO confirm how IsFFTField derives them) collapses to 1,
    // which silently breaks every FFT of length > 2. Over this field only
    // length-2 radix-2 FFTs exist; use the BN254 scalar field (2-adicity 28) for
    // FFT-based polynomial arithmetic.
    const TWO_ADICITY: u64 = 1;
    // Primitive 2^TWO_ADICITY-th (i.e. 2nd) root of unity: p - 1 = -1 mod p.
    //`0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46`
    // const ROOT_OF_UNITY: Fq = Fq::from_raw([
    //     0x3c208c16d87cfd46,
    //     0x97816a916871ca8d,
    //     0xb85045b68181585d,
    //     0x30644e72e131a029,
    // ]);
    const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: U256 = UnsignedInteger::from_hex_unchecked(
        "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46",
    );

    fn field_name() -> &'static str {
        "bn256"
    }
}
```