Как сохранить и загрузить selected и все переменные в tensorflow 2.0 с помощью tf.train.Контрольная точка?

#python #tensorflow #tensorflow2.0

#python #tensorflow #tensorflow2.0

Вопрос:

Как мне сохранить выбранные переменные в tensorflow 2.0, показанные ниже, в файл и загрузить их в некоторые определенные переменные в другом коде, используя tf.train.Контрольная точка?

 class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def saveVariables(self):
        # how to save self.initList's 3,55 and 60th elements and self.moreList's 4th element
  

Также, пожалуйста, покажите, как сохранить все переменные и перезагрузить с помощью tf.train.Контрольная точка. Заранее спасибо.

Комментарии:

1. Я не уверен, что понимаю проблему. Я предполагаю, что вы прочитали информацию о контрольных точках в версии 2.0 . Если вы создадите tf.train.Checkpoint для определенных переменных, которые вы хотите, это должно сработать, верно? Или, что мешает вам это сделать?

2. Я не понимаю приведенную выше официальную ссылку, она слишком сложная. Кроме того, я не понимаю, почему каждый tf-учебник должен быть написан для keras crap. Я был бы рад, если бы вы могли просто сохранить вышеупомянутые 3 переменные и восстановить. А также сохраните все 110 из них и восстановите с помощью tf.train. Контрольная точка ПРОСТЫМ СПОСОБОМ, В ОТЛИЧИЕ ОТ ДОКУМЕНТОВ.

Ответ №1:

Я не уверен, что это то, что вы имеете в виду, но вы можете создать tf.train.Checkpoint объект специально для переменных, которые вы хотите сохранить и восстановить. Смотрите следующий пример:

 import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()
        self.ckpt = self.makeCheckpoint()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def makeCheckpoint(self):
        return tf.train.Checkpoint(
            init3=self.initList[3], init55=self.initList[55],
            init60=self.initList[60], more4=self.moreList[4])

    def saveVariables(self):
        self.ckpt.save('./ckpt')

    def restoreVariables(self):
        status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
        status.assert_consumed()  # Optional check

# Create variables
v1 = manyVariables()
# Assigned fixed values
for i, v in enumerate(v1.initList):
    v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
    v.assign(100   i * tf.ones_like(v))
# Save them
v1.saveVariables()

# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594  -0.18425298 -0.4898144  -1.2330629   0.08798431]
#  [ 1.5002227   0.99475247  0.7817361   0.3849587  -0.59548247]
#  [-0.57121766 -1.277224    0.6957546  -0.67618763  0.0510064 ]
#  [ 0.85491985  0.13310803 -0.93152267  0.10205163  0.57520276]
#  [-1.0606447  -0.16966362 -1.0448577   0.56799036 -0.90726566]]

# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]]
  

Если вы хотите сохранить все переменные в списках, вы могли бы заменить makeCheckpoint что-то вроде этого:

 def makeCheckpoint(self):
    return tf.train.Checkpoint(
        **{f'init{i}': v for i, v in enumerate(self.initList)},
        **{f'more{i}': v for i, v in enumerate(self.moreList)})
  

Обратите внимание, что у вас могут быть «вложенные» контрольные точки, поэтому, в более общем плане, у вас могла бы быть функция, которая создает контрольную точку для списка переменных, например, вот так:

 def listCheckpoint(varList):
    # Use 'item{}'.format(i) if using Python <3.6
    return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})
  

Тогда вы могли бы просто использовать это:

 def makeCheckpoint(self):
    return tf.train.Checkpoint(init=listCheckpoint(self.initList),
                               more=listCheckpoint(self.moreList))
  

Комментарии:

1. Большое спасибо, чувак @jdehesa . Это именно то, чего я хотел. Пожалуйста, также покажите, как сохранить их все и восстановить, поскольку этот метод не будет работать для многих переменных. Тогда я приму ваш ответ и соответствующим образом отредактирую свой вопрос, чтобы другие также могли извлечь из него пользу.

2. @caissalover Я отредактировал ответ, посмотрите, охватывает ли он то, что вы искали.

3. Не совсем. Это грубая сила последнего. «Самый простой способ управлять переменными — присоединить их к объектам Python, а затем ссылаться на эти объекты. Подклассы tf.train. Контрольная точка, tf.keras. слои. Слой и tf.keras. Модель автоматически отслеживает переменные, присвоенные их атрибутам » В официальном документе, создавая контрольную точку и предоставляя v1, наш объект в качестве аргумента для нее должен сохранить его. Все еще не понимаю, что сделали эти ребята, но что-то вроде этого должно сохранять все в любом объекте manyVariables. Раньше было проще передавать объект сеанса в tf.train. Заставка

4. @caissalover Я согласен с вами, что многие изменения кажутся немного «способом Keras или шоссе». Раньше контрольные точки просто сохраняли переменные по умолчанию в коллекции глобальных переменных, но в 2.x коллекций нет, так что этого больше нет. Как говорится, если вы создаете свой материал на основе моделей / слоев Keras, все должно просто работать ™, но если вы хотите сделать что-то по-другому, вы в значительной степени предоставлены сами себе. Конечно, вы можете создать свой собственный класс, управляющий этим для вас, с переопределяемыми методами для создания переменных с контрольными точками… что будет повторять то, что делает TF / Keras.

5. Итак, я думаю, что ваш путь — это путь. Я думал сделать manyVariables подклассом tf.train. Контрольная точка должна что-то делать, но наследование — это не мое. Пока я буду придерживаться вашего пути. Без коллекций глобальных переменных определенно сложнее сохранять и восстанавливать.

Ответ №2:

В следующем коде я сохраняю массив, называемый variables, в текстовый файл с именем по вашему выбору. Этот файл будет находиться в той же папке, что и ваш файл python. ‘wb’ в функции open означает запись с усечением (то есть удаление всего, что ранее было в файле) и использует формат байт. Я использую pickle для обработки сохранения / синтаксического анализа списка.

 import pickle

    def saveVariables(self, variables): #where 'variables' is a list of variables
        with open("nameOfYourFile.txt", 'wb ') as file:
           pickle.dump(variables, file)

    def retrieveVariables(self, filename):
        variables = []
        with open(str(filename), 'rb') as file:
            variables = pickle.load(file)
        return variables
  

Чтобы сохранить определенный материал в свой файл, просто добавьте его в качестве аргумента variables в saveVariables вот так:

 myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)
  

Для извлечения переменных из текстового файла с определенным именем:

 myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
SixtiethEl = myVariables[2]
fourthEl = myVariables[3]
  

Вы могли бы добавить эти функции в любое место класса.

Однако, чтобы иметь возможность получить доступ к initList / moreList в вашем примере, вы должны либо вернуть их из их функций (как я делаю со variables списком), либо сделать их глобальными.

Комментарии:

1. Сразу же появилось сообщение об ошибке, что переменные должны быть str в файле. write(). Удалось сохранить с помощью str (переменные) в виде файла. написал аргумент, но не смог загрузить и преобразовать обратно в tensor. Должен быть лучший метод с tf.train. Checkpoint().

2. упс, ты прав, я забыл это перепроверить. Вы могли бы заставить это работать, используя pickle, который позволяет сохранять многие форматы в виде байтов. Я обновлю ответ. У Tensorflow действительно может быть автоматическая функция для этого, я никогда с ней не работал, поэтому я не уверен.

3. Это сделало свое дело. Большое вам спасибо. Хотя должен быть и какой-то специфичный для tensorflow способ, поэтому я подожду еще немного другого ответа с tf.train. Контрольная точка, если это возможно..