데이터과학 삼학년

ROC, Precision-Recall Curve for Multi classification 본문

Machine Learning

ROC, Precision-Recall Curve for Multi classification

Dan-k 2024. 2. 20. 15:00
반응형

ROC Curve
- 다중 분류(multi-class classification) 문제에서는 각 label을 one-vs-rest 방식으로 이진화하여 label별 ROC 커브를 그림

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import (
    auc,
    average_precision_score,
    precision_recall_curve,
    roc_curve,
)
from sklearn.preprocessing import OneHotEncoder

# Plot one-vs-rest ROC curves (plus the micro-average) for a multi-class
# classifier and mark the threshold maximising Youden's J on each curve.
#
# Assumes y_test is a 1D array-like of class labels (e.g. a pandas Series) and
# y_pred_proba is an (n_samples, n_classes) array of predicted scores whose
# columns follow the encoder's category order -- TODO confirm against the model.
# np.asarray accepts both a Series and an ndarray, so re-running is safe.
y_test = np.asarray(y_test).reshape(-1, 1)

# One-hot encode the labels. The dense-output keyword was renamed from
# `sparse` to `sparse_output` across scikit-learn releases, so request the
# default sparse matrix and densify it explicitly to stay version-agnostic.
encoder = OneHotEncoder(categories='auto')
y_test_onehot = encoder.fit_transform(y_test).toarray()

# Derive the class count from the data instead of hard-coding it.
num_classes = y_test_onehot.shape[1]
if np.shape(y_pred_proba) != y_test_onehot.shape:
    raise ValueError(
        f"y_pred_proba has shape {np.shape(y_pred_proba)} but the one-hot "
        f"labels have shape {y_test_onehot.shape}; every class must be "
        f"present in y_test and have a matching probability column."
    )

# Compute ROC curve and ROC area for each class (one-vs-rest)
fpr = {}
tpr = {}
thresholds = {}
roc_auc = {}

for i in range(num_classes):
    fpr[i], tpr[i], thresholds[i] = roc_curve(y_test_onehot[:, i], y_pred_proba[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area by pooling every
# (sample, class) decision into a single binary problem.
fpr["micro"], tpr["micro"], thresholds["micro"] = roc_curve(y_test_onehot.ravel(), y_pred_proba.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Plot ROC curve for each class
plt.figure(figsize=(10, 8))
colors = ['red', 'green', 'gray']
for i in range(num_classes):
    # Cycle through the palette so classes beyond the third are still drawn.
    color = colors[i % len(colors)]
    # Youden's J = TPR - FPR. Index 0 is the artificial "predict nothing"
    # point whose threshold is a sentinel (inf / max score + 1), so it is
    # excluded from the search for a usable operating threshold.
    optimal_idx = int(np.argmax(tpr[i][1:] - fpr[i][1:])) + 1
    plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'Class {i} (AUC = {roc_auc[i]:.2f}), threshold={thresholds[i][optimal_idx]:.2f}')
    plt.scatter(fpr[i][optimal_idx], tpr[i][optimal_idx], marker='+', s=200, color='blue')

# Plot micro-average ROC curve
plt.plot(fpr["micro"], tpr["micro"], color='gold', label=f'Micro-average (AUC = {roc_auc["micro"]:.2f})', linestyle='--', linewidth=2)

plt.plot([0, 1], [0, 1], 'k--', linewidth=2)  # Chance-level diagonal
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()

 

source : https://stats.stackexchange.com/questions/2151/how-to-plot-roc-curves-in-multiclass-classification

 
Precision-Recall curve
- precision과 recall 사이의 trade-off를 고려


# Plot one-vs-rest precision-recall curves (plus the micro-average) for a
# multi-class classifier, annotating each curve with its average precision.
#
# Relies on y_test_onehot (one-hot encoded true labels) and y_pred_proba
# (per-class predicted scores) from the ROC section above.
n_classes = y_test_onehot.shape[1]

precision = {}
recall = {}
average_precision = {}

# Calculate precision-recall curve for each class
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test_onehot[:, i], y_pred_proba[:, i])
    # Use the step-wise average precision rather than trapezoidal auc():
    # linear interpolation between PR points can be overly optimistic.
    average_precision[i] = average_precision_score(y_test_onehot[:, i], y_pred_proba[:, i])

# Compute micro-average precision-recall curve over all pooled
# (sample, class) decisions.
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test_onehot.ravel(), y_pred_proba.ravel())
average_precision["micro"] = average_precision_score(y_test_onehot.ravel(), y_pred_proba.ravel())

# Plot Precision-Recall curve for each class
plt.figure(figsize=(10, 8))

colors = ['red', 'green', 'gray']
for i in range(n_classes):
    # Cycle through the palette so classes beyond the third are still drawn.
    plt.plot(recall[i], precision[i], color=colors[i % len(colors)], lw=2,
             label=f"Class {i} (AP = {average_precision[i]:.2f})")

plt.plot(recall["micro"], precision["micro"], color='gold', lw=2,
         linestyle='--',
         label=f"Micro-average Precision-Recall curve (AP = {average_precision['micro']:.2f})")

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve for Multi-Class Classification')
plt.legend(loc="lower right")
plt.show()
source : https://stats.stackexchange.com/questions/2151/how-to-plot-roc-curves-in-multiclass-classification

 
최적의 Threshold

import numpy as np
# roc_curve is the function actually used below; plot_roc_curve was removed
# from scikit-learn (use RocCurveDisplay for plotting instead).
from sklearn.metrics import roc_curve

# Calculate the ROC curve for a binary problem.
# Assumes y_test holds binary labels and y_prob2 the positive-class
# scores -- TODO confirm against the model that produced y_prob2.
fpr, tpr, thresholds = roc_curve(y_test, y_prob2)

# Youden's J statistic: J = sensitivity + specificity - 1 = TPR - FPR
J = tpr - fpr
# Index 0 is the artificial "predict nothing" point whose threshold is a
# sentinel (inf / max score + 1), so skip it when picking the optimum.
optimal_index = int(np.argmax(J[1:])) + 1
best_threshold = thresholds[optimal_index]
728x90
반응형
LIST
Comments