Variables not found in graph for convert_variables_to_constants

#python #tensorflow

Question:

Hi all,

Here is a simple example of a deep model with quantization.

There are two files: DnnModel.py and TrainModel.py.

DnnModel.py

 import numpy as np
import tensorflow as tf

class DNN:

  def __init__(self, archList, X, reuse, scopeName):
      
      self.scopeName = scopeName
      
      with tf.variable_scope(self.scopeName, reuse = reuse): 
    
         self.logits = 0
         
         self.weights = []
         self.bias = []
         self.midLayers = []
         
         for midLayerIdx in range(1, len(archList)):
             
             self.weights.append(tf.Variable(tf.random_normal([archList[midLayerIdx-1], archList[midLayerIdx]])))
             self.bias.append(tf.Variable(tf.random_normal([archList[midLayerIdx]])))
             
             if midLayerIdx == 1:
                
                self.midLayers.append(tf.nn.relu(tf.add(tf.matmul(X, self.weights[-1]), self.bias[-1])))                    
                
             elif midLayerIdx == len(archList) - 1:
                
                self.logits = tf.add(tf.matmul(self.midLayers[-1], self.weights[-1]), self.bias[-1])
                
             else:
                
                self.midLayers.append(tf.nn.relu(tf.add(tf.matmul(self.midLayers[-1], self.weights[-1]), self.bias[-1]))) 
      
         self.prediction = tf.nn.softmax(self.logits)
      
         self.netVariables = tf.trainable_variables()
      
      self.variables = [var for var in self.netVariables if var.name.startswith(self.scopeName)]
 

TrainModel.py

 import numpy as np
import tensorflow as tf
from DnnModel import DNN

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

(train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data()

# Neural network architecture
archList = [784, 512, 256, 10]

# Training procedure settings
train_batch_size = 50
train_batch_number = train_data.shape[0]
quant_delay_epoch = 1

learning_rate = 0.001
num_steps = 500
display_step = 100


# [Graph][Session] Tensorflow graph and session
train_graph = tf.Graph()                      # [Graph] Create graph
train_sess = tf.Session(graph = train_graph)  # [Session] Create corresponding session

with train_graph.as_default():

     # [Model] Build model
     # Building of neural network must be in the scope of graph !!!!

     X = tf.placeholder("float", [None, archList[0]])
     Y = tf.placeholder("float", [None, archList[-1]])

     train_dnn = DNN(archList, X, True, 'DNN')

     correct_pred = tf.equal(tf.argmax(train_dnn.prediction, 1), tf.argmax(Y, 1))
     accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

     # [Graph][Quantization Aware Training] 
     # Create fake quantiation flow with created graph
     tf.contrib.quantize.create_training_graph(input_graph = train_graph, quant_delay = int(train_batch_number / train_batch_size * quant_delay_epoch))

     # [Optimizer][Loss] 
     # Definition of loss function must be ''after QAT'' !!!!!
     loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = train_dnn.logits, labels = Y))
     optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate = learning_rate)
     train_op = optimizer.minimize(loss_op)

     train_sess.run(tf.global_variables_initializer())    

     for step in range(1, num_steps + 1):
         batch_x, batch_y = mnist.train.next_batch(train_batch_size)
    
         # Run optimization op (backprop)
         train_sess.run(train_op, feed_dict = {X: batch_x, Y: batch_y})
         if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = train_sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                                       Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " +
                  "{:.4f}".format(loss) + ", Training Accuracy= " +
                  "{:.3f}".format(acc))

     print("Optimization Finished!")

     # Calculate accuracy for MNIST test images
     print("Testing Accuracy:", 
           train_sess.run(accuracy, feed_dict={X: mnist.test.images,
                                        Y: mnist.test.labels}))

     saver = tf.compat.v1.train.Saver()
     saver.save(train_sess, './Native_Tensorflow_Model/path_to_checkpoints')

 # [Graph][Model] Save the frozen graph
 inference_graph = tf.Graph()
 inference_sess = tf.Session(graph = inference_graph)

with inference_graph.as_default():

     X = tf.placeholder("float", [None, archList[0]])
     inference_dnn = DNN(archList, X, True, 'DNN')

     tf.contrib.quantize.create_eval_graph(input_graph = inference_graph)
     inference_graph_def = inference_graph.as_graph_def()

     print("All variables : ", [n.name for n in tf.get_default_graph().as_graph_def().node])

     inference_graph_def = inference_graph.as_graph_def()
     inference_saver = tf.compat.v1.train.Saver()
     inference_saver.restore(inference_sess, './Native_Tensorflow_Model/path_to_checkpoints')

     graphVariables = [var.name for var in inference_graph.as_graph_def().node]
     frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
         inference_sess,
         inference_graph_def,
         graphVariables)

     with open('./Native_Tensorflow_Model/path_to_frozen_graph.pb', 'wb') as f:
         f.write(frozen_graph_def.SerializeToString())

At the tf.compat.v1.graph_util.convert_variables_to_constants step, the log shows "AssertionError: save/filename/input is not in graph". However, I did not see any element with that name in the list of names I printed.

How can I solve this problem? Many thanks!!

Answer #1:

I found a solution. Links: enter link description here and enter link description here.
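
For context, the assertion most likely comes from the Saver: in the question's code, tf.compat.v1.train.Saver() is created after inference_graph_def has been captured, so the Saver's save/* nodes (among them save/filename/input) exist in the live graph but not in the GraphDef that is passed to convert_variables_to_constants, while the output-name list is rebuilt from inference_graph.as_graph_def() after the Saver already exists. That also explains why the name never showed up in the earlier print: it ran before the Saver was constructed. A minimal sketch to confirm the mismatch, reusing the names from the question's code:

# Sketch: compare the names requested as outputs with the names actually
# present in the GraphDef being frozen (inference_graph_def and
# graphVariables as defined in the question's code).
requested_names = set(graphVariables)
graph_def_names = {node.name for node in inference_graph_def.node}

# Any name printed here triggers the "is not in graph" assertion; the
# Saver's save/* nodes, e.g. save/filename/input, are expected to appear.
print(sorted(requested_names - graph_def_names))

The fix below sidesteps the mismatch by taking the names from the model's own variables via var.op.name (the op name without the :0 suffix that var.name carries), which are guaranteed to be present in the captured GraphDef.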

Original content:

 graphVariables = [var.name for var in inference_graph.as_graph_def().node]

 frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
     inference_sess,
     inference_graph_def,
     graphVariables)

Modified content:

 graphVariables = [var.op.name for var in inference_dnn.variables]

frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    inference_sess,
    inference_graph_def,
    graphVariables)
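
As a sanity check (not part of the original answer; a minimal sketch assuming the frozen-graph path from the question), the written .pb can be reloaded to confirm that every variable really was folded into a constant:

import tensorflow as tf

# Reload the frozen graph written by the question's code and verify that
# no Variable ops are left; they should all have become Const nodes.
frozen_def = tf.compat.v1.GraphDef()
with open('./Native_Tensorflow_Model/path_to_frozen_graph.pb', 'rb') as f:
    frozen_def.ParseFromString(f.read())

leftover_variables = [n.name for n in frozen_def.node
                      if n.op in ('Variable', 'VariableV2')]
print('Remaining variable ops:', leftover_variables)  # expected: []

One caveat: convert_variables_to_constants keeps only the subgraph needed to compute the names passed as output_node_names, so with the variable names as outputs the resulting GraphDef holds the (now constant) weights but not necessarily the full inference path. For a servable frozen model it is more common to pass the prediction op's name instead (for this model presumably something like 'DNN/Softmax'; the exact name depends on how the graph was built).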