#tensorflow #keras #semantic-segmentation
#tensorflow #keras #семантическая сегментация
Вопрос:
Я работаю над семантической сегментацией с использованием segmentation-models
библиотеки. Я изменил это руководство, чтобы рассмотреть 8 классов для сегментации вместо 2, рассмотренных в примере. При обучении модели я получил ConcatOp : Dimensions of inputs should match:
ошибку. Знаете ли вы, что вызвало эту ошибку?
# define optomizer
optim = keras.optimizers.Adam(LR)
# Segmentation models losses can be combined together by ' ' and scaled by integer or float factor
# set class weights for dice_loss (car: 1.; pedestrian: 2.; background: 0.5;)
dice_loss = sm.losses.DiceLoss(class_weights=np.array([1, 2, 2, 2, 2, 2, 1, 1, 0.5]))
focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
# total_loss = dice_loss (1 * focal_loss)
# actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
total_loss = sm.losses.categorical_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
# compile keras model with defined optimozer, loss and metrics
model.compile(optim, total_loss, metrics)
# Dataset for train images
train_dataset = Dataset(
x_train_dir,
y_train_dir,
classes=CLASSES,
augmentation=get_training_augmentation(),
preprocessing=get_preprocessing(preprocess_input),
)
# Dataset for validation images
valid_dataset = Dataset(
x_valid_dir,
y_valid_dir,
classes=CLASSES,
augmentation=get_validation_augmentation(),
preprocessing=get_preprocessing(preprocess_input),
)
train_dataloader = Dataloder(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = Dataloder(valid_dataset, batch_size=1, shuffle=False)
# check shapes for errors
assert train_dataloader[0][0].shape == (BATCH_SIZE, 320, 320, 3)
assert train_dataloader[0][1].shape == (BATCH_SIZE, 320, 320, n_classes)
# define callbacks for learning rate scheduling and best checkpoints saving
callbacks = [
keras.callbacks.ModelCheckpoint('./best_model.h5', save_weights_only=True, save_best_only=True, mode='min'),
keras.callbacks.ReduceLROnPlateau(),
]
# train model
history = model.fit_generator(
train_dataloader,
steps_per_epoch=len(train_dataloader),
epochs=EPOCHS,
callbacks=callbacks,
validation_data=valid_dataloader,
validation_steps=len(valid_dataloader),
)
Epoch 1/40
61/62 [============================>.] - ETA: 0s - loss: 0.9784 - iou_score: 0.0108 - f1-score: 0.0124
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-59-859d1e145522> in <module>()
6 callbacks=callbacks,
7 validation_data=valid_dataloader,
----> 8 validation_steps=len(valid_dataloader),
9 )
13 frames
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' object_name '` call to the '
90 'Keras 2 API: ' signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1730 use_multiprocessing=use_multiprocessing,
1731 shuffle=shuffle,
-> 1732 initial_epoch=initial_epoch)
1733
1734 @interfaces.legacy_generator_methods_support
/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
240 validation_steps,
241 callbacks=callbacks,
--> 242 workers=0)
243 else:
244 # No need for try/except because
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' object_name '` call to the '
90 'Keras 2 API: ' signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in evaluate_generator(self, generator, steps, callbacks, max_queue_size, workers, use_multiprocessing, verbose)
1789 workers=workers,
1790 use_multiprocessing=use_multiprocessing,
-> 1791 verbose=verbose)
1792
1793 @interfaces.legacy_generator_methods_support
/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in evaluate_generator(model, generator, steps, callbacks, max_queue_size, workers, use_multiprocessing, verbose)
399 outs = model.test_on_batch(x, y,
400 sample_weight=sample_weight,
--> 401 reset_metrics=False)
402 outs = to_list(outs)
403 outs_per_batch.append(outs)
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
1557 ins = x y sample_weights
1558 self._make_test_function()
-> 1559 outputs = self.test_function(ins)
1560
1561 if reset_metrics:
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py in __call__(self, inputs)
3725 value = math_ops.cast(value, tensor.dtype)
3726 converted_inputs.append(value)
-> 3727 outputs = self._graph_fn(*converted_inputs)
3728
3729 # EagerTensor.numpy() will often make a copy to ensure memory safety.
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
1549 TypeError: For invalid positional/keyword argument combinations.
1550 """
-> 1551 return self._call_impl(args, kwargs)
1552
1553 def _call_impl(self, args, kwargs, cancellation_manager=None):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _call_impl(self, args, kwargs, cancellation_manager)
1589 raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
1590 list(kwargs.keys()), list(self._arg_keywords)))
-> 1591 return self._call_flat(args, self.captured_inputs, cancellation_manager)
1592
1593 def _filtered_call(self, args, kwargs):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1690 # No tape is watching; skip to running the function.
1691 return self._build_call_outputs(self._inference_function.call(
-> 1692 ctx, args, cancellation_manager=cancellation_manager))
1693 forward_backward = self._select_forward_and_backward_functions(
1694 args,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
543 inputs=args,
544 attrs=("executor_type", executor_type, "config_proto", config),
--> 545 ctx=ctx)
546 else:
547 outputs = execute.execute_with_cancellation(
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [1,1536,44,64] vs. shape[1] = [1,816,43,64]
[[node decoder_stage0_concat_4/concat (defined at /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_508722]
Function call stack:
keras_scratch_graph
Комментарии:
1. Проверьте свой набор данных проверки аналогично тому, как
assert train_dataloader[0][0].shape == (BATCH_SIZE, 320, 320, 3) assert train_dataloader[0][1].shape == (BATCH_SIZE, 320, 320, n_classes)
2. Это решило проблему, набор проверки не имел правильного размера. Большое спасибо за ваше внимание.
3. @PPR, если ваша проблема решена, не могли бы вы, пожалуйста, опубликовать решение в разделе ответов на благо сообщества. Спасибо!
4. @TFer2 Я столкнулся с этой проблемой. Не могли бы вы рассказать мне, как вы это решили?
5. Комментарий Алекса помог мне решить эту проблему. Распечатайте свой набор данных проверки и проверьте на наличие несоответствий.