#c #simd #avx512 #simd-library #synet
Вопрос:
Существуют инструкции AVX-512 VNNI, начиная с процессора Intel Cascade Lake, которые могут ускорить вывод квантованных нейронных сетей на процессоре. В частности, существует функция instuction _mm512_dpbusd_epi32
( vpdpbusd
), которая позволяет выполнять умножение 8-разрядных целых чисел со знаком и без знака и накапливать их в 32-разрядные целочисленные аккумуляторы. Ниже приведен псевдокод этой инструкции:
void _mm512_dpbusd_epi32(int32_t sum[16], uint8_t a[16][4], int8_t b[16][4])
{
for(int i = 0; i < 16; i)
sum[i] =
(int)a[i][0]*b[i][0] (int)a[i][1]*b[i][1]
(int)a[i][2]*b[i][2] (int)a[i][3]*b[i][3];
}
К сожалению, процессоры intel до Cascade Lake не имеют этой инструкции, поэтому возникает вопрос об эмуляции этой с использованием предыдущего расширения (например, AVX-512BW).
Поэтому мой вопрос таков: как сделать эту эмуляцию максимально эффективной?
Ответ №1:
Я думаю, что на этот вопрос нет одного правильного ответа.
С одной стороны, быстрая эмуляция _mm512_dpbusd_epi32
с использованием расширения AVX-512BW может выглядеть так:
inline __m512i _mm512_dpbusd_epi32_bw_fast(__m512i i32, __m512i u8, __m512i i8)
{
__m512i i16 = _mm512_maddubs_epi16(u8, i8); //possible overflow of INT16.
__m512i _1 = _mm512_set1_epi16(1);
return _mm512_add_epi32(i32, _mm512_madd_epi16(i16, _1));
}
В этой реализации используются только 3 инструкции (и все они быстрые).
Но это может дать неверный результат из-за возможного переполнения INT16 в _mm512_maddubs_epi16
инструкции.
С другой стороны, правильная эмуляция выглядит ужасно и требует 14 инструкций (и некоторые из них особенно медленные).:
inline __m512i _mm512_hadd_epi32(__m512i a, __m512i b)
{
static const __m512i IDX0 = _mm512_setr_epi32(
0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E,
0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E);
static const __m512i IDX1 = _mm512_setr_epi32(
0x01, 0x03, 0x05, 0x07, 0x09, 0x0B, 0x0D, 0x0F,
0x11, 0x13, 0x15, 0x17, 0x19, 0x1B, 0x1D, 0x1F);
__m512i ab0 = _mm512_permutex2var_epi32(a, IDX0, b);
__m512i ab1 = _mm512_permutex2var_epi32(a, IDX1, b);
return _mm512_add_epi32(ab0, ab1);
}
inline __m512i _mm512_dpbusd_epi32_bw_exact(__m512i i32, __m512i u8, __m512i i8)
{
__m512i u8_i16lo = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(u8, 0));
__m512i i8_i16lo = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(i8, 0));
__m512i i32lo = _mm512_madd_epi16(u8_i16lo, i8_i16lo);
__m512i u8_i16hi = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(u8, 1));
__m512i i8_i16hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(i8, 1));
__m512i i32hi = _mm512_madd_epi16(u8_i16hi, i8_i16hi);
return _mm512_add_epi32(i32, _mm512_hadd_epi32(i32lo, i32hi));
}