объяснение кода, скрывающегося за маской трансформатора (PyTorch)

#deep-learning #pytorch #transformer

Вопрос:

в реализации трансформатора я нашел следующую функцию ( size я полагаю, это длина последовательности)::

 def _gen_sqr_nxt_mask(self, size):
    msk = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    msk = msk.float().masked_fill(msk == 0, float('-inf'))
    msk = msk.masked_fill(msk == 1, float(0.0))
    return msk
 

он используется таким образом, внутри forward функции:

 op = self.enc_transformer(source, self.mask_source)
 

enc_transformer определяется как:

     layers_enc = TransformerEncoderLayer(num_inputs, num_heads, num_hidden, dropout)
    self.enc_transformer = TransformerEncoder(layers_enc, num_layers)
 

в рамках init функции.

может ли кто-нибудь объяснить мне, как именно должна выглядеть маска для трансформатора? с точки зрения формы и ценностей. например, похоже, что функция ‘_gen_sqr_nxt_mask’ создает матрицу с-inf и 0 — это всегда значения, которые должна использовать маска?