Нарушена ли загрузка в eager TensorFlow прямо сейчас?

#python-3.x #tensorflow

#python-3.x #tensorflow

Вопрос:

Веса в классах, наследуемых от tf.keras.Модель, похоже, не может загрузиться в данный момент. Я не могу загрузить веса из Example() за пределами класса, используя контрольную точку, поэтому я попытался сделать это внутри, что, по общему мнению, должно работать. Он может сохранять веса, как и при простом сохранении Example() , однако он по-прежнему не может их загрузить. Это мой код модели:

 class Example(tf.keras.Model):
    def __init__(self, cfg):
        super(Example, self).__init__()

        self.model = tf.keras.Sequential([
             ........layers.......
        ])

        # Create saver
        self.save_path = cfg.save_dir   cfg.extension
        self.ckpt_prefix = self.save_path   '/ckpt'
        self.saver = tf.train.Checkpoint(model=self.model)

    def call(self, x_in):
        x_out = self.model(x_in)
        return x_out

    def save(self):
        self.saver.save(file_prefix=self.ckpt_prefix)

    def load(self):
        self.saver.restore(tf.train.latest_checkpoint(self.save_path))
 

И это то, что я использую, чтобы проверить, загружается ли он:

 example = Example()
if Path(self.example.save_path).is_dir():
            print(self.example.weights)
            print(self.example.model.weights)
            self.example.load()
            print(self.example.weights)
            print(self.example.model.weights)
 

Вывод:

 []
[]
[]
[]
 

Это было протестировано как на tensorflow 1.3, так и на 2.0, и я могу подтвердить, что веса не являются пустыми после первой партии, а также что это контрольная точка / сохранение.

Комментарии:

1. Вы проверили статус восстановления? status = self.saver.restore(...); status.assert_existing_objects_matched(); status.assert_consumed();

2. Спасибо за ответ! Когда я print(status) , я получаю: <tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f45fc082630>

3. И когда я запускаю status.assert_existing_objects_matched() , я получаю: File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed raise AssertionError("Unresolved object in checkpoint: %s" % (node,)) AssertionError: Unresolved object in checkpoint: attributes { name: "VARIABLE_VALUE" full_name: "Variable" checkpoint_key: "best_acc/.ATTRIBUTES/VARIABLE_VALUE" }

4. Я полагаю, это означает, что что-то идет не так с контрольной точкой, но как мне это отладить?

5. Я узнал, что это происходит, когда модель повторно инициализируется, то есть всякий раз, когда вам нужно будет фактически использовать контрольную точку. Я привел минимальный жизнеспособный пример и проблему с tensorflow. Я думаю, что это ошибка.

Ответ №1:

Как выясняется, существует три разных способа, которыми TensorFlow выполняет контрольную точку, в зависимости от того, что проверяется.

  1. Объект с контрольной точкой — это просто переменная. Это восстанавливается сразу после вызова checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path)) .
  2. Объект с контрольной точкой представляет собой модель с определенной формой ввода. Это также немедленно восстанавливается.
  3. Объект с контрольной точкой представляет собой модель без определенной формы ввода. Именно здесь поведение меняется, поскольку TensorFlow выполняет «отложенное» восстановление и НЕ восстанавливает веса модели до тех пор, пока входные данные не будут переданы в модель.

Вот пример:

 import os
import tensorflow as tf
import numpy as np

# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()

# Create model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])

# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)

# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
    logits = model(img)
    loss = tf.losses.mean_squared_error(truth, logits)

# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)

# Save model
checkpoint_path = './ckpt/'
checkpoint.save('./ckpt/')

# Check if weights update
print("Are weights empty after training?", model.weights == [])

# Reset model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])

# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))

# This next line is REQUIRED to restore
#model(img)

print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()
 

С выводом:

 Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
  File "test.py", line 58, in <module>
    status.assert_consumed()
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}
 

Однако раскомментирование строки model(img) приведет к следующему результату:

 Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>
 

Поэтому необходимо передать входные данные, чтобы правильно восстановить модель, инвариантную к форме.

Ссылки:

https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations
https://github.com/tensorflow/tensorflow/issues/27937

Комментарии:

1. Поведение для меня немного отличается от этого, однако для меня, по сути, суть та же. Если я загружу контрольную точку и запущу модель без загрузки фиктивного изображения, она не будет вести себя так, как ожидалось. ЕСЛИ я восстановлю контрольную точку, а затем загружу изображение в модель перед началом переподготовки, она будет работать хорошо. Спасибо!!