#pytorch
#pytorch
Вопрос:
Мне было интересно, есть ли более эффективная альтернатива для приведенного ниже кода, без использования цикла «for» в 4-й строке?
import torch
n, d = 37700, 7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
mask = torch.zeros(n, d, dtype=torch.bool)
mask.scatter_(dim=1, index=sample, value=True)
По сути, то, что я пытаюсь сделать, это создать n
d
тензор по маске, такой, чтобы в каждой строке точно k
случайные элементы были истинными.
Ответ №1:
Вот способ сделать это без цикла. Давайте начнем со случайной матрицы, в которой все элементы отображаются iid, в данном случае равномерно на [0,1] . Затем мы берем k-й квантиль для каждой строки и устанавливаем для всех меньших или равных элементов значение True, а для остальных значение False в каждой строке:
rand_mat = torch.rand(n, d)
k_th_quant = torch.topk(rand_mat, k, largest = False)[0][:,-1:]
mask = rand_mat <= k_th_quant
Цикл не требуется 🙂 x2.1598 быстрее, чем код, который вы прикрепили к моему процессору.
Комментарии:
1. Хороший ответ, однако, я думаю, что я должен был предоставить вам реальные значения для
n
,d
, иk
которые я на самом деле использую в своем коде (я редактирую вопрос). Сn=37700
,d=7842
, иk=4
, мой собственный код работает около 5 секунд на моем процессоре, в то время как ваш занимает около 18 секунд.2. Спасибо, поэтому я обновил его, и теперь он стал еще лучше и быстрее для ваших новых значений n, d и k. Мой занимает 2,44 с, а ваш — 5,27 с.