#python #tensorflow
#python #тензорный поток
Вопрос:
Я хочу загрузить одну и ту же переменную в предварительно подготовленной модели в несколько переменных в новой модели
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1 1)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(v1)
with tf.Session() as sess:
sess.run(init_op)
sess.run(v1 1)
save_path = saver.save(sess, "/tmp/model.ckpt")
и послесловия
# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[3])
# Add ops to save and restore all the variables.
saver = tf.train.Saver({"v1" : v1,"v1":v2})
with tf.Session() as sess:
saver.restore(sess, "/tmp/model.ckpt")
Т. е. я хочу, чтобы обе переменные были инициализированы из переменной v1 из предыдущей модели.
Следующий пример завершается сбоем, поскольку в нем говорится, что графики разные.
Ответ №1:
Вычислите присвоенное значение переменной из исходного графика, а затем инициализируйте новые переменные из нового графика с этим значением:
import tensorflow as tf
with tf.Graph().as_default():
# the variable from the original graph
v0 = tf.Variable(tf.random_normal([2, 2]))
with tf.Session(graph=v0.graph) as sess:
sess.run(v0.initializer)
init_val = v0.eval() # <-- evaluate the assigned value
print('original graph:')
print(init_val)
# original graph:
# [[-1.7466899 1.1560178 ]
# [-0.46535382 1.7059366 ]]
# variables from new graph
with tf.Graph().as_default():
v1 = tf.Variable(init_val) # <-- variable from new graph
v2 = tf.Variable(init_val) # <-- variable from new graph
with tf.Session(graph=v1.graph) as sess:
sess.run([v.initializer for v in [v1, v2]])
print('new graph:')
print(v1.eval())
print(v2.eval())
# new graph:
# [[-1.7466899 1.1560178 ]
# [-0.46535382 1.7059366 ]]
# [[-1.7466899 1.1560178 ]
# [-0.46535382 1.7059366 ]]
Ответ №2:
Вот другой метод, повторяющий переменные из предыдущего графика:
def load_pretrained(sess):
checkpoint_path = 'pretrainedmodel.ckpt'
vars_to_load = [var for var in tf.get_collection(tf.GraphKeys.VARIABLES) if
("some_scope" in var.op.name)]
assign_ops = []
reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
for var in vars_to_load:
for name,shape in tf.contrib.framework.list_variables(checkpoint_path):
if(var.op.name ~some regex comperison~ name):
assign_ops.append(tf.assign(var,reader.get_tensor(name)))
break
sess.run(assign_ops)