как выводить изображения из Batchdataset изображений в keras

#python #tensorflow #keras

#python #тензорный поток #keras

Вопрос:

После создания набора данных изображений с использованием image_dataset_from_directory из keras, как получить первое изображение из набора данных в формате numpy, которое можно отобразить с помощью pyplot.imshow?

 import tensorflow as tf
import matplotlib.pyplot as plt

test_data = tf.keras.preprocessing.image_dataset_from_directory(
    "C:\Users\Admin\Downloads\kagglecatsanddogs_3367a",
    validation_split=.1,
    subset='validation',
    seed=123)
for e in test_data.as_numpy_iterator():
    print(e[1:])
  

Ответ №1:

В приведенном выше коде e — это не изображение, а скорее кортеж, содержащий изображение и метки.
Код:

 plt.figure(figsize=(10, 10))
class_names = test_data.class_names
for images, labels in test_data.take(1):
    for i in range(32):
        ax = plt.subplot(6, 6, i   1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
  

Вы можете использовать test_data.take(1) , чтобы взять один пакет из вашего test_data и визуализировать его.

Ваш вывод будет выглядеть примерно так:
введите описание изображения здесь

Ответ №2:

Код: если вы используете функцию предварительной обработки набора данных изображений tensorflow из каталога.

 batch_size = 32
img_height = 180
img_width = 180
seed = 123
    
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir_train, 
                                                                   labels='inferred', 
                                                                   seed=seed, 
                                                                   batch_size=batch_size, 
                                                                   image_size=(img_width, img_height),
                                                                   label_mode='categorical',
                                                                   subset="training",
                                                                   validation_split=0.2)

class_names = train_ds.class_names

import matplotlib.pyplot as plt

### To visualize the images
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(batch_size):
    ax = plt.subplot(6, 6, i   1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[np.argmax(labels[i])])
    plt.axis("off")

# Plotting the images
plt.show()