#python #pytorch #comparison
Вопрос:
У меня есть три тензора логических масок , которые я хочу создать логическую маску, которая, если значение совпадает с тремя тензорами, то это 1
, иначе 0
.
Я пытался torch.where(A == B == C, 1, 0)
, но, похоже, это не поддерживает такое.
Ответ №1:
torch.eq
Оператор поддерживает только сравнение двоичных тензоров, поэтому вам необходимо выполнить два сравнения:
(A==B) amp; (B==C)
Ответ №2:
Вы можете использовать:
((A == B) amp; (B == C))
При необходимости вы всегда можете преобразовать логический тензор в соответствующий тип:
((A == B) amp; (B == C)).to(float)
Ответ №3:
AFAIK, тензор-это в основном массив NumPy, привязанный к устройству. Если это не слишком дорого для вашего приложения и вы можете позволить себе сделать это на процессоре, вы можете просто преобразовать его в NumPy и сделать то, что вам нужно, с помощью сравнения.