Почему встраивание dimemsion должно быть кратным числу заголовков в MultiheadAttention?

#python-3.x #pytorch #transformer #attention-model

#python-3.x #pytorch #трансформатор #внимание-модель

Вопрос:

Я изучаю трансформатор. Вот документ pytorch для MultiheadAttention. В их реализации я видел, что существует ограничение:

  assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 

Зачем требовать ограничения: embed_dim must be divisible by num_heads? если мы вернемся к уравнению

MultiHead (Q, K, V) = Объединение (head1, ..., headh) WOwhereheadi = Внимание (QWiQ, KWiK, VWiV)

Предположим, что: Q , K , V являются n x emded_dim матрицами; все весовые матрицы W emded_dim x head_dim ,

Тогда конкат [head_i, ..., head_h] будет n x (num_heads*head_dim) матрицей;

W^O с размером (num_heads*head_dim) x embed_dim

[head_i, ..., head_h] * W^O станет n x embed_dim выходом

Я не знаю, зачем нам это нужно embed_dim must be divisible by num_heads .

Допустим, у нас есть num_heads=10000 , результаты одинаковы, поскольку произведение матрицы на матрицу уничтожит эту информацию.

Ответ №1:

Когда у вас есть последовательность seq_len x emb_dim (т.е. 20 x 8 ) и вы хотите использовать num_heads=2 , последовательность будет разделена по emb_dim измерению. Поэтому вы получаете две 20 x 4 последовательности. Вы хотите, чтобы каждая головка имела одинаковую форму, и если emb_dim она не делится на num_heads это, это не сработает. Возьмем , к примеру , последовательность 20 x 9 и еще раз num_heads=2 . Тогда вы получите 20 x 4 и 20 x 5 , которые не являются одним и тем же измерением.

Комментарии:

1. Хороший пример, если seq_len x emb_dim есть 20 x 9 , и num_heads=2 , пусть выбирают head_dim=77 , тогда мы можем получить head_i 20 x 144 матрицу is. как таковой [head_1, head_2] 20 x 288 , мы все равно можем выбрать W^O is 288 x 9 . мы все еще можем получить финал 20 x 9 . Я хочу сказать, что мы также можем отобразить emb_dim на любую длину и использовать W^O для ее возврата emb_dim . Зачем нужно погружаться emb_dim в четную длину? Спасибо.