# Revisiting Batch Normalization for Training Low-Latency Deep Spiking Neural Networks From Scratch [Reproduction]
Authors:
Wenxuan Liang (5979579)
- Modify the code to extract the values of the learnable parameters of different layers across time steps on the CIFAR-10 dataset
- Analyze the characteristics of the learnable parameter's distribution
- Make poster and write blog
Ruopu Wu (5925010)
- Modify the code to extract and compute the spike rate for each layer during the training process
- Make poster and write blog
Anran Yang (5904625)
- Reproduce the result on CIFAR-10 with the early exit algorithm
- Reproduce the heatmap of the learnable parameter and analyze the early exit algorithm
- Make poster and write blog
Boxiang Zhu (5980402)
- Write new code implementing a fully connected neural network with BNTT and test it on the Sequential MNIST dataset
- Make poster and write blog; maintain the GitHub repository
GitHub Link: https://github.com/baimuchu/BNTT-Batch-Normalization-Through-Time/tree/main
Paper Link: https://arxiv.org/abs/2010.01729
## Introduction
### Spiking Neural Network
The Spiking Neural Network (SNN) is the model that the original paper improves upon. SNNs are biologically inspired models used in deep learning. Other neural networks use ReLU or other activation functions to compute the output of each layer. In an SNN, each layer instead outputs spikes, which are discrete activations occurring at specific points in time. The spikes encode time-specific information and mimic the way biological neurons send electrical signals in the brain.
In a Spiking Neural Network, the fundamental unit is a spiking neuron that takes input and generates output spikes based on its internal membrane state. The most widely used model for spike generation is the Leaky-Integrate-and-Fire (LIF) model [1]. In this model, a spiking neuron generates spikes in three steps: integration, leaking, and firing. During integration, the neuron integrates the incoming input current $I(t)$, scaled by the membrane resistance $R$, into its membrane potential $U_m$. Leaking means that the membrane potential gradually decays over time, at a rate set by the membrane time constant $\tau_m$. When the membrane potential reaches a certain threshold, the neuron emits an output spike and its membrane potential is reset. The integration and leaking dynamics below threshold follow:
$$
\tau_m \frac{dU_m}{dt} = -U_m + RI(t) \tag{1}
$$
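To make the integrate, leak, and fire steps concrete, here is a minimal sketch that simulates a single LIF neuron using a forward-Euler discretization of Eq. (1). The step size `dt`, the threshold of 1.0, and the reset-to-zero rule are illustrative assumptions. They are not values taken from the paper.

```python
def simulate_lif(current, tau_m=10.0, resistance=1.0, threshold=1.0,
                 dt=1.0, u_reset=0.0):
    """Simulate one LIF neuron driven by the input current I(t).

    Returns the membrane potential after each step and a 0/1 spike train.
    """
    u = u_reset
    potentials, spikes = [], []
    for i_t in current:
        # Integrate and leak: Euler step of tau_m * dU/dt = -U + R * I(t)
        u += (dt / tau_m) * (-u + resistance * i_t)
        # Fire: emit a spike and reset once the threshold is reached
        if u >= threshold:
            spikes.append(1)
            u = u_reset
        else:
            spikes.append(0)
        potentials.append(u)
    return potentials, spikes


_, spike_train = simulate_lif([2.0] * 14)
print(spike_train)  # -> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]
```

With a constant input, the potential climbs toward $RI$ and crosses the threshold every 7 steps. With zero input, it stays at rest and the neuron never fires.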
This model encodes time-dependent information in the spikes, in contrast to the continuous activations of a conventional neural network. One advantage of SNNs is power efficiency, since the generated spikes are binary and sparse in time. Because the LIF model carries state across time, SNNs are also inherently suited to processing time-based data.
However, Spiking Neural Networks also have drawbacks. The spike function is non-differentiable, so standard gradient descent cannot be applied directly. One workaround is to train a conventional network with an activation function such as ReLU by gradient descent, and then use normalization to make it behave like an SNN. This approach, however, generally requires many time steps, which means high latency. In the original paper, each normalization layer uses a separate learnable parameter $\gamma$ for every time step. Combined with mechanisms like early exit, this enables faster, low-latency training. It also lets the network better learn the characteristics of the input spikes across time.
### Batch Normalization Through Time (BNTT)
Batch Normalization Through Time (BNTT) is an innovative approach designed specifically for Spiking Neural Networks (SNNs) to address the inherent challenges posed by their temporal dynamic nature. Unlike traditional Batch Normalization (BN) methods which apply consistent parameters across all timesteps, BNTT introduces a novel mechanism where the BN parameters are varied at each timestep. This temporal decoupling of parameters enables BNTT to adapt to and learn from the time-varying input distributions effectively.
The fundamental equation for a BNTT layer can be expressed as:
$$
\hat{x}_t^l = \gamma_t^l \left(\frac{x_t^l - \mu_t^l}{\sqrt{(\sigma_t^l)^2 + \epsilon}}\right) + \beta_t^l
\tag{2}
$$
where $x_t^l$ is the input to layer $l$ at timestep $t$; $\mu_t^l$ and $(\sigma_t^l)^2$ are the mean and variance for layer $l$ at timestep $t$; $\gamma_t^l$ and $\beta_t^l$ are the learnable scale and shift parameters for layer $l$ at timestep $t$; $\hat{x}_t^l$ is the normalized, rescaled, and shifted output; and $\epsilon$ is a small constant added for numerical stability to prevent division by zero.
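Eq. (2) can be illustrated with a minimal, dependency-free sketch. Each time step is normalized with its own batch statistics and then scaled and shifted by its own $\gamma_t$ and $\beta_t$. The inputs are nested Python lists `x[t][b][c]` (time step, sample, feature) rather than tensors. That layout and the name `bntt_forward` are our own illustrative choices. Only training-mode batch statistics are shown. At inference, batch normalization normally uses running estimates of the mean and variance, which in BNTT would also be kept separately for each time step.

```python
import math


def bntt_forward(x, gamma, beta, eps=1e-5):
    """Batch Normalization Through Time (training-mode statistics).

    x[t][b][c]:  input at time step t, sample b, feature c.
    gamma[t][c], beta[t][c]: separate scale/shift for every time step t.
    """
    out = []
    for t, x_t in enumerate(x):
        batch, num_features = len(x_t), len(x_t[0])
        y_t = [[0.0] * num_features for _ in range(batch)]
        for c in range(num_features):
            col = [x_t[b][c] for b in range(batch)]
            # Statistics come only from time step t, not from all time steps
            mu = sum(col) / batch
            var = sum((v - mu) ** 2 for v in col) / batch
            denom = math.sqrt(var + eps)
            for b in range(batch):
                y_t[b][c] = gamma[t][c] * (col[b] - mu) / denom + beta[t][c]
        out.append(y_t)
    return out


# Two time steps, batch of two, one feature; eps=0 for exact numbers
print(bntt_forward([[[1.0], [3.0]], [[10.0], [20.0]]],
                   gamma=[[2.0], [0.5]], beta=[[0.0], [1.0]], eps=0.0))
# -> [[[-2.0], [2.0]], [[0.5], [1.5]]]
```

The two time steps have very different raw scales. Each step is still mapped onto its own $\gamma_t$ and $\beta_t$, which is exactly the decoupling that BNTT relies on.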
In SNNs, this mechanism is particularly crucial because the activity of the network varies over time. BNTT allows for different normalization parameters at each timestep, thus adapting more effectively to temporal signal changes, which enhances network performance without sacrificing the sparsity of spikes. This method enables the network to make effective predictions at earlier timesteps, avoiding the need to wait until the end of the time sequence, thereby reducing latency and improving computational efficiency.
Key characteristics of BNTT include:
- **Temporal Adaptation**: BNTT layers adjust their parameters independently for each timestep, which helps in capturing the temporal dynamics of input spikes more accurately.
- **Enhanced Learning**: By aligning the normalization parameters closely with the temporal variations in data, BNTT facilitates more stable and effective training of SNNs.
- **Improved Performance**: BNTT has been shown to enhance the robustness of SNNs against variations in input data and noise, leading to better performance on various tasks compared to other normalization techniques.
The structure of BNTT involves computing the mean and variance for each timestep independently, followed by applying these statistics to normalize the layer outputs. This method ensures that each layer of the SNN can respond optimally to the spikes received at different times, thereby maintaining a dynamic and adaptive learning environment.
### Experiment Overview
In the original paper, the authors validate the effectiveness of BNTT on three static datasets (i.e., CIFAR-10, CIFAR-100, Tiny-ImageNet), one neuromorphic dataset (i.e., DVS-CIFAR10), and one sequential dataset (i.e., Sequential MNIST), achieving near state-of-the-art performance. The CIFAR-10 dataset is shown in Figure 1.

<center>
Figure 1. CIFAR-10 Dataset
</center></br>
To further assess how BNTT layers affect model performance and the spiking activity within each layer of SNNs, the authors investigated and visualized two quantities in different layers: the distribution of the BNTT learnable parameter (i.e., $\gamma$) and the spike rate. The conversion method and the surrogate gradient method maintain consistent spike activity across layers over time. In contrast, the authors concluded that layers trained with BNTT exhibit a Gaussian-like pattern in spike activity: each layer's activity peaks in a specific time range and then diminishes. Furthermore, the peak times differ across layers, with earlier layers peaking sooner and later layers peaking at subsequent time steps. This implies that the learnable parameters in BNTT enable the network to pass visual information temporally from shallow to deeper layers effectively. Based on these observations, the authors further propose a temporal early exit algorithm. It allows an SNN to make its prediction at an earlier time step instead of waiting until the end of the time window.
In our reproduction, (i) we followed the original authors' methodology to train an SNN model with BNTT layers on the CIFAR-10 and Sequential MNIST datasets. Due to computational constraints, we limited our analysis to these two datasets. For the same reason, we halved the batch size when training the VGG-9 network on CIFAR-10. (ii) Then, using the model trained on CIFAR-10, we extracted and visualized several parameters, focusing on the behavior of the learnable parameter $\gamma$ as described in the paper. (iii) Based on these findings, we analyzed $\gamma$ in detail to show how BNTT captures the temporal characteristics of SNNs. (iv) By monitoring the distribution of $\gamma$, we also explored the early exit algorithm, which reduces the number of time steps required at inference and can therefore decrease latency. (v) Finally, we examined the spike rate of each layer to shed light on the advantages of BNTT over the traditional ANN-SNN conversion approach.
## Our Reproduction
Table 1 shows our overall reproduction results on the different datasets together with the original results in the paper. On CIFAR-10, we used VGG-9 as the backbone and set the number of time steps to 25, as in the original paper. Due to computational constraints, we set the batch size to 128, half of the original. After training for 120 epochs, we obtained an accuracy of 84.6%. On the same dataset, we then applied early exit, reducing the number of time steps to 20, and trained the model again, reaching an accuracy of 89.5%. We also applied BNTT with fully connected layers to the Sequential MNIST dataset and obtained 98.6% accuracy, exceeding the original result.
| Method | Dataset | Time-steps | Reproduced Accuracy (%) | Accuracy in Paper (%) |
|-------------------|------------------|------------|-------------------------|-----------------------|
| BNTT | CIFAR-10 | 25 | 84.6 | 90.5 |
| BNTT + early exit | CIFAR-10 | 20 | 89.5 | 90.3 |
| BNTT (FC) | Sequential MNIST | 25 | 98.6 | 96.6 |
<center>
Table 1. Comparison of reproduction results with results in the paper
</center>
### Reproducing Sequential MNIST on a Fully Connected Network
In our reproduction of the study on the Sequential MNIST dataset, we implemented a fully connected neural network model enhanced by Batch Normalization Through Time (BNTT) layers. This model was tasked with processing images as a sequence of pixels, a challenging format that diverges from traditional 2D image processing. The architecture comprised three layers, transitioning from the input sequence of 784 features to 256 neurons in both the first and second layers, and finally to the 10 output classes corresponding to the digits 0-9. Each layer utilized BNTT to effectively adapt to the inherent temporal dynamics of the input, allowing the network to capture subtle changes in the input sequence over time.
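The architecture above can be sketched as a forward pass in plain Python, with no tensor library, so that it stays self-contained. This is a minimal sketch, not the exact code in our repository. Several choices in it are illustrative assumptions:

- Poisson-style rate coding of the pixel intensities.
- Bias-free dense layers.
- BNTT with $\gamma$ only (no $\beta$), initialized to 1 because the network is untrained here.
- A leak factor of 0.99 and a hard reset to zero after a spike.
- An output layer that accumulates its input over time instead of spiking.

Training with backpropagation through time is not shown.

```python
import math
import random


def linear(x, w):
    """Bias-free dense layer: x[b][i], w[o][i] -> y[b][o]."""
    return [[sum(wi * xi for wi, xi in zip(row, xb)) for row in w] for xb in x]


def bntt(h, gamma_t, eps=1e-5):
    """Normalize h[b][c] with this time step's batch statistics, scale by gamma_t[c]."""
    batch = len(h)
    out = [[0.0] * len(h[0]) for _ in range(batch)]
    for c in range(len(h[0])):
        col = [h[b][c] for b in range(batch)]
        mu = sum(col) / batch
        var = sum((v - mu) ** 2 for v in col) / batch
        for b in range(batch):
            out[b][c] = gamma_t[c] * (col[b] - mu) / math.sqrt(var + eps)
    return out


def fc_snn_forward(images, sizes, num_steps=25, leak=0.99, threshold=1.0, seed=0):
    """Fully connected SNN with a BNTT layer after each hidden layer.

    images: flattened images with pixel intensities in [0, 1].
    sizes:  layer widths, e.g. [784, 256, 256, 10].
    Returns output-layer values averaged over time (one row per image).
    """
    rng = random.Random(seed)
    weights = [[[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)]
                for _ in range(n_out)]
               for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    hidden = sizes[1:-1]
    # One gamma vector per hidden layer AND per time step
    gammas = [[[1.0] * n for _ in range(num_steps)] for n in hidden]
    batch = len(images)
    mem = [[[0.0] * n for _ in range(batch)] for n in hidden]
    outputs = [[0.0] * sizes[-1] for _ in range(batch)]
    for t in range(num_steps):
        # Rate coding (assumption): a pixel spikes with probability = intensity
        x = [[1.0 if rng.random() < p else 0.0 for p in img] for img in images]
        for l, w in enumerate(weights[:-1]):
            h = bntt(linear(x, w), gammas[l][t])
            spikes = []
            for b in range(batch):
                row = []
                for c, h_bc in enumerate(h[b]):
                    # Leaky integration of the normalized input
                    mem[l][b][c] = leak * mem[l][b][c] + h_bc
                    fired = mem[l][b][c] >= threshold
                    row.append(1.0 if fired else 0.0)
                    if fired:
                        mem[l][b][c] = 0.0  # hard reset after a spike
                spikes.append(row)
            x = spikes
        # Output layer: no spiking, accumulate its input over time
        for b, row in enumerate(linear(x, weights[-1])):
            for k, v in enumerate(row):
                outputs[b][k] += v / num_steps
    return outputs


logits = fc_snn_forward([[0.5] * 8, [0.1] * 8], sizes=[8, 6, 6, 3], num_steps=5)
print([max(range(len(r)), key=r.__getitem__) for r in logits])  # predicted classes
```

The example uses tiny layer widths because the pure-Python loops are slow. Calling it with `sizes=[784, 256, 256, 10]` matches the architecture described above, but in practice a tensor library would be used.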
The model was trained with a batch size of 512 for 62 epochs, with each input processed over 25 time steps, reaching a best accuracy of 98.6%. This confirms that BNTT layers handle sequential input effectively, and it exceeds the 96.6% reported in the original paper. The improvement suggests that BNTT layers help the network cope with internal covariate shift, a common issue with dynamic inputs such as Sequential MNIST.
The experiment demonstrated the pivotal role of the BNTT layer in modulating the network's response to sequential data. By dynamically adjusting the normalization parameters across different time steps, particularly the learnable parameter $\gamma$, the BNTT layer ensures that each segment of the input sequence is accurately represented and contributes to the final classification outcome. This approach not only boosts the accuracy but also highlights the adaptability of BNTT in various sequence-based learning tasks.
### Analysis on Learnable Parameter $\gamma$
The novelty of the reproduced paper is a Batch Normalization Through Time (BNTT) layer that accelerates the training of Spiking Neural Networks. A BNTT layer contains one Batch Normalization layer per time step, each with its own parameters, in particular its own learnable parameter $\gamma$. The output at each time step is normalized with that time step's parameters. This lets the model better represent the time-specific characteristics of the spikes and increases its representation power. The learnable parameter $\gamma$ is therefore central: it gives the network this per-time-step flexibility during training.
In our reproduction, we wanted to analyze how $\gamma$ behaves across layers and time steps. We took the original VGG-9 model trained on CIFAR-10 and added code to the training process that outputs the values of $\gamma$ for every BNTT layer and time step.
To illustrate the characteristics of $\gamma$, we extract its values for the BNTT layers that follow convolutional layers 1, 4, and 7 of the VGG-9 model. The model has 8 BNTT layers in total: 7 after convolutional layers and 1 after a fully connected layer. Figure 2 shows the distribution of $\gamma$ as histograms, with the value of $\gamma$ on the x-axis and its frequency on the y-axis.
The figure shows that the distribution of $\gamma$ differs across layers: in deeper layers, $\gamma$ peaks at later time steps. In the BNTT layer after the first convolutional layer, $\gamma$ peaks at a very early time step, whereas after the fourth convolutional layer it peaks later. After the peak, $\gamma$ quickly drops to near zero at later time steps. This pattern is also consistent with our analysis of the early exit algorithm below.
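One way to produce histograms like those in Figure 2 from the extracted values is sketched below. We assume the values have been collected into nested lists `gammas[l][t]`, where `l` is the BNTT layer, `t` the time step, and each entry a per-channel value. In PyTorch, the $\gamma$ of a `BatchNorm` layer is its `weight` attribute. For a per-time-step layer `bn`, the values can therefore be obtained with `bn.weight.detach().cpu().tolist()`. The function name and the shared binning range are illustrative choices.

```python
def gamma_histograms(gammas, layers, num_bins=10):
    """Histogram of gamma values for each (layer, time step).

    gammas[l][t] is the list of per-channel gamma values of BNTT layer l at
    time step t. One shared value range is used so that plots are comparable.
    Returns {(layer, t): list of bin counts}.
    """
    values = [g for l in layers for per_t in gammas[l] for g in per_t]
    lo, hi = min(values), max(values)
    width = (hi - lo) / num_bins or 1.0  # avoid zero width if all values are equal
    hists = {}
    for l in layers:
        for t, per_t in enumerate(gammas[l]):
            counts = [0] * num_bins
            for g in per_t:
                # The maximum value falls in the last bin
                counts[min(int((g - lo) / width), num_bins - 1)] += 1
            hists[(l, t)] = counts
    return hists
```

For the layers shown in Figure 2 (convolutional layers 1, 4, and 7), the call would be `gamma_histograms(gammas, layers=(0, 3, 6))` with 0-based indices.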
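Rather than judging by eye where each layer peaks, the time step with the largest channel-averaged $\gamma$ can be computed directly. The sketch below uses the same assumed `gammas[l][t]` layout (lists of per-channel values). The example numbers are made up to show the expected shift of the peak in deeper layers. They are not our measured values.

```python
def mean_gamma(gammas):
    """Average gamma over channels: gammas[l][t] (list) -> avg[l][t] (float)."""
    return [[sum(per_t) / len(per_t) for per_t in layer] for layer in gammas]


def peak_time_steps(gammas):
    """Time step at which each layer's average gamma is largest."""
    return [max(range(len(row)), key=row.__getitem__) for row in mean_gamma(gammas)]


example = [
    [[0.8, 1.0], [0.3, 0.5], [0.1, 0.1]],  # shallow layer: peaks first
    [[0.1, 0.3], [0.7, 0.9], [0.2, 0.4]],  # middle layer
    [[0.0, 0.2], [0.2, 0.4], [0.6, 0.8]],  # deep layer: peaks last
]
print(peak_time_steps(example))  # -> [0, 1, 2]
```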
However, our results differ from those of the original paper. The overall pattern matches, but the distribution in the original paper is much more distinct and sparse, and for deeper convolutional layers the peak lies at even later time steps. Our computing power was limited to a 16 GB GPU, so we used a batch size of 128 instead of the original 256, and we trained for 60 epochs instead of 120. We observed that in early epochs the $\gamma$ values are concentrated in the first few time steps and only later spread to subsequent time steps. We therefore believe the differences are mainly due to the smaller batch size and fewer training epochs.

<center>
Figure 2. The distribution of Learnable Parameter across different convolution layers and time-step. X-axis is value, Y-axis is frequency.
</center></br>
### Early Exit Algorithm
The primary goal of the early exit strategy is to reduce inference latency by letting the network compute its output in fewer time steps [2]. This saves computation and therefore reduces cost and power consumption. Such efficiency aligns well with the low-power motivation of spiking neural networks.
Early exit during inference is based on the BNTT learnable parameter $\gamma$. In general, a larger $\gamma$ indicates more active spiking, while a smaller $\gamma$ means fewer spikes are generated. The previous analysis shows that $\gamma$ shapes the spike activity of each layer so that it rises to a peak and then falls again at later time steps, forming a Gaussian-like trend. A low $\gamma$ scales down the normalized input to the neurons, which has an effect similar to raising the firing threshold and thus reduces spike activity. Conversely, a high $\gamma$ promotes spike activity. Since spike activity correlates directly with $\gamma$, spikes can be expected to have minimal impact on the classification outcome once the $\gamma$ values in every layer have decreased to their lowest levels. Consequently, we monitor the average $\gamma$ of each layer at every time step. Inference terminates once the values in all layers fall below a predetermined threshold.
In the experiment, we use VGG-9 as the backbone. It consists of 7 convolutional layers and a fully connected layer, giving 8 BNTT layers in total. We examine the distribution of $\gamma$ over time steps in each layer and plot a heatmap of the $\gamma$ values of each BNTT layer across all time steps, shown in Figure 3. Note that, due to GPU limitations, we could only use a batch size of 128 during training, half of the 256 used in the original paper. This may cause substantial deviations from the original results, as discussed below.
From the figure, we find that for t > 13, all averaged $\gamma$ values fall below the threshold of 0.1. There is also a discernible delay in the trend of $\gamma$, and hence of spiking activity, from the first layer to the last. In the first layer, $\gamma$ is highest during the initial time steps and gradually decreases towards the end. In subsequent layers, the peak of $\gamma$ appears increasingly later. The largest $\gamma$ value overall appears in the 7th BNTT layer, which immediately follows the 7th convolutional layer, and the peak values in the adjacent layers are also relatively high. One possible explanation is that this layer of VGG-9 conveys more crucial information and thus triggers more active spiking. This behavior suggests that $\gamma$ enables SNNs to pass visual information temporally from shallow to deeper layers effectively. All of these trends agree with the findings reported in the original paper. The main difference lies in the early exit time step: ours is 13, compared to 20 in the original paper. Training with half the original batch size may not fully leverage the benefits of batch normalization through time (BNTT). This in turn affects the distribution and values of $\gamma$ over time steps and tends to yield a model that uses fewer time steps.
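The exit rule described above is simple enough to sketch in a few lines. We assume the channel-averaged $\gamma$ values are available as `avg_gamma[l][t]`. The function name is hypothetical. Returning the first time step at which every layer is below the threshold is our reading of the rule. Because $\gamma$ is fixed after training, this only needs to be computed once.

```python
def early_exit_step(avg_gamma, threshold=0.1):
    """First time step at which every layer's average gamma is below threshold.

    avg_gamma[l][t]: channel-averaged gamma of BNTT layer l at time step t.
    Inference can stop at the returned step, i.e. only steps 0..t-1 are run.
    If the condition is never met, all time steps are used.
    """
    num_steps = len(avg_gamma[0])
    for t in range(num_steps):
        if all(layer[t] < threshold for layer in avg_gamma):
            return t
    return num_steps


avg = [[0.9, 0.5, 0.2, 0.05, 0.01],   # shallow layer
       [0.1, 0.6, 0.4, 0.08, 0.02]]   # deeper layer
print(early_exit_step(avg))  # -> 3, so only time steps 0, 1, 2 are needed
```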
As a result, we set the early exit time at t = 13. This allows us to establish the optimal time-step for early exit prior to forward propagation, without the need for extra computations. In essence, the temporal early exit method helps us pinpoint the earliest possible time-step during inference that captures essential information, thereby reducing inference latency without a substantial compromise in accuracy.

<center>
Figure 3. Heatmap of learnable parameter over timesteps
</center></br>
### Spike Rate in Different Layers
To illustrate the difference between the BNTT algorithm and the traditional ANN-SNN conversion approach, we also analyze the spike rate of each layer of the spiking neural network during training. Specifically, the spike rate of a layer is the total number of spikes it produces over all time steps divided by the number of neurons in the layer:
$$
R_s(l) = \frac{\#\text{spikes of layer } l \text{ over all timesteps}}{\#\text{neurons of layer } l} \tag{3}
$$
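Eq. (3) for a single layer can be written directly. We assume the spikes are recorded as `spikes[t][i]` (0 or 1), with the neurons of a convolutional layer flattened over channels and positions. As written, Eq. (3) can exceed 1. The optional `per_step` flag is our own addition: it also divides by the number of time steps, giving the average fraction of neurons that fire per step.

```python
def spike_rate(spikes, per_step=False):
    """Spike rate R_s(l) of one layer, following Eq. (3).

    spikes[t][i] is 1 if neuron i fired at time step t, else 0.
    With per_step=True the result is also divided by the number of time steps.
    """
    num_steps = len(spikes)
    num_neurons = len(spikes[0])
    total_spikes = sum(sum(step) for step in spikes)
    rate = total_spikes / num_neurons
    return rate / num_steps if per_step else rate


example = [[1, 0, 1, 0],
           [0, 0, 1, 1],
           [1, 1, 0, 0]]
print(spike_rate(example))                 # -> 1.5
print(spike_rate(example, per_step=True))  # -> 0.5
```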
Spike rate reflects the percentage of neurons activated in each layer of the neural network during training. Traditional spiking neural network training approaches, such as backpropagation-based SNN training and ANN-SNN conversion [3], use a large number of time steps, so the spike rate varies greatly from layer to layer. BNTT adjusts the batch normalization parameters dynamically at each time step and uses fewer time steps. With BNTT, the spike rate should therefore be more even across layers during training.
To extract these values during training, we modified the code. We added a "current epoch" parameter to the model and updated it at every epoch. In the last training epoch, the spikes generated by each layer of the network (i.e., the layer outputs) are written to a txt file. We then processed this file to obtain the results in Figure 4.
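The logging step can be sketched as follows. This is an illustrative sketch, not the exact code in our repository. The text-file format (one `layer_name total_spikes num_neurons` line per layer) and both function names are hypothetical. We assume the per-layer counts have already been aggregated before they are written, for example averaged per sample.

```python
def log_spike_counts(path, epoch, num_epochs, layer_stats):
    """Write per-layer spike statistics to a text file, only in the last epoch.

    layer_stats: list of (layer_name, total_spikes, num_neurons) tuples.
    Returns True if the file was written.
    """
    if epoch != num_epochs - 1:  # epochs are counted from 0
        return False
    with open(path, "w", encoding="utf-8") as f:
        for name, total_spikes, num_neurons in layer_stats:
            f.write(f"{name} {total_spikes} {num_neurons}\n")
    return True


def read_spike_rates(path):
    """Parse the file written above and return {layer_name: spike rate}."""
    rates = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            name, total_spikes, num_neurons = line.split()
            rates[name] = int(total_spikes) / int(num_neurons)
    return rates
```

In a training loop, `log_spike_counts` would be called once per epoch with the current epoch index. The file is then only produced at the end of training.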

<center>
Figure 4. Spike rate of each layer after training in VGG-9 of CIFAR-10 dataset
</center></br>
Comparing our results with those in the paper, we largely reproduce the reported spike rates. The spike rate of the whole network during training is around 85%, and the spike rate is fairly uniform across layers, which illustrates well how spikes propagate through the network under BNTT. However, we trained with a smaller batch size and fewer epochs than the paper, and as a result our spike rates are lower.
## Discussion
In this reproduction, we largely reproduced the properties of the BNTT algorithm described in the paper. Applying BNTT to a fully connected network on the Sequential MNIST dataset even achieved higher accuracy than reported in the paper. This suggests that BNTT is broadly applicable to spiking neural networks.
**1. Enhancement of SNN Performance**
Our results demonstrate that BNTT significantly enhances the performance of SNNs. By allowing temporal adaptation of normalization parameters, BNTT ensures that the network can effectively handle the dynamic nature of input data over time.
**2. Impact of Hardware Limitations**
However, with the VGG-9 model on CIFAR-10, our trained network does not perform as well as reported in the paper. The main reason is hardware limitations. We trained on an NVIDIA P100 GPU, so we had to adjust the training parameters to fit on it and reduced the batch size to 128, half of the batch size used in the paper. BNTT adds a BNTT layer after each convolutional layer, and each one estimates the mean and variance separately for every time step from the current mini-batch. A smaller batch makes these per-time-step statistics noisier, so reducing the batch size is expected to affect this algorithm considerably. This is visible in our reproduction results. In Figure 2, for example, the distribution pattern in our results is correct, but the distribution in the original paper is much more pronounced and sparse, with peaks at later time steps for deeper convolutional layers. The impact of the smaller batch size on training and on the network is also reflected in Figures 3 and 4. Despite this, we still observe the characteristic properties of the BNTT algorithm in our results.
**3. Future Research Directions**
The consistency in spike rates across layers, despite reduced batch sizes, invites further investigation into the algorithm's potential to facilitate more efficient and compact SNN architectures. Future research could explore:
- Adaptation to Different Computational Resources: How does BNTT perform across a range of hardware capabilities, especially in low-resource environments?
- Integration with Other Neural Network Paradigms: Could the principles of BNTT be successfully integrated with other types of neural networks, particularly those used in non-spiking, conventional deep learning tasks?
- Long-Term Network Performance: Investigating the long-term stability and performance of SNNs under continuous operation can further validate the practical utility of BNTT in industrial and real-world applications.
## Conclusion
In this reproduction, we implemented the features of the BNTT algorithm described in the paper and found that it works across different network structures and datasets. We reduced the batch size due to hardware limitations. Even so, by extracting network parameters during training, we still observe that the distribution of $\gamma$ differs across layers, and we were able to apply the early exit algorithm. In addition, we observe that the spike rate is almost equal across layers under BNTT. Overall, these findings suggest that BNTT can significantly improve the efficiency and accuracy of SNNs in real-time applications. Future work should explore BNTT's scalability under varied computational conditions and its application to more complex datasets.
## Reference
[1] Dayan, P., and Abbott, L. F. "Theoretical Neuroscience: Computational and Mathematical Modeling of Neural Systems." MIT Press. (2001).
[2] Teerapittayanon, S., McDanel, B., and Kung, H.-T. "BranchyNet: Fast inference via early exiting from deep neural networks," in 2016 23rd International Conference on Pattern Recognition (ICPR), 2464–2469. (2016).
[3] Chen, Y., Zhang, S., Ren, S., and Qu, H. "Gradual surrogate gradient learning in deep spiking neural networks," in ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Singapore, 8927–8931. (2022). doi: 10.1109/ICASSP43922.2022.9746774.