#python #tensorflow #keras
#python #tensorflow #keras
Вопрос:
Мне было интересно, знает ли кто-нибудь, как работает build()
функция из tf.keras.layers.Layer
класса под капотом. Согласно документации:
сборка вызывается, когда вы знаете формы входных тензоров и можете выполнить остальную часть инициализации
итак, мне кажется, что класс ведет себя примерно так:
class MyDenseLayer:
def __init__(self, num_outputs):
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_weight("kernel",
shape=[int(input_shape[-1]), self.num_outputs])
def __call__(self, input):
self.build(input.shape) ## build is called here when input shape is known
return tf.matmul(input, self.kernel)
Я не могу представить, build()
что это будет вызываться вечно __call__
, но это единственное место, куда передаются входные данные. Кто-нибудь знает, как именно это работает под капотом?
Ответ №1:
Layer.build()
Метод обычно используется для создания экземпляра веса слоя. Смотрите исходный код для tf.keras.layers.Dense
в качестве примера и обратите внимание, что тензоры веса и смещения создаются в этой функции. Layer.build()
Метод принимает input_shape
аргумент, а форма весов и смещений часто зависит от формы входных данных.
Layer.call()
Метод, с другой стороны, реализует прямой проход слоя. Вы не хотите перезаписывать __call__
, потому что это реализовано в базовом классе tf.keras.layers.Layer
. В пользовательском слое вы должны реализовать call()
.
Layer.call()
не вызывается Layer.build()
. Однако она Layer().__call__()
вызывается, если слой еще не был собран (исходный код), и это установит атрибут, self.built = True
предотвращающий Layer.build()
повторный вызов. Другими словами, Layer.__call__()
вызывается Layer.build()
только при первом вызове.
Комментарии:
1. Я понимаю. Итак, он вызывается при первом прямом проходе, как я и предполагал. Спасибо за предоставление исходных текстов. Очень признателен.
2.
input_shape
Ввод в метод сборки автоматически предоставляется keras? Я не вижу,input_shape
чтобы он был явно определен3.
Layer.build()
метод принимаетinput_shape
аргумент