#machine-learning #deep-learning #pytorch
#машинное обучение #глубокое обучение #пыторч #pytorch
Вопрос:
Я использую библиотеку Pytorch и ищу способ заморозить веса и смещения в моей модели.
Я видел эти 2 варианта:
-
model.train(False)
-
for param in model.parameters(): param.requires_grad = False
В чем разница (если таковая есть) и какой из них я должен использовать, чтобы заморозить текущее состояние моей модели?
Ответ №1:
Они очень разные.
Независимо от процесса backprop, некоторые уровни имеют разное поведение при обучении или оценке модели. В pytorch их всего 2: BatchNorm (который, я думаю, прекращает обновлять текущее среднее значение и отклонение при оценке) и Dropout (который удаляет значения только в режиме обучения). Итак, model.train()
и model.eval()
(эквивалентно model.train(false)
) просто установите логический флаг, чтобы сообщить этим двум слоям «заморозить себя». Обратите внимание, что эти два уровня не имеют никаких параметров, на которые влияет обратная операция (я думаю, что тензоры буфера batchnorm изменены во время прямого прохождения)
С другой стороны, установка для всех ваших параметров значения «require_grad=false» просто указывает pytorch прекратить запись градиентов для backprop. Это не повлияет на пакетный уровень и уровни отсева
Как заморозить вашу модель, зависит от вашего варианта использования, но я бы сказал, что самый простой способ — использовать torch.jit.trace. Это создаст замороженную копию вашей модели, в точности в том состоянии, в котором она была при вашем вызове trace
. Ваша модель осталась неизменной.
Обычно вы вызываете
model.eval()
traced_model = torch.jit.trace(model, input)
Комментарии:
1. Это torch.jit.trace обратимо? На случай, если я решу, что хочу тренироваться еще несколько эпох
2. Я отредактировал свой ответ, чтобы добавить ясности. Трассировка вашей модели не повлияет на нее, она просто создаст ее замороженную копию
3. У BatchNorm действительно есть обучаемые параметры — новое среднее и стандартное отклонение. Пожалуйста, уточните свой ответ.
Ответ №2:
Существует два способа зависания в PyTorch при обучении:
- установка
requires_grad
вFalse
- установка скорости обучения lr равной нулю
Пока model.train(False)
это способ не обучаться. 😉