#python #python-3.x #tensorflow #keras #deep-learning
#питон #python-3.x #tensorflow #keras #глубокое обучение
Вопрос:
Я попытался использовать пользовательские метрики для своей модели и сохранить контрольную точку модели на основе этой пользовательской метрики. Среда, которую я использую, — это ядро kaggle с графическим процессором в качестве ускорителя.
Код, который я использовал:
ckp = tf.keras.callbacks.ModelCheckpoint(filepath, monitor="val_f1score",
mode='max', save_weights_only=True,
save_best_only=True, verbose=1)
model.compile("adam",
loss=tf.keras.losses.categorical_crossentropy,
metrics=["acc",
tfa.metrics.F1Score(num_classes=18, name="f1score"),
]
)
model.fit(X_train, y_train, epochs=300, batch_size=64, validation_data=(X_val, y_val),
callbacks=[ckp]
)
вызывает эту ошибку:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-39-dca145b7ccb9> in <module>
1 model.fit(X_train, y_train, epochs=300, batch_size=64, validation_data=(X_val, y_val),
----> 2 callbacks=[ckp]
3 )
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1135 epoch_logs.update(val_logs)
1136
-> 1137 callbacks.on_epoch_end(epoch, epoch_logs)
1138 training_logs = epoch_logs
1139 if self.stop_training:
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
410 for callback in self.callbacks:
411 if getattr(callback, '_supports_tf_logs', False):
--> 412 callback.on_epoch_end(epoch, logs)
413 else:
414 if numpy_logs is None: # Only convert once.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
1247 # pylint: disable=protected-access
1248 if self.save_freq == 'epoch':
-> 1249 self._save_model(epoch=epoch, logs=logs)
1250
1251 def _should_save_on_batch(self, batch):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _save_model(self, epoch, logs)
1289 'skipping.', self.monitor)
1290 else:
-> 1291 if self.monitor_op(current, self.best):
1292 if self.verbose > 0:
1293 print('nEpoch d: %s improved from %0.5f to %0.5f,'
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Однако, когда я удалил ModelCheckpoint или установил для монитора значение val_loss или val_acc, model.fit() работает нормально.
Комментарии:
1. Какие-либо обновления по этому поводу?