Tf-trt преобразование saved_model.pb сбой

#tensorflow #tensorrt #tensorrt-python

Вопрос:

Вот мой код преобразования.

 import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow import gfile, compat
from tensorflow.core.protobuf import saved_model_pb2
# from tensorflow.core.protobuf import meta_graph_pb2


with tf.Session() as sess:
        converter = trt.TrtGraphConverter(input_saved_model_dir=".", 
                                          input_saved_model_signature_key='predictor',
                                          input_saved_model_tags=[tf.saved_model.tag_constants.SERVING],
                                          is_dynamic_op=True,
                                          precision_mode="FP16"
                                         )
        converter.convert()
        converter.save("trt_model.pb")
 

И когда он выполняется, он создает исключения:

 20066 Traceback (most recent call last):
20067   File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_gr      aph_def_internal
20068     graph._c_graph, serialized, options)  # pylint: disable=protected-access
20069 tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node global_step/Assign was passed int64 from       global_step:0 incompatible with expected int64_ref.
20070 
20071 During handling of the above exception, another exception occurred:
20072 
20073 Traceback (most recent call last):
20074   File "converter.py", line 24, in <module>
20075     converter.save("trt_model.pb")
20076   File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/compiler/tensorrt/trt_convert.py", line 717, in       save
20077     importer.import_graph_def(self._converted_graph_def, name="")
20078   File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
20079     return func(*args, **kwargs)
20080   File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_gra      ph_def
20081     producer_op_list=producer_op_list)
20082   File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_gr      aph_def_internal
20083     raise ValueError(str(e))
20084 ValueError: Input 0 of node global_step/Assign was passed int64 from global_step:0 incompatible with expected int64_ref
 

Я проверяю узел графика.
перед обращением это похоже на

 ···
global_step/Initializer/zeros Const
global_step VariableV2
global_step/Assign Assign
global_step/read Identity
···
 

после того, как это

 ···
IteratorGetNext IteratorGetNext
global_step/Assign Assign
inference/embed_continuous/Initializer/truncated_normal/TruncatedNormal TruncatedNormal
···
 

более конкретно
, прежде чем

 name: "global_step/Assign"
op: "Assign"
input: "global_step"
input: "global_step/Initializer/zeros"
attr {
  key: "T"
  value {
    type: DT_INT64
  }
}
attr {
  key: "_class"
  value {
    list {
      s: "loc:@global_step"
    }
  }
}
attr {
  key: "_output_shapes"
  value {
    list {
      shape {
      }
    }
  }
}
attr {
  key: "use_locking"
  value {
    b: true
  }
}
attr {
  key: "validate_shape"
  value {
    b: true
  }
}
 

after

 name: "global_step/Assign"
op: "Assign"
input: "global_step"
input: "global_step/Initializer/zeros"
attr {
  key: "T"
  value {
    type: DT_INT64
  }
}
attr {
  key: "use_locking"
  value {
    b: true
  }
}
attr {
  key: "validate_shape"
  value {
    b: true
  }
}
 

должно быть, что-то не так с моим кодом или моделью, но я не могу понять. Кто-нибудь может мне помочь?

Похоже, что ваш пост в основном состоит из кода; пожалуйста, добавьте еще несколько деталей. …