#python-3.x #tensorflow
#python-3.x #tensorflow
Вопрос:
Веса в классах, наследуемых от tf.keras.Модель, похоже, не может загрузиться в данный момент. Я не могу загрузить веса из Example() за пределами класса, используя контрольную точку, поэтому я попытался сделать это внутри, что, по общему мнению, должно работать. Он может сохранять веса, как и при простом сохранении Example() , однако он по-прежнему не может их загрузить. Это мой код модели:
class Example(tf.keras.Model):
def __init__(self, cfg):
super(Example, self).__init__()
self.model = tf.keras.Sequential([
........layers.......
])
# Create saver
self.save_path = cfg.save_dir cfg.extension
self.ckpt_prefix = self.save_path '/ckpt'
self.saver = tf.train.Checkpoint(model=self.model)
def call(self, x_in):
x_out = self.model(x_in)
return x_out
def save(self):
self.saver.save(file_prefix=self.ckpt_prefix)
def load(self):
self.saver.restore(tf.train.latest_checkpoint(self.save_path))
И это то, что я использую, чтобы проверить, загружается ли он:
example = Example()
if Path(self.example.save_path).is_dir():
print(self.example.weights)
print(self.example.model.weights)
self.example.load()
print(self.example.weights)
print(self.example.model.weights)
Вывод:
[]
[]
[]
[]
Это было протестировано как на tensorflow 1.3, так и на 2.0, и я могу подтвердить, что веса не являются пустыми после первой партии, а также что это контрольная точка / сохранение.
Комментарии:
1. Вы проверили статус восстановления?
status = self.saver.restore(...); status.assert_existing_objects_matched(); status.assert_consumed();
2. Спасибо за ответ! Когда я
print(status)
, я получаю:<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f45fc082630>
3. И когда я запускаю
status.assert_existing_objects_matched()
, я получаю:File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed raise AssertionError("Unresolved object in checkpoint: %s" % (node,)) AssertionError: Unresolved object in checkpoint: attributes { name: "VARIABLE_VALUE" full_name: "Variable" checkpoint_key: "best_acc/.ATTRIBUTES/VARIABLE_VALUE" }
4. Я полагаю, это означает, что что-то идет не так с контрольной точкой, но как мне это отладить?
5. Я узнал, что это происходит, когда модель повторно инициализируется, то есть всякий раз, когда вам нужно будет фактически использовать контрольную точку. Я привел минимальный жизнеспособный пример и проблему с tensorflow. Я думаю, что это ошибка.
Ответ №1:
Как выясняется, существует три разных способа, которыми TensorFlow выполняет контрольную точку, в зависимости от того, что проверяется.
- Объект с контрольной точкой — это просто переменная. Это восстанавливается сразу после вызова
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))
. - Объект с контрольной точкой представляет собой модель с определенной формой ввода. Это также немедленно восстанавливается.
- Объект с контрольной точкой представляет собой модель без определенной формы ввода. Именно здесь поведение меняется, поскольку TensorFlow выполняет «отложенное» восстановление и НЕ восстанавливает веса модели до тех пор, пока входные данные не будут переданы в модель.
Вот пример:
import os
import tensorflow as tf
import numpy as np
# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()
# Create model
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(256, 3, padding="same"),
tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])
# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)
# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
logits = model(img)
loss = tf.losses.mean_squared_error(truth, logits)
# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)
# Save model
checkpoint_path = './ckpt/'
checkpoint.save('./ckpt/')
# Check if weights update
print("Are weights empty after training?", model.weights == [])
# Reset model
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(256, 3, padding="same"),
tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])
# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))
# This next line is REQUIRED to restore
#model(img)
print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()
С выводом:
Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
File "test.py", line 58, in <module>
status.assert_consumed()
File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
name: "VARIABLE_VALUE"
full_name: "sequential/conv2d/kernel"
checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}
Однако раскомментирование строки model(img)
приведет к следующему результату:
Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>
Поэтому необходимо передать входные данные, чтобы правильно восстановить модель, инвариантную к форме.
Ссылки:
https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations
https://github.com/tensorflow/tensorflow/issues/27937
Комментарии:
1. Поведение для меня немного отличается от этого, однако для меня, по сути, суть та же. Если я загружу контрольную точку и запущу модель без загрузки фиктивного изображения, она не будет вести себя так, как ожидалось. ЕСЛИ я восстановлю контрольную точку, а затем загружу изображение в модель перед началом переподготовки, она будет работать хорошо. Спасибо!!