Переменный размер пакета и несколько графических процессоров в tensorflow

#python #tensorflow #multi-gpu

Вопрос:

Я хотел бы разработать пользовательский цикл обучения для нескольких графических процессоров с переменным размером пакета с помощью Tensorflow. В каждую эпоху обучения размер пакета увеличивается на единицу до тех пор, пока не будет достигнут максимальный размер буфера.

Вот мой текущий код, который я перенастроил из руководства Tensorflow и в котором я использую фиксированный размер пакета. Есть какие-нибудь предложения?

 def __init__(self, param):
  # Constructor fragment (the author elided the rest of the setup).
  ...       
  # Under MirroredStrategy the model must be built inside strategy.scope()
  # so that its variables are created as mirrored (per-replica) variables.
  if self.parameters.get('multi_gpu'):
    self.strategy = tf.distribute.MirroredStrategy()
    print('Number of GPU devices: {}'.format(self.strategy.num_replicas_in_sync))
    # Global batch size = per-replica batch size * replica count, per the
    # tf.distribute custom-training-loop convention.
    self.GLOBAL_BATCH_SIZE = self.parameters.get('batch_size') * self.strategy.num_replicas_in_sync
    with self.strategy.scope():
      self.build_model()
  else:
    # Single-device path: build the model outside any distribution scope.
    self.build_model()


def build_model(self):
  # Installs the train/test step callables produced by the factory methods.
  # NOTE(review): `optimizer` and `to_fine_tune` are not defined anywhere in
  # this snippet — presumably locals or attributes elided by the author;
  # confirm their origin before reusing this code.
  if self.parameters.get('multi_gpu'):
    # Distributed variants wrap the per-replica step via strategy.run.
    self.distributed_train_step_fn = self.get_model_train_step_function(self.detection_model, optimizer, to_fine_tune)
    self.distributed_test_step_fn = self.get_model_test_step_function(self.detection_model, optimizer, to_fine_tune)
  else:
    self.train_step_fn = self.get_model_train_step_function(self.detection_model, optimizer, to_fine_tune)
    self.test_step_fn = self.get_model_test_step_function(self.detection_model,optimizer, to_fine_tune)

...

def get_model_train_step_function(self, model, optimizer, vars_to_fine_tune):
  """Build and return the training-step callable.

  Returns the distributed wrapper when 'multi_gpu' is set, otherwise the
  plain per-replica step function.
  """
  ...
  @tf.function(experimental_relax_shapes=True)
  def train_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
    ...

  @tf.function(experimental_relax_shapes=True)
  def distributed_train_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
    per_replica_losses = self.strategy.run(
        train_step_fn,
        args=(image_tensors, groundtruth_boxes_list, groundtruth_classes_list))
    # Combine the PerReplica losses into one scalar for logging/averaging.
    return self.strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

  # BUG FIX: the selection below was indented inside distributed_train_step_fn,
  # so this factory implicitly returned None and per_replica_losses was unused.
  # It must run at the factory level, after both step functions are defined.
  if self.parameters.get('multi_gpu'):
    return distributed_train_step_fn
  else:
    return train_step_fn

...

def get_model_test_step_function(self, model, optimizer, vars_to_fine_tune):
  """Build and return the evaluation-step callable.

  Returns the distributed wrapper when 'multi_gpu' is set, otherwise the
  plain per-replica step function.
  """
  ...
  @tf.function(experimental_relax_shapes=True)
  def test_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
    ...

  @tf.function(experimental_relax_shapes=True)
  def distributed_test_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
    per_replica_losses = self.strategy.run(
        test_step_fn,
        args=(image_tensors, groundtruth_boxes_list, groundtruth_classes_list))
    # Combine the PerReplica losses into one scalar for logging/averaging.
    return self.strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

  # BUG FIX: the selection below was indented inside distributed_test_step_fn,
  # so this factory implicitly returned None and per_replica_losses was unused.
  # It must run at the factory level, after both step functions are defined.
  if self.parameters.get('multi_gpu'):
    return distributed_test_step_fn
  else:
    return test_step_fn

...

def fine_tune(self):
  """Run the custom training/evaluation loop over the configured epochs."""
  ...
  # Resolve the effective batch size and step functions ONCE instead of
  # re-branching on 'multi_gpu' inside every loop iteration: the two branches
  # differed only in the batch size used for slicing and the step callable.
  if self.parameters.get('multi_gpu'):
    batch_size = self.GLOBAL_BATCH_SIZE
    train_step = self.distributed_train_step_fn
    test_step = self.distributed_test_step_fn
  else:
    batch_size = self.parameters.get('batch_size')
    train_step = self.train_step_fn
    test_step = self.test_step_fn

  num_batches = int(train_size / batch_size)
  num_test_batches = int(len(self.test_images) / batch_size)

  ...
  for epoch in range(self.parameters.get('epochs')):
    ...
    # NOTE(review): range(num_batches - 1) skips one full batch every epoch;
    # use range(num_batches) if that is not intentional — confirm.
    for idx in tqdm(range(num_batches - 1)):
      # BUG FIX: the slice upper bound was garbled as `(idx   1)`;
      # it must be `(idx + 1)` to select the idx-th batch.
      example_keys = all_keys[idx * batch_size: (idx + 1) * batch_size]
      gt_boxes_list = [self.train_boxes_tensors[key] for key in example_keys]
      gt_classes_list = [self.train_classes_one_hot_tensors[key] for key in example_keys]
      image_tensors = [self.train_image_tensors[key] for key in example_keys]
      batch_loss = train_step(image_tensors, gt_boxes_list, gt_classes_list)
      # BUG FIX: the loss was only accumulated in the single-GPU branch, so
      # the multi-GPU path never updated epoch_loss_avg; accumulate always.
      epoch_loss_avg.update_state(batch_loss)

    print('Testing...')
    for idx in tqdm(range(num_test_batches - 1)):
      example_keys = all_test_keys[idx * batch_size: (idx + 1) * batch_size]
      gt_boxes_list = [self.test_boxes_tensors[key] for key in example_keys]
      gt_classes_list = [self.test_classes_one_hot_tensors[key] for key in example_keys]
      image_tensors = [self.test_image_tensors[key] for key in example_keys]
      batch_val_loss = test_step(image_tensors, gt_boxes_list, gt_classes_list)
      # BUG FIX: same accumulation defect as the training loop above.
      epoch_val_loss_avg.update_state(batch_val_loss)

    print('Epoch loss ', str(epoch_loss_avg.result()), ' val_loss ', str(epoch_val_loss_avg.result()))