我为 int8 矩阵乘法实现了 8 位整数乘法。
这是我的代码,但我认为它真的很慢。
inline __m512i int8_mul(__m512i a, __m512i b) {
// Convert vectors INT8 to INT16
__m512i a_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(a));
__m512i a_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a, 1));
__m512i b_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(b));
__m512i b_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b, 1));
// Multiply vectors in INT16
__m512i mul_lo = _mm512_mullo_epi16(a_lo, b_lo);
__m512i mul_hi = _mm512_mullo_epi16(a_hi, b_hi);
// Combine results
__m512i result = _mm512_setzero_si512();
result = _mm512_inserti64x4(result, _mm512_cvtepi16_epi8(mul_lo), 0);
result = _mm512_inserti64x4(result, _mm512_cvtepi16_epi8(mul_hi), 1);
return result;
}
还有其他方法可以获得更高的性能吗?
附注为什么intel在epi8中不支持乘法?
AMX 具有 i8 点积 (https://www.felixcloutier.com/x86/tdpbssd:tdpbsud:tdpbusd:tdpbuud)。 AVX-VNNI 对于 i8 x u8 的点积有
vpdpbusd
(https://www.felixcloutier.com/x86/vpdpbusd)。 即使在没有 AVX-512 的 Alder Lake 系列上也可以使用此功能。 AVX-VNNI 似乎只有 i8 x u8(如 pmaddubsw
)和 i16 x i16。 不是两个输入都是 i8 的。 这些只是扩大点积,而不是没有求和的纯垂直c[i] = a[i] * b[i]
。
(扩大是有符号与无符号很重要的原因;在您将结果截断回 8 位的情况下,无论您是符号扩展还是零扩展到 16 位都没有关系。结果不依赖于将 MSB 解释为 2^n-1 或 -2^n-1)
vpmaddubsw
(i8 x u8,将对进行饱和求和到 i16),是的,解压到 16 位是正确的方法。
解包到奇数/偶数元素(
srli_epi16(v, 8)
和and(v, set1_epi16(0x00ff)
)比每个输入进行 3 次随机播放要便宜,并且允许您使用 Shift / Blend 重新打包,而不是 3 次随机播放。
(您用 4 来编写它,包括
_mm512_inserti64x4(zero, cvt(lo), 0)
。您可以从 __m512i result = _mm512_castsi256_si512(cvt(lo))
开始。一个好的编译器 可能 已经优化了第一个 _mm512_inserti64x4
,但有些编译器比其他编译器更字面地理解内在函数。)
并且由于高位垃圾不会影响乘法的结果(部分乘积被添加,进位仅从低位传播到高位),我们甚至不必隔离每个 i16 元素下半部分中的偶数元素.
我们还可以通过执行
(a_odd<<8) * b_odd
在每个 i16 的高半部分中生成奇数乘积来保存最后的移位,因此实际上我们通过清除低位来生成 a_odd_hi
,将奇数字节保留在原处。
// UNTESTED, let me know if there are any silly mistakes that need fixing
#include <immintrin.h>
#include <stdint.h>
//inline
__m512i int8_mul(__m512i a, __m512i b)
{
__m512i a_even = a; // _mm512_and_si512(a, _mm512_set1_epi16(0x00ff)); // not needed, high garbage is fine
__m512i a_odd_shifted = _mm512_and_si512(a, _mm512_set1_epi16(0xff00));
__m512i b_even = b;
__m512i b_odd_lo = _mm512_srli_epi16(b, 8);
// Multiply vectors in INT16
__m512i mul_even = _mm512_mullo_epi16(a_even, b_even);
__m512i mul_odd = _mm512_mullo_epi16(a_odd_shifted, b_odd_lo);
// Combine results
// blend using the same vector constant we already needed, instead of a k mask
// first source operand is a variable not needed later so it can be overwritten
__m512i result = _mm512_ternarylogic_epi32(mul_even, _mm512_set1_epi16(0xff00), mul_odd, 0xB8); // 0xB8: B ? C : A
// alternate version using a mask
// __m512i result = _mm512_mask_blend_epi8(0xAAAAAAAAAAAAAAAA, mul_even, mul_odd);
// another alternative:
// __m512i result = _mm512_mask_mov_epi8(mul_even, 0xAAAAAAAAAAAAAAAA, mul_odd);
return result;
}
(变量命名:我还考虑了
a_odd_hi
来反映我们将想要的位留在单词的上半部分这一事实。a_odd_shifted
应该意味着 a_odd
是 u8
或 i8
我们想要的值,它在 u16 中为 a_odd << 8
但是。这不是很好,因为我们实际上“没有”转移到那里。它是一小段代码,全部包含在一个函数中,所以它基本上没问题,但我仍然对我对变量名称的任何想法不满意。 a_oddx256
是另一种看起来更笨重的选择。)GCC 和 Clang 都做得相当合理 (Godbolt)。
# clang19 -O3 -march=x86-64-v4
.LCPI0_1:
.long 4278255360 # in .rodata
# in .text
int8_mul:
vpsrlw zmm2, zmm1, 8
vpmullw zmm1, zmm1, zmm0
vpandd zmm0, zmm0, dword ptr [rip + .LCPI0_1]{1to16}
vpmullw zmm0, zmm2, zmm0
vpternlogd zmm0, zmm1, dword ptr [rip + .LCPI0_1]{1to16}, 228
ret
GCC 从 mov-immediate +
vpbroadcastd zmm1, eax
实现常量(在浪费一条指令将一个输入移动到不同的向量寄存器之后),clang 选择广播加载它两次。 内联时,两者都应将常量设置提升到循环之外。
我想我可以使用掩码常数来进行输出混合和输入掩码,使用a_odd_shifted = _mm512_maskz_mov_epi8(0xAAAAAAAAAAAAAAAA, a);
。 这仍然只是一个需要设置的常量,但对于 64 位来说,它需要是
movabs rcx, imm64
/ kmovq k1, rcx
或其他什么,而不是 mov ecx, imm32
/ vpbroadcastd zmm2, ecx
。 Masked vmovdqu8
仍然采用向量执行单元,与 vpandd
相同。
GCC 按编写方式编译原始文件,对 2 个输入进行 6 次洗牌,并进行 3 次重新打包。 (它确实优化了零向量中的插入)。但是 clang 做了一些
完全不同的事情,重新矢量化它更像我的版本。
# clang19 for the original version!!
.LCPI1_1:
.short 255
int8_mul_shuffle:
vpbroadcastw zmm2, word ptr [rip + .LCPI1_1] # set1 (0x00ff)
vpandq zmm3, zmm2, zmm0 # a_even
vpmaddubsw zmm3, zmm1, zmm3 # b_even * a_even (plus 0 * b_odd = 0)
vpandnq zmm0, zmm2, zmm0 # a_odd
vpmaddubsw zmm0, zmm1, zmm0 # b_odd * a_odd (plus 0 * b_even = 0)
vpsllw zmm0, zmm0, 8 # result_odd <<= 8
vpternlogq zmm0, zmm3, zmm2, 248 # blend
ret
ret
,这是6条指令。 我的版本是 5,第一个乘法可以立即开始,任何输入都不必先经过 and
。 另外,我的关键路径延迟稍好一些:从准备就绪开始,它会为并行运行的奇数元素执行
vpandq
/
vpmaddubsw
/ vpsllw
/
vpternlogq
all in a dependency chain. Mine lets the shift and
和`。不过,
vpmaddubsw
是一种处理奇数/偶数的有趣方法,因为它在每对中都执行(x_even * y_even) + (x_odd * y_odd)
,对第一个输入进行符号扩展,对第二个输入进行零扩展。 因此,将一个输入的奇数或偶数元素归零意味着相应的元素将乘以 0。如果它倾向于使用比
vpmullw
更少的幂,则 IDK 仅是 8 位加宽乘法而不是 16 位非加宽。
Clang 的版本可以使用
vpmullw
代替 vpandq
/
vpmaddubsw
来表示偶数 x 偶数乘积。 这将是一个直接替代品,因为
vpmaddubsw
也会产生大量垃圾,所以它们像我一样混合,而不仅仅是 OR。 对一个输入进行与运算并在第二个输入之前移动另一个输入不会保存指令,但会缩短延迟(假设两者同时准备好),并且永远不会(?)更糟。