Какой метод тензорного потока решает, какую конкретную партию примеров следует изучить модели?

#tensorflow

Вопрос:

Я пытаюсь понять реализацию SGD в tensorflow.

Я начал с gradient_descent.py из-за имени файла.

Согласно документу keras, оптимизатору необходимо реализовать _resource_apply_dense метод, который соответствует коду (частично), показанному ниже:

 def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    if self._momentum:
    momentum_var = self.get_slot(var, "momentum")
    return gen_training_ops.ResourceApplyKerasMomentum(
        ...
 

Я хотел бы знать, кто передает var переменную в _resource_apply_dense метод? Другими словами, какой метод решает, что эта конкретная серия примеров предназначена для изучения моделью?

Ответ №1:

Проверяя keras optimizer_v2 или tensorflow, мы находим единственное использование этой функции во всей кодовой базе tensorflow:

    #...
   def apply_grad_to_update_var(var, grad):
      #...
      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
 

Позже мы увидим в том же файле, что var переменная исходит из аргумента _distributed_apply функции:

 #...
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    #...
    with name_scope_only_in_function_or_graph(name or self._name):
      for grad, var in grads_and_vars:
      #...
 

Наконец, grads_and_vars аргумент определяется как List of (gradient, variable) pairs в функции apply_gradients :

   #...
  def apply_gradients(self,
                      grads_and_vars,
    #...
    """...
    Args:
      grads_and_vars: List of (gradient, variable) pairs.
    """
 

Если вы проверите вхождения apply_gradients (этого поиска), вы увидите, что это обычный способ обновления весов сети и, таким образом, контролируется шагом «обновление» оптимизатора.

Комментарии:

1.Огромное спасибо. Как раз перед этой строкой self._create_all_weights(var_list) я вставил print(var_list[0].numpy(), var_list[1].numpy()) , а затем получил веса и смещения, а не обучающие примеры, какие-нибудь идеи?

2. Обратите внимание, что var_list это список обучаемых переменных, а не список обучающих примеров! вы касаетесь кода, связанного с обновлением обучаемых переменных!