Try   HackMD

繪製分類模型結果 - ROC Curve

使用套件
scikit-plot
https://pypi.org/project/scikit-plot/
matplotlib
https://pypi.org/project/matplotlib/

安裝方式

pip install scikit-plot
pip install matplotlib

用法: ref from stackoverflow

import scikitplot as skplt
import matplotlib.pyplot as plt

y_true = # ground truth labels
y_probas = # predicted probabilities generated by sklearn classifier
skplt.metrics.plot_roc_curve(y_true, y_probas)
plt.show()

範例資料:
假設我們模型跑完的結果長這樣

  • 欄位說明:
    • filename: 圖片檔名
    • gt_class: ground truth 類別 label
    • pred_class: 模型 predict 類別結果 label
    • pred_prob: 模型 predict 各類別結果的詳細機率值

把 gt_class 資料整理成 list 存到 y_true 裡面
長這樣 (label 範圍 0 ~ 6,共 7 個類別)

y_true = [2, 5, 5, 3, 2, 1, ...... ]

把 pred_prob 資料整理成二維 list 存到 prob_list 裡面
長這樣 (我的模型輸出的每一筆資料含有 7 個機率值)

prob_list = [[8.5276943e-06, 5.0007920e-06, 9.9994439e-01, 4.3...], 
             [5.6465615e-06 4.8560273e-06 6.0416551e-06 5.7..],
             [2.5744524e-04 1.5103277e-03 3.2300744e-04 3.1...],
             [2.5181988e-02 8.9589570e-04 6.6928881e-01 3.0...],
             [7.5951771e-06 9.9995345e-01 8.5064212e-06 5.4...],
             ...
             ]

繪製 ROC Curve

import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_roc_curve(y_true,prob_list)
plt.savefig('{}_roc_curve.png'.format(model_name)) # 圖片存檔
plt.show()  # 顯示出來