Как установить ‘training = False’ при использовании customize test_step() для оценки модели в Keras, использующем customize train_step()?

#python #tensorflow #keras

#python #тензорный поток #keras

Вопрос:

Я использую tensorflow.keras и хочу оценить свою модель с помощью customize test_step (), который использует customize train_step () . Я хотел бы обучить и оценить следующую модель:

 class Whole_model(tf.keras.Model):
     def __init__(self, EEG_gen_model, emg_feature_extractor, 
                  eeg_feature_extractor, seq2seq_model):
      super(Whole_model, self).__init__()
      self.EEG_gen_model= EEG_gen_model
      self.emg_feature_extractor= emg_feature_extractor
      self.eeg_feature_extractor= eeg_feature_extractor
      self.seq2seq_model=seq2seq_model
    
    
  
   def compile(self, EEG_gen_optimizer, emg_feature_optim, eeg_feature_optim, 
       seq2seq_optim, EEG_gen_loss, seq2seq_loss_fn, gen_accuracy, accuracy):
       super(Whole_model, self).compile()
       self.EEG_gen_optimizer = EEG_gen_optimizer
       self.emg_feature_optim=emg_feature_optim
       self.eeg_feature_optim=eeg_feature_optim
       self.seq2seq_optim=seq2seq_optim
       self.EEG_gen_loss = EEG_gen_loss
       self.seq2seq_loss_fn=seq2seq_loss_fn
       self.gen_accuracy=gen_accuracy
       self.accuracy=accuracy
   

    #we can use different optimizer for each model

   def train_step(self, data):
       x_train, [y_train_eeg, y]= data
       y = tf.reshape(y, [-1, no_Epochs , 5])
       n_samples_per_epoch = x_train.shape[1]
       print(n_samples_per_epoch)
       eeg_gen_input=tf.reshape(x_train, [-1, n_samples_per_epoch, 1])
       y_eeg_gen= tf.reshape(y_train_eeg, [-1, n_samples_per_epoch, 1])


       #tf.argmax(pred_classes,1)
       # Train the EEG generator
       with tf.GradientTape() as tape:
          EEG_Gen= self.EEG_gen_model(eeg_gen_input)
          gen_model_loss= self.EEG_gen_loss(y_train_eeg, EEG_Gen)
          gen_accuracy=self.accuracy(y_train_eeg, EEG_Gen)
      grads = tape.gradient(gen_model_loss, self.EEG_gen_model.trainable_weights)
      self.EEG_gen_optimizer.apply_gradients(zip(grads, self.EEG_gen_model.trainable_weights))

    # #SEQ2SEQ 
    emg_inp = x_train
    eeg_inp = self.EEG_gen_model(emg_inp)
    emg_enc_seq=self.emg_feature_extractor(emg_inp)
    eeg_enc_seq=self.eeg_feature_extractor(eeg_inp)
    emg_eeg_attention_seq = tf.keras.layers.Attention()([emg_enc_seq, eeg_enc_seq])
    input_layer=tf.keras.layers.Concatenate()([emg_enc_seq, emg_eeg_attention_seq])
    len_epoch=input_layer.shape[1] 
    inputs=tf.reshape(input_layer, [-1, no_Epochs ,len_epoch]) 
   

    # Train the discriminator
    with tf.GradientTape() as tape:
        outputs=self.seq2seq_model(inputs)
        seq2seq_loss= self.seq2seq_loss_fn(y, outputs)
        accuracy=self.accuracy(y_train, tf.argmax(outputs,1))
    grads = tape.gradient(seq2seq_loss, self.seq2seq_model.trainable_weights)
    self.seq2seq_optim.apply_gradients(zip(grads, self.seq2seq_model.trainable_weights))

    #fEATURE EXTRACTOR
    emg_inp = x_train
    eeg_inp = self.EEG_gen_model(emg_inp)
    eeg_enc_seq=self.emg_feature_extractor(emg_inp)
   
    with tf.GradientTape() as tape:
        eeg_enc_seq=self.eeg_feature_extractor(eeg_inp)
        emg_eeg_attention_seq = tf.keras.layers.Attention()([emg_enc_seq, eeg_enc_seq])
        input_layer=tf.keras.layers.Concatenate()([emg_enc_seq, emg_eeg_attention_seq])
        len_epoch=input_layer.shape[1] 
        inputs=tf.reshape(input_layer, [-1, no_Epochs ,len_epoch])
        outputs=self.seq2seq_model(inputs)
        seq2seq_loss= self.seq2seq_loss_fn(y, outputs)
     grads = tape.gradient(seq2seq_loss, self.eeg_feature_extractor.trainable_weights)
     self.seq2seq_optim.apply_gradients(zip(grads, self.eeg_feature_extractor.trainable_weights))    



    emg_inp = x_train
    eeg_inp = self.EEG_gen_model(emg_inp)
    with tf.GradientTape() as tape:
        eeg_enc_seq=self.emg_feature_extractor(emg_inp)
        eeg_enc_seq=self.eeg_feature_extractor(eeg_inp)
        emg_eeg_attention_seq = tf.keras.layers.Attention()([emg_enc_seq, eeg_enc_seq])
        input_layer=tf.keras.layers.Concatenate()([emg_enc_seq, emg_eeg_attention_seq])
        len_epoch=input_layer.shape[1] 
        inputs=tf.reshape(input_layer, [-1, no_Epochs ,len_epoch])
        outputs=self.seq2seq_model(inputs)
        seq2seq_loss= self.seq2seq_loss_fn(y, outputs)
        accuracy=self.accuracy(y_train, tf.argmax(outputs,1))
    grads = tape.gradient(seq2seq_loss, self.emg_feature_extractor.trainable_weights)
   
   
  self.emg_feature_extractor.apply_gradients(zip(grads,self.emg_feature_extractor.trainable_weights))
    return {"seq2seq_loss": seq2seq_loss, 'gen_model_loss':gen_model_loss, "gen_accuracy": gen_accuracy, "accuracy": accuracy}
  

Где EEG_gen_model, emg_feature_extractor, eeg_feature_extractor и seq2seq_model являются вспомогательной моделью основной модели.

Теперь я хочу использовать customize test_step () для оценки модели. аналогично следующему:

     def test_step(self, data):
       x_emg, y = data
       no_Epochs=3
       y = tf.reshape(y, [-1, no_Epochs , 5])

       with tf.GradientTape() as tape:
         emg_inp = x_emg
         eeg_inp = self.EEG_gen_model(emg_inp)
         emg_enc_seq=self.emg_feature_extractor(emg_inp)
         eeg_enc_seq=self.eeg_feature_extractor(eeg_inp)
         emg_eeg_attention_seq = tf.keras.layers.Attention()([emg_enc_seq, eeg_enc_seq])
         input_layer=tf.keras.layers.Concatenate()([emg_enc_seq, emg_eeg_attention_seq])
         len_epoch=input_layer.shape[1] 
         inputs=tf.reshape(input_layer, [-1, no_Epochs ,len_epoch]) 
         outputs=self.seq2seq_model(inputs) # Forward pass
         # Compute our own loss
         loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # Update weights
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
  

Я сомневаюсь, «Как установить ‘training = False'» в разделе test_step () в этом сценарии?

Любые предложения приветствуются.

Ответ №1:

Попробуйте это:

 def test_step(self, data):
   x_emg, y = data
   no_Epochs=3
   y = tf.reshape(y, [-1, no_Epochs , 5])

   emg_inp = x_emg
   eeg_inp = self.EEG_gen_model(emg_inp)
   emg_enc_seq=self.emg_feature_extractor(emg_inp)
   eeg_enc_seq=self.eeg_feature_extractor(eeg_inp)
   emg_eeg_attention_seq = tf.keras.layers.Attention()([emg_enc_seq, eeg_enc_seq])
   input_layer=tf.keras.layers.Concatenate()([emg_enc_seq, emg_eeg_attention_seq])
   len_epoch=input_layer.shape[1] 
   inputs=tf.reshape(input_layer, [-1, no_Epochs ,len_epoch]) 
   outputs=self.seq2seq_model(inputs) # Forward pass
   # Compute our own loss
   loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

   self.compiled_metrics.update_state(y, y_pred)
   # Return a dict mapping metric names to current value
   return {m.name: m.result() for m in self.metrics}
  

Комментарии:

1. Большое вам спасибо, @Andrey за ваш ответ. Пожалуйста, объясните причину, по которой не используется «with tf.GradientTape () as tape:».

2. Большое спасибо, Андрей. Пожалуйста, объясните, как он автоматически установит training = False и примет функцию потерь, используемую во время обучения в train_step ()?

3. Он не будет автоматически устанавливать training = False . Если вашей модели требуется, чтобы параметр ‘training’ был установлен явно — вы должны добавить его в свой вызов в test_step (): outputs=self.seq2seq_model (входы, обучение = False). Test_step не будет принимать функцию потерь из train_step — вы должны явно установить функцию потерь: loss = self.compiled_loss (y, y_pred, regularization_losses =self.losses)

4. Спасибо @Andrey. Теперь я хочу использовать model.predict () . Но я получаю сообщение об ошибке типа «NotImplementedError: при создании подкласса Model класса вы должны реализовать call метод.». Как это решить?

5. @Chandra пожалуйста, создайте новый вопрос. Покажите свои исследовательские усилия там