ПаЙторч: сравните три тензора?

#python #pytorch #comparison

Вопрос:

У меня есть три тензора логических масок , которые я хочу создать логическую маску, которая, если значение совпадает с тремя тензорами, то это 1 , иначе 0 .

Я пытался torch.where(A == B == C, 1, 0) , но, похоже, это не поддерживает такое.

Ответ №1:

torch.eq Оператор поддерживает только сравнение двоичных тензоров, поэтому вам необходимо выполнить два сравнения:

 (A==B) amp; (B==C)
 

Ответ №2:

Вы можете использовать:

 ((A == B) amp; (B == C))
 

При необходимости вы всегда можете преобразовать логический тензор в соответствующий тип:

 ((A == B) amp; (B == C)).to(float)
 

Ответ №3:

AFAIK, тензор-это в основном массив NumPy, привязанный к устройству. Если это не слишком дорого для вашего приложения и вы можете позволить себе сделать это на процессоре, вы можете просто преобразовать его в NumPy и сделать то, что вам нужно, с помощью сравнения.