#python #image #tensorflow #keras #tensorflow-datasets
Вопрос:
У меня есть следующая функция для возврата набора данных обучения и проверки:
def load_from_directory(path, shuffle=False): train_ds = tfk.preprocessing.image_dataset_from_directory( directory=path, image_size=IMAGE_SIZE, validation_split=VALIDATION_SPLIT, batch_size=BATCH_SIZE, seed=SEED, subset='training', label_mode='binary', shuffle=shuffle ) val_ds = tfk.preprocessing.image_dataset_from_directory( directory=path, image_size=IMAGE_SIZE, validation_split=VALIDATION_SPLIT, batch_size=BATCH_SIZE, seed=SEED, subset='validation', label_mode='binary', shuffle=False ) return train_ds, val_ds train_ds, val_ds = load_from_directory(path=TRAINING_PATH, shuffle=True)
Проблема в том, что после некоторых странных результатов (точность проверки 100% после 2-й эпохи) Я проанализировал состав проверочного набора и пришел к выводу, что он содержит изображения только из одного класса.
Это очень странно, но я не знаю, как с этим справиться. Я использую набор данных «Кошки и собаки» от Microsoft, который содержит массу примеров каждого класса.
Чтобы отобразить на диаграмме распределение классов, я делаю следующее:
import plotly.graph_objects as go labels = np.concatenate([y for _, y in train_ds], axis=0) _, counts = np.unique(labels, return_counts=True) fig = go.Figure( data=[ go.Pie( labels=CLASS_NAMES, values=counts, hole=.5, marker_colors=['rgb(205, 152, 36)', 'rgb(129, 180, 179)', 'rgb(177, 180, 34)'] )], layout_title_text='Train Class Frequency' ) fig.update_layout(width=400, height=400) fig.show() labels = np.concatenate([y for _, y in val_ds], axis=0) _, counts = np.unique(labels, return_counts=True) fig = go.Figure( data=[ go.Pie( labels=CLASS_NAMES, values=counts, hole=.5, marker_colors=['rgb(205, 152, 36)', 'rgb(129, 180, 179)', 'rgb(177, 180, 34)'] )], layout_title_text='Validation Class Frequency' ) fig.update_layout(width=400, height=400) fig.show()
Еще более странная вещь заключается в том, что при suffle=True
создании набора данных в наборе данных есть два класса, но не имеет смысла устанавливать этот флаг в значение True.
Ответ №1:
Я запустил ваш код и не вижу проблемы. Я использовал набор данных с 2 классами. и запустил его с помощью shuffle=True, а также с помощью shuffle=False. Чтобы проверить, имеет ли val_ds нужное количество классов, используйте
print(val_ds.class_names)
Комментарии:
1. Замена категориального на двоичный в label_mode устранит проблему. Я думаю, что в моем случае это было проблемой. Спасибо за вашу помощь