对于一个研究问题,我需要使用 AVX2/AVX 指令实现非常高效的 4 位乘法(只需要低 4 位)。
我目前的做法是:
__m256i _mm256_mullo_epi4(const __m256i a, const __m256i b) {
__m256i mask_f_0 = _mm256_set1_epi16(0x000f);
__m256i tmp_mul_0 = _mm256_and_si256(_mm256_mullo_epi16(a, b), mask_f_0);
__m256i tmp_mul_1 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 4), _mm256_srli_epi16(b, 4)), mask_f_0);
__m256i tmp_mul_2 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 8), _mm256_srli_epi16(b, 8)), mask_f_0);
__m256i tmp_mul_3 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 12), _mm256_srli_epi16(b, 12)), mask_f_0);
__m256i tmp1 = _mm256_xor_si256(tmp_mul_0, _mm256_slli_epi16(tmp_mul_1, 4));
__m256i tmp2 = _mm256_xor_si256(tmp1, _mm256_slli_epi16(tmp_mul_2, 8));
__m256i tmp = _mm256_xor_si256(tmp2, _mm256_slli_epi16(tmp_mul_3, 12));
return tmp;
}
此实现利用相对昂贵的
_mm256_mullo_epi16
指令 4 次来分别计算每个 4 位 limb
。
这能以某种方式更快地完成吗?更准确地说:是否可以减少所需指令的数量?
你的函数对我来说已经很完美了,因为没有 _mm256_mullo_epi8 内在函数。除了切换 AVX512 之外,我认为没有什么可以加快速度。话虽如此,最后的 3 个异或运算不需要全部依赖于之前的结果,因此您可以按如下方式重新排列它们,并希望从某些指令级并行性中受益。
__m256i _mm256_mullo_epi4(const __m256i a, const __m256i b) {
__m256i mask_f_0 = _mm256_set1_epi16(0x000f);
__m256i tmp_mul_0 = _mm256_and_si256(_mm256_mullo_epi16(a, b), mask_f_0);
__m256i tmp_mul_1 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 4), _mm256_srli_epi16(b, 4)), mask_f_0);
__m256i tmp_mul_2 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 8), _mm256_srli_epi16(b, 8)), mask_f_0);
__m256i tmp_mul_3 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 12), _mm256_srli_epi16(b, 12)), mask_f_0);
__m256i tmp1 = _mm256_xor_si256(tmp_mul_0, _mm256_slli_epi16(tmp_mul_1, 4));
__m256i tmp2 = _mm256_xor_si256(_mm256_slli_epi16(tmp_mul_3, 12), _mm256_slli_epi16(tmp_mul_2, 8));
return _mm256_xor_si256(tmp1, tmp2);
}
我真的不知道为什么,但有时我发现如果使用较少的中间结果,内在函数的工作速度会更快一些。如果您不介意牺牲可读性,您可以将性能与上述函数的以下版本进行比较,而无需显式声明中间结果。
__m256i _mm256_mullo_epi4(const __m256i a, const __m256i b) {
__m256i mask_f_0 = _mm256_set1_epi16(0x000f);
return _mm256_xor_si256(_mm256_xor_si256(_mm256_and_si256(_mm256_mullo_epi16(a, b), mask_f_0),
_mm256_slli_epi16(_mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 4), _mm256_srli_epi16(b, 4)), mask_f_0), 4)),
_mm256_xor_si256(_mm256_slli_epi16(_mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 12), _mm256_srli_epi16(b, 12)), mask_f_0), 12),
_mm256_slli_epi16(_mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 8), _mm256_srli_epi16(b, 8)), mask_f_0), 8)));
}