#python #tensorflow #keras #tensorflow2.0
Вопрос:
Я пытаюсь создать пользовательскую модель (SharedModel) с пользовательскими блоками (ClassA, ClassB, ClassC). Визуальный обзор можно найти здесь. Эта модель нужна для того, чтобы обучать её как есть, а затем удалять class_c на этапе вывода (inference). Это игрушечная модель, кратко иллюстрирующая реальную.
Код для ClassA:
class ClassA(keras.Model):
    """Block applying Conv2D, BatchNormalization, ReLU and MaxPooling2D in sequence."""

    def __init__(self, **kwargs):
        """Create the four sub-layers; ``kwargs`` are forwarded to ``keras.Model``."""
        super().__init__(**kwargs)
        weight_decay = 0.0001
        self.L1 = Conv2D(
            64,
            kernel_size=(3, 3),
            padding="same",
            kernel_regularizer=regularizers.l2(weight_decay),
        )
        self.L2 = BatchNormalization()
        self.L3 = Activation('relu')
        self.L4 = MaxPooling2D(pool_size=(2, 2), name="pool1")

    def call(self, inputs):
        """Feed ``inputs`` through L1, L2, L3 and L4 in that order and return the output."""
        features = inputs
        for layer in (self.L1, self.L2, self.L3, self.L4):
            features = layer(features)
        return features

    def get_config(self):
        """Return a dict mapping each sub-layer attribute name to the layer object itself."""
        # Wrapping these entries in an outer "layers" key was tried and left disabled:
        # "layers": {
        # }
        return {attr: getattr(self, attr) for attr in ("L1", "L2", "L3", "L4")}
Реализации ClassB и ClassC идентичны ClassA.
Код для SharedModel:
# @tf.keras.utils.register_keras_serializable()
class SharedModel(keras.Model):
    """Model in which one ClassA block feeds both a ClassB block and a ClassC block."""

    def __init__(self, **kwargs):
        """Instantiate the three blocks; ``kwargs`` are forwarded to ``keras.Model``."""
        super().__init__(**kwargs)
        self.a_class = ClassA()
        self.b_class = ClassB()
        self.c_class = ClassC()

    def call(self, inputs, **kwargs):
        """Return ``(b_class(a), c_class(a))`` where ``a = a_class(inputs)``."""
        shared = self.a_class(inputs)
        branch_b = self.b_class(shared)
        branch_c = self.c_class(shared)
        return branch_b, branch_c

    def get_config(self):
        """Return a dict mapping each block attribute name to the block object itself."""
        config = {
            # "layers": {
            "a_class": self.a_class,
            "b_class": self.b_class,
            "c_class": self.c_class,
            # "build_graph": self.build_graph
            # }
        }
        # "base_config": super(SharedModel, self).get_config()}
        return config

    @classmethod
    def from_config(cls, config):
        """Build an instance by passing every entry of ``config`` as a keyword argument."""
        return cls(**config)

    # @tf.keras.utils.register_keras_serializable()
    def build_graph(self, dim):
        """Wrap ``call`` in a functional ``Model`` whose ``Input`` has shape ``dim``."""
        graph_input = Input(shape=dim)
        graph_outputs = self.call(inputs=graph_input)
        return Model(inputs=graph_input, outputs=graph_outputs, name="Shared Model")
Я могу создавать и сохранять модель без проблем. Однако, когда я пытаюсь загрузить её обратно с помощью
# Load the model saved under "GraphModel", mapping the name "SharedModel" to the custom class.
recreate_model=tf.keras.models.load_model("GraphModel", custom_objects={"SharedModel": SharedModel})
и пытаюсь вызвать её — через recreate_model(input) или даже явно через recreate_model.call(input), — я получаю следующую ошибку:
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/load.py in _unable_to_call_layer_due_to_serialization_issue(layer, *unused_args, **unused_kwargs)
902
903 raise ValueError(
--> 904 f'Cannot call custom layer {layer.name} of type {type(layer)}, because '
905 'the call function was not serialized to the SavedModel.'
906 'Please try one of the following methods to fix this issue:'
ValueError: Exception encountered when calling layer "shared_model" (type SharedModel).
Cannot call custom layer shared_model of type <class 'keras.saving.saved_model.load.SharedModel'>, because the call function was not serialized to the SavedModel.Please try one of the following methods to fix this issue:
(1) Implement `get_config` and `from_config` in the layer/model class, and pass the object to the `custom_objects` argument when loading the model. For more details, see: https://www.tensorflow.org/guide/keras/save_and_serialize
(2) Ensure that the subclassed model or layer overwrites `call` and not `__call__`. The input shape and dtype will be automatically recorded when the object is called, and used when saving. To manually specify the input shape/dtype, decorate the call function with `@tf.function(input_signature=...)`.
Call arguments received:
• unused_args=('tf.Tensor(shape=(None, 216, 64, 1), dtype=float32)',)
• unused_kwargs={'training': 'None'}
As shown, I’ve implemented the get_config method for all the subclasses and the from_config method for SharedModel, and I am passing SharedModel in custom_objects. I also tried passing the blocks in custom_objects as well, but then I get this error while loading:
1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
1323
1324 # First, we create all layers and enqueue nodes to be processed
-> 1325 for layer_data in config['layers']:
1326 process_layer(layer_data)
1327 # Then we process nodes in order of layer depth.
KeyError: 'layers'
I did think of implementing from_config for the blocks as well, but it seems to be optional.
According to the docs, when the custom objects are passed, recreate_model should be an instance of the same class as model (<__main__.SharedModel>), but it is a <keras.saving.saved_model.load.SharedModel> instead.
From these errors, I’m guessing I’m missing something in get_config and from_config. Could someone tell me what I’m doing wrong and point me to a solution?
Further, how can I make build_graph in SharedModel accessible after loading?