---
tags: machine-learning
---

# Convolutional Neural Network with Numpy (Slow)

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/0.png" width="80%">
</div>

> In this post, we are going to see how to implement a Convolutional Neural Network using only Numpy.
>
> The main goal here is not just to provide boilerplate code, but to give an in-depth explanation of the underlying mechanisms through illustrations, especially for backward propagation, where things get trickier.
>
> However, some knowledge of the building blocks of Convolutional Neural Networks is required.
>
> To see the full implementation, please refer to my [repository](https://github.com/3outeille/CNNumpy).
>
> <a href="https://hackmd.io/@machine-learning/blog-post-cnnumpy-fast" style="color:red">For more advanced readers, here is another post</a> **where we implement a faster CNN using im2col/col2im methods**.
>
> Also, if you want to read some of my other blog posts, feel free to check out my [blog](https://3outeille.github.io/deep-learning/).

# I) Architecture

We are going to implement the LeNet-5 architecture.

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/1.png">
</div>

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/2.png">
</div>

# II) Forward propagation

- An image is of shape (h, w, c) where:
    - h: image height.
    - w: image width.
    - c: number of channels.
- Here is an RGB image (3 channels):

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/3.png" width="70%">
</div>

## 1) Convolutional layer

- A convolution operation is defined as follows:

$$\boxed{Conv(Input, Kernel) = Output}$$

:::info
- In theory, a convolution operation is a cross-correlation operation with its kernels rotated by 180°.
- In practice, we don't really care whether a convolution or a cross-correlation is used, since the main goal is to learn the kernels. (If you were learning English irregular verbs, for example, it wouldn't matter whether you learned them from top to bottom or from bottom to top. What matters is that you learn them!)
- However, if you decide to use a true convolution, you will have to **rotate the kernels during both forward and backward propagation**.
- For your information, PyTorch's **nn.Conv2d()** uses cross-correlation.
- **To keep things simple, in this blog post "convolution" will refer to cross-correlation**.
:::

- We are going to explain:
    - How to get the output shape after a convolution.
    - How the content of the output is generated during a convolution.

---

- At the beginning of our architecture, we want to perform a convolution between an input of shape (32,32,1) and 6 kernels of shape (5,5,1).
- In order to perform a convolution operation, **both Input and Kernel must have exactly the same number of channels** (in our case 1).
- The output will be a new image of shape (28,28,6).
- ==Since we have 6 kernels here, we are going to perform 6 convolution operations, which will produce 6 outputs (feature maps).==
- ==The 6 feature maps will then be stacked to form the 6 channels of the new image.==
- During the convolution operation, the following formula is used to get the (28,28) shape:

$$O = \left \lfloor \frac{I + 2p - K}{s} + 1 \right \rfloor$$

- O: output size (height or width).
- I: input size (height or width).
- p: padding = 0 (default value).
- K: kernel size.
- s: stride = 1 (default value).
- $\left \lfloor ... \right \rfloor$: floor function.

<br>

- Using the formula above for our example, we have:

\begin{align*}
O &= \left \lfloor \frac{32 + 2 \cdot 0 - 5}{1} + 1 \right \rfloor \\
&= \left \lfloor \frac{27}{1} + 1 \right \rfloor \\
&= 28
\end{align*}

---

- Now that we know where the output shape comes from, let's see how the content of the output image is generated.
- ==During the convolution operation, **the kernel slides over the whole input**.==
- In the following example, we perform a convolution between a (5,5,3) input and 1 kernel of shape (3,3,3) to get a (3,3,1) image.
- ==**At each position, we perform an element-wise multiplication and sum everything to get a single value**.==

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/4.gif">
</div>

<br>

- If we perform a convolution between a (5,5,3) input and 6 kernels of shape (3,3,3), we repeat the convolution operation 6 times, once for each (3,3,3) kernel.
- This results in an output of shape (3,3,6).
- Here is an implementation of what we have seen so far. Note that the code stores a batch of m images channels-first, with shape (m, c, h, w), and does not pad the input itself, so it assumes p = 0 (as in LeNet-5).

```python=
def forward_conv(self, X):
    """
    Performs a forward convolution.

    Parameters:
    - X: output of the previous layer, of shape (m, n_C_prev, n_H_prev, n_W_prev).

    Returns:
    - out: output of the convolution, of shape (m, n_F, n_H, n_W).
    """
    self.cache = X
    m, n_C_prev, n_H_prev, n_W_prev = X.shape

    # Define output size.
    n_C = self.n_F
    n_H = int((n_H_prev + 2 * self.p - self.f)/ self.s) + 1
    n_W = int((n_W_prev + 2 * self.p - self.f)/ self.s) + 1

    out = np.zeros((m, n_C, n_H, n_W))

    for i in range(m): # For each image.
        for c in range(n_C): # For each output channel (one per kernel).
            for h in range(n_H): # Slide the filter vertically.
                h_start = h * self.s
                h_end = h_start + self.f
                for w in range(n_W): # Slide the filter horizontally.
                    w_start = w * self.s
                    w_end = w_start + self.f
                    # Element-wise multiplication + sum over all input channels, plus bias.
                    out[i, c, h, w] = np.sum(X[i, :, h_start:h_end, w_start:w_end] * self.W['val'][c, ...]) + self.b['val'][c]

    return out
```

## 2) Pooling layer

- Our architecture uses **average pooling** layers. Another common pooling layer is max pooling.
- **The goal of average pooling is to reduce the height and width of the image, but not its number of channels**, by using a stride > 1.

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/5-a.gif" width="80%">
</div>

- Here is the implementation.

```python=
def forward_pool(self, X):
    """
    Apply average pooling.

    Parameters:
    - X: output of the activation function, of shape (m, n_C_prev, n_H_prev, n_W_prev).

    Returns:
    - A_pool: X after the average pooling layer.
    """
    m, n_C_prev, n_H_prev, n_W_prev = X.shape

    # Define output size.
    n_C = n_C_prev
    n_H = int((n_H_prev - self.f)/ self.s) + 1
    n_W = int((n_W_prev - self.f)/ self.s) + 1

    A_pool = np.zeros((m, n_C, n_H, n_W))

    for i in range(m):
        for c in range(n_C):
            for h in range(n_H):
                h_start = h * self.s
                h_end = h_start + self.f
                for w in range(n_W):
                    w_start = w * self.s
                    w_end = w_start + self.f
                    A_pool[i, c, h, w] = np.mean(X[i, c, h_start:h_end, w_start:w_end])

    self.cache = X

    return A_pool
```

## 3) Multilayer perceptron

- ==To go from the output of the convolutional/pooling layers to the MLP part, we have to flatten the output.==
- For example, the last pooling layer gives an output of shape (5,5,16).
- After flattening it, we get a vector of size 5×5×16 = 400.
- Then, we just perform a weighted sum plus a bias, as in any fully connected layer.
- For more information, please refer to my other [blog post](https://hackmd.io/@machine-learning/SkXPHwL8L#II-Forward-propagation) about forward propagation in an MLP.
- Here is the implementation.

```python=
def forward_mlp(self, fc):
    """
    Performs a forward propagation between 2 fully connected layers.

    Parameters:
    - fc: input of the fully connected layer, of shape (m, n_in).

    Returns:
    - A_fc: output of the fully connected layer, of shape (m, n_out).
    """
    self.cache = fc
    # self.W['val'] has shape (n_out, n_in), hence the transpose.
    A_fc = np.dot(fc, self.W['val'].T) + self.b['val']
    return A_fc
```

# III) Backward propagation

Here comes the tricky part. Most of the tutorials I have read so far only say that backward propagation works the same way as in an MLP (which is true). However, we will see that it is not that straightforward to implement, especially in the convolutional layer.

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/5-b.png" width="80%">
</div>

## 1) Multilayer perceptron

- To compute the loss gradient, we first need to compute the errors $\delta$.
- The error $\delta$ is computed differently depending on whether we are at:
    - The last layer $L$ of the network (softmax level).
    - Any other layer $l$.
- For more details, please refer to one of my [blog posts](https://hackmd.io/1I0ij9HAQM-ocZIQu6aCWQ?view#IV-Backward-propagation).

---

- At the last layer $L$, the formula to compute the error is (this simple form holds for a softmax output combined with the cross-entropy loss):

$$\boxed{\delta^{(L)} = (a^{(L)} - y)}$$

- $\delta^{(L)}$: error at the last layer.
- $a^{(L)}$: activation function output at the last layer.
- $y$: ground truth label.

<br>

- At every other layer $l$, the formula to compute the error is:

$$\boxed{\delta^{(l)} = ((\Theta^{(l)})^T \delta^{(l+1)})\ .*\ a'(z^{(l)})}$$

- $\delta^{(l)}$: error at layer $l$.
- $\Theta^{(l)}$: weight matrix between layer $l$ and layer $l+1$.
- $\delta^{(l+1)}$: error at layer $l+1$.
- $.*$: element-wise multiplication.
- $a'(z^{(l)})$: derivative of the activation function at layer $l$.
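To make the two error formulas concrete, here is a minimal NumPy sketch in batched form (one row per example). It assumes a softmax output trained with the cross-entropy loss and tanh hidden activations (the post's code mentions `TanH.backward()`); the function names `softmax`, `output_error` and `hidden_error` are hypothetical and not taken from the repository. With rows as examples, $\Theta^T \delta$ becomes `delta_next @ theta`, where `theta` is the weight matrix between layer $l$ and layer $l+1$, of shape (n_{l+1}, n_l), the same layout as `self.W['val']` in the code of this section.

```python
import numpy as np

def softmax(z):
    # Subtract the row-wise max for numerical stability.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def output_error(z_L, y_onehot):
    # delta^(L) = a^(L) - y, valid for softmax + cross-entropy.
    return softmax(z_L) - y_onehot

def hidden_error(delta_next, theta, z):
    # delta^(l) = (Theta^T delta^(l+1)) .* a'(z^(l)), written for row-wise batches.
    # theta: weights between layer l and layer l+1, shape (n_{l+1}, n_l).
    # Assumes a tanh activation, whose derivative is 1 - tanh(z)^2.
    return (delta_next @ theta) * (1.0 - np.tanh(z) ** 2)

# Usage: 2 examples, 4 hidden units, 3 classes.
rng = np.random.default_rng(0)
z_hidden = rng.normal(size=(2, 4))          # pre-activations z^(l)
theta = rng.normal(size=(3, 4))             # weights from layer l to layer L
y = np.eye(3)[[0, 2]]                       # one-hot labels
z_out = np.tanh(z_hidden) @ theta.T         # pre-activations z^(L)
delta_L = output_error(z_out, y)            # shape (2, 3)
delta_hidden = hidden_error(delta_L, theta, z_hidden)  # shape (2, 4)
```

Comparing `delta_hidden` with finite differences of the summed cross-entropy loss is a quick way to convince yourself the formulas are right. The sketch sums over the batch, whereas the gradient formula that follows averages over the m examples.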
<br>

- We can then compute the loss gradient at each layer with the following formula, where $\Theta_{j,k}^{(l)}$ is the weight connecting unit $j$ of layer $l$ to unit $k$ of layer $l+1$, and $t$ indexes the $m$ training examples:

$$\boxed{\dfrac{\partial \mathcal{L}}{\partial \Theta_{j,k}^{(l)}} = \frac{1}{m}\sum_{t=1}^m a_j^{(t)(l)} {\delta}_k^{(t)(l+1)}}$$

- This gradient is then used to update the weights, with learning rate $\alpha$:

$$\boxed{\Theta^{(l)} \leftarrow \Theta^{(l)} - \alpha \frac{\mathrm{\partial \mathcal{L}} }{\mathrm{\partial}\Theta^{(l)}}}$$

- Here is the implementation:

```python=
def backward(self, deltaL):
    """
    Returns the error of the current layer and computes the gradients.

    Parameters:
    - deltaL: error coming from the next layer, of shape (m, n_out).

    Returns:
    - new_deltaL: error at the input of the current layer, of shape (m, n_in).
    - self.W['grad'], self.b['grad']: weight and bias gradients.
    """
    fc = self.cache
    m = fc.shape[0]

    # Compute gradients (averaged over the batch).
    self.W['grad'] = (1/m) * np.dot(deltaL.T, fc)
    self.b['grad'] = (1/m) * np.sum(deltaL, axis = 0)

    # Compute error.
    # We still need to multiply new_deltaL by the derivative of the activation
    # function, which is done in TanH.backward().
    new_deltaL = np.dot(deltaL, self.W['val'])

    return new_deltaL, self.W['grad'], self.b['grad']
```

## 2) Pooling layer

- During forward propagation, we averaged the values of the input within each pooling window.
- ==During backward propagation, we need to back-propagate the error to the input **proportionally**: each of the f×f input values of a window receives 1/(f×f) of that window's output gradient.==
- ==**Remember, no weight gradients are computed here, since pooling has no learnable parameters! We only compute the layer gradient.**==

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/6.gif">
</div>

- Here is the implementation:

```python=
def backward(self, dout):
    """
    Distributes the error through the pooling layer.

    Parameters:
    - dout: error w.r.t. the pooling output, of shape (m, n_C, n_H, n_W).

    Returns:
    - dX: error w.r.t. the pooling input (same shape as X).
    """
    X = self.cache

    # Shape of the incoming error.
    m, n_C, n_H, n_W = dout.shape
    dX = np.zeros(X.shape)

    for i in range(m): # For each image.
        for c in range(n_C): # For each channel.
            for h in range(n_H): # Slide the filter vertically.
                h_start = h * self.s
                h_end = h_start + self.f
                for w in range(n_W): # Slide the filter horizontally.
                    w_start = w * self.s
                    w_end = w_start + self.f
                    # Each input value of the window gets an equal share of the error.
                    average = dout[i, c, h, w] / (self.f * self.f)
                    filter_average = np.full((self.f, self.f), average)
                    dX[i, c, h_start:h_end, w_start:w_end] += filter_average

    return dX
```

## 3) Convolutional layer

- Let's come back to the part of the architecture where we perform a convolution between a (14,14,6) input and 16 kernels of shape (5,5,6).
- This gives us a (10,10,16) output image.

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/7.png" width="40%">
</div>

- The idea of backward propagation is to propagate the gradient backwards, from the output of the network towards its input.
- In order to perform backward propagation in the example above, we have to do 2 things:
    - ==a) Compute the layer gradient $\frac{\partial L}{\partial I}$ with respect to the (14,14,6) input, from the (10,10,16) output gradient.==
    - ==b) Compute the kernel gradients $\frac{\partial L}{\partial K}$ of the 16 (5,5,6) kernels, from the same (10,10,16) output gradient.==
- The optimizer (SGD, Adam, RMSprop) will then update the kernel values.
- In the following, I will be using 2 formulas. For more details, feel free to take a look at this [blog post](https://www.jefkine.com/general/2016/09/05/backpropagation-in-convolutional-neural-networks/) and at this [note](http://people.csail.mit.edu/jvb/pubs/papers/cnn_tutorial.pdf).

---

### ==a) Compute Layer gradient==

- The formula to compute the layer gradient is:

$$\boxed{\frac{\partial L}{\partial I} = Conv(K, \frac{\partial L}{\partial O})}$$

- $\frac{\partial L}{\partial I}$: input gradient.
- $K$: kernels.
- $\frac{\partial L}{\partial O}$: output gradient.
- $Conv$: convolution operation. More precisely (for stride 1), this is a "full" convolution: $\frac{\partial L}{\partial O}$ is zero-padded by (kernel size − 1) = 4 on each side and cross-correlated with the kernels rotated by 180°, which brings the (10,10) spatial size back to (14,14).

<br>

- However, there is a small problem when we actually want to implement it.
- ==The formula asks us to perform a convolution between **16** kernels $K$ of shape **(5,5,6)** and $\frac{\partial L}{\partial O}$ of shape **(10,10,16)**.==
- ==But we know that in order to perform a convolution operation, **both arguments need to have exactly the same number of channels**, which is not the case here (6 != 16).==
- During forward propagation, in order to get the (10,10,16) output, 16 convolutions were performed between the (14,14,6) input and 16 kernels of shape (5,5,6).
- ==During backward propagation, the **16 channels (feature maps) of the (10,10) output now contain the gradient** that needs to be **back-propagated to the (14,14,6) input layer**.==
- ==Thus, **we need to "broadcast" the gradient in each (10,10) feature map to its associated kernel, which is then used to compute the gradient of the (14,14,6) input**.==

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/8.gif">
</div>

- As you can see, sliding the kernels over the (14,14,6) input is in fact a **convolution**! It was just less obvious to notice.

<br>

---

### ==b) Compute Kernel gradient==

- The formula to compute the kernel gradient is:

$$\boxed{\frac{\partial L}{\partial K} = Conv(I, \frac{\partial L}{\partial O})}$$

- $\frac{\partial L}{\partial K}$: kernel gradients.
- $I$: input image.
- $\frac{\partial L}{\partial O}$: output gradient.
- $Conv$: convolution operation.

<br>

- Same problem as before: performing this convolution directly is again not straightforward because of the channel mismatch (6 != 16).
- During forward propagation, we performed a convolution between the (14,14,6) input and 16 kernels of shape (5,5,6), which gave us a (10,10,16) output.
- Thus, each feature map was produced by a convolution between the input and one kernel.
- ==Then, during backward propagation, **each feature map of the output contains the gradient** that needs to be back-propagated to its kernel.==
- ==It makes sense that we need to **"broadcast" the gradient in each (10,10) feature map to each "slide" we made over the input during forward propagation, and add it to the gradient of the associated kernel.**==

<div style="text-align: center">
    <img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/blog-deep-learning/cnnumpy-naive/9.gif">
</div>

---

- Here is the implementation of the above steps:

```python=
def backward(self, dout):
    """
    Distributes the error from the next layer through the convolutional layer
    and computes the gradients of its kernels and biases.

    Parameters:
    - dout: error w.r.t. the output of this layer, of shape (m, n_F, n_H_dout, n_W_dout).

    Returns:
    - dX: error w.r.t. the input of this layer (same shape as X).
    - self.W['grad'], self.b['grad']: kernel and bias gradients.
    """
    X = self.cache

    m, n_C, n_H, n_W = X.shape
    m, n_C_dout, n_H_dout, n_W_dout = dout.shape

    dX = np.zeros(X.shape)
    # Reset the kernel gradient: it is accumulated with += below, so this
    # prevents gradients from a previous call being added to this one.
    self.W['grad'] = np.zeros(self.W['val'].shape)

    # Compute dW and dX in the same pass (assumes no padding, as in forward_conv).
    for i in range(m): # For each example.

        for c in range(n_C_dout): # For each output channel (one per kernel).

            for h in range(n_H_dout): # Slide the filter vertically.
                h_start = h * self.s
                h_end = h_start + self.f

                for w in range(n_W_dout): # Slide the filter horizontally.
                    w_start = w * self.s
                    w_end = w_start + self.f

                    self.W['grad'][c, ...] += dout[i, c, h, w] * X[i, :, h_start:h_end, w_start:w_end]
                    dX[i, :, h_start:h_end, w_start:w_end] += dout[i, c, h, w] * self.W['val'][c, ...]

    # Compute db.
    for c in range(self.n_F):
        self.b['grad'][c, ...] = np.sum(dout[:, c, ...])

    return dX, self.W['grad'], self.b['grad']
```

:::info
- At this point, I hope you understand how to build a Convolutional Neural Network from scratch in a naive way.
- However, the naive implementation takes a long time to train (mainly due to the nested for loops).
- As an example, it takes around 4 hours to perform a single epoch on the MNIST dataset.
- In the following <a href="https://hackmd.io/@bouteille/B1Cmns09I" style="color:red">post</a>, we are going to see how to implement a faster CNN using im2col/col2im methods.
:::
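Before switching to the faster version, a numerical gradient check is a cheap way to gain confidence in the naive loops of this post. The sketch below is a minimal standalone version, not the repository's code: `conv_forward`, `conv_backward`, `input_grad_full_conv` and `numerical_grad` are hypothetical names, the layers are plain functions instead of class methods, and padding is assumed to be 0. It keeps the post's (m, C, H, W) layout and sliding-window loops, and it also checks, for stride 1, that the input gradient equals a "full" cross-correlation of the zero-padded output gradient with the 180°-rotated kernels.

```python
import numpy as np

def conv_forward(X, W, b, s=1):
    # X: (m, C, H, W_in), W: (F, C, f, f), b: (F,). No padding (p = 0).
    m, _, H, W_in = X.shape
    F, _, f, _ = W.shape
    n_H = (H - f) // s + 1
    n_W = (W_in - f) // s + 1
    out = np.zeros((m, F, n_H, n_W))
    for i in range(m):
        for c in range(F):
            for h in range(n_H):
                for w in range(n_W):
                    hs, ws = h * s, w * s
                    # Element-wise multiplication + sum over the window (cross-correlation).
                    out[i, c, h, w] = np.sum(X[i, :, hs:hs + f, ws:ws + f] * W[c]) + b[c]
    return out

def conv_backward(dout, X, W, s=1):
    # Same loops as the forward pass: each output gradient is scattered back
    # onto its input window (dX) and onto the kernel that produced it (dW).
    m, F, n_H, n_W = dout.shape
    f = W.shape[2]
    dX = np.zeros_like(X)
    dW = np.zeros_like(W)  # fresh buffer, so nothing accumulates across calls
    for i in range(m):
        for c in range(F):
            for h in range(n_H):
                for w in range(n_W):
                    hs, ws = h * s, w * s
                    dW[c] += dout[i, c, h, w] * X[i, :, hs:hs + f, ws:ws + f]
                    dX[i, :, hs:hs + f, ws:ws + f] += dout[i, c, h, w] * W[c]
    db = dout.sum(axis=(0, 2, 3))
    return dX, dW, db

def input_grad_full_conv(dout, W):
    # Stride 1 only: zero-pad dout by f - 1 and cross-correlate it with the
    # kernels rotated by 180 degrees (a "full" convolution).
    m, F, n_H, n_W = dout.shape
    _, C, f, _ = W.shape
    padded = np.pad(dout, ((0, 0), (0, 0), (f - 1, f - 1), (f - 1, f - 1)))
    W_rot = W[:, :, ::-1, ::-1]
    H, W_in = n_H + f - 1, n_W + f - 1
    dX = np.zeros((m, C, H, W_in))
    for i in range(m):
        for ch in range(C):
            for h in range(H):
                for w in range(W_in):
                    # Sum over all F output channels at once.
                    dX[i, ch, h, w] = np.sum(padded[i, :, h:h + f, w:w + f] * W_rot[:, ch])
    return dX

def numerical_grad(loss_fn, A, eps=1e-6):
    # Central differences; perturbs A in place and restores each entry afterwards.
    grad = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        old = A[idx]
        A[idx] = old + eps
        plus = loss_fn()
        A[idx] = old - eps
        minus = loss_fn()
        A[idx] = old
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

# Usage: with loss = sum(out * G), the output gradient dL/dO is exactly G.
rng = np.random.default_rng(0)
X = rng.normal(size=(2, 3, 6, 6))
W = rng.normal(size=(4, 3, 3, 3))
b = rng.normal(size=4)
G = rng.normal(size=(2, 4, 4, 4))
dX, dW, db = conv_backward(G, X, W)
loss = lambda: np.sum(conv_forward(X, W, b) * G)
print(np.allclose(dX, numerical_grad(loss, X), atol=1e-5))
print(np.allclose(dW, numerical_grad(loss, W), atol=1e-5))
print(np.allclose(dX, input_grad_full_conv(G, W)))
```

The check only trusts the forward loops, so it catches indexing mistakes in the backward loops. The full-convolution comparison is restricted to stride 1, because with a larger stride the output gradient would first have to be dilated with zeros.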