Machine Learning
ROC and Precision-Recall Curves for Multi-class Classification
Dan-k
2024. 2. 20. 15:00
반응형
ROC Curve
- multi-class classification 문제에서는 각 label별로 ROC 커브를 그림 (one-vs-rest: 해당 label을 positive, 나머지 label을 negative로 보는 binary 문제로 변환)
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import auc, average_precision_score, precision_recall_curve, roc_curve
from sklearn.preprocessing import OneHotEncoder
# --- Per-class (one-vs-rest) ROC curves + micro-average ----------------------
# Inputs expected from the surrounding notebook (not shown in this snippet):
#   y_test       -- 1D class labels (pandas Series or any array-like)
#   y_pred_proba -- (n_samples, n_classes) predicted probabilities; the columns
#                   are assumed to follow the sorted class order, as sklearn's
#                   predict_proba produces -- TODO confirm for your model.

# OneHotEncoder expects a 2D column vector. np.asarray accepts both a pandas
# Series and a plain ndarray (the original `.values` only worked for pandas).
y_test = np.asarray(y_test).reshape(-1, 1)

# `sparse` was renamed to `sparse_output` in scikit-learn 1.2 and removed in
# 1.4; try the current name first and fall back for older releases.
try:
    encoder = OneHotEncoder(sparse_output=False, categories='auto')
except TypeError:
    encoder = OneHotEncoder(sparse=False, categories='auto')

# Dense (n_samples, n_classes) 0/1 indicator matrix, columns in sorted label order.
y_test_onehot = encoder.fit_transform(y_test)

# Derive the class count from the data instead of hard-coding it.
class_labels = encoder.categories_[0]
num_classes = len(class_labels)
if y_pred_proba.shape[1] != num_classes:
    raise ValueError(
        f"y_pred_proba has {y_pred_proba.shape[1]} columns but y_test contains "
        f"{num_classes} classes"
    )

# One binary ROC curve per class: that class is positive, all others negative.
fpr = {}
tpr = {}
thresholds = {}
roc_auc = {}
for i in range(num_classes):
    fpr[i], tpr[i], thresholds[i] = roc_curve(y_test_onehot[:, i], y_pred_proba[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Micro-average: pool every (sample, class) decision into a single binary problem.
fpr["micro"], tpr["micro"], thresholds["micro"] = roc_curve(y_test_onehot.ravel(), y_pred_proba.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Plot the ROC curve of each class, marking the Youden-optimal operating point.
plt.figure(figsize=(10, 8))
colors = ['red', 'green', 'gray']
for i in range(num_classes):
    # Use the tab10 palette once the base colors run out. zip() over a fixed
    # color list would silently drop every class beyond the third.
    color = colors[i] if i < len(colors) else plt.cm.tab10(i % 10)
    # Youden's J statistic: the index that maximizes TPR - FPR.
    optimal_idx = np.argmax(tpr[i] - fpr[i])
    plt.plot(fpr[i], tpr[i], color=color, lw=2,
             label=f'Class {class_labels[i]} (AUC = {roc_auc[i]:.2f}), threshold={thresholds[i][optimal_idx]:.2f}')
    plt.scatter(fpr[i][optimal_idx], tpr[i][optimal_idx], marker='+', s=200, color='blue')

# Micro-average ROC curve.
plt.plot(fpr["micro"], tpr["micro"], color='gold', label=f'Micro-average (AUC = {roc_auc["micro"]:.2f})', linestyle='--', linewidth=2)
plt.plot([0, 1], [0, 1], 'k--', linewidth=2)  # chance-level diagonal
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()

Precision-Recall curve
- threshold 변화에 따른 precision과 recall 사이의 trade-off를 확인
# --- Per-class (one-vs-rest) Precision-Recall curves + micro-average ---------
# Reuses y_test_onehot (0/1 indicator matrix) and y_pred_proba from the ROC
# section. The original looped over an undefined `n_classes` (NameError), so
# the class count is derived from the indicator matrix here.
n_classes = y_test_onehot.shape[1]

precision = {}
recall = {}
average_precision = {}
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test_onehot[:, i], y_pred_proba[:, i])
    # average_precision_score is the step-wise AP. auc(recall, precision) uses
    # trapezoidal (linear) interpolation, which is not AP and tends to
    # overestimate the area under a PR curve.
    average_precision[i] = average_precision_score(y_test_onehot[:, i], y_pred_proba[:, i])

# Micro-average: pool every (sample, class) decision into one binary problem.
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test_onehot.ravel(), y_pred_proba.ravel())
average_precision["micro"] = average_precision_score(y_test_onehot, y_pred_proba, average="micro")

# Plot the Precision-Recall curve of each class.
plt.figure(figsize=(10, 8))
colors = ['red', 'green', 'gray']
for i in range(n_classes):
    # Use the tab10 palette once the base colors run out, so no class is
    # dropped (zip() over a fixed color list would truncate).
    color = colors[i] if i < len(colors) else plt.cm.tab10(i % 10)
    plt.plot(recall[i], precision[i], color=color, lw=2,
             label=f"Class {i} (AP = {average_precision[i]:.2f})")
plt.plot(recall["micro"], precision["micro"], color='gold', lw=2,
         linestyle='--',
         label=f"Micro-average Precision-Recall curve (AP = {average_precision['micro']:.2f})")
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve for Multi-Class Classification')
plt.legend(loc="lower right")
plt.show()

최적의 Threshold (Youden's J statistic 기준)
import numpy as np
# plot_roc_curve was removed in scikit-learn 1.2, so importing it raises
# ImportError; it was unused here anyway (RocCurveDisplay is its replacement
# for plotting). roc_curve is what this snippet actually needs.
from sklearn.metrics import roc_curve

# Binary ROC curve. Presumably y_test holds 0/1 labels and y_prob2 the
# positive-class scores of a separate binary model -- TODO confirm.
fpr, tpr, thresholds = roc_curve(y_test, y_prob2)

# Youden's J statistic: J = Sensitivity + Specificity - 1 = TPR - FPR.
J = tpr - fpr
# argmax returns the first maximum. Index 0 is roc_curve's "predict nothing
# positive" point (TPR = FPR = 0, J = 0). Its threshold is a sentinel above
# every score (np.inf in recent scikit-learn), and it is chosen only when no
# threshold beats chance.
optimal_index = np.argmax(J)
best_threshold = thresholds[optimal_index]
728x90
반응형
LIST