#python #pytorch #quantization
#python #pytorch #квантование
Вопрос:
Я использую приведенный ниже код, чтобы получить квантованный формат int 8 без знака в pytorch. Однако я не могу преобразовать quant
переменную в to np.uint8
. Возможно ли это сделать?
import torch
quant = torch.quantize_per_tensor(torch.tensor([-1.0, 0.352, 1.321, 2.0]), 0.1, 10, torch.quint8)
Ответ №1:
Это можно сделать с помощью torch.int_repr()
import torch
import numpy as np
# generate a test float32 tensor
float32_tensor = torch.tensor([-1.0, 0.352, 1.321, 2.0])
print(f'{float32_tensor.dtype}n{float32_tensor}n')
# convert to a quantized uint8 tensor. This format keeps the values in the range of
# the float32 format, with the resolution of a uint8 format (256 possible values)
quint8_tensor = torch.quantize_per_tensor(float32_tensor, 0.1, 10, torch.quint8)
print(f'{quint8_tensor.dtype}n{quint8_tensor}n')
# map the quantized data to the actual uint8 values (and then to an np array)
uint8_np_ndarray = torch.int_repr(quint8_tensor).numpy()
print(f'{uint8_np_ndarray.dtype}n{uint8_np_ndarray}')
Вывод
torch.float32
tensor([-1.0000, 0.3520, 1.3210, 2.0000])
torch.quint8
tensor([-1.0000, 0.4000, 1.3000, 2.0000], size=(4,), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10)
uint8
[ 0 14 23 30]