---
title: How to deal with class imbalance in PyTorch
tags: Templates, Talk
description: View the slide with "Slide Mode".
---

# How to deal with Class Imbalance

![](https://i.imgur.com/QeTAVYW.png)

Class imbalance is a common problem in machine learning. It appears naturally when annotating data of any type (tabular, text, image) and in any domain. For example, when classifying the road surface condition (RSC), it makes sense that the label "Dry" is the majority class.

We will use an RSC classification dataset to illustrate the methods explained in this post. It is freely available here: https://arc-gis-hub-home-arcgishub.hub.arcgis.com/datasets/IowaDOT::road-weather-information-system-rwis-surface-data

It is a simple open tabular dataset with RSC measurements from Iowa, along with related information such as temperature, location, etc.

Here are the proportions of the different classes in the dataset from December 15, 2020 to May 4, 2021:

![](https://i.imgur.com/UXLIePW.png)

The "Dry" class is the majority class at 62%, and the 3 classes "Ice Watch", "Wet" and "Trace Moisture" together represent 35% of the dataset. Finally, the 3 classes "Frost", "Ice Warning" and "Chemical Wet" represent only 3% of the dataset.

This dataset is an example of class imbalance: the proportions of the 7 classes are very unequal.

We will see several methods to deal with this class imbalance problem:
- Metric choice
- Data augmentation
- Data sampling
- Cost-sensitive learning

# Metric choice

A metric tells us how good our classifier is, but we still need to define what "good" means, and this depends on the problem we want to address.

The most basic and intuitive metric is accuracy: the ratio of the number of correct predictions to the total number of predictions. On an imbalanced dataset, this metric is misleading.
If a model simply assigns every sample to the majority class, it obtains an accuracy equal to the proportion of that class in the dataset (62% here), while being useless for every other class. This failure mode is common in deep learning, where gradient-descent training on imbalanced data can drift toward predicting a single class.

Another basic metric is precision: the number of true positives (TP) divided by TP + FP (false positives). In other words, it is the proportion of samples predicted as positive (TP + FP) that really are positive (TP).

The complementary metric to precision is recall: the proportion of actual positives (TP + FN, where FN are the false negatives) that are correctly classified (TP), i.e. TP / (TP + FN). Recall focuses on the positive class and ignores the true negatives.

The importance given to the precision or recall of a classifier depends on the problem. When detecting a certain class (positive by convention) is critical, we aim for a high recall, so that as many positive samples as possible are classified as positive, even if some negative samples are also classified as positive by mistake. Otherwise, we aim for a high precision, so that the classifier predicts as few negative samples as positive as possible.

To summarize, high precision implies few FP, and high recall implies few FN.

The F1-score is a widely used metric defined as the harmonic mean of precision and recall: F1 = 2 × precision × recall / (precision + recall). The harmonic mean gives more weight to the lower of the two values, so a classifier cannot reach a high F1-score by excelling at only one of them.

![](https://i.imgur.com/gfGWWU3.png)

The notions of precision and recall extend easily to multi-class classification by computing them for each class separately (that class versus all the others). In the same way, the F1-score then has as many values as there are classes, and these values must be combined into a global F1-score (a single value to use as a metric).
There are 3 ways to do this:
- Arithmetic mean: Macro-F1
- Mean weighted by the number of samples per class: Weighted-F1
- F1-score of the global precision and recall, obtained by summing TP, FP and FN over all classes: Micro-F1. In single-label multi-class classification, global precision = global recall = accuracy, so this brings us back to the misleading metric.

Note that these different ways of combining per-class values apply to precision and recall as well.

The name Macro-F1 is sometimes used for a slightly different metric: the Macro-precision and the Macro-recall (arithmetic means of the per-class precision and recall) are computed first, and then their harmonic mean is taken. The sklearn and torchmetrics libraries compute the Macro-F1 in the standard way, as the arithmetic mean of the per-class F1-scores, without going through the Macro-precision and the Macro-recall.

The Weighted-F1 is not suitable for our problem: because it weights each class by its number of samples, it is dominated by the majority classes and is misleading, just like accuracy. We therefore use the Macro-F1, which gives the same weight to every class.

However, in this example, 3 classes are in extreme minority: together they represent 3% of the dataset, which is not enough for the model to take them into account during training. In addition, "Ice Warning" and "Frost" are very close to the "Ice Watch" class, so we would expect the model to classify them as "Ice Watch", which would penalize the score. Moreover, let's assume that "Chemical Wet" is irrelevant to the classification problem because it does not have a natural cause. It is therefore preferable to leave these 3 classes out of the metric: despite their small numbers, they would have a considerable impact on the (unweighted) Macro-F1.

How to implement this? We simply compute the F1-score of each class with a library such as torchmetrics, and take the arithmetic mean of these F1-scores, leaving out the 3 classes in question.
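A minimal sketch of this computation, written in plain PyTorch rather than torchmetrics so that the per-class formula stays visible. The class-to-index mapping and the helper names (`per_class_f1`, `f1_summary`, `macro_f1_without`) are assumptions for illustration. The functions take the predictions and labels of a whole epoch, and the toy example reproduces the "always predict Dry" model discussed above.

```python
import torch

# Assumed index order for the 7 RSC classes (hypothetical mapping).
CLASSES = ["Dry", "Ice Watch", "Wet", "Trace Moisture", "Frost", "Ice Warning", "Chemical Wet"]
EXCLUDED = {"Frost", "Ice Warning", "Chemical Wet"}


def confusion_matrix(preds: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Counts with rows = true classes, columns = predicted classes."""
    flat = labels * num_classes + preds
    return torch.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def per_class_f1(preds, labels, num_classes):
    cm = confusion_matrix(preds, labels, num_classes).float()
    tp = cm.diag()
    fp = cm.sum(dim=0) - tp  # predicted as the class, but belonging to another one
    fn = cm.sum(dim=1) - tp  # belonging to the class, but predicted as another one
    # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN); set to 0 when the class never
    # occurs and is never predicted (denominator 0).
    denom = 2 * tp + fp + fn
    return torch.where(denom > 0, 2 * tp / denom.clamp(min=1), torch.zeros_like(tp))


def f1_summary(preds, labels, num_classes):
    f1 = per_class_f1(preds, labels, num_classes)
    support = torch.bincount(labels, minlength=num_classes).float()
    return {
        "macro": f1.mean().item(),                                # plain arithmetic mean
        "weighted": (f1 * support / support.sum()).sum().item(),  # weighted by class size
        "micro": (preds == labels).float().mean().item(),         # reduces to accuracy here
    }


def macro_f1_without(preds, labels, class_names, excluded):
    """Arithmetic mean of the per-class F1-scores, leaving out the `excluded` classes."""
    f1 = per_class_f1(preds, labels, len(class_names))
    keep = torch.tensor([name not in excluded for name in class_names])
    return f1[keep].mean().item()


# A "majority class" model on a toy epoch: everything is predicted as "Dry".
labels = torch.tensor([0] * 6 + [1, 2, 3, 4])
preds = torch.zeros_like(labels)
print(f1_summary(preds, labels, len(CLASSES)))
print(macro_f1_without(preds, labels, CLASSES, EXCLUDED))  # 0.1875
```

On this toy epoch, the micro (accuracy) and weighted scores look decent (0.6 and 0.45) while the Macro-F1 exposes the model. Because the functions take full prediction/label tensors, they are meant to be called once per epoch on the concatenated batch outputs (e.g. with `torch.cat`).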
#### In practice

In deep learning, the loss function and the metrics are usually computed online on batches of data. For accuracy with a constant batch size, the mean of the per-batch accuracies over an epoch equals the accuracy computed on all the prediction/label pairs of the epoch at once. For the Macro-F1, the two calculations differ, so the Macro-F1 must be computed at the end of the epoch, which requires keeping the prediction/label pairs of the whole epoch in memory.

#### Remarks

Precision and recall can also be useful when considering production constraints. In the RSC classification task, it is important that the ADAS system does not miss that the road surface is wet or icy, for obvious safety reasons, which calls for a high recall on these classes.

The confusion matrix gives a complete picture of the performance of a classifier. With rows for predicted classes and columns for true classes, the sum of a row is the number of samples predicted as that class (TP + FP), and the sum of a column is the number of samples belonging to that class (TP + FN). Note that sklearn's `confusion_matrix` uses the transposed convention, with true classes as rows.

# Image augmentation

Data augmentation artificially creates new data from the existing training dataset by applying modifications to it. In computer vision, possible modifications include cropping, rotation, flipping, applying a filter to the image, varying the saturation, etc.

![](https://i.imgur.com/YNKYHCR.png)

Data augmentation serves several purposes. It can be used to enlarge an image dataset that does not contain enough data, or to make the model learn certain invariances, by synthetically creating cases that do not appear often enough in the data for the model to learn them by itself (such as sun glare in front of the camera).

It can also be used to create more data for the underrepresented classes of a dataset. In that case, the modifications are applied to the relevant images separately from training, and the resulting images are added to the dataset.
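A minimal sketch of this offline approach, assuming the images are already loaded as a float tensor of shape (N, C, H, W) with values in [0, 1]. `random_augment` (a random horizontal flip plus brightness jitter) is a hypothetical stand-in for a real torchvision or Albumentations pipeline, and `oversample_offline` is an illustrative name.

```python
import torch


def random_augment(image: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    """Hypothetical augmentation: random horizontal flip + brightness jitter on a (C, H, W) image in [0, 1]."""
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=[-1])
    factor = 0.8 + 0.4 * torch.rand(1, generator=generator).item()  # brightness factor in [0.8, 1.2)
    return (image * factor).clamp(0.0, 1.0)


def oversample_offline(images, labels, minority_classes, copies=2, seed=0):
    """Return a new dataset with `copies` augmented versions of every minority-class image appended."""
    generator = torch.Generator().manual_seed(seed)
    is_minority = torch.isin(labels, torch.tensor(minority_classes))
    minority_idx = torch.nonzero(is_minority).flatten().tolist()
    if not minority_idx:
        return images, labels
    # One augmented copy per (image, repetition); labels are repeated in the same order.
    new_images = torch.stack([random_augment(images[i], generator)
                              for i in minority_idx for _ in range(copies)])
    new_labels = labels[minority_idx].repeat_interleave(copies)
    return torch.cat([images, new_images]), torch.cat([labels, new_labels])


images = torch.rand(6, 3, 4, 4)
labels = torch.tensor([0, 0, 0, 0, 1, 2])
aug_images, aug_labels = oversample_offline(images, labels, minority_classes=[1, 2], copies=3)
print(aug_images.shape, torch.bincount(aug_labels).tolist())  # torch.Size([12, 3, 4, 4]) [4, 4, 4]
```

The synthetic images end up as ordinary dataset entries: they have to be stored, and they are fixed once generated.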
Augmenting the dataset offline in this way has a few drawbacks:
- The new images must be stored.
- It is difficult to scale: changing the way the data is augmented may require additional operations on the synthetic data already added (e.g. deleting it and regenerating it).
- We also want to apply data augmentation to the other classes, not only to the underrepresented ones (for the reasons explained above).

An alternative is to apply data augmentation on the fly, while changing the way the data is sampled during training.

#### In practice

The torchvision package provides data augmentation techniques for computer vision tasks, but Albumentations (https://albumentations.ai/) is often preferred because it is fast and very complete. This website lets you visualize image augmentations with Albumentations in a few clicks: https://albumentations-demo.herokuapp.com/.

# Data sampling

There are different data sampling methods in machine learning. To solve class imbalance problems in particular, the simplest ones are:
- Under-sampling the data of over-represented classes.
- Over-sampling the data of underrepresented classes during training.

In addition, the data can be sampled at several levels:
- At the epoch level, which changes the granularity of the training.
- At the training batch level.

One method is to sample the data at the epoch level so that the class distribution is balanced within each epoch, while still being able to apply data augmentation on the fly. This is equivalent to doing random under-sampling at each epoch.

Implementing such a custom method with PyTorch requires an understanding of the `torch.utils.data` module.

#### Torch.utils.data.Dataset

This is an abstract class that represents the data in a map-style way: the `__getitem__()` method must be implemented to associate each sample of the dataset with an index, and `__len__()` is usually implemented too, since the default samplers rely on it. For example, `custom_dataset[idx]` could read the idx-th image and its corresponding label from a folder on disk.
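A minimal map-style dataset sketch, assuming the images and labels are already in memory as tensors (a real implementation might instead read each image from disk in `__getitem__()`). The `transform` hook is where on-the-fly augmentation happens: it is applied anew each time a sample is read, so no augmented image is ever stored. `TensorImageDataset` and `random_hflip` are hypothetical names.

```python
import torch
from torch.utils.data import Dataset


def random_hflip(image: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Hypothetical on-the-fly augmentation: flip a (C, H, W) image horizontally with probability p."""
    if torch.rand(1).item() < p:
        return torch.flip(image, dims=[-1])
    return image


class TensorImageDataset(Dataset):
    """Map-style dataset: dataset[idx] returns the idx-th (image, label) pair."""

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)  # re-applied at every access, never stored
        return image, self.labels[idx]


dataset = TensorImageDataset(torch.rand(8, 3, 32, 32), torch.randint(0, 7, (8,)), transform=random_hflip)
image, label = dataset[3]
print(len(dataset), image.shape, label.item())
```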
We use this class to create our custom dataset: each sample corresponds to a row of our dataframe, which contains an "img_path" column giving the path of the image on disk.

![](https://i.imgur.com/MQZzi4E.png)

#### Torch.utils.data.DataLoader

![](https://i.imgur.com/UQOYWlZ.png)

A DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.

The `dataset` argument is the only mandatory argument; it indicates where to load the data from. The DataLoader supports two types of datasets:
- Map-style (`torch.utils.data.Dataset`)
- Iterable-style (`torch.utils.data.IterableDataset`)

We are interested in the `sampler` argument, because we want to change its default behavior to solve the class imbalance problem. It defines the strategy used to draw samples from the dataset; if it is specified, `shuffle` must not be.

The `batch_sampler` argument does the same thing for batches of indices; it is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`.

Neither `sampler` nor `batch_sampler` is compatible with an `IterableDataset`, since such a dataset has no notion of index.

By default, automatic batching uses the default `collate_fn` function to aggregate the samples into a batch before the DataLoader outputs it.

Setting `num_workers` above 0 loads the data in several worker processes, which bypasses the constraint imposed by the Global Interpreter Lock (GIL) of a single Python process. The main process generates the sample indices and sends them to the workers.

The `prefetch_factor` argument (the number of batches loaded in advance by each worker, 2 by default) can be decreased to save RAM.

#### Torch.utils.data.Sampler

The custom sampler passed as an argument to the DataLoader inherits from this class. The `__iter__()` method must be implemented to provide a way to iterate over the indices of the dataset elements.

From this class we create a sampling method at the epoch level. We start by computing a weight for each sample from the number of occurrences of its class in the dataset.
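A minimal sketch of this weight computation, assuming inverse class frequency (one common choice; the author's exact code is only shown as an image): each sample gets the weight 1 / (number of samples of its class), so every class ends up with the same total weight. `per_sample_weights` is a hypothetical helper name.

```python
import torch


def per_sample_weights(labels: torch.Tensor) -> torch.Tensor:
    """Weight of each sample = 1 / number of samples of its class."""
    counts = torch.bincount(labels)       # occurrences of each class
    return 1.0 / counts[labels].float()   # look up the class count of each sample


labels = torch.tensor([0, 0, 0, 1, 2, 2])
weights = per_sample_weights(labels)
# Total weight per class: every class gets the same share.
class_totals = torch.zeros(3).index_add_(0, labels, weights)
print(weights, class_totals)
```

Here the labels `[0, 0, 0, 1, 2, 2]` give the weights `[1/3, 1/3, 1/3, 1, 1/2, 1/2]`, and each class sums to 1.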
The weights do not need to sum to 1. Finally, `torch.multinomial()` draws a given number of sample indices without replacement, with probabilities proportional to the weights computed in the custom sampler's constructor. It is then enough to pass an instance of this class as the `sampler` argument of the DataLoader.

![](https://i.imgur.com/fDYa7Gl.png)

On the data from one station, drawing a fixed number of samples uniformly at random gives the following proportions:

![](https://i.imgur.com/FJtzkB8.png)

Using the epoch sampling weighted by class instead:

![](https://i.imgur.com/VbfLhbi.png)

The classes in extreme minority ("Frost", "Ice Warning", "Chemical Wet") were drawn in full and are still under-represented, because the draw was made without replacement: a sample cannot be drawn more than once. The 4 other classes, previously imbalanced, now appear in the same proportion.

We can also inherit from this class to sample directly at the training batch level; an instance of the custom sampler is then passed as the `batch_sampler` argument of the DataLoader. As with the epoch-level method, the `__iter__()` method of the custom sampler is called once at the beginning of each epoch.

![](https://i.imgur.com/z0ozofd.png)

#### Remarks

With PyTorch Lightning, the `reload_dataloaders_every_epoch` argument of the `pytorch_lightning.Trainer` class reloads the dataloaders at each epoch, and with them the custom samplers (more recent versions replace it with `reload_dataloaders_every_n_epochs`). Note that a plain PyTorch DataLoader already calls its sampler's `__iter__()` every time it is iterated, so a sampler that draws its indices in `__iter__()` yields different batches at each epoch, which is essential.

The `torch.utils.data.Subset` class could also be used to train on a class-balanced subset of the dataset.

The `torch.utils.data.sampler.WeightedRandomSampler` class samples elements according to a list of weights that it takes as an argument.
It can be used to do the same thing as the epoch sampling method seen previously (with `replacement=False`, since it draws with replacement by default), in which case the weights must be computed separately. However, it is often necessary to change the default behavior, for example by manually capping the number of draws of certain classes that you do not want to see above a certain threshold, and a custom sampler gives this flexibility.
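A minimal sketch of the two custom samplers described in this section and of the built-in alternative. The class names `EpochBalancedSampler` and `BalancedBatchSampler` are hypothetical, the inverse-class-frequency weights are an assumption (the author's code is only shown as images), and the batch sampler draws each class with replacement, which is one possible design among others. Toy random data stands in for the RSC dataset.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset, WeightedRandomSampler


class EpochBalancedSampler(Sampler):
    """Hypothetical epoch-level sampler: draws `num_samples` distinct indices per
    epoch, with probabilities proportional to inverse class frequency."""

    def __init__(self, labels: torch.Tensor, num_samples: int):
        counts = torch.bincount(labels)
        # The weights need not sum to 1: torch.multinomial normalizes them.
        self.weights = 1.0 / counts[labels].double()
        self.num_samples = num_samples

    def __iter__(self):
        # Called each time the DataLoader is iterated (once per epoch),
        # so every epoch gets a fresh draw.
        idx = torch.multinomial(self.weights, self.num_samples, replacement=False)
        return iter(idx.tolist())

    def __len__(self):
        return self.num_samples


class BalancedBatchSampler(Sampler):
    """Hypothetical batch-level sampler: every batch holds the same number of
    indices from each class, drawn with replacement within the class."""

    def __init__(self, labels: torch.Tensor, batch_size: int, num_batches: int):
        self.class_indices = [torch.nonzero(labels == c).flatten() for c in labels.unique()]
        self.per_class = batch_size // len(self.class_indices)
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for idx in self.class_indices:
                picks = torch.randint(len(idx), (self.per_class,))
                batch.extend(idx[picks].tolist())
            yield batch  # a list of dataset indices = one batch

    def __len__(self):
        return self.num_batches


# Toy imbalanced data standing in for the RSC dataframe.
labels = torch.tensor([0] * 80 + [1] * 15 + [2] * 5)
dataset = TensorDataset(torch.randn(len(labels), 4), labels)

# Epoch level: passed through `sampler` (so `shuffle` is left unset).
epoch_sampler = EpochBalancedSampler(labels, num_samples=30)
epoch_loader = DataLoader(dataset, batch_size=10, sampler=epoch_sampler)

# Batch level: passed through `batch_sampler` (no batch_size/shuffle/sampler/drop_last).
batch_loader = DataLoader(dataset, batch_sampler=BalancedBatchSampler(labels, batch_size=9, num_batches=5))

# Built-in alternative with the same weights, drawn without replacement.
builtin_sampler = WeightedRandomSampler(epoch_sampler.weights, num_samples=30, replacement=False)

for x, y in batch_loader:
    print(torch.bincount(y, minlength=3).tolist())  # [3, 3, 3] for each of the 5 batches
```

Capping a class, as mentioned above, would amount to changing the per-class counts in `__iter__()`, which is easy with a custom sampler but not exposed by `WeightedRandomSampler`.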