Danae Savvidi
# Reproduction: Balanced Resonate-and-Fire Neurons

### Authors

**Danae Savvidi (5457769)**
- Reproduce ECG parts of Figure 5.
- Run models to reproduce SHD parts of Figure 5.
- Ablation study to test the effect of the 3 BRF mechanisms.

**Henry Page (5472636)**
- Reproduce S-MNIST parts of Figure 5.
- JAX port of the code and experiments to determine speedups.

**Bernadett Bakos (5515890)**
- Reproduce PS-MNIST parts of Figure 5.
- Evaluation of different hidden layer sizes in BRF neurons.

### Links

- GitHub link: https://github.com/1henrypage/brf-neurons
- Blog post: https://hackmd.io/@danaesav/brf-neurons
- Paper link: https://arxiv.org/abs/2402.14603

### AI Disclaimer

We have used AI to rewrite some parts of our report for better phrasing. This included inputting our ideas and asking it to rewrite them in more formal language.

## Introduction

Spiking Neural Networks (SNNs) represent one of the most promising frontiers in neuromorphic computing, offering the potential for ultra-low-power AI systems that more closely mimic the brain's computational principles. However, despite their biological plausibility and energy-efficiency advantages, SNNs have historically struggled with training stability and practical performance compared to traditional artificial neural networks.

Enter Balanced Resonate-and-Fire (BRF) neurons. BRF neurons introduce three critical stabilization mechanisms: a refractory period that temporarily increases the firing threshold, a smooth reset mechanism that increases dampening rather than abruptly resetting the membrane potential, and a mathematically derived divergence boundary that ensures stable dynamics during training. The original paper demonstrated impressive results across multiple challenging datasets, showing not only improved accuracy but also faster convergence and more efficient spiking behavior.

However, as with any research, independent reproduction and deeper analysis are essential to validate these claims and understand the underlying mechanisms. In this reproduction study, we tackled different aspects of the BRF model to evaluate its capabilities and limitations. Whether you're interested in the theoretical foundations of spiking neural networks, practical implementation considerations, or neuromorphic computing applications, this study offers useful perspectives on one of the most promising recent advances in the field. Through our combined efforts, we not only validate the original claims but also extend our understanding of the paper, providing the reader with both confirmation of the BRF model's effectiveness and new tools for future research and development.

### Spiking Neural Networks

Spiking Neural Networks (SNNs) are a type of neural network that aims to be more biologically realistic than standard artificial neural networks (ANNs). Instead of using continuous activations, neurons in SNNs communicate by sending discrete spikes over time, similar to how real neurons in the brain work. This makes them especially useful for processing time-based or sequential data, such as audio signals or brain activity.

In an SNN, each neuron keeps track of its membrane potential, and when this potential crosses a certain threshold, the neuron "fires" and sends a spike to connected neurons. The timing of these spikes is important, and this time-based representation allows SNNs to naturally model temporal patterns. One of the main advantages of SNNs is their potential for energy efficiency, especially when used on neuromorphic hardware that only computes when spikes occur. However, training SNNs has historically been difficult because of the non-differentiable nature of spikes.
Recently, methods such as surrogate gradients have made it possible to train SNNs using backpropagation, which has led to growing interest in using them for machine learning tasks.

### Resonate-and-Fire Neurons

Resonate-and-fire (RF) neurons are a simplified spiking neuron model designed to capture the resonant properties observed in some biological neurons. Unlike the classic integrate-and-fire (IF) model, which responds primarily to the cumulative input current, RF neurons exhibit subthreshold oscillations and are tuned to fire in response to inputs at particular frequencies. This makes them particularly useful for modeling neurons in systems like auditory pathways, where timing and frequency sensitivity are essential.

Despite these advantages, RF neurons face several limitations that make them challenging to use in practice. Their resonant dynamics can easily become unstable, especially during training, leading to excessive or uncontrolled spiking. This instability not only introduces noise into the network but can also prevent meaningful learning. Additionally, standard RF models often struggle to maintain consistent performance across different tasks and datasets, particularly those requiring long-term temporal processing. These issues have motivated the development of improved models, such as the Balanced Resonate-and-Fire (BRF) neuron, which introduces mechanisms to control spiking behavior and ensure stable learning while preserving the useful resonance properties of the original model.

### Balanced Resonate-and-Fire Neurons

Balanced Resonate-and-Fire (BRF) neurons are an improved version of the original RF model, designed to address some of its limitations during training and inference. Standard RF neurons can become unstable or produce excessive spikes due to their resonant dynamics.
The BRF model introduces three key modifications: a refractory period that temporarily raises the spiking threshold after firing, a smooth reset mechanism that increases dampening instead of abruptly resetting the membrane potential, and a divergence boundary that ensures stability by constraining the dampening factor based on frequency.

## Reproduction - Group (Figure 5)

We reproduce Figure 5 of the paper, shown below:

![Screenshot 2025-05-22 at 16.06.01](https://hackmd.io/_uploads/H1JDyh2Zee.png)

The authors ran each model 5 times to obtain these results. Due to limited computational resources, we ran each model only once. The authors clearly specified all hyperparameters used in their experiments, not only in the paper itself but also in their publicly available codebase, allowing us to replicate the setup accurately.

### Reproduction for SMNIST dataset

<div style="display: flex; justify-content: space-between; gap: 20px; margin: 20px 0;">
  <img src="https://hackmd.io/_uploads/H1zqBaB7el.png" alt="smnist_accuracy_plot" style="width: 100%; height: auto; min-width: 0;"/>
  <img src="https://hackmd.io/_uploads/SyXqB6rmlx.png" alt="smnist_brf_parameters_plot" style="width: 100%; height: auto; min-width: 0;"/>
</div>

#### Accuracy

The test accuracy curves on the S-MNIST dataset for all three models closely align with those presented in the original paper, both in terms of final accuracy and overall learning dynamics.

#### Divergence Boundary Analysis

In our parameter analysis, we also observe the formation of clusters. However, their locations differ from those reported in the original paper. While the original study identified clusters between $\omega$ values of 18–28 and around 40, our results show clusters around $\omega$ values of 23, 40, 60, and 76. Additionally, the original parameters were located in the range of 0–60 for $\omega$, whereas ours span a broader range from 0 to 80, indicating a more dispersed distribution.
A similar trend is seen in the dampening factor $b_c$: the original paper reported values between -30 and 0, whereas our results range from -45 to 0. Another notable difference is that our optimized parameters occasionally exceed the divergence boundary, whereas in the original study all parameters remained within it.

### Reproduction for PSMNIST dataset

<div style="display: flex; justify-content: space-between; gap: 20px; margin: 20px 0;">
  <img src="https://hackmd.io/_uploads/Hkm5S6HQxl.png" alt="psmnist_accuracy_plot" style="width: 100%; height: auto; min-width: 0;"/>
  <img src="https://hackmd.io/_uploads/SyX9HpH7xl.png" alt="psmnist_brf_parameters_plot" style="width: 100%; height: auto; min-width: 0;"/>
</div>

#### Accuracy

The test accuracies for the BRF model on the PS-MNIST dataset behave similarly to those reported in the original paper, converging at a comparable speed and reaching roughly the same final accuracy. In contrast, the ALIF model follows a different trend. Its accuracy remains low until around the 30th epoch, which slows convergence. It reaches a final accuracy of 78.5%, significantly lower than the 94.3% reported in the original study. This difference may be the result of performing only a single run, which may have been an outlier. Repeated runs could yield different results.

#### Divergence Boundary Analysis

The $b_c$ and $\omega$ parameters exhibit distributions similar to those reported in the original paper, generally clustering near the divergence boundary. The only difference is that our values occasionally exceeded this boundary, while the original parameters remained entirely within it.
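
The check below tests whether a learned pair $(\omega, b_c)$ lies inside the divergence boundary $p(\omega) = \frac{-1 + \sqrt{1 - (\delta\omega)^2}}{\delta}$, which is defined in the ablation study below. It assumes that the RF dynamics $\dot{u} = (b + i\omega)u + I$ are discretized with forward Euler, $u(t) = u(t-\delta) + \delta[(b + i\omega)u(t-\delta) + I(t)]$. This is our reading, not confirmed against the authors' code. Under that assumption, a free oscillation is multiplied by $1 + \delta(b + i\omega)$ every step. Requiring $(1 + \delta b)^2 + (\delta\omega)^2 \le 1$ gives the upper bound $b \le p(\omega)$, so points with $b_c > p(\omega)$ are exactly those whose oscillations grow. The helpers `p_omega`, `euler_gain` and `outside_boundary`, the step size $\delta = 0.01$ and the example points are all hypothetical.

```python
import math


def p_omega(omega: float, delta: float) -> float:
    """Divergence boundary p(omega) = (-1 + sqrt(1 - (delta*omega)^2)) / delta.

    Only defined for |delta * omega| <= 1.
    """
    x = delta * omega
    if abs(x) > 1.0:
        raise ValueError("p(omega) is undefined for |delta * omega| > 1")
    return (-1.0 + math.sqrt(1.0 - x * x)) / delta


def euler_gain(b: float, omega: float, delta: float) -> float:
    """|1 + delta*(b + i*omega)|: per-step growth of a free oscillation
    under the assumed forward-Euler update."""
    return abs(1.0 + delta * complex(b, omega))


def outside_boundary(points, delta: float):
    """Return the (omega, b_c) pairs with b_c > p(omega), i.e. pairs whose
    free oscillation grows from step to step."""
    return [(w, b) for (w, b) in points if b > p_omega(w, delta)]


if __name__ == "__main__":
    delta = 0.01  # hypothetical time step, not taken from the paper
    # Hypothetical learned (omega, b_c) pairs, for illustration only.
    points = [(10.0, -1.0), (40.0, -5.0), (60.0, -30.0), (80.0, -20.0)]
    for w, b in points:
        print(f"omega={w:5.1f}  b_c={b:6.1f}  p(omega)={p_omega(w, delta):8.3f}  "
              f"gain={euler_gain(b, w, delta):.4f}")
    print("outside:", outside_boundary(points, delta))
```

A gain above 1 and $b_c > p(\omega)$ always flag the same points. This is why a learned $b'$ that turns negative shows up as a point outside the plotted boundary.
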
### Reproduction for ECG dataset

<div style="display: flex; justify-content: space-between; gap: 20px; margin: 20px 0;">
  <img src="https://hackmd.io/_uploads/H1tok6pWgg.png" alt="ecg_accuracy_plot" style="width: 100%; height: auto; min-width: 0;"/>
  <img src="https://hackmd.io/_uploads/BJ4bg3nZeg.png" alt="ecg_brf_parameters_plot" style="width: 100%; height: auto; min-width: 0;"/>
</div>

#### Accuracy

The test accuracy curves for the RF and BRF models look similar to those reported in the original BRF paper. Notably, the ALIF model required significantly more training time: nearly two full days. This supports the original claim that BRF models converge much faster. Interestingly, our reproduced ALIF model achieved a considerably higher final accuracy than reported in the original paper. However, this came with the expected trade-off: training was much slower than for the BRF or RF models.

#### Divergence Boundary Analysis

In general, the parameter values align with those in the paper. The exception is a few outliers in the original figure, with angular frequencies in the 10–20 range and dampening factors between -5 and -8, which do not appear in our run. However, we believe these values could be observed with a few more runs, since we also obtained some outliers with angular frequencies above 7.5 and dampening factors below -4.

### Reproduction for SHD dataset

<div style="display: flex; justify-content: space-between; gap: 20px; margin: 20px 0;">
  <img src="https://hackmd.io/_uploads/Bk-tKiefge.png" alt="shd_accuracy_plot" style="width: 100%; height: auto; min-width: 0;"/>
  <img src="https://hackmd.io/_uploads/SJZFtjxMge.png" alt="shd_brf_parameters_plot" style="width: 100%; height: auto; min-width: 0;"/>
</div>

#### Accuracy

Both the RF and BRF models show test accuracy trends consistent with those reported in the original paper. In contrast, the ALIF model performs significantly better in our experiments: it converges more quickly and reaches a higher final accuracy than reported.
However, the original paper also noted a high variance in ALIF's test accuracy across multiple runs. Our result may therefore be an outlier, especially since it comes from a single run, whereas the original study averaged results over multiple runs.

#### Divergence Boundary Analysis

The $b_c$ and $\omega$ parameters behave similarly to what was reported in the original paper. Both the initial and optimized parameter values follow comparable distributions. While some points in our experiment do lie outside the divergence boundary, the deviation is smaller than what was observed in the S-MNIST and PS-MNIST cases.

### Conclusion

Across all datasets, the test accuracy curves for the RF and BRF models closely match those reported in the original paper. The ALIF model, on the other hand, performed better in some cases and worse in others. This variability is likely because our results are based on a single run, whereas the original paper averaged results over multiple runs and reported high variance.

The $b_c$ and $\omega$ parameters generally behave similarly to the original study across all datasets. The main difference is that, in every dataset, our parameters occasionally exceed the divergence boundary. This may be explained by the fact that the divergence boundary is only enforced on $p(\omega)$, while $b'$ is learned without any constraints. The plots display $b_c = p(\omega) - b'$, so an unconstrained $b'$ that becomes negative pushes $b_c$ above $p(\omega)$, i.e. outside the boundary.

## Reproductions - Individual

### Danae - Ablation Study: Impact of the 3 BRF Mechanisms on Performance

In this ablation study, we explore the performance effect of the three key mechanisms of the Balanced Resonate-and-Fire (BRF) neuron model: a refractory period (RP), smooth reset (SmR), and a divergence boundary (DB).
#### Motivation

This ablation study aims to isolate the effect of each individual mechanism introduced in the BRF neuron model, to better understand its contribution to the model's overall performance. While the full BRF architecture has been shown to improve stability, convergence, and spiking efficiency, it is unclear to what extent each mechanism is responsible for these improvements. Without this analysis, it is difficult to assess whether all components are necessary or whether similar results could be achieved with a simpler model. Additionally, reproducing and analyzing this model helps validate the original claim that all three mechanisms are beneficial, and provides insight into the trade-offs of including each one.

#### Changes

Here we describe the changes we made to enable or disable the different mechanisms.

The **refractory period** (RP) models the temporary increase in spiking threshold that occurs after a neuron fires. Mathematically, this is captured by a decaying term $q(t) = \gamma q(t-\delta) + z(t-\delta)$, which is added to the firing threshold $\vartheta(t) = \vartheta_c + q(t)$. This regulates how frequently a neuron can fire. Implementation-wise, disabling RP means removing $q(t)$ from the threshold calculation ($\vartheta(t) = \vartheta_c$) and from the spike condition. In summary:

$$
q(t) = \begin{cases} \gamma q(t-\delta) + z(t-\delta) & \text{if RP enabled} \\ 0 & \text{otherwise} \end{cases}
$$

$$
\vartheta(t) = \begin{cases} \vartheta_c + q(t) & \text{if RP enabled} \\ \vartheta_c & \text{otherwise} \end{cases}
$$

The **smooth reset** (SmR) modifies the way a neuron resets after firing. Traditional RF neurons use a hard reset, meaning they zero out the membrane potential after a spike. In BRF, instead of an abrupt reset, the refractory term $q(t)$ is also subtracted from the membrane's dampening factor, $b(t) = p(\omega) - b' - q(t)$. This causes the membrane potential to decay more quickly after spiking, but continuously.
To disable SmR, we simply remove $q(t)$ from the computation of $b(t)$, so that the dampening no longer depends on recent spiking. In summary:

$$
b(t) = \begin{cases} p(\omega) - b' - q(t) & \text{if SmR enabled} \\ p(\omega) - b' & \text{otherwise} \end{cases}
$$

The **divergence boundary** (DB) is one of the most important stability mechanisms in the BRF model. The RF neuron's dynamics are based on the complex-valued linear differential equation $\dot{u} = (b + i\omega)u + I$. When discretized, it can become unstable for certain combinations of dampening $b$, frequency $\omega$, and time step $\delta$. The BRF authors analytically derive a boundary for $b$ that ensures convergence: $p(\omega) = \frac{-1 + \sqrt{1 - (\delta\omega)^2}}{\delta}$. This constraint prevents the neuron's oscillations from diverging over time. In our ablation study, removing the divergence boundary meant replacing $p(\omega)$ with a fixed value ($-1$). In summary:

$$
p(\omega) = \begin{cases} \frac{-1 + \sqrt{1 - (\delta\omega)^2}}{\delta} & \text{if DB enabled} \\ -1 & \text{otherwise} \end{cases}
$$

#### Experiments and Metrics

We test the following scenarios: full BRF, no RP, no SmR, and no DB. We compare them on two tasks (ECG and SHD) in terms of convergence, accuracy, stability, and SOPs (spike operations). The authors of the paper report SOP as the average number of spikes per test sample (total spikes divided by dataset size), measured on the test set. Since training spans multiple epochs, they report SOP for the model that achieved the highest validation accuracy during training. We do the same in our experiments.

#### Results & Discussion

A table summarizing the results can be seen below.

![results_table](https://hackmd.io/_uploads/H1NZx_efel.png)

##### Accuracy

Removing the smooth reset (SmR) component led to a slight but consistent improvement in accuracy on both the ECG and SHD datasets.
This suggests that while SmR encourages sparsity by damping activity after spikes, it might also suppress some useful subthreshold oscillations. Some tasks have periodic or extended timing patterns, such as the rhythmic waveforms in ECG or SHD's long, sparse audio-based spike trains. In these tasks, keeping those oscillations without the extra dampening seems to help learning. As explained below, however, this comes at the cost of higher SOPs.

The divergence boundary (DB) had the biggest effect on accuracy. Turning it off caused the largest performance drop (85.8% &rarr; 83.1% on ECG and 91.7% &rarr; 88.3% on SHD). A possible explanation is given in the Convergence section below. The refractory period (RP) had a smaller impact, slightly lowering accuracy when removed, but not as much as DB or SmR.

##### Spike Efficiency (SOPs)

Interestingly, removing DB led to the lowest number of spike operations on both tasks. Without the boundary constraint, the dampening factor becomes fixed and, as discussed below, possibly overly suppressive, resulting in fewer spikes. This, however, comes at the cost of learning.

In contrast, removing RP or SmR increased spike counts. This is expected, since both mechanisms are designed to reduce unnecessary firing. For SmR, however, the extra spikes came with higher accuracy. This could indicate that the additional spikes carried useful information, rather than being damped away as in the no-DB scenario.

The full BRF model struck a good balance, keeping spike operations relatively low while still maintaining solid accuracy. This suggests that RP, SmR, and DB work together to produce efficient, accurate spiking that matches the task's needs, rather than minimizing spikes for the sake of it.
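
The equations from the Changes section can be combined into one per-timestep update with on/off switches for each mechanism. The sketch below simulates a single scalar neuron in pure Python. It is not the PyTorch code we modified for the experiments.

It makes the following assumptions:
- A forward-Euler step $u(t) = u(t-\delta) + \delta[(b(t) + i\omega)u(t-\delta) + I(t)]$.
- A spike whenever $\mathrm{Re}\,u(t) > \vartheta(t)$.
- Parameter values $\delta = 0.01$, $\vartheta_c = 1$ and $\gamma = 0.9$.

`Ablation`, `boundary`, `count_spikes` and `sop` are hypothetical names. `sop` follows the SOP definition above: total spikes divided by the number of samples. As in the summary equations, disabling RP sets $q(t) = 0$, which also removes the $q(t)$ term from $b(t)$.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class Ablation:
    """Which BRF mechanisms are enabled (hypothetical helper)."""
    rp: bool = True   # refractory period
    smr: bool = True  # smooth reset
    db: bool = True   # divergence boundary


def boundary(omega: float, delta: float, db: bool) -> float:
    """p(omega) when DB is enabled, otherwise the fixed value -1."""
    if not db:
        return -1.0
    return (-1.0 + math.sqrt(1.0 - (delta * omega) ** 2)) / delta


def count_spikes(inputs: Sequence[float], omega: float, b_offset: float, cfg: Ablation,
                 delta: float = 0.01, theta_c: float = 1.0, gamma: float = 0.9) -> int:
    """Simulate one BRF neuron on a real-valued input sequence and count its spikes."""
    u = 0j        # complex membrane potential
    q = 0.0       # refractory variable q(t)
    z_prev = 0.0  # spike at the previous step, z(t - delta)
    spikes = 0
    for i_t in inputs:
        # q(t) = gamma * q(t - delta) + z(t - delta), or 0 when RP is disabled
        q = gamma * q + z_prev if cfg.rp else 0.0
        # b(t) = p(omega) - b' - q(t); the q(t) term is dropped when SmR is disabled
        b = boundary(omega, delta, cfg.db) - b_offset - (q if cfg.smr else 0.0)
        # assumed forward-Euler step of du/dt = (b + i*omega) u + I
        u = u + delta * ((b + 1j * omega) * u + i_t)
        # spike if Re(u) exceeds theta(t) = theta_c + q(t)
        z = 1.0 if u.real > theta_c + q else 0.0
        spikes += int(z)
        z_prev = z
    return spikes


def sop(samples: Sequence[Sequence[float]], omega: float, b_offset: float, cfg: Ablation) -> float:
    """Average number of spikes per sample (total spikes / number of samples)."""
    return sum(count_spikes(x, omega, b_offset, cfg) for x in samples) / len(samples)


if __name__ == "__main__":
    configs = {"full BRF": Ablation(), "no RP": Ablation(rp=False),
               "no SmR": Ablation(smr=False), "no DB": Ablation(db=False)}
    pulse = [150.0, 0.0]  # one strong input step followed by silence
    for name, cfg in configs.items():
        print(name, count_spikes(pulse, omega=10.0, b_offset=0.0, cfg=cfg))
```

Take the `pulse` input with $\omega = 10$ and $b' = 0$. Only the no-RP configuration spikes twice. In every other configuration, the threshold rises to $\vartheta_c + q = 2$ after the first spike, and the decaying potential (real part just below 1.5) does not reach it.
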
##### Convergence

<div style="display: flex; justify-content: space-between;">
  <img src="https://hackmd.io/_uploads/HyzwFixzxe.png" alt="shd_accuracy_plot_ablation" width="500" height="230">
  <img src="https://hackmd.io/_uploads/ByVHzJZfel.png" alt="ecg_accuracy_plot_ablation" width="500" height="230">
</div>

The divergence boundary (DB) had a significant impact on convergence, but its effect differed across datasets. On the ECG task, removing DB resulted in unstable training and lower accuracy. Interestingly, this was not due to excessive spiking; in fact, SOPs were lowest without DB. Instead, the fixed dampening (set to −1 when DB is disabled) seems to have been too strong for the low-frequency neurons typically involved in ECG. The BRF parameters plot shows that the optimized dampening values in ECG are usually above −1, i.e. weaker dampening. These neurons are supposed to resonate slowly and carry information over time. When the dampening is too strong, their oscillations decay too quickly. This can cause the model to lose important subthreshold dynamics and fail to respond meaningfully to the input, leading to poor learning.

For SHD, removing DB had less of an effect on convergence. This is likely because SHD inputs consist of short, event-based spike bursts, and the neurons that respond to them tend to have higher frequencies. In that range, the fixed dampening was closer to what the DB would have assigned anyway, so the membrane dynamics remained relatively stable. The BRF parameter plot suggests it may again be slightly stronger than the optimized values, although these are too spread out to draw a firm conclusion. As a result, the SHD model without DB could still train stably, although with slightly lower accuracy.
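
How close the fixed value $-1$ is to $p(\omega)$ depends on the frequency. Setting $p(\omega) = -1$ gives $\sqrt{1 - (\delta\omega)^2} = 1 - \delta$, so the crossover frequency is $\omega^* = \sqrt{2\delta - \delta^2}/\delta$.
- Below $\omega^*$, $p(\omega) > -1$, so the fixed value damps more strongly than the boundary would.
- Above $\omega^*$, $p(\omega) < -1$, so the fixed value damps less, and only $b'$ and $q(t)$ keep $b(t)$ below $p(\omega)$.

The pure-Python sketch below checks this numerically. It assumes a forward-Euler update $u \leftarrow u + \delta(b + i\omega)u$ for the free (input-less) dynamics and a hypothetical step size $\delta = 0.01$. The step sizes actually used for ECG and SHD may differ, and the function names are illustrative.

```python
import math


def p_omega(omega: float, delta: float) -> float:
    """Divergence boundary p(omega); requires |delta * omega| <= 1."""
    return (-1.0 + math.sqrt(1.0 - (delta * omega) ** 2)) / delta


def crossover_omega(delta: float) -> float:
    """Frequency at which p(omega) equals the fixed no-DB value of -1."""
    return math.sqrt(2.0 * delta - delta ** 2) / delta


def amplitude_after(b: float, omega: float, delta: float, steps: int) -> float:
    """|u| after `steps` forward-Euler steps of du/dt = (b + i*omega) u, starting from u = 1."""
    u = 1 + 0j
    for _ in range(steps):
        u = u + delta * (b + 1j * omega) * u
    return abs(u)


if __name__ == "__main__":
    delta = 0.01  # hypothetical time step
    print(f"crossover omega: {crossover_omega(delta):.2f}")
    for omega in (2.0, 10.0, 20.0):
        pb = p_omega(omega, delta)
        print(f"omega={omega:5.1f}  p(omega)={pb:7.3f}  "
              f"|u|_100 with b=p(omega): {amplitude_after(pb, omega, delta, 100):.3f}  "
              f"with b=-1: {amplitude_after(-1.0, omega, delta, 100):.3f}")
```

On the boundary, the amplitude is preserved exactly. With $b = -1$, a low-frequency oscillation shrinks over 100 steps, while one above $\omega^*$ grows unless $b'$ or $q(t)$ add enough extra dampening.
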

Overall, the divergence boundary seems to play a more critical role when the task involves low-frequency dynamics and long temporal dependencies, as in ECG. However, this effect should be explored further before any definitive conclusions are drawn. Lastly, SmR and RP had little effect on convergence speed and stability.

#### Conclusion

<!-- Between the three mechanisms, we believe based on the presented results that the divergence boundary proved most essential - both for higher accuracy and training stability. While not as impactful, the refractory period and smooth reset still help reduce spike counts/better filter quality spikes. -->

The results of this ablation study support several of the original authors' claims about the importance of the BRF neuron's three core mechanisms. Among them, the divergence boundary (DB) proved most essential, for both accuracy and convergence. The smooth reset (SmR) and refractory period (RP) had more subtle but still meaningful effects. Interestingly, disabling SmR slightly improved accuracy but increased spike counts compared to the full BRF model, which suggests a trade-off between sparsity and expressiveness. These findings do not contradict the original paper; rather, they analyze its conclusions further. Overall, the results confirm that the full BRF model offers a well-balanced design that supports both efficient and effective learning in spiking neural networks.

### Henry - JAX Port: Speeding up training

#### Motivation

My contribution focused on porting the existing PyTorch implementation to JAX (**New code variant**). Firstly, JAX's just-in-time (JIT) compilation and automatic differentiation capabilities offer significant performance advantages, particularly for the iterative numerical simulations common in SNNs, and could potentially decrease the training time.
Secondly, JAX's functional programming paradigm and immutability often lead to cleaner, more maintainable code, reducing the likelihood of subtle bugs that can plague stateful models. Finally, running the BRF model in a different deep learning framework provides an independent check of its core mechanisms. This helps confirm that the observed benefits are inherent to the model design rather than artifacts of a specific framework's implementation.

The outcomes of this contribution are two-fold:
- We explore whether porting the BRF model to JAX delivers the expected performance boost from JIT compilation, and quantify the speed improvement.
- We investigate whether running our JAX-powered BRF model on Tensor Processing Units (TPUs) provides further performance benefits beyond the gains from running JAX locally.

#### JAX Speedup techniques

##### Understanding the Role of JIT Compilation in JAX

Throughout our JAX implementation, you will encounter frequent use of `@jax.jit`. This is not merely a stylistic choice; it is a critical optimization that underpins JAX's efficiency, particularly for our spiking neural network model.

JIT, or just-in-time compilation, changes how Python functions are executed. Without it, JAX dispatches operations to the device one at a time, adding a small overhead between each instruction. With JIT compilation, the entire function is compiled into optimized machine code before execution. This removes the per-operation dispatch overhead, lets the compiler reorder and fuse instructions, and executes the computation in a single efficient pass. In essence, JIT compilation bridges the gap between high-level Python code and low-level performance, making it indispensable for computationally intensive tasks.

###### How JIT works in JAX

1. **Tracing the Operations**: When a `@jax.jit`-decorated function is called for the first time with a new set of input shapes and dtypes (not necessarily new values), JAX "traces" through the Python code. It records all the mathematical operations and data flows.
2. **Building an XLA Graph**: The trace is then turned into a computational graph for XLA (Accelerated Linear Algebra), a specialized compiler developed by Google to optimize numerical computations.
3. **Compiling to Machine Code**: XLA then compiles this graph into efficient, low-level machine code tailored to your hardware.
4. **Executing Rapidly**: From that point onward, every time you call the JIT-compiled function with inputs of the same shapes and dtypes, JAX does not re-run the Python code. Instead, it directly executes the pre-compiled machine code.

###### JIT Function Example

The following example, adapted from the [JAX documentation](https://docs.jax.dev/en/latest/jit-compilation.html), shows why JIT is not just a nice-to-have but a necessity for high-performance machine learning. The `%timeit` lines are IPython/Jupyter magics.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
# 2.42 ms ± 590 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

Without `jax.jit`, each operation is dispatched to the accelerator individually. This prevents the XLA compiler from optimizing across the entire computation, for example by fusing the element-wise operations.

```python
selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
# 273 μs ± 4.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

The JIT-compiled version achieved a **~8.9x speedup** (2.42 ms vs. 273 μs).
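
The caching described in steps 1 and 4 can be observed directly. Python side effects inside a jitted function run only while JAX traces it, so a counter incremented in the function body records how often (re)tracing happens. This is a minimal sketch; `scaled_sum` and `trace_count` are illustrative names, not part of our codebase.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces the function body


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return jnp.sum(2.0 * x)


scaled_sum(jnp.ones(3))   # first call: trace + compile for shape (3,), float32
scaled_sum(jnp.zeros(3))  # same shape and dtype: reuses the compiled code
scaled_sum(jnp.ones(5))   # new shape (5,): triggers a retrace
print(trace_count)
```

This is also why inputs with varying shapes, such as unpadded sequences of different lengths, can erase the benefit of JIT: every new shape pays the tracing and compilation cost again.
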
###### Usage of JIT

We applied JIT compilation wherever we could (step functions, forward pass functions, etc.). These are perfect candidates for JIT because they are called over and over again during training.

For optimal performance with JIT, it is vital to tell JAX which arguments of a function will stay constant during execution and which will change. That is where `static_argnums` comes in. For instance, in our `gaussian` function, `mu` and `sigma` are hyperparameters that usually do not change during a single computation. By marking them with `@partial(jax.jit, static_argnums=(1, 2))`, we tell JAX to "hardcode" these constant values directly into the compiled machine code, allowing for even deeper optimizations. The trade-off is that each new combination of static values triggers a recompilation.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp

_INV_SQRT_2PI = 1.0 / math.sqrt(2.0 * math.pi)  # normalization constant of the standard Gaussian


@jax.jit
def step(x: jnp.ndarray) -> jnp.ndarray:
    # Heaviside step (spike function): 1.0 where x > 0, else 0.0
    return (x > 0).astype(jnp.float32)


# mu and sigma are hyperparameters that stay constant. They can be compiled away!!
@partial(jax.jit, static_argnums=(1, 2))
def gaussian(x: jnp.ndarray, mu: float = 0.0, sigma: float = 1.0) -> jnp.ndarray:
    scale = _INV_SQRT_2PI / sigma
    arg = (x - mu) / sigma
    return scale * jnp.exp(-0.5 * arg * arg)


# Surrogate gradient: step function in the forward pass, Gaussian in the backward pass
@jax.custom_vjp
def StepGaussianGrad(x: jnp.ndarray):
    return step(x)


def StepGaussianGrad_fwd(x: jnp.ndarray):
    y = step(x)
    return y, x  # keep x as the residual for the backward pass


def StepGaussianGrad_bwd(x: jnp.ndarray, g: jnp.ndarray):
    dfdx = gaussian(x)
    return (g * dfdx,)


# Register the custom forward/backward rules
StepGaussianGrad.defvjp(StepGaussianGrad_fwd, StepGaussianGrad_bwd)
```

##### Recurrent Computations in JAX `nn.scan`

SNNs inherently process sequences timestep by timestep. A naive Python `for` loop over timesteps would be very slow due to Python's interpretation overhead. `nn.scan` (Flax's lifted version of `jax.lax.scan`) takes this sequential computation and compiles it into a single, highly optimized XLA loop. This is what allows our `SimpleResRNN` to run efficiently on GPUs and TPUs, making training feasible for long sequences. Rather than unrolling the loop at compile time, scan traces the loop body once and reuses the compiled body at every step, so compilation time does not grow with the sequence length.
Think of it as a glorified `for` loop that runs on your accelerator (GPU/TPU) without Python's usual loop overhead. An excerpt from our implementation is shown below:

```python
# Fields and methods have been omitted for clarity
# For the full implementation, visit https://github.com/1henrypage/brf-neurons
from typing import Tuple

import flax.linen as nn
import jax.numpy as jnp


class SimpleResRNN(nn.Module):

    def setup(self):
        ...
        self.scanned_core = nn.scan(
            ResScanCore,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=0,
            out_axes=0
        )(self.hidden_cell, self.out_cell)

    def __call__(self, x: jnp.ndarray, train: bool = True) -> Tuple[jnp.ndarray, Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray], jnp.ndarray]:
        ...
        (final_hidden_z, final_hidden_u, _, _, final_out_u, spike_sum), outs = self.scanned_core(init_state, x)
        ...


class ResScanCore(nn.Module):
    hidden_cell: nn.Module
    out_cell: nn.Module

    @nn.compact
    def __call__(self, carry, x_t):
        hidden_z, hidden_u, hidden_v, hidden_q, out_u, spike_sum = carry
        # Recurrent input: current input concatenated with the previous hidden spikes
        x_in = jnp.concatenate([x_t, hidden_z], axis=1)
        new_hidden_z, new_hidden_u, new_hidden_v, new_hidden_q = self.hidden_cell(
            x_in, (hidden_z, hidden_u, hidden_v, hidden_q)
        )
        # Running total of hidden-layer spikes
        new_spike_sum = spike_sum + jnp.sum(new_hidden_z)
        new_out_u = self.out_cell(new_hidden_z, out_u)
        # Return (new carry, per-timestep output collected by scan)
        return (new_hidden_z, new_hidden_u, new_hidden_v, new_hidden_q, new_out_u, new_spike_sum), new_out_u
```

##### Leveraging `jax.vmap` for loss calculation

When dealing with sequential data over large batches, it is important to calculate losses efficiently. Let's look at our `apply_seq_loss_jax` function:

```python
import jax
import jax.numpy as jnp

# nll_loss_fn (per-element negative log-likelihood) is defined elsewhere in our repository.

@jax.jit
def apply_seq_loss_jax(outputs, targets):
    # outputs, targets: (sequence_length, batch_size, num_classes); targets are one-hot
    targets_indices = jnp.argmax(targets, axis=2)
    # This vmap is crucial to apply the loss function correctly across sequence length and batch
    losses_per_element = jax.vmap(jax.vmap(nll_loss_fn))(outputs, targets_indices)
    # Mean over all (timestep, sample) pairs
    return jnp.sum(losses_per_element) / (outputs.shape[0] * outputs.shape[1])
```

The most important line here is `losses_per_element = jax.vmap(jax.vmap(nll_loss_fn))(outputs, targets_indices)`.
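
`nll_loss_fn` lives elsewhere in our repository, so the self-contained sketch below substitutes a hypothetical `nll_single`, the negative log-likelihood of one prediction. It checks that the nested `vmap` gives the same result as explicit Python loops over the sequence and batch axes. It assumes outputs of shape `(sequence_length, batch_size, num_classes)` holding log-probabilities. The random inputs are only for illustration.

```python
import jax
import jax.numpy as jnp


def nll_single(log_probs, target):
    """Hypothetical per-element loss: negative log-likelihood of one prediction.

    log_probs: (num_classes,) log-probabilities; target: scalar class index.
    """
    return -log_probs[target]


@jax.jit
def seq_loss_vmap(outputs, targets):
    # outputs, targets: (seq_len, batch, num_classes); targets are one-hot
    target_idx = jnp.argmax(targets, axis=2)
    # inner vmap maps over the batch axis, outer vmap over the sequence axis
    losses = jax.vmap(jax.vmap(nll_single))(outputs, target_idx)
    return jnp.sum(losses) / (outputs.shape[0] * outputs.shape[1])


def seq_loss_loop(outputs, targets):
    """Reference version with explicit Python loops (slow, for checking only)."""
    seq_len, batch, _ = outputs.shape
    total = 0.0
    for t in range(seq_len):
        for b in range(batch):
            total += -outputs[t, b, jnp.argmax(targets[t, b])]
    return total / (seq_len * batch)


key_out, key_tgt = jax.random.split(jax.random.PRNGKey(0))
outputs = jax.nn.log_softmax(jax.random.normal(key_out, (4, 3, 5)), axis=-1)
targets = jax.nn.one_hot(jax.random.randint(key_tgt, (4, 3), 0, 5), 5)
print(seq_loss_vmap(outputs, targets), seq_loss_loop(outputs, targets))
```

The loop version issues one small indexing operation per (timestep, sample) pair. The `vmap` version is traced once and compiled into a single gather followed by a reduction.
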
###### Vectorization

Instead of looping in Python over our `sequence_length` and `batch_size` dimensions, `jax.vmap` transforms our single-element `nll_loss_fn` into a function that operates on entire batches or sequences at once. It creates a "batched" version of our function, pushing the parallel computation down to the highly optimized XLA level.

- Inner `vmap` (Batch Dimension): `jax.vmap(nll_loss_fn)` takes our `nll_loss_fn` (which expects a single prediction and target) and maps it across the first axis of its inputs. Inside the outer `vmap` it receives one timestep at a time, so that first axis is `batch_size`: it processes every element of the batch for a given timestep.
- Outer `vmap` (Sequence Dimension): The outer `vmap` wraps the already batch-vectorized function and maps it across the first axis of the full arrays, the `sequence_length` dimension. The loss calculation is therefore parallelized across both the sequence and batch dimensions.

###### Importance of Vectorization

- **Eliminates Python Loops**: This is the most critical benefit. `vmap` removes the need for explicit `for` loops in Python, which would introduce significant overhead. The per-element function is rewritten into batched array operations that XLA compiles together.
- **Maximizes Parallelism**: By explicitly telling JAX to vectorize over these key dimensions, we let the underlying hardware (GPU/TPU) perform many loss calculations simultaneously.
- **Composability**: `vmap` is a JAX transformation, so it composes with other JAX transformations such as `jax.jit`. This gives us the best of both worlds: highly optimized, compiled code that also operates efficiently on batched and sequential data.

#### Performance Benchmarks: Quantifying the Speedups

The theoretical advantages of JAX are compelling, but what do they mean in practice for SNN training?
To answer this, we ran a series of benchmarks comparing the original (PyTorch-based) implementation on a local PC against our JAX port running both on the same PC and on a Google Cloud TPU.

##### Experimental Setup

- Local PC: Equipped with an NVIDIA GTX 1080 Ti GPU.
- Google TPU: We used a Google Cloud TPU v2-8 for accelerated JAX execution.
- Methodology:
  - Metrics: All reported times are the average training time per epoch, calculated over the 2nd through 11th epochs. We skipped the first epoch because it includes JAX's one-time JIT compilation, which makes it slower and less representative of sustained performance.
  - Model: We used the `SimpleResRNN` with one `BRFCell` and one `LICell`. Their implementation is described in the original paper.
  - Hyperparameters: We replicated the original authors' experimental settings for all external parameters.

\* For the SHD dataset, our local GTX 1080 Ti could not fit the full training set into memory, so the SHD* results use only the first 100 samples of the training set.

\** These results represent a single experimental run. For a robust statistical analysis, we recommend repeating the experiments to establish measurement uncertainty.

##### Results

<div style="display: flex; justify-content: space-between; gap: 20px; margin: 20px 0;">
<img src="https://hackmd.io/_uploads/H1pvqnBXeg.png" alt="Average epoch training time" style="width: 100%; height: auto; min-width: 0;"/>
<img src="https://hackmd.io/_uploads/rJpvq2B7eg.png" alt="Log scaled training time" style="width: 100%; height: auto; min-width: 0;"/>
</div>

##### Discussion

By simply adopting JAX, we observe order-of-magnitude speedups across all datasets. We attribute this improvement largely to JAX's aggressive JIT compilation, efficient array operations, and functional programming paradigm.
This suggests that JAX substantially reduces the overhead associated with Python, turning our computations into highly optimized XLA graphs even on a consumer-grade GPU.

While the jump from JAX on the PC to JAX on the TPU is not as uniformly dramatic as the initial JAX port, we still see significant additional speedups, particularly on the ECG and SMNIST datasets. TPUs are custom-built for large-scale matrix multiplications and parallel computation, making them well suited to deep learning workloads. The more modest gains on SHD* may be due to the smaller dataset size, where the overhead of communicating with the TPU is still a noticeable factor, or to its computational pattern exploiting the TPU's capabilities less fully than the other tasks do.

Our benchmarks also pointed to an avenue for even greater speedups, most evident in the SMNIST results. In our initial porting effort, we prioritized translating the core BRF neuron model and its dynamics correctly to JAX. Due to time constraints, our data loading pipeline for these experiments was not optimized for a pure JAX environment: within each training epoch, every batch of data was loaded as a PyTorch tensor and then converted to a JAX array before being fed into our `SimpleResRNN` model. This torch-to-JAX conversion, while seemingly minor, adds a repeated per-batch overhead. For a dataset like SMNIST, which has a large number of training samples and therefore many batches per epoch, this per-batch conversion can accumulate into a significant bottleneck.

###### Why this matters for performance

JAX performs best when it operates directly on its own array types and when entire computations are JIT-compiled. Constantly converting data from another framework's tensor type to a JAX array breaks this flow and introduces extra host-device transfers or re-allocations that JAX might otherwise avoid.
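One way to avoid the per-batch conversion described above is to convert the dataset to JAX arrays once, before training, and do all shuffling and batching on the JAX side. The sketch below is hypothetical: the helper `make_epoch_batches`, the random stand-in data and the batch size are illustrative rather than taken from the repository, and it assumes the whole training set fits in device memory.

```python=
import jax
import jax.numpy as jnp
import numpy as np


def make_epoch_batches(x_all: jax.Array, y_all: jax.Array, batch_size: int, key: jax.Array):
    """Hypothetical helper: yield shuffled (x, y) batches from arrays already on the device.

    The last incomplete batch is dropped so every batch has the same shape and a
    jitted train step is compiled only once.
    """
    num_samples = x_all.shape[0]
    perm = jax.random.permutation(key, num_samples)  # new sample order for this epoch
    for i in range(num_samples // batch_size):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield x_all[idx], y_all[idx]  # indexing happens on the JAX side, no torch tensors involved


# Convert once, before training (random stand-in data in place of a real dataset)
x_all = jnp.asarray(np.random.rand(1000, 784, 1).astype(np.float32))
y_all = jnp.asarray(np.random.randint(0, 10, size=1000))

key = jax.random.PRNGKey(0)
for epoch in range(2):
    key, epoch_key = jax.random.split(key)
    for x_batch, y_batch in make_epoch_batches(x_all, y_all, batch_size=256, key=epoch_key):
        pass  # the jitted train step would be called here
```

Each epoch then only draws a new permutation and slices the resident arrays. If the model expects time-major inputs, as a scan with `in_axes=0` does, each batch would still need a `jnp.swapaxes(x_batch, 0, 1)` before the forward pass.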
We hypothesize that this repeated torch-to-JAX conversion was a primary contributor to the residual slowness we observed, especially for SMNIST compared to the relative gains on the other datasets. For ECG and SHD*, where the data size or batching characteristics differ, this overhead may have been less pronounced.

###### Looking Ahead

A natural next step for further optimization would be to decouple data loading from PyTorch entirely, for example by loading each dataset directly into NumPy or JAX arrays and batching them on the JAX side.

#### Conclusion

Replicating the original Balanced Resonate-and-Fire Neurons codebase in JAX has clearly demonstrated the substantial performance benefits this framework offers. These gains are critical for accelerating research, enabling faster experimentation, and tackling larger, more ambitious SNN architectures. We encourage you to explore the full repository to see how we implemented these optimizations and ported all the models included in the original authors' code.

### Bernadett - Effects of hidden layer size

To evaluate the impact of hidden layer size on BRF networks, we ran experiments with different hidden layer configurations. These experiments were performed on the PS-MNIST dataset because of its increased difficulty, which arises from the input permutation. The original paper used a fixed hidden layer size of 256 neurons across all experiments, reporting a final accuracy of 99% on the S-MNIST dataset and 95% on the PS-MNIST dataset. Since PS-MNIST is a permuted version of S-MNIST, it presents a more challenging task and may benefit from a more complex model architecture.

We tested the BRF model on PS-MNIST using three configurations: a single hidden layer with 256 neurons, a single hidden layer with 512 neurons, and two hidden layers with 128 and 256 neurons, respectively. All experiments were run for 100 epochs.
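For readers unfamiliar with the setup, here is a sketch of how PS-MNIST is typically derived from S-MNIST: each 28×28 image is flattened into a 784-step sequence of single pixels, and one fixed random permutation, shared by every sample in both the training and test sets, reorders the timesteps. The function names and the seed below are illustrative assumptions, not the original authors' code.

```python=
import jax
import jax.numpy as jnp


def to_sequential(images: jax.Array) -> jax.Array:
    # S-MNIST: (num_samples, 28, 28) -> (num_samples, 784, 1), one pixel per timestep
    return images.reshape(images.shape[0], -1, 1)


def to_permuted_sequential(images: jax.Array, perm: jax.Array) -> jax.Array:
    # PS-MNIST: reorder the 784 timesteps with the same permutation for every sample
    return to_sequential(images)[:, perm, :]


# Drawn once with a fixed (illustrative) seed and reused for training and test data
perm = jax.random.permutation(jax.random.PRNGKey(42), 28 * 28)

images = jax.random.uniform(jax.random.PRNGKey(0), (8, 28, 28))  # stand-in for MNIST images
s_mnist = to_sequential(images)
ps_mnist = to_permuted_sequential(images, perm)
```

Both versions contain exactly the same pixel values; only their order in time differs. Because neighbouring pixels are no longer adjacent in time, the network has to integrate information over much longer time spans, which is what makes PS-MNIST the harder task.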
![hidden_layer_comparison](https://hackmd.io/_uploads/BkHiCw6Qle.png)

The plot indicates that hidden layer size does affect the model's test accuracy. Specifically, larger hidden layers lead to higher test accuracy, suggesting that increased model size is beneficial for more challenging datasets like PS-MNIST.

For a fair comparison, we also evaluated the BRF neurons with 512 hidden neurons on the S-MNIST dataset and plotted the resulting test accuracy curves.

![image](https://hackmd.io/_uploads/r1imQzJ4ex.png)

<!-- From the plot we can see that the larger hidden layer size is also beneficial for the more simple S-MNIST dataset, the amount of improvement however is much less significan than it is for the PS-MNIST dataset. -->

The plot shows that a larger hidden layer also improves performance on the simpler S-MNIST dataset, although the improvement is notably smaller than the one observed on the more challenging PS-MNIST dataset.

<!-- ![image](https://hackmd.io/_uploads/SkIX6waXll.png) -->

![image](https://hackmd.io/_uploads/Hyk8QfJEeg.png)

<!-- Table 2 shows the performance of the BRF neurons after 100 epochs with different hidden layer sizes on the PS-MNIST and S-MNIST datasets. Here we can see that with a single hidden layer of size 256, the S-MNIST dataset has 4.32% higher accuracy than the PS-MNIST, while with a layer size of 512, this accuracy is only 2.41% higher. -->

Table 2 presents the performance of BRF neurons after 100 epochs with varying hidden layer sizes on the PS-MNIST and S-MNIST datasets. With a single hidden layer of 256 neurons, the accuracy on S-MNIST is 4.32 percentage points higher than on PS-MNIST. When the hidden layer size is increased to 512, the gap narrows to 2.41 percentage points, suggesting that larger hidden layers help close the performance difference between the two datasets.
