Вывод тензорного потока с использованием сохраненной модели

#python #tensorflow

#python #тензорный поток

Вопрос:

Я новичок в TensorFlow 2.3.1 и пытаюсь понять, как выполняется вывод. После загрузки сохраненной модели я хочу передать тензор только с единицами, чтобы убедиться, что модель выдает то, что мы ожидаем. Например…

 import tensorflow as tf

resnet18_tf = tf.saved_model.load("resnet18.tf")
x_tf = tf.ones((1,3,224,224), tf.float32)

resnet18_tf(x_tf)
 

Однако приведенный выше код приводит к следующей ошибке…

 ---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-34-33fa05a7412b> in <module>
      4 x_tf = tf.ones((1,3,224,224), tf.float32)
      5 
----> 6 resnet18_tf(x_tf)

ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (1 total):
    * Tensor("None_0:0", shape=(1, 3, 224, 224), dtype=float32)
  Keyword arguments: {}

Expected these arguments to match one of the following 1 option(s):

Option 1:
  Positional arguments (0 total):
    * 
  Keyword arguments: {'input': TensorSpec(shape=(1, 3, 224, 224), dtype=tf.float32, name='input')}
 

Я почти уверен, что форма правильная, но я изо всех сил пытаюсь интерпретировать это сообщение об ошибке. Как вы вводите TensorSpec для устранения этой ошибки?

Комментарии:

1. может быть, попробуйте передать его в качестве аргумента ключевого слова? resnet18_tf(input=x_tf)

2. Вау, это сработало… Спасибо!

Ответ №1:

Сообщение об ошибке

 Expected these arguments to match one of the following 1 option(s):

Option 1:
  Positional arguments (0 total):
    * 
  Keyword arguments: {'input': TensorSpec(shape=(1, 3, 224, 224), dtype=tf.float32, name='input')}
 

предполагает, что функция ожидает аргументы ключевого слова и никаких позиционных аргументов. Словарь указывает, что ключевое слово является input .

 import tensorflow as tf

resnet18_tf = tf.saved_model.load("resnet18.tf")
x_tf = tf.ones((1,3,224,224), tf.float32)

resnet18_tf(input=x_tf)