Keras: Ошибка атрибута при построении модели

#python #tensorflow #keras #plot

Вопрос:

Я пытаюсь построить модель, определенную ниже:

 def build_model_1(input_shape):      initializer = tf.keras.initializers.GlorotNormal()  model= Sequential()   # print("Input Shape: {}".format(input_shape))  #1st Conv Layer  model.add(Conv2D(20, (3,3),strides=(3,1),input_shape= input_shape, activation='relu',  kernel_initializer=initializer, bias_initializer='zeros'))  BatchNormalization()  #2nd Conv Layer  model.add(Conv2D(20,kernel_size= (2,2), strides=(2,2), activation='relu',  kernel_initializer=initializer, bias_initializer='zeros'))  model.add(MaxPool2D(pool_size=(2,2), strides=(2, 2), padding='same'))  BatchNormalization()  #3rd Conv Layer  model.add(Conv2D(30,kernel_size= (1,4),strides=(1,4), activation='relu',  kernel_initializer=initializer, bias_initializer='zeros'))  BatchNormalization()  #Flatten and Dense  model.add(Flatten())  model.add(Dense(200, activation='relu',  kernel_initializer=initializer, bias_initializer='zeros'))  model.add(Dense(100, activation='relu',  kernel_initializer=initializer, bias_initializer='zeros'))    #Output Layer  model.add(Dense(4, activation='softmax'))   return model  

Однако, когда я пытаюсь построить модель с tf.keras.utils.plot_model(build_model_1, 'build_model_1.png', show_layer_names=True) помощью , я получаю следующую ошибку:

 ~/Studies/Thesis/MA_Naren_Sadhwani/06-Git/ma_naren_sadhwani/src/CNN/CNN.py in lt;modulegt; ----gt; 1 tf.keras.utils.plot_model(build_model_1, 'build_model_1.png', show_layer_names=True)  ~/opt/anaconda3/envs/Talos_env/lib/python3.9/site-packages/tensorflow/python/keras/utils/vis_utils.py in plot_model(model, to_file, show_shapes, show_dtype, show_layer_names, rankdir, expand_nested, dpi)  320 This enables in-line display of the model plots in notebooks.  321 """ --gt; 322 dot = model_to_dot(  323 model,  324 show_shapes=show_shapes,  ~/opt/anaconda3/envs/Talos_env/lib/python3.9/site-packages/tensorflow/python/keras/utils/vis_utils.py in model_to_dot(model, show_shapes, show_dtype, show_layer_names, rankdir, expand_nested, dpi, subgraph)  130 sub_w_last_node = {}  131  --gt; 132 layers = model.layers  133 if not model._is_graph_network:  134 node = pydot.Node(str(id(model)), label=model.name)  AttributeError: 'function' object has no attribute 'layers'  

Может ли кто-нибудь подсказать мне, чего мне здесь не хватает? Если да, то каков был бы правильный способ построения модели или определения модели?

Комментарии:

1. Вы передаете функцию вместо модели в plot_model

2. Спасибо. Это сработало