Keras generator и fit_generator, как создать генератор, чтобы избежать ошибки ‘function shape’

#python #tensorflow #keras

#python #tensorflow #keras

Вопрос:

Я создаю генератор для Keras, чтобы иметь возможность загружать изображения моего набора данных, поскольку он немного велик для моей оперативной памяти.

Я построил генератор следующим образом:

 # import the necessary packages
import tensorflow
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import pandas as pd
from tqdm import tqdm

#loading
path_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset- 
images_improved.txt"
df = pd.read_csv(path_to_txt ,sep='t')
arr = np.array(df)
#epochs and steps:
NUM_TRAIN_IMAGES = 0
NUM_EPOCHS = 30

def image_generator(arr, bs, mode="train", aug=None):
  while True:
    images = []
    labels = []
    for row in arr:
      if len(images) < bs:
        img = (cv2.resize(cv2.imread("/content/test/leafsnap-dataset/"   
        row[0]),(224,224)))
        images.append(img)
        labels.append([row[2]])
        NUM_TRAIN_IMAGES  = 1
      else:
        break


  if aug is not None:
    (images, labels) = next(aug.flow(np.array(images),labels, 
     batch_size=bs))

  obj = OneHotEncoder()
  values = obj.fit_transform(labels).toarray()

  yield (np.array(images), labels)
  

Затем я вызываю fit_generator из последовательной модели (cnn работал, пока я не получил ошибку OOM)

 #create the augmentation function:
 aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")

#create the generator:
gen = image_generator(arr, bs = 32, mode = "train", aug = aug)

history = model.fit_generator(image_generator,
    steps_per_epoch = NUM_TRAIN_IMAGES,
    epochs = NUM_EPOCHS)
  

И отсюда я получаю эту ошибку:

 # Create generator from NumPy or EagerTensor Input.
--> 377   num_samples = int(nest.flatten(data)[0].shape[0])
378   if batch_size is None:
379     raise ValueError('You must specify `batch_size`')
AttributeError: 'function' object has no attribute 'shape'
  

Комментарии:

1. Во-первых, ваша функция генератора не экономит память. Потому что сначала вы загружаете все изображения. Вы должны выполнить итерацию по файлам изображений и внутри цикла получить np.array.

Ответ №1:

Я вижу здесь две основные ошибки.

Во-первых, ваша функция генератора не экономит память. Потому что вы загружаете все изображения сначала (цикл while). Вы должны выполнить итерацию по файлам изображений и внутри цикла получить np.массив изображений с меткой.

Во-вторых, вы передаете имя функции генератора в fit_generator, когда вы должны использовать его возвращаемый object — gen.

Комментарии:

1. О боже, не увидел ошибки между функцией и именем объекта после сотни раз прочтения кода. Что касается оптимизации, да, я новичок в концепции, я сейчас ее оптимизирую. Спасибо.

2. Я предполагаю, что оператор yield находится внутри некоторого цикла? Если это не так, этот генератор будет таким же эффективным с точки зрения памяти, как и без генератора

3. На самом деле я думаю, что это не внутри цикла, потому что вы не перебираете все изображения, а только до тех пор, пока не достигнете размера bs. Но моя реализация в любом случае неверна.