---
title: How to deal with class imbalance in PyTorch
tags: Templates, Talk
description: View the slide with "Slide Mode".
---

# How to deal with Class Imbalance

![](https://i.imgur.com/QeTAVYW.png)

Class imbalance is a common problem in machine learning. It appears naturally when annotating data of any type (tabular, text, image) and in any domain. For example, when classifying the road surface condition (RSC), it makes sense that the label "Dry" is the majority class.

We use an RSC classification dataset to illustrate the methods explained in this post. It is freely available here: https://arc-gis-hub-home-arcgishub.hub.arcgis.com/datasets/IowaDOT::road-weather-information-system-rwis-surface-data

It is a simple open-source tabular dataset of RSC measurements from Iowa, together with related information such as temperature, location, etc. Here are the proportions of the different classes in the dataset from December 15, 2020 to May 4, 2021:

![](https://i.imgur.com/UXLIePW.png)

The "Dry" class is the majority at 62%, and the 3 classes "Ice Watch", "Wet" and "Trace Moisture" together represent 35% of the dataset. Finally, the 3 classes "Frost", "Ice Warning" and "Chemical Wet" represent only 3% of the dataset. This dataset is an example of class imbalance: the proportions of the 7 classes are very unequal.

We will see several methods to deal with this class imbalance problem:

- Metric choice
- Data augmentation
- Data sampling
- Cost-sensitive learning

# Metric choice

A metric tells us how good our classifier is, but we still have to define what "good" means, and this depends on the problem we want to address.

The most basic and intuitive metric is accuracy: the ratio of the number of correct predictions to the total number of predictions. On an imbalanced dataset, this metric is misleading.
If a model simply assigns the majority class to every sample, it obtains an accuracy equal to the proportion of the majority class in the dataset (62% here). This happens frequently in deep learning, where training by gradient descent can naively converge to predicting a single class.

Another basic metric is precision: the ratio of the number of true positives TP to TP + FP, in other words the proportion of samples predicted as positive (TP + FP) that really are positive (TP). The complementary metric is recall: the proportion of actual positives (TP + FN) that are correctly classified (TP), i.e. TP / (TP + FN). Recall focuses on the positive class and ignores the negative class.

The relative importance of precision and recall depends on the problem. When missing a given class (positive by convention) is critical, we aim for a high recall, so that as many positive samples as possible are classified as positive, even if some negative samples are also wrongly classified as positive. Otherwise, we aim for a high precision, so that the classifier labels as few negative samples as possible as positive. To summarize, high precision means few FP, and high recall means few FN.

The F1-score is a widely used metric defined as the harmonic mean of precision $P$ and recall $R$: $F_1 = \frac{2 \cdot P \cdot R}{P + R}$. The harmonic mean gives more weight to the lower value, so a classifier cannot reach a high F1-score by excelling at only one of the two.

![](https://i.imgur.com/gfGWWU3.png)

The notions of precision and recall extend easily to multi-class classification by computing them for each class separately (that class against all the others). In the same way, the F1-score then has as many values as there are classes, and these values still have to be combined into a global F1-score (a single value used as the metric).
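
To make the accuracy trap and the per-class metrics concrete, here is a small plain-Python sketch (no library required). The toy labels and the helper `per_class_precision_recall_f1` are hypothetical. A degenerate model that always predicts the majority class reaches 80% accuracy on these 10 samples, while its recall on both minority classes is zero. Each class is treated one-vs-rest, and undefined ratios (0/0) are set to 0, which is one common convention. The per-class F1 list it returns still has to be reduced to a single number.

```python
def per_class_precision_recall_f1(y_true, y_pred, num_classes):
    """Return lists of per-class precision, recall and F1 (0.0 when undefined)."""
    precision, recall, f1 = [], [], []
    for c in range(num_classes):
        # One-vs-rest counts for class c.
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        p_c = tp / (tp + fp) if tp + fp else 0.0
        r_c = tp / (tp + fn) if tp + fn else 0.0
        # Harmonic mean: pulled towards the smaller of precision and recall.
        f_c = 2 * p_c * r_c / (p_c + r_c) if p_c + r_c else 0.0
        precision.append(p_c)
        recall.append(r_c)
        f1.append(f_c)
    return precision, recall, f1


# Toy labels: class 0 is the majority (8 of 10 samples).
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
# A degenerate model that always predicts the majority class.
y_pred = [0] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
precision, recall, f1 = per_class_precision_recall_f1(y_true, y_pred, num_classes=3)
print(accuracy)  # 0.8, the share of the majority class
print(recall)    # [1.0, 0.0, 0.0]
print(f1)
```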

There are three ways to combine the per-class F1-scores into a single value:

- Arithmetic mean: Macro-F1
- Mean weighted by the number of samples per class: Weighted-F1
- F1 computed from the global TP, FP and FN counts pooled over all classes: Micro-F1. In single-label multi-class classification, global precision = global recall = accuracy, so this brings us back to the misleading metric.

Note that the same averaging schemes also apply to precision and recall.

The name Macro-F1 is sometimes used for a slightly different metric: the Macro-precision and Macro-recall (arithmetic means of the per-class precision and recall) are computed first, and then their harmonic mean is taken. The sklearn and torchmetrics libraries compute Macro-F1 in the standard way, as the arithmetic mean of the per-class F1-scores, without going through Macro-precision and Macro-recall.

Weighted-F1 is not suitable for our problem: because it weights the mean by the number of samples in each class, it is dominated by the majority classes and is as misleading as accuracy. We therefore use Macro-F1, which gives the same weight to every class.

However, in this example some classes are extremely rare: the 3 smallest classes represent only 3% of the dataset, which is not enough for the model to take them into account during training. Moreover, "Ice Warning" and "Frost" are very close to the "Ice Watch" class, so we would expect the model to classify them as "Ice Watch", which would penalize the score. Finally, let us assume that "Chemical Wet" is irrelevant to the classification problem because it does not have a natural cause. It is therefore preferable to leave these 3 classes out of the metric: despite their tiny sample count, they would have a considerable impact on the unweighted Macro-F1.

How do we implement this? We simply compute the F1-score of each class with a library such as torchmetrics, and take the arithmetic mean of these F1-scores without the 3 classes in question.
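
Below is a minimal plain-Python sketch of that computation. The helper names, the toy labels and the excluded class index are hypothetical: class 3 plays the role of a rare class (like "Frost") that the model confuses with class 2. The sketch compares Macro-F1 with and without the excluded class, Weighted-F1 and accuracy. In practice the per-class scores would come from torchmetrics or from sklearn's `f1_score(y_true, y_pred, average=None)`.

```python
def per_class_f1(y_true, y_pred, num_classes):
    """F1-score of each class, computed one-vs-rest (0.0 when undefined)."""
    scores = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        # 2TP / (2TP + FP + FN) equals the harmonic mean of precision and recall.
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


def macro_f1(y_true, y_pred, num_classes, excluded=()):
    """Unweighted mean of the per-class F1-scores, skipping `excluded` classes."""
    scores = per_class_f1(y_true, y_pred, num_classes)
    kept = [s for c, s in enumerate(scores) if c not in excluded]
    return sum(kept) / len(kept)


def weighted_f1(y_true, y_pred, num_classes):
    """Per-class F1 averaged with weights = number of true samples per class."""
    scores = per_class_f1(y_true, y_pred, num_classes)
    support = [sum(t == c for t in y_true) for c in range(num_classes)]
    return sum(s * n for s, n in zip(scores, support)) / len(y_true)


y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]  # the rare class 3 is predicted as class 2

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy)                                             # 0.9
print(round(weighted_f1(y_true, y_pred, 4), 3))             # 0.867
print(round(macro_f1(y_true, y_pred, 4), 3))                # 0.667
print(round(macro_f1(y_true, y_pred, 4, excluded={3}), 3))  # 0.889
```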

#### In practice

In deep learning, the loss and the metrics are usually computed online, batch by batch. For accuracy with a constant batch size, averaging the per-batch accuracies over an epoch gives the same result as computing the accuracy on all the prediction/label pairs of the epoch at once. This is not true for Macro-F1, so we must compute it at the end of the epoch, keeping the epoch's prediction/label pairs in RAM.

#### Remarks

Precision and recall can also be useful when considering production constraints. In the RSC classification task, it is important that the ADAS system does not miss a wet or icy road surface, for obvious reasons, so recall on those classes matters most.

The confusion matrix gives a complete picture of a classifier's performance. If its rows are the predicted classes and its columns the true classes, the sum of a row is the number of samples predicted as that class (TP + FP), and the sum of a column is the number of samples that actually belong to that class (TP + FN). Note that sklearn's `confusion_matrix` uses the transposed convention (rows are true labels, columns are predictions).

# Image augmentation

Data augmentation artificially creates new data from the existing training set by applying modifications to it. In computer vision, typical modifications are cropping, rotation, flipping, filtering, changing the saturation, etc.

![](https://i.imgur.com/YNKYHCR.png)

Data augmentation serves several purposes. It can complete an image dataset that does not contain enough data. It can also help the model learn certain invariances, by focusing on augmentations that synthetically create cases that do not occur often enough for the model to learn them by itself (such as sun glare in front of the camera). Finally, it can be used to create more data for the underrepresented classes of a dataset: in that case, the relevant images are modified separately from training and added to the dataset.

This method has a few drawbacks:

- We need to store the new images.
- It is difficult to scale: changing the way we augment the data may require further operations on the synthetic data already added (e.g. deleting it and replacing it).
- We want to apply data augmentation to all classes, not just the underrepresented ones (for the reasons explained above).

An alternative is to apply data augmentation on the fly, while changing the way the data is sampled during training.

#### In practice

The torchvision package provides data augmentation techniques for computer vision tasks, but we prefer Albumentations (https://albumentations.ai/), which runs faster and is very complete. I recommend this website to visualize Albumentations augmentations in a few clicks: https://albumentations-demo.herokuapp.com/.

# Data sampling

There are several data sampling methods in machine learning. To solve class imbalance problems, the simplest ones are:

- Under-sampling the over-represented classes.
- Over-sampling the underrepresented classes during training.

In addition, the data can be sampled at several levels:

- At the epoch level, which changes the granularity of training.
- At the training batch level.

One method is to sample the data at the epoch level so that the classes are evenly distributed within each epoch, while still applying data augmentation on the fly. This amounts to random under-sampling at each epoch. Implementing such a custom method in PyTorch requires understanding the `torch.utils.data` module.

#### torch.utils.data.Dataset

This abstract class represents a map-style dataset: subclasses must implement the `__getitem__()` method, which maps each index to a dataset sample. For example, `custom_dataset[idx]` could read the idx-th image and its label from a folder on disk.
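
Since the RSC data used here is tabular, the sketch below shows a hypothetical map-style dataset (`TabularRSCDataset` is not from the original post) over in-memory features and labels; an image dataset would instead open the file at `img_path` inside `__getitem__()`. It also implements `__len__()`, which PyTorch's default samplers and DataLoader options rely on.

```python
import torch
from torch.utils.data import Dataset


class TabularRSCDataset(Dataset):
    """Hypothetical map-style dataset over in-memory rows (features, label)."""

    def __init__(self, features, labels):
        self.features = torch.as_tensor(features, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        # Number of samples; used by the default samplers of the DataLoader.
        return len(self.labels)

    def __getitem__(self, idx):
        # Map-style access: dataset[idx] returns the idx-th sample and its label.
        return self.features[idx], self.labels[idx]


# Three toy rows with two made-up numeric features each.
dataset = TabularRSCDataset([[1.0, -3.2], [4.5, 0.1], [2.0, 2.0]], [0, 2, 1])
x, y = dataset[1]
print(x, y)
```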

We use this class to create our custom dataset: each sample corresponds to a row of our dataframe, which contains an "img_path" column giving the location of the image on disk.

![](https://i.imgur.com/MQZzi4E.png)

#### torch.utils.data.DataLoader

![](https://i.imgur.com/UQOYWlZ.png)

A DataLoader combines a dataset and a sampler and provides an iterable over the dataset.

The `dataset` argument is the only mandatory one: it indicates where to load the data from. The DataLoader accepts two types of dataset:

- Map-style (`torch.utils.data.Dataset`)
- Iterable-style (`torch.utils.data.IterableDataset`)

We are interested in the `sampler` argument, because we want to change its default behavior to solve the class imbalance problem. It defines the strategy used to draw samples from the dataset; if it is specified, `shuffle` must not be. The `batch_sampler` argument does the same thing for whole batches of indices; it is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`. The `sampler` and `batch_sampler` arguments are incompatible with `IterableDataset`, since an iterable-style dataset does not support access by index, and samplers only yield indices (a batch sampler yields a batch of indices at a time).

By default, automatic batching uses the default `collate_fn` to assemble the samples into a batch before the DataLoader outputs it.

The `num_workers` argument works around the Global Interpreter Lock (GIL) of a Python process by loading data in several worker processes. The main process generates the sample indices and sends them to the workers. The `prefetch_factor` argument (the number of batches each worker loads in advance, 2 by default when `num_workers > 0`) can be decreased to save RAM.

#### torch.utils.data.Sampler

A custom sampler passed to the DataLoader must inherit from this class. Its `__iter__()` method must be implemented to provide a way to iterate over the indices of the dataset elements.

From this class we create an epoch-level sampling method. We start by computing a weight for each sample from the number of occurrences of its class in the dataset.
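
A minimal sketch of such an epoch-level sampler, assuming inverse class-frequency weights (so that every class gets the same total weight) and a fixed number of draws per epoch. `EpochBalancedSampler` and the toy data are hypothetical, not the author's code from the screenshot. Because it is passed as `sampler=`, `shuffle` is left unset on the DataLoader, and since the DataLoader calls `__iter__()` again each time it is iterated, every epoch gets a fresh draw.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class EpochBalancedSampler(Sampler):
    """Hypothetical epoch-level sampler.

    Each epoch, draws `num_samples` distinct dataset indices, each with a
    probability inversely proportional to the frequency of the sample's class.
    """

    def __init__(self, labels, num_samples):
        labels = torch.as_tensor(labels, dtype=torch.long)
        class_counts = torch.bincount(labels)
        # One weight per sample: 1 / (number of samples in its class).
        # torch.multinomial does not require the weights to sum to 1.
        self.weights = 1.0 / class_counts[labels].double()
        self.num_samples = num_samples

    def __iter__(self):
        # replacement=False: an index appears at most once per epoch, so a
        # class with too few samples is exhausted before reaching its share.
        indices = torch.multinomial(self.weights, self.num_samples, replacement=False)
        return iter(indices.tolist())

    def __len__(self):
        return self.num_samples


# Toy data: 100 samples, 3 heavily imbalanced classes.
labels = torch.tensor([0] * 80 + [1] * 15 + [2] * 5)
dataset = TensorDataset(torch.randn(len(labels), 4), labels)

sampler = EpochBalancedSampler(labels, num_samples=30)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for x, y in loader:
    print(x.shape, torch.bincount(y, minlength=3).tolist())
```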
The weights do not need to sum to 1. Then, `torch.multinomial()` draws a given number of sample indices without replacement, each sample being drawn with a probability proportional to the weight computed in the custom sampler's constructor. It is then sufficient to pass an instance of this class as the `sampler` argument of the DataLoader.

![](https://i.imgur.com/fDYa7Gl.png)

On the data from one station, drawing a fixed number of samples at random, we get the following proportions:

![](https://i.imgur.com/FJtzkB8.png)

Then, using the epoch-level sampling weighted by class:

![](https://i.imgur.com/VbfLhbi.png)

The extremely underrepresented classes ("Frost", "Ice Warning", "Chemical Wet") were all drawn and are still underrepresented, because the draw is made without replacement. The 4 other classes, previously imbalanced, now appear in the same proportion.

We can also inherit from this class to create a sampling method directly at the training-batch level; an instance of the custom sampler is then passed to the `batch_sampler` argument of the DataLoader. The custom sampler's `__iter__()` method is then called once at the beginning of each epoch, just as in the epoch-level sampling method.

![](https://i.imgur.com/z0ozofd.png)

#### Remarks

With PyTorch Lightning, the `reload_dataloaders_every_epoch` argument of the `pytorch_lightning.Trainer` class reloads the DataLoaders at each epoch, and in particular re-runs the `__iter__()` methods of custom samplers. This way, the batches of each epoch differ from one another, which is essential. (In more recent Lightning versions, this argument was replaced by `reload_dataloaders_every_n_epochs`.)

The `torch.utils.data.Subset` class could also have been used to train on a class-balanced subset of the dataset.

The `torch.utils.data.WeightedRandomSampler` class samples elements according to a list of weights that it takes as an argument.
It can do the same thing as the epoch-level sampling method seen previously (passing `replacement=False`, since it draws with replacement by default), provided the weights are computed separately. However, the default behavior often needs to be changed, for example to manually cap the number of draws of certain classes that should not appear above a given threshold.
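
For comparison, a sketch of a similar toy setup with `WeightedRandomSampler`, where the per-sample weights are computed outside the sampler (inverse class frequency, an assumption rather than the author's exact weighting). With the default `replacement=True`, the 5 samples of the rarest class can be repeated within the epoch to reach a balanced draw; with `replacement=False`, each index is drawn at most once.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy data (hypothetical): 100 samples, 3 heavily imbalanced classes.
labels = torch.tensor([0] * 80 + [1] * 15 + [2] * 5)
dataset = TensorDataset(torch.randn(len(labels), 4), labels)

# Weights are computed outside the sampler: one weight per sample, equal to
# the inverse of its class frequency (they do not need to sum to 1).
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].double()

# replacement=True (the default) lets rare samples be drawn several times
# per epoch.
sampler = WeightedRandomSampler(sample_weights, num_samples=60, replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

drawn = torch.cat([y for _, y in loader])
# Each class is drawn about 20 times on average; exact counts vary per run.
print(torch.bincount(drawn, minlength=3).tolist())
```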