tf.compat.v1.disable_eager_execution() с tf.data.dateset

#tensorflow2.0 #eager-execution #tf.data.dataset

#tensorflow2.0 #нетерпеливое выполнение #tf.data.dataset

Вопрос:

Я использую tensorflow 2.2. У меня есть два числовых массива (объекты и метки), которые я передаю в tf.data.dataset.from_tensor_slices():

 train_dataset = tf.data.Dataset.from_tensors(feature_train_slice, label_train_slice).shuffle(buffer_size).reapeat()

test_dataset = tf.data.Dataset.from_tensors(feature_test_slice, label_test_slice).shuffle(buffer_size).repeat()
  

Я пытаюсь передать эти данные в мою model.fit():

 history = self.model.fit(ds_train,
                            steps_per_epoch=int(train_steps / (batch_size)),
                            verbose=1,
                            epochs=epochs,
                            callbacks=self.call_back(),
                            use_multiprocessing=True,
                            validation_data = test_dataset,
                            validation_steps = int(validation_steps / (batch_size))
                            )
  

Я использовал

 tf.compat.v1.disable_eager_execution()
  

в начале моего кода. Если я прокомментирую это, обучение начнется без проблем, но, как я понимаю, обучение идет медленнее (каждый шаг занимает 2 секунды на 2080TI). Если я оставлю это, каждый шаг займет около 1,2 секунды. Однако программа никогда не передает строку

 train_dataset = tf.data.Dataset.from_tensors(feature_train_slice, label_train_slice).shuffle().reapeat()
  

Я вышел из программы более чем на 30 минут, и хотя потребляется около 60 ГБ (моя оперативная память составляет 64 ГБ), программа, похоже, ничего не делает. Кто-нибудь видел это раньше? любая помощь приветствуется.

Ответ №1:

вместо reapeat() должна прийти функция repeat().