#python #tensorflow #keras #protocol-buffers
#питон #тензорный поток #keras #протокол-буферы
Вопрос:
для моего проекта я пытаюсь сделать вывод на основе моей обученной модели, хранящейся в файле saved_model.pb. Я сомневаюсь, что ошибка связана с моим кодом, который вы можете увидеть здесь, но, скорее всего, из-за проблемы с установкой:
from PIL import Image
import numpy as np
import scipy
from scipy import misc
import matplotlib.pyplot as plt
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
with tf.Graph().as_default() as graph: # Set default graph as graph
with tf.Session() as sess:
# Load the graph in graph_def
print("load graph")
# We load the protobuf file from the disk and parse it to retrive the unserialized graph_drf
with gfile.FastGFile("saved_model.pb",'rb') as f:
from scipy.io import wavfile
samplerate, data = wavfile.read('sound.wav')
# Set FCN graph to the default graph
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
# Import a graph_def into the current default Graph (In this case, the weights are (typically) embedded in the graph)
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="",
op_dict=None,
producer_op_list=None
)
# Print the name of operations in the session
for op in graph.get_operations():
print("Operation Name :",op.name) # Operation name
print("Tensor Stats :",str(op.values())) # Tensor name
# INFERENCE Here
l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0') # Input Tensor
l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0') # Output Tensor
print("Shape of input : ", tf.shape(l_input))
f.global_variables_initializer()
# Run model on single image
Session_out = sess.run( m_output, feed_dict = {m_input : data} )
print("Predicted class:", class_names[Session_out[0].argmax()] )
Обратная трассировка выглядит следующим образом
Traceback (most recent call last):
File "/home/pi/model_inference/test.py", line 11, in <module>
graph_def.ParseFromString(f.read())
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/message.py", line 199, in ParseFromString
return self.MergeFromString(serialized)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1145, in MergeFromString
if self._InternalParse(serialized, 0, length) != length:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 754, in DecodeField
if value._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 733, in DecodeRepeatedField
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 888, in DecodeMap
if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1199, in InternalParse
buffer, new_pos, wire_type) # pylint: disable=protected-access
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 989, in _DecodeUnknownField
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 968, in _DecodeUnknownFieldSet
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 993, in _DecodeUnknownField
raise _DecodeError('Wrong wire type in tag.')
google.protobuf.message.DecodeError: Wrong wire type in tag.
Важно отметить, что я пробую это на Raspberry pi v4 (таким образом, на нем работает Linux). Я был бы рад любому намеку, что делать.
Заранее спасибо!
Комментарии:
1. Здравствуйте, вам удалось решить эту проблему?
Ответ №1:
Попробуйте использовать tf.saved_model.load(path)
where path
— это путь к папке вашей сохраненной модели, содержащей папку assets и variables.перейдите по этой ссылке на руководство по выводу tensorflow object-detection api.
Ответ №2:
Похоже, что ваш файл "saved_model.pb"
не является сохраненным (wireformat) протобуфером типа message GraphDef
. Может быть, вы можете посмотреть, как это было сохранено, и найти несколько инструкций о том, как загрузить его обратно? Просто догадываюсь по названию, может ли это быть модель keras, и вы должны использовать tf.keras.models.load_model
?