#python #machine-learning #keras #roc #auc
#python #машинное обучение #keras #ОКР #auc
Вопрос:
Как я могу вычислить ROC и AUC в Keras, используя набор данных test_generator? Посмотрите коды, которые я написал, но не вижу ни одной кривой, нанесенной на график, и я хотел бы вычислить кривую на основе всех наборов данных в наборе данных Test_Dataset для двоичной классификации. Я попытался применить некоторые коды из StackOverflow, но это не сработало и не вычисляло то, что я хотел. Я новичок в машинном обучении. Спасибо.
from tensorflow.keras.models import load_model
from sklearn import metrics
import matplotlib.pyplot as plt
from skimage.transform import rescale,resize
import numpy as np `from tensorflow.keras.preprocessing.image import ImageDataGenerator
img_datagen = ImageDataGenerator(rescale=1/255)
test_generator = img_datagen.flow_from_directory(
'Test_Dataset', # This is the source directory for testing images
target_size=(150, 150), # All images will be resized to 150 x 150
batch_size=51,
# Specify the classes explicitly
classes = ['BAPL_0_1', 'BAPL_2_3'],
# Since we use categorical_crossentropy loss, we need categorical labels
class_mode='categorical')
X,y = test_generator.next()
X.shape
pred_prob1 = model.predict(X)
predict_label1 = np.argmax(pred_prob1, axis=-1)
true_label1 = np.argmax(y, axis=-1)
y = np.array(true_label1)
scores = np.array(predict_label1)
fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=9)
roc_auc = metrics.auc(fpr, tpr)
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic (ROC)')
plt.legend(loc="lower right")
plt.show()
Вот модель, разработанная с использованием Keras, и модель была обучена подкатегориям в каждом классе, имеющим 51 изображение для каждой папки в классе, вы можете увидеть в Test_Dataset .
Test_Dataset здесь