--- title: 2023/03/01 # 簡報的名稱 tags: meeting # 簡報的標籤 slideOptions: # 簡報相關的設定 theme: black # 顏色主題 transition: 'fade' # 換頁動畫 spotlight: enabled: true --- # Paper studying 進入訓練之前,我們先來看看我們有什麼困難 1. Imbalance data 2. Multi-label classification Weekly Experimentals Results: https://wandb.ai/ddcvlab/CXR-Multi-label-Binary-Classification/runs/j6zraxnh?workspace=user-ddcvlab tags: CXR multi-label classification, imbalance samples, eight classes ## Loss function for imbalance data 本章節介紹目前對付不平衡資料常用的 loss function ### Focal Loss 論文地址: https://arxiv.org/pdf/1708.02002.pdf 由於物件偵測任務中,常常出現對於大物件的預測框有比較高的準確度,但是對於小物件的預測框通常表現都很差,所以作者提出了 Focal Loss 用以解決這類問題 除了用於物件偵測的任務中,我們也常用該方法解決不平衡資料的問題 **Focal Loss:** $$ \begin{align} \mbox{loss}&=-\alpha_t(1-p_t)^\gamma\log p_t \\ p_t &= \begin{cases} p, &\mbox{ if } y=1 \\ 1-p, &\mbox{ else} \end{cases} \end{align} $$ 當 $\gamma=0, \alpha=1$ 時,Focal loss 退化到 CE,我們來看看這兩個參數有什麼用 #### Focal Loss $\alpha$ 當 $\gamma=0$,Focal loss 等於 $-\alpha_t\log(p_t)=-\alpha_t\mbox{CE}$ 由上可知 $\alpha$ 可以用來調整每個類別對於 loss 影響大小的權重參數 #### Focal Loss $\gamma$ 當 $\alpha=1$ 時,Focal loss 等於 $-(1-p_t)^\gamma\log p_t=-(1-p_t)^\gamma\mbox{CE}$ 當真實類別為 1 時 $(y=1)$,$\gamma$ 的影響見下圖 ![](https://i.imgur.com/d4tQwkk.png) 由上圖可知,當 $\gamma$ 越大,對於那些好分類的類別 (well-classified examples) 我們給他越小的權重,這會讓模型盡量關注那分難分的樣本 #### Coding ```python! def focal_loss(input_values, gamma): """計算 Focal Loss""" p = torch.exp(-input_values) loss = (1 - p) ** gamma * input_values return loss.mean() class FocalLoss(nn.Module): def __init__(self, weight=None, gamma=0.): super(FocalLoss, self).__init__() assert gamma >= 0 self.gamma = gamma self.weight = weight def forward(self, input, target): """ 計算 Focal Loss :param input: 模型預測輸出 :param target: 真實標籤 :return: Focal Loss """ return focal_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), self.gamma) ``` ### LDAM Loss 論文地址: https://arxiv.org/pdf/1906.07413.pdf LDAM Loss (Label-Distribution-Aware Margin Loss) 是一種針對不平衡數據分類問題設計的損失函數,該函數可以幫助模型更好地處理多類別不平衡數據集。LDAM Loss 通過使用標籤分佈信息調整損失函數的權重,使得模型更關注在少數類別上。 具體來說,LDAM Loss 首先根據每個樣本的標籤分佈信息計算出一個權重向量,這個向量會被用來調整交叉熵損失函數中每個類別的權重。為了使得權重向量更能反映標籤分佈的信息,LDAM Loss 還會引入一個 margin 參數,通過調節 margin 的大小,可以控制權重向量的分佈緊密程度,進一步影響損失函數的權重。 ![](https://i.imgur.com/7MgjR2S.png) #### Coding ```python! class LDAMLoss(nn.Module): def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30): super(LDAMLoss, self).__init__() # 計算每個類別的 m 值 m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) # 將 m 值縮放到 0 到 max_m 之間 m_list = m_list * (max_m / np.max(m_list)) # 將 m_list 轉為 PyTorch 的浮點數張量 m_list = torch.cuda.FloatTensor(m_list) self.m_list = m_list # LDAM 的 scaling factor s,需大於 0 assert s > 0 self.s = s # 損失函數的權重 self.weight = weight def forward(self, x, target): # 將 target 做 one-hot 編碼 index = torch.zeros_like(x, dtype=torch.uint8) index.scatter_(1, target.data.view(-1, 1), 1) # 將 index 轉為浮點數張量 index_float = index.type(torch.cuda.FloatTensor) # 計算 batch 內每個類別的 m 值 batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1)) batch_m = batch_m.view((-1, 1)) # 計算每個樣本的 x_m 值 x_m = x - batch_m # 將 output 設為 x_m 或 x,取決於樣本的類別 output = torch.where(index, x_m, x) # 計算交叉熵損失 return F.cross_entropy(self.s*output, target, weight=self.weight) ``` ### LMF Loss 論文地址: https://arxiv.org/pdf/2212.12741.pdf LMF 簡單的把 LDAM Loss 和 Focal Loss 做線性加權 ![](https://i.imgur.com/UJEWy3g.png) 作者宣稱結果要比其他 loss function 在 F1 指標上漲點 2% ~ 10% ![](https://i.imgur.com/dgPxcqh.png) ## Technique for multi-label classification 接下來我們 review 最近在 CXR 上多標籤分類任務的一些 paper ### Multi-Label Chest X-Ray Classification via Deep Learning 論文地址: https://arxiv.org/ftp/arxiv/papers/2211/2211.14929.pdf 代碼地址: https://github.com/aravindsp/CS598_DL_Chest_X_RayClassification 參考價值: ⭐ 觀後結語: 作者沒講什麼,就是個跑完 code 後發的論文 #### Dataset 作者使用 MIMIC-Chest X-Ray Database (MIMIC-CXR) 資料集,共有 350000 張胸腔影像 有些難判讀的胸腔影像,資料中的標籤標示為 -1,代表醫生也很難分辨影像是屬於哪一種病灶,作者的處理方法是根據不同類別將 -1 的影像判別為正或負 #### Method 作者將原始影像 downsample 到 224\*224 #### Experiment ![](https://i.imgur.com/PtD01U8.png) ![](https://i.imgur.com/TQ0fJIy.png) ### CheXNet: Radiologist-Level Pneumonia Detection on Chest X-Rays with Deep Learning 論文地址: https://arxiv.org/ftp/arxiv/papers/2211/2211.14929.pdf 代碼地址: None 參考價值: ⭐ 觀後結語: 聽說這篇論文很有名,我沒看出來這篇論文提出什麼重要的成果 #### Dataset 作者使用 ChestX-ray 14 資料集,總共 112120 張胸腔影像,作者在做資料集分割時沒有考慮類別,反而是讓訓練、驗證、測試中沒有重複的**病人** #### Method 作者使用的以下方法 1. 權重的 BCE 2. 將影像 downsample 到 224\*224 3. 使用 DenseNet121 訓練 4. 使用 ImageNet 的標準差和平均數做 normalized 5. 使用隨機水平翻轉作為 augmentation ![](https://i.imgur.com/l4f7424.png) #### Experimental ![](https://i.imgur.com/zYrmmz6.png) ### SwinCheX: Multi-label classification on chest X-ray images with transformers 論文地址: https://arxiv.org/pdf/2206.04246.pdf 代碼地址: https://github.com/rohban-lab/SwinCheX 參考價值: ⭐⭐⭐ 觀後結語: 使用多頭的方式解決 multi-label 問題,可以嘗試看看 #### Dataset 作者使用 ChestX-ray 14 資料集,總共 112120 張胸腔影像 #### Method 作者汲取了 multi-task 任務的思想,用 SwinT 當作 shared component,最後在不同類別用不同的分類器去判別每個類別的機率 ![](https://i.imgur.com/nDRQk7Y.png) #### Experimental 作者在實驗過程中的結論有以下幾點 1. 在任何一個模型中,一開始 data 的 AUC 表現都很好,但從某個 epoch 後越來越差,作者認為是 over-fitting 的問題 (但是作者沒有提出怎麼解決) 2. 在 ViT 模型中,作者標住了幾點 - batchsize: 32 - learning rate: 3e-2 ![](https://i.imgur.com/VoZ1eZN.png) 3. 在 headless 的實驗,作者使用了 14 個神經元作為輸出,loss 為 BCE ### Data-Efficient Vision Transformers for Multi-Label Disease Classification on Chest Radiographs 論文地址: https://arxiv.org/pdf/2208.08166.pdf 代碼地址: None 參考價值: ⭐⭐ 觀後結語: 主要研究知識蒸馏和 Transformer 對於 CXR 的影響 #### Dataset 作者使用 CheXpert 資料集,總共 224316 張胸腔影像,類別數量也是 14,資料中有被標記為 -1 的影像,在這篇論文中,作者把這些影像視為正類別 #### Method 作者採用了以下方法: 1. 將影像 downsample 到 224\*224 2. 數據增強採用 random augment 和 random erasing 3. 作者先訓練了一個 DenseNet201 網路作為教師網路,再以 DeiT 作為學生網路 4. 50 epochs,loss 為 Weighted BCE 5. 根據訓練目標不同使用不同學習率 (在 DeiT 中提到使用 ConvNet 作為教師網路會比用 transformer 好) - DenseNet201: 1e-4 - DeiT: 5e-5 #### Experimental ![](https://i.imgur.com/7Iq4cU3.png) 可以看到加上知識蒸餾後真的有效提升網路表現 #### Coding ```python! class DistilledVisionTransformer(VisionTransformer): """包含蒸馏令牌的视觉Transformer。 论文: Training data-efficient image transformers & distillation through attention https://arxiv.org/abs/2012.12877 这个蒸馏ViT的实现来自于https://github.com/facebookresearch/deit """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 蒸馏令牌 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) num_patches = self.patch_embed.num_patches self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() trunc_normal_(self.dist_token, std=.02) trunc_normal_(self.pos_embed, std=.02) self.head_dist.apply(self._init_weights) def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) # 在实现中,两个token是连续的,不像上图中一个在前一个在后 x = torch.cat((cls_tokens, dist_token, x), dim=1) x = x + self.pos_embed x = self.pos_drop(x) # transformer encoder部分 for blk in self.blocks: x = blk(x) x = self.norm(x) # 返回两个token embedding后的结果 return x[:, 0], x[:, 1] def forward(self, x): x, x_dist = self.forward_features(x) x = self.head(x) x_dist = self.head_dist(x_dist) if self.training: return x, x_dist else: # 在推理时,返回两个分类器预测的平均值 return (x + x_dist) / 2 ``` ```python! # 利用pretrained weight进行finetune import timm import torch import torch.nn as nn # model_name设置 # vit: 'vit_base_patch16_224' class MyViT(nn.Module): def __init__(self, target_size, model_name, pretrained=False): super(MyViT, self).__init__() self.model = timm.create_model(model_name, pretrained=pretrained) n_features = self.model.head.in_features # 改成自己任务的图像类别数 self.model.head = nn.Linear(n_features, target_size) def forward(self, x): x = self.model(x) return x class MyDeiT(nn.Module): def __init__(self, target_size, pretrained=False): super(MyDeiT, self).__init__() self.model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_distilled_patch16_224', pretrained=True) n_features = self.model.head.in_features # 改成自己任务的图像类别数 self.model.head = nn.Linear(n_features, target_size) self.model.head_dist = nn.Linear(n_features, target_size) def forward(self, x): x, x_dist = self.model(x) return x, x_dist ```