#tensorflow2.0 #eager-execution #tf.data.dataset
#tensorflow2.0 #нетерпеливое выполнение #tf.data.dataset
Вопрос:
Я использую tensorflow 2.2. У меня есть два числовых массива (объекты и метки), которые я передаю в tf.data.dataset.from_tensor_slices():
train_dataset = tf.data.Dataset.from_tensors(feature_train_slice, label_train_slice).shuffle(buffer_size).reapeat()
test_dataset = tf.data.Dataset.from_tensors(feature_test_slice, label_test_slice).shuffle(buffer_size).repeat()
Я пытаюсь передать эти данные в мою model.fit():
history = self.model.fit(ds_train,
steps_per_epoch=int(train_steps / (batch_size)),
verbose=1,
epochs=epochs,
callbacks=self.call_back(),
use_multiprocessing=True,
validation_data = test_dataset,
validation_steps = int(validation_steps / (batch_size))
)
Я использовал
tf.compat.v1.disable_eager_execution()
в начале моего кода. Если я прокомментирую это, обучение начнется без проблем, но, как я понимаю, обучение идет медленнее (каждый шаг занимает 2 секунды на 2080TI). Если я оставлю это, каждый шаг займет около 1,2 секунды. Однако программа никогда не передает строку
train_dataset = tf.data.Dataset.from_tensors(feature_train_slice, label_train_slice).shuffle().reapeat()
Я вышел из программы более чем на 30 минут, и хотя потребляется около 60 ГБ (моя оперативная память составляет 64 ГБ), программа, похоже, ничего не делает. Кто-нибудь видел это раньше? любая помощь приветствуется.
Ответ №1:
вместо reapeat() должна прийти функция repeat().