我需要更高的 int8 向量乘法性能(Intel AVX)

问题描述 投票:0回答:1

我为 int8 矩阵乘法实现了 8 位整数乘法。

这是我的代码,但我认为它真的很慢。

inline __m512i int8_mul(__m512i a, __m512i b) {
// Convert vectors INT8 to INT16
__m512i a_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(a));
__m512i a_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a, 1));
__m512i b_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(b));
__m512i b_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b, 1));

    // Multiply vectors in INT16
    __m512i mul_lo = _mm512_mullo_epi16(a_lo, b_lo);
    __m512i mul_hi = _mm512_mullo_epi16(a_hi, b_hi);
    
    // Combine results
    __m512i result = _mm512_setzero_si512();
    result = _mm512_inserti64x4(result, _mm512_cvtepi16_epi8(mul_lo), 0);
    result = _mm512_inserti64x4(result, _mm512_cvtepi16_epi8(mul_hi), 1);
    return result;

}

还有其他方法可以获得更高的性能吗?

附注为什么intel在epi8中不支持乘法?

performance vectorization avx avx2 avx512
1个回答
0
投票

AMX 具有 i8 点积 (https://www.felixcloutier.com/x86/tdpbssd:tdpbsud:tdpbusd:tdpbuud)。 AVX-VNNI 对于 i8 x u8 的点积有

vpdpbusd
(https://www.felixcloutier.com/x86/vpdpbusd)。 即使在没有 AVX-512 的 Alder Lake 系列上也可以使用此功能。 AVX-VNNI 似乎只有 i8 x u8(如
pmaddubsw
)和 i16 x i16。 不是两个输入都是 i8 的。 这些只是扩大点积,而不是没有求和的纯垂直
c[i] = a[i] * b[i]

(扩大是有符号与无符号很重要的原因;在您将结果截断回 8 位的情况下,无论您是符号扩展还是零扩展到 16 位都没有关系。结果不依赖于将 MSB 解释为 2^n-1 或 -2^n-1)


如果您无法使用

vpmaddubsw
(i8 x u8,将对进行饱和求和到 i16),是的,解压到 16 位是正确的方法。

解包到奇数/偶数元素(

srli_epi16(v, 8)
and(v, set1_epi16(0x00ff)
)比每个输入进行 3 次随机播放要便宜,并且允许您使用 Shift / Blend 重新打包,而不是 3 次随机播放。

(您用 4 来编写它,包括

_mm512_inserti64x4(zero, cvt(lo), 0)
。您可以从
__m512i result = _mm512_castsi256_si512(cvt(lo))
开始。一个好的编译器 可能 已经优化了第一个
_mm512_inserti64x4
,但有些编译器比其他编译器更字面地理解内在函数。)

并且由于高位垃圾不会影响乘法的结果(部分乘积被添加,进位仅从低位传播到高位),我们甚至不必隔离每个 i16 元素下半部分中的偶数元素.

我们还可以通过执行

(a_odd<<8) * b_odd
在每个 i16 的高半部分中生成奇数乘积来保存最后的移位,因此实际上我们通过清除低位来生成
a_odd_hi
,将奇数字节保留在原处。

// UNTESTED, let me know if there are any silly mistakes that need fixing
#include <immintrin.h>
#include <stdint.h>

//inline
__m512i int8_mul(__m512i a, __m512i b)
{
    __m512i a_even = a;  // _mm512_and_si512(a, _mm512_set1_epi16(0x00ff));  // not needed, high garbage is fine
    __m512i a_odd_shifted = _mm512_and_si512(a, _mm512_set1_epi16(0xff00));
    __m512i b_even = b;
    __m512i b_odd_lo = _mm512_srli_epi16(b, 8);

    // Multiply vectors in INT16
    __m512i mul_even = _mm512_mullo_epi16(a_even, b_even);
    __m512i mul_odd  = _mm512_mullo_epi16(a_odd_shifted, b_odd_lo);
    
    // Combine results
   // blend using the same vector constant we already needed, instead of a k mask
   // first source operand is a variable not needed later so it can be overwritten
    __m512i result = _mm512_ternarylogic_epi32(mul_even, _mm512_set1_epi16(0xff00), mul_odd, 0xB8); // 0xB8: B ? C : A
   // alternate version using a mask
   // __m512i result = _mm512_mask_blend_epi8(0xAAAAAAAAAAAAAAAA, mul_even, mul_odd);
   // another alternative:
   // __m512i result = _mm512_mask_mov_epi8(mul_even, 0xAAAAAAAAAAAAAAAA, mul_odd);

   return result;
}

(变量命名:我还考虑了

a_odd_hi
来反映我们将想要的位留在单词的上半部分这一事实。
a_odd_shifted
应该意味着
a_odd
u8
i8
我们想要的值,它在 u16 中为
a_odd << 8
但是。这不是很好,因为我们实际上“没有”转移到那里。它是一小段代码,全部包含在一个函数中,所以它基本上没问题,但我仍然对我对变量名称的任何想法不满意。 a_oddx256是另一种看起来更笨重的选择。)
GCC 和 Clang 都做得相当合理 (

Godbolt)。 # clang19 -O3 -march=x86-64-v4 .LCPI0_1: .long 4278255360 # in .rodata # in .text int8_mul: vpsrlw zmm2, zmm1, 8 vpmullw zmm1, zmm1, zmm0 vpandd zmm0, zmm0, dword ptr [rip + .LCPI0_1]{1to16} vpmullw zmm0, zmm2, zmm0 vpternlogd zmm0, zmm1, dword ptr [rip + .LCPI0_1]{1to16}, 228 ret

GCC 从 mov-immediate + 
vpbroadcastd zmm1, eax

实现常量(在浪费一条指令将一个输入移动到不同的向量寄存器之后),clang 选择广播加载它两次。 内联时,两者都应将常量设置提升到循环之外。

我想我可以使用掩码常数来进行输出混合和输入掩码,使用

a_odd_shifted = _mm512_maskz_mov_epi8(0xAAAAAAAAAAAAAAAA, a);

。 这仍然只是一个需要设置的常量,但对于 64 位来说,它需要是

movabs rcx, imm64
/
kmovq k1, rcx
或其他什么,而不是
mov ecx, imm32
/
vpbroadcastd zmm2, ecx
。 Masked
vmovdqu8
仍然采用向量执行单元,与
vpandd
相同。

GCC 按编写方式编译原始文件,对 2 个输入进行 6 次洗牌,并进行 3 次重新打包。 (它确实优化了零向量中的插入)。

但是 clang 做了一些

完全不同的事情,重新矢量化它更像我的版本。

# clang19 for the original version!! .LCPI1_1: .short 255 int8_mul_shuffle: vpbroadcastw zmm2, word ptr [rip + .LCPI1_1] # set1 (0x00ff) vpandq zmm3, zmm2, zmm0 # a_even vpmaddubsw zmm3, zmm1, zmm3 # b_even * a_even (plus 0 * b_odd = 0) vpandnq zmm0, zmm2, zmm0 # a_odd vpmaddubsw zmm0, zmm1, zmm0 # b_odd * a_odd (plus 0 * b_even = 0) vpsllw zmm0, zmm0, 8 # result_odd <<= 8 vpternlogq zmm0, zmm3, zmm2, 248 # blend ret

不包括不断的设置和
ret
,这是6条指令。 我的版本是 5,第一个乘法可以立即开始,任何输入都不必先经过

and

。  另外,我的关键路径延迟稍好一些:从准备就绪开始,它会为并行运行的奇数元素执行
vpandq
 / 
vpmaddubsw
 / vpsllw
/
vpternlogq
all in a dependency chain. Mine lets the shift and
和`。
不过,
vpmaddubsw

是一种处理奇数/偶数的有趣方法,因为它在每对中都执行

(x_even * y_even) + (x_odd * y_odd)

,对第一个输入进行符号扩展,对第二个输入进行零扩展。  因此,将一个输入的奇数或偶数元素归零意味着相应的元素将乘以 0。如果它倾向于使用比 
vpmullw
 更少的幂,则 IDK 仅是 8 位加宽乘法而不是 16 位非加宽。 
Clang 的版本可以使用 
vpmullw

代替

vpandq

 / 
vpmaddubsw
 来表示偶数 x 偶数乘积。  这将是一个直接替代品,因为 
vpmaddubsw
 也会产生大量垃圾,所以它们像我一样混合,而不仅仅是 OR。  对一个输入进行与运算并在第二个输入之前移动另一个输入不会保存指令,但会缩短延迟(假设两者同时准备好),并且永远不会(?)更糟。
	
© www.soinside.com 2019 - 2024. All rights reserved.