###### tags: `Deep Learning` `processing` # 多標籤分類的評估指標筆記 ## 多標籤分類簡介 **多標籤分類**(**multi-label classification**)是現實世界裡常見的一種分類任務形式。跟**多類別分類**(**multi-class classification**)不同,模型在預測類別時,可以預測多種類別的。 舉個例子,多類別分類是單選題,而多標籤分類則是**複選題**。 既然任務形式變複雜了,那麼與之對應的評估指標也可以分成兩種: 1. 不考慮部份正確的情況 2. 考慮部份正確的情況 當然,以分數計算的角度來看,前者遠比後者嚴苛。 ### 部份正確的評估指標 #### 準確率 Accuracy 以考慮部份正確的情況,可以使用以下公式計算 accuracy。 - **n**: 資料筆數 - **i**: 第 i 筆資料 - **true**: 標準答案 - **pred**: 預測答案 ![](https://i.imgur.com/9P6EcL3.png) 實際上就是把多標籤分類標準答案與預測,一筆筆去匹配,分母為預測結果的聯集(union),分子為預測結果的交集(intersection),接著把每筆資料的分數加總,除以資料數,便能得到『計算部份正確情況的準確率(accuracy)』。 以下是一段範例程式碼: ```=python= # coding: utf-8 from typing import List import numpy as np def partial_correct_accuracy( y_true: List[int], y_pred: List[int], ) -> float: count = 0 for i in range(len(y_true)): p = sum(np.logical_and(y_true[i], y_pred[i])) q = sum(np.logical_or(y_true[i], y_pred[i])) count += p / q return count / len(y_true) def main() -> None: """ Entry point """ y_true = [ [1, 0, 1, 1], [1, 1, 0, 0], [1, 0, 1, 0], ] y_pred = [ [1, 0, 0, 0], [0, 1, 0, 1], [1, 1, 1, 0], ] print("accuracy: ", partial_correct_accuracy(y_true, y_pred)) if __name__ == "__main__": main() ``` Output: ``` 0.4444444444444444 ``` 試驗算: ![](https://i.imgur.com/egA42Yj.png) <br/> #### 精確率 Precision ![](https://i.imgur.com/W3DV02o.png) 跟計算部份正確情況的 accuracy 公式相比,精確率我們只需要調整分母的部份即可。 分母從標準答案與預測答案的聯集,轉變為只計算預測答案的『數量』。 以實際原理來說明的話,就是計算我們所做出的預測中(分母),究竟有多少是和標準答案是一致的(分子)。 除了直接用公式計算外,也可以使用 sklearn.metrics.precision_score() 去算分。 **在多標籤分類考慮部份正確的情況下,要設定參數 average="samples"**。 ```=python= # coding: utf-8 from typing import List import numpy as np from sklearn import metrics def partial_correct_precision( y_true: List[int], y_pred: List[int], ) -> float: count = 0 for i in range(len(y_true)): if sum(y_pred[i]) == 0: continue p = sum(np.logical_and(y_true[i], y_pred[i])) q = sum(y_pred[i]) count += p / q return count / len(y_true) def main() -> None: """ Entry point """ y_true = [ [1, 0, 1, 1], [1, 1, 0, 0], [1, 0, 1, 0], ] y_pred = [ [1, 0, 0, 0], [0, 1, 0, 1], [1, 1, 1, 0], ] print("precision:", partial_correct_precision(y_true, y_pred)) print("scikit-learn precision:", metrics.precision_score(y_true, y_pred, average="samples")) if __name__ == "__main__": main() ``` Output: ``` precision: 0.7222222222222222 scikit-learn precision: 0.7222222222222222 ``` <br/> #### 召回率 Recall ![](https://i.imgur.com/XYmtAiq.png) 跟 precision 的計算方式很像、唯一不同之處在於分母是標準答案標籤的數量(**比方說 [1, 0, 1] 的數量就是 2**) 實際意義為標準答案的標籤中,我們模型預測出了多少筆。故為『召回率』。 ```=python= # coding: utf-8 from typing import List import numpy as np from sklearn import metrics def partial_correct_recall( y_true: List[int], y_pred: List[int], ) -> float: count = 0 for i in range(len(y_true)): if sum(y_true[i]) == 0: continue p = sum(np.logical_and(y_true[i], y_pred[i])) q = sum(y_true[i]) count += p / q return count / len(y_true) def main() -> None: """ Entry point """ y_true = [ [1, 0, 1, 1], [1, 1, 0, 0], [1, 0, 1, 0], ] y_pred = [ [1, 0, 0, 0], [0, 1, 0, 1], [1, 1, 1, 0], ] print("recall:", partial_correct_recall(y_true, y_pred)) print("scikit-learn recall:", metrics.recall_score(y_true, y_pred, average="samples")) if __name__ == "__main__": main() ``` Output: ``` recall: 0.611111111111111 scikit-learn recall: 0.611111111111111 ``` <br/> #### F1-score ![](https://i.imgur.com/TjEUd2L.png) F1-score 為 precision 及 recall 的平均調和數,可以看到公式同時考慮了兩種指標的情況。 ```=python= # coding: utf-8 from typing import List import numpy as np from sklearn import metrics def partial_correct_f1( y_true: List[int], y_pred: List[int], ) -> float: count = 0 for i in range(len(y_true)): if sum(y_true[i]) == 0 and sum(y_pred[i]) == 0: continue p = sum(np.logical_and(y_true[i], y_pred[i])) q = sum(y_true[i]) + sum(y_pred[i]) count += 2 * p / q return count / len(y_true) def main() -> None: """ Entry point """ y_true = [ [1, 0, 1, 1], [1, 1, 0, 0], [1, 0, 1, 0], ] y_pred = [ [1, 0, 0, 0], [0, 1, 0, 1], [1, 1, 1, 0], ] print("recall: ", partial_correct_f1(y_true, y_pred)) print("scikit-learn recall:", metrics.f1_score(y_true, y_pred, average="samples")) if __name__ == "__main__": main() ``` Output: ``` recall: 0.6 scikit-learn recall: 0.6 ```