#python #tensorflow #multi-gpu
Question:
I would like to build a custom training loop for multiple GPUs in TensorFlow with a variable batch size: in each training epoch the batch size increases by one until a maximum buffer size is reached.
Here is my current code, which I adapted from the TensorFlow tutorial and which still uses a fixed batch size. Any suggestions?
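To make the goal concrete, this is roughly the schedule I have in mind (a minimal sketch; max_batch_size is a hypothetical parameter that I would add to my config):

# Per-replica batch size grows by one per epoch, capped at a maximum.
batch_size_for_epoch = min(self.parameters.get('batch_size') + epoch,
                           self.parameters.get('max_batch_size'))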
def __init__(self, param):
    ...
    if self.parameters.get('multi_gpu'):
        self.strategy = tf.distribute.MirroredStrategy()
        print('Number of GPU devices: {}'.format(self.strategy.num_replicas_in_sync))
        # Global batch size = per-replica batch size * number of replicas;
        # note that it is computed once here and never updated afterwards.
        self.GLOBAL_BATCH_SIZE = self.parameters.get('batch_size') * self.strategy.num_replicas_in_sync
        with self.strategy.scope():
            self.build_model()
    else:
        self.build_model()
def build_model(self):
    if self.parameters.get('multi_gpu'):
        self.distributed_train_step_fn = self.get_model_train_step_function(self.detection_model, optimizer, to_fine_tune)
        self.distributed_test_step_fn = self.get_model_test_step_function(self.detection_model, optimizer, to_fine_tune)
    else:
        self.train_step_fn = self.get_model_train_step_function(self.detection_model, optimizer, to_fine_tune)
        self.test_step_fn = self.get_model_test_step_function(self.detection_model, optimizer, to_fine_tune)
    ...
def get_model_train_step_function(self, model, optimizer, vars_to_fine_tune):
    ...
    @tf.function(experimental_relax_shapes=True)
    def train_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
        ...

    @tf.function(experimental_relax_shapes=True)
    def distributed_train_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
        per_replica_losses = self.strategy.run(
            train_step_fn,
            args=(image_tensors, groundtruth_boxes_list, groundtruth_classes_list))
        # Reduce the per-replica losses to a single scalar loss for this step.
        return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    if self.parameters.get('multi_gpu'):
        return distributed_train_step_fn
    else:
        return train_step_fn
...
def get_model_test_step_function(self, model, optimizer, vars_to_fine_tune):
    ...
    @tf.function(experimental_relax_shapes=True)
    def test_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
        ...

    @tf.function(experimental_relax_shapes=True)
    def distributed_test_step_fn(image_tensors, groundtruth_boxes_list, groundtruth_classes_list):
        per_replica_losses = self.strategy.run(
            test_step_fn,
            args=(image_tensors, groundtruth_boxes_list, groundtruth_classes_list))
        # Reduce the per-replica validation losses to a single scalar.
        return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    if self.parameters.get('multi_gpu'):
        return distributed_test_step_fn
    else:
        return test_step_fn
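For context, inside train_step_fn (body elided above) I follow the tutorial's pattern of averaging the loss over the global batch. I assume that with a variable batch size this value would have to be passed in on every call rather than captured once, along these lines (per_example_loss and current_global_batch_size are placeholder names):

# Sketch: average the per-example losses over the current global batch size,
# so that summing the per-replica results yields the true mean batch loss.
loss = tf.nn.compute_average_loss(per_example_loss,
                                  global_batch_size=current_global_batch_size)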
...
def fine_tune(self):
    ...
    if self.parameters.get('multi_gpu'):
        num_batches = int(train_size / self.GLOBAL_BATCH_SIZE)
        num_test_batches = int(len(self.test_images) / self.GLOBAL_BATCH_SIZE)
    else:
        num_batches = int(train_size / self.parameters.get('batch_size'))
        num_test_batches = int(len(self.test_images) / self.parameters.get('batch_size'))
    ...
    for epoch in range(self.parameters.get('epochs')):
        ...
        for idx in tqdm(range(num_batches - 1)):
            if self.parameters.get('multi_gpu'):
                example_keys = all_keys[idx * self.GLOBAL_BATCH_SIZE:(idx + 1) * self.GLOBAL_BATCH_SIZE]
                gt_boxes_list = [self.train_boxes_tensors[key] for key in example_keys]
                gt_classes_list = [self.train_classes_one_hot_tensors[key] for key in example_keys]
                image_tensors = [self.train_image_tensors[key] for key in example_keys]
                batch_loss = self.distributed_train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)
            else:
                example_keys = all_keys[idx * self.parameters.get('batch_size'):(idx + 1) * self.parameters.get('batch_size')]
                gt_boxes_list = [self.train_boxes_tensors[key] for key in example_keys]
                gt_classes_list = [self.train_classes_one_hot_tensors[key] for key in example_keys]
                image_tensors = [self.train_image_tensors[key] for key in example_keys]
                batch_loss = self.train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)
            epoch_loss_avg.update_state(batch_loss)

        print('Testing...')
        for idx in tqdm(range(num_test_batches - 1)):
            if self.parameters.get('multi_gpu'):
                example_keys = all_test_keys[idx * self.GLOBAL_BATCH_SIZE:(idx + 1) * self.GLOBAL_BATCH_SIZE]
                gt_boxes_list = [self.test_boxes_tensors[key] for key in example_keys]
                gt_classes_list = [self.test_classes_one_hot_tensors[key] for key in example_keys]
                image_tensors = [self.test_image_tensors[key] for key in example_keys]
                batch_val_loss = self.distributed_test_step_fn(image_tensors, gt_boxes_list, gt_classes_list)
            else:
                example_keys = all_test_keys[idx * self.parameters.get('batch_size'):(idx + 1) * self.parameters.get('batch_size')]
                gt_boxes_list = [self.test_boxes_tensors[key] for key in example_keys]
                gt_classes_list = [self.test_classes_one_hot_tensors[key] for key in example_keys]
                image_tensors = [self.test_image_tensors[key] for key in example_keys]
                batch_val_loss = self.test_step_fn(image_tensors, gt_boxes_list, gt_classes_list)
            epoch_val_loss_avg.update_state(batch_val_loss)

        print('Epoch loss ', str(epoch_loss_avg.result()), ' val_loss ', str(epoch_val_loss_avg.result()))
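Concretely, what I imagine is recomputing the batch size and the number of batches at the top of every epoch, roughly as in the sketch below (max_batch_size is again a hypothetical parameter), but I am unsure whether re-slicing with a growing batch size plays well with MirroredStrategy and tf.function retracing:

for epoch in range(self.parameters.get('epochs')):
    # Grow the per-replica batch size by one each epoch, up to a cap.
    per_replica_bs = min(self.parameters.get('batch_size') + epoch,
                         self.parameters.get('max_batch_size'))
    if self.parameters.get('multi_gpu'):
        global_bs = per_replica_bs * self.strategy.num_replicas_in_sync
    else:
        global_bs = per_replica_bs
    num_batches = int(train_size / global_bs)
    for idx in tqdm(range(num_batches - 1)):
        example_keys = all_keys[idx * global_bs:(idx + 1) * global_bs]
        ...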