Набор проверки содержит изображения только из одного класса, если они получены через каталог image_dataset_from_directory

#python #image #tensorflow #keras #tensorflow-datasets

Вопрос:

У меня есть следующая функция для возврата набора данных обучения и проверки:

 def load_from_directory(path, shuffle=False):  train_ds = tfk.preprocessing.image_dataset_from_directory(  directory=path,  image_size=IMAGE_SIZE,  validation_split=VALIDATION_SPLIT,  batch_size=BATCH_SIZE,  seed=SEED,  subset='training',  label_mode='binary',  shuffle=shuffle  )   val_ds = tfk.preprocessing.image_dataset_from_directory(  directory=path,  image_size=IMAGE_SIZE,  validation_split=VALIDATION_SPLIT,  batch_size=BATCH_SIZE,  seed=SEED,  subset='validation',  label_mode='binary',  shuffle=False  )   return train_ds, val_ds  train_ds, val_ds = load_from_directory(path=TRAINING_PATH, shuffle=True)  

Проблема в том, что после некоторых странных результатов (точность проверки 100% после 2-й эпохи) Я проанализировал состав проверочного набора и пришел к выводу, что он содержит изображения только из одного класса.

Это очень странно, но я не знаю, как с этим справиться. Я использую набор данных «Кошки и собаки» от Microsoft, который содержит массу примеров каждого класса.

Чтобы отобразить на диаграмме распределение классов, я делаю следующее:

 import plotly.graph_objects as go  labels = np.concatenate([y for _, y in train_ds], axis=0) _, counts = np.unique(labels, return_counts=True)  fig = go.Figure(  data=[  go.Pie(  labels=CLASS_NAMES,   values=counts,   hole=.5,   marker_colors=['rgb(205, 152, 36)', 'rgb(129, 180, 179)', 'rgb(177, 180, 34)']  )],   layout_title_text='Train Class Frequency' )  fig.update_layout(width=400, height=400) fig.show()  labels = np.concatenate([y for _, y in val_ds], axis=0) _, counts = np.unique(labels, return_counts=True)  fig = go.Figure(  data=[  go.Pie(  labels=CLASS_NAMES,   values=counts,   hole=.5,   marker_colors=['rgb(205, 152, 36)', 'rgb(129, 180, 179)', 'rgb(177, 180, 34)']  )],   layout_title_text='Validation Class Frequency' )  fig.update_layout(width=400, height=400) fig.show()   

Еще более странная вещь заключается в том, что при suffle=True создании набора данных в наборе данных есть два класса, но не имеет смысла устанавливать этот флаг в значение True.

Результаты

Ответ №1:

Я запустил ваш код и не вижу проблемы. Я использовал набор данных с 2 классами. и запустил его с помощью shuffle=True, а также с помощью shuffle=False. Чтобы проверить, имеет ли val_ds нужное количество классов, используйте

 print(val_ds.class_names)  

Комментарии:

1. Замена категориального на двоичный в label_mode устранит проблему. Я думаю, что в моем случае это было проблемой. Спасибо за вашу помощь