#tensorflow
Вопрос:
Я пытаюсь понять реализацию SGD в tensorflow.
Я начал с gradient_descent.py из-за имени файла.
Согласно документу keras, оптимизатору необходимо реализовать _resource_apply_dense
метод, который соответствует коду (частично), показанному ниже:
def _resource_apply_dense(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
if self._momentum:
momentum_var = self.get_slot(var, "momentum")
return gen_training_ops.ResourceApplyKerasMomentum(
...
Я хотел бы знать, кто передает var
переменную в _resource_apply_dense
метод? Другими словами, какой метод решает, что эта конкретная серия примеров предназначена для изучения моделью?
Ответ №1:
Проверяя keras optimizer_v2 или tensorflow, мы находим единственное использование этой функции во всей кодовой базе tensorflow:
#...
def apply_grad_to_update_var(var, grad):
#...
if "apply_state" in self._dense_apply_args:
apply_kwargs["apply_state"] = apply_state
update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
if var.constraint is not None:
with ops.control_dependencies([update_op]):
return var.assign(var.constraint(var))
Позже мы увидим в том же файле, что var
переменная исходит из аргумента _distributed_apply
функции:
#...
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
#...
with name_scope_only_in_function_or_graph(name or self._name):
for grad, var in grads_and_vars:
#...
Наконец, grads_and_vars
аргумент определяется как List of (gradient, variable) pairs
в функции apply_gradients
:
#...
def apply_gradients(self,
grads_and_vars,
#...
"""...
Args:
grads_and_vars: List of (gradient, variable) pairs.
"""
Если вы проверите вхождения apply_gradients
(этого поиска), вы увидите, что это обычный способ обновления весов сети и, таким образом, контролируется шагом «обновление» оптимизатора.
Комментарии:
1.Огромное спасибо. Как раз перед этой строкой
self._create_all_weights(var_list)
я вставилprint(var_list[0].numpy(), var_list[1].numpy())
, а затем получил веса и смещения, а не обучающие примеры, какие-нибудь идеи?2. Обратите внимание, что
var_list
это список обучаемых переменных, а не список обучающих примеров! вы касаетесь кода, связанного с обновлением обучаемых переменных!