SSE4.2 _mm_cmpistrm/_mm_cmpestrm instructions give wrong results


I want to use the following code to compute the intersection of array a and array b:

#include <nmmintrin.h>
#include <cstdint>
#include <cstdio>
void test(uint16_t *a, uint16_t *b) {
    __m128i v_a = _mm_loadu_si128((__m128i*)a);
    __m128i v_b = _mm_loadu_si128((__m128i*)b);
    __m128i res_v1 = _mm_cmpistrm(v_a, v_b, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK);
    uint32_t mask_a = _mm_extract_epi32(res_v1, 0);
    __m128i res_v2 = _mm_cmpistrm(v_b, v_a, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK);
    uint32_t mask_b = _mm_extract_epi32(res_v2, 0);
    printf("match a:");
    for (int i = 0; i < 8; i++) {
        if (mask_a & (1 << (7 - i))) {
            printf(" %d", a[i]);
        }
    }
    putchar('\n');
    printf("match b:");
    for (int i = 0; i < 8; i++) {
        if (mask_b & (1 << (7 - i))) {
            printf(" %d", b[i]);
        }
    }
    putchar('\n');
}
int main() {
    uint16_t a[] = {13, 18, 19, 24, 97, 104, 456, 1024};
    uint16_t b[] = {11, 17, 18, 19, 24, 58, 104, 456};
    test(a, b);
}

The output I expect is:

match a: 18 19 24 104 456
match b: 18 19 24 104 456

However, the actual output is:

match a: 13 18 24 97 104
match b: 17 18 24 58 104

It looks like the _mm_cmpistrm instruction is giving me wrong results!

Is there a mistake in my code, or does the Intel CPU have a bug?

If I made a mistake, how do I fix it?

c++ x86 intel simd sse
1 Answer

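The code uses the masks "backwards" in two different ways:

  • It implicitly assumes the mask is bit-reversed: it tests bit 7 - i for element i, but bit i of the mask corresponds to element i (bit 0 is the first element).
  • It interprets the mask produced by _mm_cmpistrm(v_a, v_b, ...) as indicating which elements of a are also in b, but that mask actually indicates which elements of b are also in a. The first operand is the set of values to search for; the mask describes the elements of the second operand.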
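To make both points concrete, here is a minimal scalar sketch of what _mm_cmpistrm(set, str, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK) computes for eight 16-bit elements. The helper name cmpistrm_equal_any_u16 and the Words alias are hypothetical (they are not part of any intrinsics header). The sketch assumes the default positive polarity and models implicit-length strings by ignoring everything from the first zero element onward; it is an illustrative model, not a bit-exact replacement for the instruction in every mode.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical alias: eight unsigned 16-bit elements, as loaded into one __m128i.
using Words = std::array<uint16_t, 8>;

// Implicit length (the "i" in cmpistrm): elements from the first zero onward are ignored.
static std::size_t implicit_len(const Words &v) {
    std::size_t n = 0;
    while (n < v.size() && v[n] != 0) ++n;
    return n;
}

// Hypothetical scalar model of
//   _mm_cmpistrm(set, str, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK)
// Bit i of the result describes str[i], the SECOND operand: it is set when
// str[i] equals any valid element of set. Assumes positive polarity.
uint32_t cmpistrm_equal_any_u16(const Words &set, const Words &str) {
    const std::size_t set_len = implicit_len(set);
    const std::size_t str_len = implicit_len(str);
    uint32_t mask = 0;
    for (std::size_t i = 0; i < str_len; ++i) {
        for (std::size_t j = 0; j < set_len; ++j) {
            if (str[i] == set[j]) {
                mask |= 1u << i;  // bit i <-> element i (bit 0 = first element)
                break;
            }
        }
    }
    return mask;
}
```

For the arrays in the question, cmpistrm_equal_any_u16(a, b) gives 0xDC (bits for b[2], b[3], b[4], b[6], b[7]) and cmpistrm_equal_any_u16(b, a) gives 0x6E (bits for a[1], a[2], a[3], a[5], a[6]). Because the "i" form uses implicit lengths, a 0 in either array would end that string early; _mm_cmpestrm takes explicit lengths instead.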

With this code:

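#include <nmmintrin.h>  // SSE4.2; compile with e.g. -msse4.2 on GCC/Clang
#include <cstdint>
#include <cstdio>

void test(uint16_t *a, uint16_t *b) {
    __m128i v_a = _mm_loadu_si128((__m128i*)a);
    __m128i v_b = _mm_loadu_si128((__m128i*)b);
    // Bit i of this mask is set when b[i] occurs anywhere in a.
    __m128i res_v1 = _mm_cmpistrm(v_a, v_b, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK);
    uint32_t mask_b = _mm_cvtsi128_si32(res_v1);
    // Bit i of this mask is set when a[i] occurs anywhere in b.
    __m128i res_v2 = _mm_cmpistrm(v_b, v_a, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK);
    uint32_t mask_a = _mm_cvtsi128_si32(res_v2);
    printf("match b:");
    for (int i = 0; i < 8; i++) {
        if (mask_b & (1 << i)) {  // bit i <-> element i, no reversal
            printf(" %d", b[i]);
        }
    }
    putchar('\n');
    printf("match a:");
    for (int i = 0; i < 8; i++) {
        if (mask_a & (1 << i)) {
            printf(" %d", a[i]);
        }
    }
    putchar('\n');
}

int main() {
    uint16_t a[] = {13, 18, 19, 24, 97, 104, 456, 1024};
    uint16_t b[] = {11, 17, 18, 19, 24, 58, 104, 456};
    test(a, b);
}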

I get the output:

match b: 18 19 24 104 456
match a: 18 19 24 104 456
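As a cross-check of the diagnosis, the sketch below decodes the same masks both ways. The decode helper and the Words alias are hypothetical, and the mask values 0xDC (from _mm_cmpistrm(v_a, v_b, ...)) and 0x6E (from _mm_cmpistrm(v_b, v_a, ...)) were worked out by hand from the example arrays, assuming bit i is set when element i of the second operand appears in the first; they were not read from hardware. Reading bit 7 - i and indexing the wrong array reproduces the question's output; reading bit i and indexing the second operand gives the intersection.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical alias: eight unsigned 16-bit elements.
using Words = std::array<uint16_t, 8>;

// Hypothetical helper: collect arr[i] for every set bit of an 8-bit element mask.
// reversed == true mimics the question's test of bit (7 - i);
// reversed == false uses the mapping bit i <-> element i.
std::vector<uint16_t> decode(uint32_t mask, const Words &arr, bool reversed) {
    assert(mask <= 0xFF);  // 8 word elements -> at most 8 mask bits
    std::vector<uint16_t> out;
    for (int i = 0; i < 8; ++i) {
        const int bit = reversed ? 7 - i : i;
        if (mask & (1u << bit)) out.push_back(arr[i]);
    }
    return out;
}
```

For example, decode(0xDC, a, true) yields 13 18 24 97 104, the question's "match a" line, while decode(0xDC, b, false) yields 18 19 24 104 456.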
© www.soinside.com 2019 - 2024. All rights reserved.