我有一个自定义的 Go (1.23.0) 汇编函数,它使用 AVX512 指令来加速一条非常常用的代码路径。该函数把手牌表示为 int64 位集,用来检查一组玩家中是否有人持有某一手扑克牌。代码如下(CardSet 只是一个 int64):
// func SubsetAVX512(cs []CardSet, hs []CardSet) int
//
// SubsetAVX512 reports whether any hand in hs is a bitwise subset of any card
// set in cs, i.e. whether (cs[j] & hs[i]) == hs[i] for some i and j. It yields
// 1 on a match and 0 otherwise, including when either slice is empty.
//
// Hands are processed eight at a time in one 512-bit register; each card set
// is broadcast to all eight lanes and tested against them.
//
// NOTE(review): the prototype above says the result is an int; confirm that it
// agrees with the Go-side declaration of this function.
#include "textflag.h"
// Offsets of the arguments and the result relative to FP. Each slice occupies
// three 8-byte words (data pointer, length, capacity); the result slot follows
// the two slices. cs_cap, hs_cap and ret_off are defined here but are not
// referenced by the code below.
#define cs_data 0(FP)
#define cs_len 8(FP)
#define cs_cap 16(FP)
#define hs_data 24(FP)
#define hs_len 32(FP)
#define hs_cap 40(FP)
#define ret_off 48(FP)
// NOSPLIT, frame $0-56: no local frame, 56 bytes of arguments plus result.
TEXT ·SubsetAVX512(SB), NOSPLIT, $0-56
// Load the slice headers into registers.
MOVQ cs+cs_data, R8 // R8 = pointer to the card sets
MOVQ cs+cs_len, R9 // R9 = number of card sets
MOVQ hs+hs_data, R10 // R10 = pointer to the hands
MOVQ hs+hs_len, R11 // R11 = number of hands
// An empty hands slice can never match.
TESTQ R11, R11
JE return_false
// An empty card-set slice can never match.
TESTQ R9, R9
JE return_false
// i indexes the hands and advances in steps of eight.
XORQ R12, R12 // R12 = i = 0
// Outer loop: one block of eight hands per iteration.
outer_loop:
CMPQ R12, R11 // i vs. number of hands
JGE return_false // every block was tried without a match
// Load eight consecutive int64 hands starting at hands[i].
// NOTE(review): the load is always a full 64 bytes. When fewer than eight
// hands remain (hands_len - i < 8) it reads past the hs_len elements, and the
// extra lanes take part in the comparison - confirm that callers pad hs to a
// multiple of eight elements.
LEAQ (R10)(R12*8), R13 // R13 = &hands[i]
VMOVDQU64 0(R13), Z0 // Z0 = hands[i .. i+7]
// Inner loop: test every card set against the current block of hands.
XORQ R14, R14 // R14 = j = 0
inner_loop:
CMPQ R14, R9 // j vs. number of card sets
JGE next_hands_block // all card sets tried for this block
// Fetch cards[j].
LEAQ (R8)(R14*8), R15 // R15 = &cards[j]
MOVQ 0(R15), AX // AX = cards[j]
// Replicate the card set into every 64-bit lane.
VPBROADCASTQ AX, Z1 // Z1 = cards[j] in all eight lanes
// Keep only the hand bits that are also present in the card set.
VPANDQ Z0, Z1, Z2 // Z2 = Z0 & Z1
// A lane is a match when masking left the hand unchanged, i.e. every bit of
// the hand is present in the card set.
VPCMPEQQ Z0, Z2, K1 // K1 bit set for each lane where Z0 == Z2
// Any set bit in K1 means at least one hand matched.
KORTESTW K1, K1 // ZF is cleared when K1 has a bit set
JNZ found_match // at least one lane matched
// Advance to the next card set.
INCQ R14 // j++
JMP inner_loop // continue with the same block of hands
next_hands_block:
// Advance to the next block of eight hands.
ADDQ $8, R12 // i += 8
JMP outer_loop // continue with the next block
found_match:
// Match found: put 1 in AX and return.
// NOTE(review): only AX is written here; nothing is stored to the result slot
// at ret+ret_off (48(FP)). If the caller reads the result from the stack, as
// Go's stack-based calling convention for assembly functions does, it will not
// see this value - confirm, and store to ret+ret_off if so.
MOVQ $1, AX // AX = 1 (true)
RET
return_false:
// No match: clear AX and return.
// NOTE(review): as above, the result slot at ret+ret_off (48(FP)) is never
// written on this path either - confirm.
XORQ AX, AX // AX = 0 (false)
RET
只要不并发调用,这段代码就能正常工作。下面这个测试可以通过:
// CardSet is a set of cards stored as a bitmask in an int64.
type CardSet int64

// SubsetAVX512 has no Go body; it is implemented in assembly.
func SubsetAVX512(cs, hs []CardSet) bool

// TestSubsetAVX512 calls SubsetAVX512 several times in sequence on the same
// inputs and requires every call to report true.
func TestSubsetAVX512(t *testing.T) {
	const calls = 5

	cards := []CardSet{3, 1}
	hands := []CardSet{3, 0}

	var hits atomic.Int64
	for remaining := calls; remaining > 0; remaining-- {
		if !SubsetAVX512(cards, hands) {
			continue
		}
		hits.Add(1)
	}

	require.Equal(t, int64(calls), hits.Load())
}
但是,下面这个测试会失败:
// CardSet is a set of cards stored as a bitmask in an int64.
type CardSet int64

// SubsetAVX512 has no Go body; it is implemented in assembly.
func SubsetAVX512(cs, hs []CardSet) bool

// TestSubsetAVX512 calls SubsetAVX512 from several goroutines at once on the
// same inputs and requires every call to report true.
func TestSubsetAVX512(t *testing.T) {
	const workers = 5

	cards := []CardSet{3, 1}
	hands := []CardSet{3, 0}

	var (
		hits atomic.Int64
		wg   sync.WaitGroup
	)

	// Register every worker up front, then launch them.
	wg.Add(workers)
	for remaining := workers; remaining > 0; remaining-- {
		go func() {
			defer wg.Done()
			if !SubsetAVX512(cards, hands) {
				return
			}
			hits.Add(1)
		}()
	}
	wg.Wait()

	require.Equal(t, int64(workers), hits.Load())
}
我认为问题在于我使用的某些寄存器被并发运行的 goroutine 覆盖了。我猜是掩码寄存器 `K1`,但这只是一个略有依据的猜测。
问题在于:Go 的调用约定要求这个函数通过栈返回结果,而你却试图通过 `AX` 返回。把返回处改为 `MOVQ $1, ret+ret_off`(返回 0 的分支同理,改为 `MOVQ $0, ret+ret_off`),就能正确返回结果,你会发现问题随之消失。