Can't use `call()` on a model loaded with `tf.keras.models.load_model()`, despite passing custom objects

#python #tensorflow #keras #tensorflow2.0

Question:

I'm trying to build a custom model (`SharedModel`) out of custom blocks (`ClassA`, `ClassB`, `ClassC`). A visual overview can be found here. The reason for structuring the model this way is to be able to train it as is and then drop `class_c` at inference time. This is a toy model that briefly stands in for the actual model.

The code for `ClassA`:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, Model, regularizers
from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, MaxPooling2D


class ClassA(keras.Model):
    def __init__(self, **kwargs):
        super(ClassA, self).__init__(**kwargs)
        weight_decay = 0.0001
        self.L1 = Conv2D(64, kernel_size=(3, 3), padding="same",
                         kernel_regularizer=regularizers.l2(weight_decay))
        self.L2 = BatchNormalization()
        self.L3 = Activation('relu')
        self.L4 = MaxPooling2D(pool_size=(2, 2), name="pool1")

    def call(self, inputs):
        x = self.L1(inputs)
        x = self.L2(x)
        x = self.L3(x)
        x = self.L4(x)
        return x

    def get_config(self):
        return {
            # "layers": {
            "L1": self.L1,
            "L2": self.L2,
            "L3": self.L3,
            "L4": self.L4,
            # }
        }
```

`ClassB` and `ClassC` are implemented identically to `ClassA`.

The code for `SharedModel`:

```python
# @tf.keras.utils.register_keras_serializable()
class SharedModel(keras.Model):
    def __init__(self, **kwargs):
        super(SharedModel, self).__init__(**kwargs)
        self.a_class = ClassA()
        self.b_class = ClassB()
        self.c_class = ClassC()

    def call(self, inputs, **kwargs):
        # Both heads consume the ClassA features; c_class is the branch
        # I want to drop at inference time.
        out1 = self.a_class(inputs)
        out2 = self.b_class(out1)
        out3 = self.c_class(out1)
        return out2, out3

    def get_config(self):
        return {
            # "layers": {
            "a_class": self.a_class,
            "b_class": self.b_class,
            "c_class": self.c_class,
            # "build_graph": self.build_graph
            # }
        }
        # "base_config": super(SharedModel, self).get_config()}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    # @tf.keras.utils.register_keras_serializable()
    def build_graph(self, dim):
        x = Input(shape=(dim))
        return Model(inputs=x, outputs=self.call(inputs=x), name="Shared Model")
```

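For context, this is roughly how the model is created and saved (a sketch, not my exact code; the input shape is taken from the error further down):

```python
model = SharedModel()

# Run one forward pass so the variables get built; in the real code this
# happens by training on inputs of shape (batch, 216, 64, 1).
_ = model(tf.random.normal((1, 216, 64, 1)))

# Save in the TensorFlow SavedModel format.
model.save("GraphModel")
```
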
I can create and save the model without any problems. However, when I try to load it back with

```python
recreate_model = tf.keras.models.load_model("GraphModel", custom_objects={"SharedModel": SharedModel})
```

and then call it, using either `recreate_model(input)` or even the explicit `recreate_model.call(input)`, the call fails.
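A minimal sketch of the failing call (the input shape is taken from the error message below):

```python
dummy_input = tf.random.normal((1, 216, 64, 1))
out2, out3 = recreate_model(dummy_input)   # recreate_model.call(dummy_input) fails the same way
```

Either way, I get the following error: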

```
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/load.py in _unable_to_call_layer_due_to_serialization_issue(layer, *unused_args, **unused_kwargs)
    902 
    903   raise ValueError(
--> 904       f'Cannot call custom layer {layer.name} of type {type(layer)}, because '
    905       'the call function was not serialized to the SavedModel.'
    906       'Please try one of the following methods to fix this issue:'

ValueError: Exception encountered when calling layer "shared_model" (type SharedModel).

Cannot call custom layer shared_model of type <class 'keras.saving.saved_model.load.SharedModel'>, because the call function was not serialized to the SavedModel.Please try one of the following methods to fix this issue:

(1) Implement `get_config` and `from_config` in the layer/model class, and pass the object to the `custom_objects` argument when loading the model. For more details, see: https://www.tensorflow.org/guide/keras/save_and_serialize

(2) Ensure that the subclassed model or layer overwrites `call` and not `__call__`. The input shape and dtype will be automatically recorded when the object is called, and used when saving. To manually specify the input shape/dtype, decorate the call function with `@tf.function(input_signature=...)`.

Call arguments received:
  • unused_args=('tf.Tensor(shape=(None, 216, 64, 1), dtype=float32)',)
  • unused_kwargs={'training': 'None'}
```

As shown, I've implemented the `get_config` method for all the subclasses and `from_config` for `SharedModel`, and I'm passing `SharedModel` in `custom_objects`. I also tried passing the blocks into `custom_objects` as well.
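That attempt looked roughly like this (a sketch; I'm assuming the bare class names are the keys Keras looks up):

```python
# Second attempt: pass every custom class, not just SharedModel.
recreate_model = tf.keras.models.load_model(
    "GraphModel",
    custom_objects={
        "SharedModel": SharedModel,
        "ClassA": ClassA,
        "ClassB": ClassB,
        "ClassC": ClassC,
    },
)
```

But then I get this error while loading: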

```
1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1323 
   1324   # First, we create all layers and enqueue nodes to be processed
-> 1325   for layer_data in config['layers']:
   1326     process_layer(layer_data)
   1327   # Then we process nodes in order of layer depth.

KeyError: 'layers'
```

I did think to implement from_config for the blocks as well, but it seems to be optional.
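For reference, what I had in mind for the blocks was simply to mirror `SharedModel.from_config` (a hypothetical sketch for `ClassA`; the other blocks would be the same):

```python
class ClassA(keras.Model):
    # __init__, call and get_config exactly as shown earlier.

    @classmethod
    def from_config(cls, config):
        # Mirror SharedModel.from_config: feed the get_config dict back into
        # the constructor as keyword arguments.
        return cls(**config)
```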

According to the docs, when the custom objects are passed, `recreate_model` should be of the same class as the original `model` (`<__main__.SharedModel>`), but instead it is `<keras.saving.saved_model.load.SharedModel>`.
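This is roughly how I'm checking the classes (`model` being the original in-memory instance):

```python
print(type(model))           # <class '__main__.SharedModel'>
print(type(recreate_model))  # <class 'keras.saving.saved_model.load.SharedModel'>
```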

From these errors, I'm guessing I'm missing something in `get_config` and `from_config`. Could someone tell me what I'm doing wrong and point me to a solution?
Further, how can I make `build_graph` in `SharedModel` accessible after loading?