---
tags: machine-learning
---
# ResNet: Summary and Implementation
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/0.png?token=AMAXSKNIC7NYZYKOZTZNQIS6WMHPA">
</div>
>This post is divided into 2 sections: Summary and Implementation.
>
>We are going to have an in-depth review of [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) and [Study of Residual Networks for Image Recognition](https://arxiv.org/pdf/1805.00325.pdf), the papers that introduce and analyze the ResNet architecture.
>
> The implementation uses PyTorch as its framework. To see the full implementation, please refer to this [repository](https://github.com/3outeille/Research-Paper-Summary/tree/master/src/architecture/resnet/pytorch).
>
> Also, if you want to read other "Summary and Implementation", feel free to check them at my [blog](https://ferdinandmom.engineer/deep-learning/).
# I) Summary
- It's important to understand that the **main problem** here is the **difficulty of optimizing a deep network rather than its lack of ability to learn features**.
- Feature learning (or representation learning) is the ability to find a transformation that maps raw data into a representation that is more suitable for a machine learning task (e.g., classification).
## 1) Problem
- Intuitively, the more layers we have, the better the accuracy will be.
- So if we take a shallow network that performs well, copy its layers, and stack them to make the model deeper, we can expect the deep network to perform at least as well as its shallower counterpart.
- ==Surprisingly, as we go deeper, accuracy increases up to a saturation point and then begins to degrade.==
- Unexpectedly, such degradation is not caused by overfitting: making the network even deeper also leads to higher **training** error.
<figure>
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/1.png?token=AMAXSKJ2RNVF56MBUXT7MOK6WMHPC">
<figcaption style="text-align:center"> Figure: Trained on CIFAR-10</figcaption>
</figure>
<br>
- Thus, the deep network performs worse than the shallow network.
- One possible explanation could be that the deep network suffered from the vanishing gradient problem.
- However, it can mostly be fixed with batch normalization and normalized initializations.
- A second explanation could be that the deep network wasn't able to learn the identity function.
- Indeed, it could at least match the shallow network by having its extra layers "learn nothing", i.e. the identity mapping (remember the deep network was built by copying and stacking the layers of the shallow network).
- But the fact that it was not able to perform exactly like the shallow network means it struggles to "learn nothing", i.e. to learn the identity function!
- ==This suggests a new problem: **Is learning better networks as easy as stacking more layers?**==
## 2) Solution
==The solution to this problem is to use a **residual module** so that adding more layers will not cause any performance degradation.==
A residual module is composed of:
- a sequence of convolutions, batch normalization and ReLU activations.
- a residual connection $x$.
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/2.png?token=AMAXSKMCEIOS26M63ITFN226WMHPE">
<figcaption > Figure: Residual module</figcaption>
</div>
<br>
- We then combine the output of the sequence with the residual connection through addition.
- Suppose $H(x) = F(x) + x$. If the deep network wants to learn the identity function, it only has to drive $F(x)$ to 0 and rely on the residual connection!
- It is easier for a stack of layers to fit a zero mapping than an identity mapping, so the proposed structure is easier to train and ensures that a deeper network will perform at least as well as its shallower counterpart (**neutral-or-better characteristic**).
- The residual connection is also called a **skip connection** because it gives the information a chance to skip the function located inside the residual module.
- ==The **skip connection** provides a clear path for gradients to backpropagate to the early layers of the network. This speeds up learning by mitigating the vanishing gradient problem== (see the short derivation after the figure below).
- However, the trade-off is that residual networks are more prone to overfitting.
- It seems that residual modules are most beneficial for very deep networks and could even hurt the performance of very shallow networks if employed improperly.
- When several residual modules are stacked, a residual network can be thought of as an ensemble of many shallower networks.
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/3.png?token=AMAXSKMUBO2XUKQXMEXU7OS6WMHPG">
<figcaption> Figure: Residual network seen as an ensemble of shallower networks</figcaption>
</div>
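
To make the gradient argument concrete, here is a quick one-line derivation (not spelled out in the paper). For a residual module $H(x) = F(x) + x$ and a loss $\mathcal{L}$, the chain rule gives

$$
\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial H}\left(\frac{\partial F}{\partial x} + I\right)
$$

Even if $\frac{\partial F}{\partial x}$ becomes very small, the identity term $I$ still carries $\frac{\partial \mathcal{L}}{\partial H}$ back to earlier layers unchanged, which is why stacking residual modules mitigates the vanishing gradient problem.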
## 3) Architecture
There are several variants of ResNet-X (where X is the number of layers).
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/4.png?token=AMAXSKIAGV56QYKSDAP7G3S6WMHPG">
</div>
<br>
- For ResNet-50/101/152, the authors used a bottleneck block because it is cheaper in terms of operations (a rough sketch follows the figure below).
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/5.png?token=AMAXSKKMRMKTDHLF2ELDBTK6WMHPG">
</div>
<br>
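To make the bottleneck idea concrete, here is a minimal sketch of such a block (1x1 reduce, 3x3, 1x1 expand with a 4x channel expansion). The class and argument names are assumptions for this illustration, not the paper's or torchvision's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BottleneckBlock(nn.Module):
    """Sketch of a bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand (x4)."""
    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        super(BottleneckBlock, self).__init__()
        out_channels = mid_channels * self.expansion
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        # Projection shortcut when shapes change, identity otherwise.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)   # residual addition
        return F.relu(out)

# e.g. the first bottleneck of "conv2_x" in ResNet-50: 64 -> 256 channels
block = BottleneckBlock(in_channels=64, mid_channels=64, stride=1)
print(block(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 256, 56, 56])
```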
- We are going to implement ResNet on CIFAR-10; the CIFAR architecture is slightly different from the ImageNet one (probably due to the smaller input image size).
- Here is the ResNet-50 architecture on Imagenet:
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/6.png?token=AMAXSKN7ZCE6LSLBUU67ZVS6WMHPI">
</div>
<br>
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/7.png?token=AMAXSKIZ2MTZI24KXSISWZK6WMHPK">
</div>
<br>
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/8.png?token=AMAXSKJJQUEJRABDD2BU6B26WMHPK" width="70%">
</div>
<br>
- We use identity shortcuts when input and output channel dimensions are the same.
- Otherwise, we have 2 options:
- A) Use identity shortcuts with zero padding to increase channel dimension.
- B) Use 1x1 convolution to increase channel dimension (projection shortcut).
- When input and output spatial dimensions don't match, we use one of the 2 above options with stride 2.
- Since we are going to train on CIFAR-10, we implement the CIFAR variant of the architecture, ResNet-56, which differs slightly from the ImageNet version:
    - No max pooling (probably due to the small input size).
    - We will use option A).
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/9.png?token=AMAXSKN2MAJZERV55W3E6NC6WMHPM">
</div>
<br>
<div style="text-align: center">
<img src="https://raw.githubusercontent.com/valoxe/image-storage-1/master/research-paper-summary/resnet/10.png?token=AMAXSKKUFYQBE26YZUO3R526WMHPM" width="70%">
</div>
<br>
# II) Implementation
We are going to implement ResNet-56 on CIFAR-10.
## 1) Architecture build
```python
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class LambdaLayer(nn.Module):
    """Wraps an arbitrary function into an nn.Module (used for the option-A shortcut)."""
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class ConvBlock(nn.Module):
    """Basic residual block: two 3x3 conv/BN layers plus a shortcut connection."""
    def __init__(self, in_channels, out_channels, stride=1, option='A'):
        super(ConvBlock, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(out_channels)),
            ('act1', nn.ReLU()),
            ('conv2', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(out_channels))
        ]))

        # Identity shortcut when shapes match, otherwise option A or B.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            if option == 'A':
                # Option A: downsample spatially by striding, then zero-pad the channel dimension.
                pad = out_channels // 4
                self.shortcut = LambdaLayer(lambda x:
                    F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, pad, pad, 0, 0)))
            if option == 'B':
                # Option B: projection shortcut (1x1 convolution).
                self.shortcut = nn.Sequential(OrderedDict([
                    ('s_conv1', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)),
                    ('s_bn1', nn.BatchNorm2d(out_channels))
                ]))

    def forward(self, x):
        out = self.features(x)
        out += self.shortcut(x)   # residual addition
        out = F.relu(out)
        return out
```
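`ConvBlock` only defines a single residual block. Here is a rough sketch of how these blocks might be assembled into ResNet-56 for CIFAR-10 (3 stages of 9 blocks each, following the paper's 6n+2 layout with n = 9), reusing the imports above. The class and method names are assumptions for this illustration; the exact model definition is in the linked repository.

```python
class ResNet56(nn.Module):
    """Sketch of ResNet-56: a 3x3 stem, then 3 stages of 9 ConvBlocks, then a linear classifier."""
    def __init__(self, num_classes=10):
        super(ResNet56, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.stage1 = self._make_stage(16, 16, num_blocks=9, stride=1)   # 32x32, 16 channels
        self.stage2 = self._make_stage(16, 32, num_blocks=9, stride=2)   # 16x16, 32 channels
        self.stage3 = self._make_stage(32, 64, num_blocks=9, stride=2)   # 8x8,   64 channels
        self.classifier = nn.Linear(64, num_classes)

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        # The first block of a stage may downsample/widen; the remaining blocks keep the shape.
        layers = [ConvBlock(in_channels, out_channels, stride=stride, option='A')]
        for _ in range(num_blocks - 1):
            layers.append(ConvBlock(out_channels, out_channels, stride=1, option='A'))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.stem(x)
        out = self.stage3(self.stage2(self.stage1(out)))
        out = F.adaptive_avg_pool2d(out, 1)       # global average pooling
        out = out.view(out.size(0), -1)
        return self.classifier(out)
```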
## 2) Training on CIFAR-10
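The call below assumes the model, device, and CIFAR-10 data loaders already exist. Here is a minimal sketch of that setup using the `ResNet56` sketch above (the variable names, normalization statistics, batch size, and 90/10 train/validation split are assumptions for this illustration; see the repository for the exact pipeline).

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Assumed preprocessing: convert to tensors and normalize with standard CIFAR-10 statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Carve a validation set out of the training set (the split ratio is an assumption).
train_subset, val_subset = torch.utils.data.random_split(train_set, [45000, 5000])

train_loader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_subset, batch_size=128, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet56().to(device)
```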
```python
train_costs, val_costs = train_model()
```
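`train_model` itself is defined in the repository. As a rough sketch of what such a loop might look like (the loss, optimizer, and hyperparameters are assumptions, not necessarily the settings that produced the log below):

```python
import torch.optim as optim

def train_model(num_epochs=15):
    # Hypothetical training loop: cross-entropy loss + SGD with momentum (assumed settings).
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    train_costs, val_costs = [], []

    for epoch in range(num_epochs):
        # Training pass.
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_costs.append(running_loss / len(train_loader))

        # Validation pass.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                val_loss += criterion(model(inputs), labels).item()
        val_costs.append(val_loss / len(val_loader))

    return train_costs, val_costs
```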
```
[Epoch 1/15]: train-loss = 2.418560 | train-acc = 0.186 | val-loss = 0.345510 | val-acc = 0.260
[Epoch 2/15]: train-loss = 1.814710 | train-acc = 0.318 | val-loss = 0.301081 | val-acc = 0.370
[Epoch 3/15]: train-loss = 1.610937 | train-acc = 0.394 | val-loss = 0.274469 | val-acc = 0.431
[Epoch 4/15]: train-loss = 1.430526 | train-acc = 0.465 | val-loss = 0.237520 | val-acc = 0.504
[Epoch 5/15]: train-loss = 1.315914 | train-acc = 0.514 | val-loss = 0.223509 | val-acc = 0.553
[Epoch 6/15]: train-loss = 1.204232 | train-acc = 0.560 | val-loss = 0.200776 | val-acc = 0.586
[Epoch 7/15]: train-loss = 1.133270 | train-acc = 0.589 | val-loss = 0.196912 | val-acc = 0.612
[Epoch 8/15]: train-loss = 1.051415 | train-acc = 0.621 | val-loss = 0.191916 | val-acc = 0.633
[Epoch 9/15]: train-loss = 0.966809 | train-acc = 0.653 | val-loss = 0.166661 | val-acc = 0.670
[Epoch 10/15]: train-loss = 0.902538 | train-acc = 0.678 | val-loss = 0.162810 | val-acc = 0.690
[Epoch 11/15]: train-loss = 0.828471 | train-acc = 0.705 | val-loss = 0.147712 | val-acc = 0.713
[Epoch 12/15]: train-loss = 0.762280 | train-acc = 0.729 | val-loss = 0.135717 | val-acc = 0.729
[Epoch 13/15]: train-loss = 0.723422 | train-acc = 0.744 | val-loss = 0.142070 | val-acc = 0.739
[Epoch 14/15]: train-loss = 0.656613 | train-acc = 0.766 | val-loss = 0.120628 | val-acc = 0.757
[Epoch 15/15]: train-loss = 0.601020 | train-acc = 0.786 | val-loss = 0.123011 | val-acc = 0.757
```
## 3) Evaluating model
```python
nb_test_examples = 10000
correct = 0

model.eval().to(device)

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Make predictions.
        prediction = model(inputs)

        # Retrieve predicted class indexes.
        _, predicted_class = torch.max(prediction.data, 1)

        # Compute number of correct predictions.
        correct += (predicted_class == labels).float().sum().item()

test_accuracy = correct / nb_test_examples
print('Test accuracy: {}'.format(test_accuracy))
```
```
Test accuracy: 0.7394
```