Привязанные числовые сравнения в PyTorch

#python #pytorch

#python #pytorch

Вопрос:

В PyTorch мы можем проводить сравнения между элементами в тензорах следующим образом:

 import torch

a = torch.tensor([[1,2], [2,3], [3,4]])
b = torch.tensor([[3,4], [1,2], [2,3]])

print(a.size())
# torch.Size([3, 2])
print(b.size())
# torch.Size([3, 2])

c = a[:, 0] < b[:, 0]

print(c)
# tensor([ True, False, False])
  

Однако, когда мы пытаемся добавить условие, фрагмент завершается ошибкой:

 c = a[:, 0] < b[:, 1] < b[:, 0]
  

Ожидаемый результат

  tensor([ False, False,  False])
  

Итак, для каждого элемента в a сравните его первый элемент со вторым элементом соответствующего элемента в b и сравните этот элемент с первым элементом того же элемента в b.

Обратная трассировка (последний последний вызов): файл «scratch_12.py «, строка 9, в c = a[:, 0] < b[:, 1] < b[:, 0] Ошибка времени выполнения: значение bool тензора с более чем одним значением неоднозначно

Почему это так, и как мы можем это решить?

Комментарии:

1. И каков желаемый результат?

Ответ №1:

Вы должны разделить свое условие на два условия, используя непосредственно amp; оператор. Что касается того, почему именно, это связано с синтаксисом torch .

 c = (a[:, 0] < b[:, 1]) amp; (b[:, 1] < b[:, 0])