Как утвердить torch.float32 / 64 с помощью isinstance?

#python #pytorch

#python #pytorch

Вопрос:

Я пишу конвейер для своего проекта глубокого обучения, в качестве хорошей практики я пытаюсь выполнить некоторые утверждения, чтобы убедиться, что тип данных совпадает, поэтому его легче отлаживать позже! Однако я не знаю, как использовать isinstance для утверждения torch.Tensor's dtype .

Например:

     assert isinstance(image, torch.Tensor) and isinstance(target['boxes'], torch.Tensor)
    assert isinstance(image.dtype, (torch.float32, torch.float64)) and isinstance(target['boxes'].dtype, (torch.float32, torch.float64)) 
  

Используется assert image.dtype in [torch.float32, torch.float64] ли здесь единственный метод? Есть ли элегантный способ сделать это?

Ответ №1:

Вы могли бы использовать torch.is_floating_point

 assert torch.is_floating_point(image) and torch.is_floating_point(target['boxes'])
  

Функция вызывает исключение, если входные данные не являются тензором. Поэтому нет необходимости выполнять независимую проверку torch.Tensor .