#python #tensorflow
Вопрос:
ВСЕ
Я приведу несколько простых примеров о глубокой модели с квантованием.
Есть два файла. DnnModel.py и TrainModel.py
DnnModel.py
import numpy as np
import tensorflow as tf
class DNN:
def __init__(self, archList, X, reuse, scopeName):
self.scopeName = scopeName
with tf.variable_scope(self.scopeName, reuse = reuse):
self.logits = 0
self.weights = []
self.bias = []
self.midLayers = []
for midLayerIdx in range(1, len(archList)):
self.weights.append(tf.Variable(tf.random_normal([archList[midLayerIdx-1], archList[midLayerIdx]])))
self.bias.append(tf.Variable(tf.random_normal([archList[midLayerIdx]])))
if midLayerIdx == 1:
self.midLayers.append(tf.nn.relu(tf.add(tf.matmul(X, self.weights[-1]), self.bias[-1])))
elif midLayerIdx == len(archList) - 1:
self.logits = tf.add(tf.matmul(self.midLayers[-1], self.weights[-1]), self.bias[-1])
else:
self.midLayers.append(tf.nn.relu(tf.add(tf.matmul(self.midLayers[-1], self.weights[-1]), self.bias[-1])))
self.prediction = tf.nn.softmax(self.logits)
self.netVariables = tf.trainable_variables()
self.variables = [var for var in self.netVariables if var.name.startswith(self.scopeName)]
TrainModel.py
import numpy as np
import tensorflow as tf
from DnnModel import DNN
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
(train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data()
# Neural network architecture
archList = [784, 512, 256, 10]
# Training procedure settings
train_batch_size = 50
train_batch_number = train_data.shape[0]
quant_delay_epoch = 1
learning_rate = 0.001
num_steps = 500
display_step = 100
# [Graph][Session] Tensorflow graph and session
train_graph = tf.Graph() # [Graph] Create graph
train_sess = tf.Session(graph = train_graph) # [Session] Create corresponding session
with train_graph.as_default():
# [Model] Build model
# Building of neural network must be in the scope of graph !!!!
X = tf.placeholder("float", [None, archList[0]])
Y = tf.placeholder("float", [None, archList[-1]])
train_dnn = DNN(archList, X, True, 'DNN')
correct_pred = tf.equal(tf.argmax(train_dnn.prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# [Graph][Quantization Aware Training]
# Create fake quantiation flow with created graph
tf.contrib.quantize.create_training_graph(input_graph = train_graph, quant_delay = int(train_batch_number / train_batch_size * quant_delay_epoch))
# [Optimizer][Loss]
# Definition of loss function must be ''after QAT'' !!!!!
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = train_dnn.logits, labels = Y))
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate = learning_rate)
train_op = optimizer.minimize(loss_op)
train_sess.run(tf.global_variables_initializer())
for step in range(1, num_steps 1):
batch_x, batch_y = mnist.train.next_batch(train_batch_size)
# Run optimization op (backprop)
train_sess.run(train_op, feed_dict = {X: batch_x, Y: batch_y})
if step % display_step == 0 or step == 1:
# Calculate batch loss and accuracy
loss, acc = train_sess.run([loss_op, accuracy], feed_dict={X: batch_x,
Y: batch_y})
print("Step " str(step) ", Minibatch Loss= "
"{:.4f}".format(loss) ", Training Accuracy= "
"{:.3f}".format(acc))
print("Optimization Finished!")
# Calculate accuracy for MNIST test images
print("Testing Accuracy:",
train_sess.run(accuracy, feed_dict={X: mnist.test.images,
Y: mnist.test.labels}))
saver = tf.compat.v1.train.Saver()
saver.save(train_sess, './Native_Tensorflow_Model/path_to_checkpoints')
# [Graph][Model] Save the frozen graph
inference_graph = tf.Graph()
inference_sess = tf.Session(graph = inference_graph)
with inference_graph.as_default():
X = tf.placeholder("float", [None, archList[0]])
inference_dnn = DNN(archList, X, True, 'DNN')
tf.contrib.quantize.create_eval_graph(input_graph = inference_graph)
inference_graph_def = inference_graph.as_graph_def()
print("All variables : ", [n.name for n in tf.get_default_graph().as_graph_def().node])
inference_graph_def = inference_graph.as_graph_def()
inference_saver = tf.compat.v1.train.Saver()
inference_saver.restore(inference_sess, './Native_Tensorflow_Model/path_to_checkpoints')
graphVariables = [var.name for var in inference_graph.as_graph_def().node]
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
inference_sess,
inference_graph_def,
graphVariables)
with open('./Native_Tensorflow_Model/path_to_frozen_graph.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
На шаге tf.compat.v1.graph_util.convert_variables_to_constants,
журнал «Ошибка утверждения: сохранить/имя файла/ввод не указан в графике«. Однако я не видел ни одного элемента с этим именем в списке имен.
Как я мог бы решить эту проблему? Большое спасибо !!
Ответ №1:
Я нашел решение. Ссылка : введите описание ссылки здесь и введите описание ссылки здесь
Оригинальное содержимое
graphVariables = [var.name for var in inference_graph.as_graph_def().node]
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
inference_sess,
inference_graph_def,
graphVariables)
Измененное содержимое
graphVariables = [var.op.name for var in inference_dnn.variables]
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
inference_sess,
inference_graph_def,
graphVariables)