Отладка Tensorflow 2.0: сбой печати в tf.функции

#python #tensorflow #keras #neural-network

#питон #тензорный поток #keras #нейронная сеть

Вопрос:

Я пытаюсь отладить относительно сложный пользовательский метод обучения с использованием пользовательских функций потерь и т.д. В частности, я пытаюсь отладить проблему на пользовательском этапе обучения, который компилируется в Tensorflow @function и устанавливается как скомпилированная модель Keras. Я хочу иметь возможность распечатать промежуточное значение тензора при сбое вызова функции. Трудность заключается в том, что, поскольку тензоры внутри @function являются значениями графика и не вычисляются немедленно, а поскольку функция выходит из строя во время вычисления, кажется, что значения на самом деле не вычисляются. Вот простой пример:

 class debug_model(tf.keras.Model):
    def __init__(self, width,depth,insize,outsize,batch_size):
        super(debug_model, self).__init__()
        self.width = width
        self.depth = depth
        self.insize = insize
        self.outsize = outsize
        
        self.net = tf.keras.models.Sequential()
        self.net.add(tf.keras.Input(shape = (insize,)))
        for i in range(depth):
            self.net.add(tf.keras.layers.Dense(width,activation = 'swish'))
        self.net.add(tf.keras.layers.Dense(outsize))
    
    def call(self,ipts):
        
        return self.net(ipts)
    
    @tf.function
    def train_step(self,data):
        ipt, target = data
        with tf.GradientTape(persistent=True) as tape_1:
            tape_1.watch(ipt)
            y = self(ipt)

            tf.print('y:',y)
            assert False
            loss = tf.keras.losses.MAE(target,y)
        trainable_vars = self.trainable_variables
        loss_grad = tape_1.gradient(loss,trainable_vars)
        self.optimizer.apply_gradients(zip(loss_grad, trainable_vars))
        self.compiled_metrics.update_state(target, y)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
 

Если вы скомпилируете эту модель с некоторыми данными по вашему выбору и запустите ее:

 train_set = tf.data.Dataset.from_tensor_slices(data_tuple).batch(opt.batchSize)
train_set.shuffle(buffer_size = trainpoints)

model = debug_model(opt.width,opt.depth,in_size,out_size,batchSize)
optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr)    
lr_sched = lambda epoch, lr: lr * 0.95**(1 / (8))    
cb_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule = lr_sched, verbose = 1)
model.build((None,1))    
model.summary()    
model.compile(optimizer=optimizer,
              loss = tf.keras.losses.MeanAbsoluteError(),
    )
    
callbacks = [
tf.keras.callbacks.ModelCheckpoint(path,
    verbose=2
),
cb_scheduler,
tf.keras.callbacks.CSVLogger(path 'log.csv')
]
    
hist = model.fit(train_set,epochs = opt.nEpochs,callbacks = callbacks)
 

Если вы загрузите это и запустите его, вы увидите, что он завершается из-за ошибки утверждения без печати. Есть ли способ заставить этот тензор оценить, чтобы я мог его распечатать?

Комментарии:

1. assert False приведет к сбою функции при построении графика. tf.print предназначен для печати данных при фактической оценке функции (когда она вызывается после сборки). Просто удалите утверждение.

2. Я понимаю, что утверждение приводит к сбою функции, это единственная цель «утверждать ложь». Это пример сбоя «train_step», который воспроизводит проблему, с которой я сталкиваюсь при более сложной реализации, которая выходит из строя по причине, которую я пока не понимаю. Утверждение здесь является заменой неизвестной причины сбоя, которую я пытаюсь диагностировать.

3. Или дело в том, что сбой происходит до того, как там вообще есть какие-либо значения, поэтому печатать нечего?

4. Да, assert произойдет сбой функции при компиляции, а не когда она фактически выполняется с «реальными» значениями, тогда tf.print как будет печататься только в последнем случае, так что это, вероятно, не то, что вы хотите. Возможно, это более уместно: tensorflow.org/api_docs/python/tf/debugging/Assert