#python #tensorflow #keras #tensorflow-datasets
Вопрос:
Я пытаюсь построить CNN в TensorFlow с помощью Python. Я загрузил свои изображения в набор данных следующим образом:
dataset = tf.keras.preprocessing.image_dataset_from_directory(
"train_data", shuffle=True, image_size=(578, 260),
batch_size=BATCH_SIZE)
Однако, если я хочу использовать train_test_split или fit_resample для этого набора данных, мне нужно разделить его на данные и метки. Я новичок в TensorFlow и не знаю, как это сделать. Был бы очень признателен за любую помощь.
Комментарии:
1. Являются ли ваши ярлыки частью «train_data»?
2. @AloneTogether да.
3. Как структурированы ваши данные в вашей папке?
4. @AloneTogether У меня есть 5 вложенных папок, полных изображений, организованных так же, как я хочу, чтобы данные были классифицированы
5. Спасибо за решение, я все еще пытаюсь разобраться
Ответ №1:
Вы можете использовать этот subset
параметр для разделения ваших данных на training
и validation
.
import tensorflow as tf
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
image_size=(256, 256),
seed=1,
batch_size=32)
val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=1,
image_size=(256, 256),
batch_size=32)
for x, y in train_ds.take(1):
print('Image --> ', x.shape, 'Label --> ', y.shape)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Image --> (32, 256, 256, 3) Label --> (32,)
Что касается ваших этикеток, то, согласно документам:
Либо «выводимые» (метки генерируются из структуры каталогов), Нет (без меток), либо список/кортеж целых меток того же размера, что и количество файлов изображений, найденных в каталоге. Метки должны быть отсортированы в соответствии с буквенно-цифровым порядком путей к файлам изображений (полученным с помощью os.walk(каталог) в Python).
Так что просто попробуйте повторить train_ds
и посмотреть, есть ли они там. Вы также можете использовать параметры label_mode
для обозначения типа имеющихся у вас меток и class_names
для явного перечисления ваших классов.
Если ваши классы сбалансированы, вы можете использовать class_weights
параметр model.fit(*)
. Для получения дополнительной информации ознакомьтесь с этим сообщением.