AVX512 assembly breaks when called concurrently from different goroutines


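I have a custom Go (1.23.0) assembly routine that uses AVX512 operations to speed up a very hot code path. The function checks whether any player in a group holds a given poker hand, with hands represented as int64 bitsets. The code looks like this (CardSet is just an int64):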

// func SubsetAVX512(cs []CardSet, hs []CardSet) int
// Returns 1 if any card set in cards contains any hand in hands, 0 otherwise

#include "textflag.h"

#define cs_data 0(FP)
#define cs_len  8(FP)
#define cs_cap  16(FP)
#define hs_data 24(FP)
#define hs_len  32(FP)
#define hs_cap  40(FP)
#define ret_off 48(FP)

// Define the function
TEXT ·SubsetAVX512(SB), NOSPLIT, $0-56

// Start of the function
    // Load parameters into registers
    MOVQ cs+cs_data, R8         // R8 = cards_ptr
    MOVQ cs+cs_len, R9          // R9 = cards_len

    MOVQ hs+hs_data, R10        // R10 = hands_ptr
    MOVQ hs+hs_len, R11         // R11 = hands_len

    // Check if hands_len == 0
    TESTQ R11, R11
    JE return_false

    // Check if cards_len == 0
    TESTQ R9, R9
    JE return_false

    // Initialize loop counters
    XORQ R12, R12                 // R12 = i = 0 (hands index)

    // Main loop over hands
outer_loop:
    CMPQ R12, R11                 // Compare i (R12) with hands_len (R11)
    JGE return_false              // If i >= hands_len, no match found

    // Load 8 hands into Z0 (512-bit register)
    LEAQ (R10)(R12*8), R13        // R13 = &hands[i]
    VMOVDQU64 0(R13), Z0          // Load 8 int64s from [R13] into Z0

    // Inner loop over cards
    XORQ R14, R14                 // R14 = j = 0 (cards index)
inner_loop:
    CMPQ R14, R9                  // Compare j (R14) with cards_len (R9)
    JGE next_hands_block          // If j >= cards_len, move to next hands block

    // Load cs from cards[j]
    LEAQ (R8)(R14*8), R15         // R15 = &cards[j]
    MOVQ 0(R15), AX               // AX = cards[j]

    // Broadcast cs into Z1
    VPBROADCASTQ AX, Z1           // Broadcast RAX into all lanes of Z1

    // Compute cs_vec & h_vec
    VPANDQ Z0, Z1, Z2             // Z2 = Z0 & Z1

    // Compare (cs_vec & h_vec) == h_vec
    VPCMPEQQ Z0, Z2, K1           // Compare Z0 == Z2, store result in mask K1

    // Check if any comparison is true
    KORTESTW K1, K1               // Test if any bits in K1 are set
    JNZ found_match               // If so, a match is found

    // Increment card index
    INCQ R14                      // j++
    JMP inner_loop                // Repeat inner loop

next_hands_block:
    // Increment hands index by 8
    ADDQ $8, R12                  // i += 8
    JMP outer_loop                // Repeat outer loop

found_match:
    // Match found, return 1
    MOVQ $1, AX                   // Set return value to 1 (true)
    RET

return_false:
    // No match found, return 0
    XORQ AX, AX                   // Set return value to 0 (false)
    RET
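
For comparison, the check the assembly is meant to perform can be sketched in pure Go. This is an assumed scalar equivalent, not code from the question: the name subsetScalar is hypothetical, and the containment rule (a card set c contains a hand h when c & h == h) is taken from the comments in the assembly above. It may be useful as an oracle when testing the AVX512 version.

```go
package main

import "fmt"

// CardSet is a bitset of cards, as in the question.
type CardSet int64

// subsetScalar is a hypothetical scalar reference (not from the question).
// It reports whether any card set in cs contains every card of any hand in hs.
func subsetScalar(cs, hs []CardSet) bool {
	for _, h := range hs {
		for _, c := range cs {
			// c contains h when every bit set in h is also set in c.
			if c&h == h {
				return true
			}
		}
	}
	return false
}

func main() {
	cs := []CardSet{3, 1}
	hs := []CardSet{3, 0}
	fmt.Println(subsetScalar(cs, hs)) // true, because 3&3 == 3
}
```

Unlike the assembly, which always loads eight hands at once with VMOVDQU64, this version never reads past the end of hs, so it also works as a reference for hand slices whose length is not a multiple of 8.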

This works fine as long as the code is not called concurrently. The following test passes:

type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
    cs := []CardSet{3, 1}
    hs := []CardSet{3, 0}
    var count int64
    for i := 0; i < 5; i++ {
        if SubsetAVX512(cs, hs) {
            atomic.AddInt64(&count, 1)
        }
    }
    require.Equal(t, int64(5), count)
}

However, this one fails:

type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
    cs := []CardSet{3, 1}
    hs := []CardSet{3, 0}
    var count int64
    wg := sync.WaitGroup{}
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            if SubsetAVX512(cs, hs) {
                atomic.AddInt64(&count, 1)
            }
        }()
    }
    wg.Wait()
    require.Equal(t, int64(5), count)
}

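I believe the problem is that some of the registers I am using are being overwritten by concurrent goroutines. My guess is that it is the mask register K1, but that is only a semi-educated guess.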
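
One way to narrow this down is to run the same concurrent harness against a pure Go stand-in. The sketch below assumes such a stand-in is acceptable as an oracle; countConcurrent and containsAny are hypothetical names, not part of the question. If this version counts 5, the harness itself is sound and the fault lies inside the assembly routine.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// CardSet is a bitset of cards, as in the question.
type CardSet int64

// countConcurrent calls fn from n goroutines and returns how many
// of those calls reported true. The function under test is a parameter,
// so the assembly routine can be swapped in for the stand-in.
func countConcurrent(fn func(cs, hs []CardSet) bool, cs, hs []CardSet, n int) int64 {
	var count atomic.Int64
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if fn(cs, hs) {
				count.Add(1)
			}
		}()
	}
	wg.Wait()
	return count.Load()
}

// containsAny is a hypothetical pure Go stand-in for SubsetAVX512.
func containsAny(cs, hs []CardSet) bool {
	for _, h := range hs {
		for _, c := range cs {
			if c&h == h {
				return true
			}
		}
	}
	return false
}

func main() {
	got := countConcurrent(containsAny, []CardSet{3, 1}, []CardSet{3, 0}, 5)
	fmt.Println(got) // 5
}
```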

go assembly avx avx512
1 Answer

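Your problem is that you return the result in AX, but assembly functions written for Go use the stack-based calling convention (ABI0), which requires the result to be written to its slot on the stack. As written, the caller reads whatever happens to be in that slot. Change the return to use

MOVQ $1, ret+ret_off

to return the result correctly (and store 0 the same way on the return_false path), and you will find the problem disappears.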
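
The offsets behind ret+ret_off can be checked from Go. The sketch below is an illustration only: abi0Frame is a hypothetical struct that mirrors the stack-based argument frame, not a runtime type, and it assumes a 64-bit platform such as amd64. It follows the signature in the assembly comment, which returns an int. Two 24-byte slice headers occupy bytes 0 to 47, so the result slot starts at offset 48 and the frame is 56 bytes, matching 48(FP) and $0-56 in the assembly.

```go
package main

import (
	"fmt"
	"unsafe"
)

// CardSet is a bitset of cards, as in the question.
type CardSet int64

// abi0Frame is a hypothetical mirror of the stack frame for
// func SubsetAVX512(cs []CardSet, hs []CardSet) int
// Assumption: arguments and results are laid out in declaration order.
type abi0Frame struct {
	cs  []CardSet // pointer, len, cap: 24 bytes on 64-bit platforms
	hs  []CardSet // pointer, len, cap: 24 bytes on 64-bit platforms
	ret int64     // result slot read by the caller
}

// frameLayout returns the byte offsets of each field and the total size.
func frameLayout() (csOff, hsOff, retOff, size uintptr) {
	var f abi0Frame
	return unsafe.Offsetof(f.cs), unsafe.Offsetof(f.hs), unsafe.Offsetof(f.ret), unsafe.Sizeof(f)
}

func main() {
	csOff, hsOff, retOff, size := frameLayout()
	fmt.Println(csOff, hsOff, retOff, size) // 0 24 48 56 on 64-bit platforms
}
```

Note that the Go declaration in the tests returns bool rather than int. In that case the caller reads only the single byte at offset 48, so a one-byte store would be the closer match for the declared type.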
