如何在 Rust 中自动向量化 (SIMD) 模乘法

问题描述 投票:0回答:1

我正在尝试优化包含模乘法的代码,让它通过 SIMD 自动向量化。也就是说,我不想使用任何库,应由编译器自动完成这项工作。下面是我能给出的最小可验证示例:

/// Modular multiplication for `u64` operands.
///
/// Widens both factors to `u128` before multiplying, so the product can
/// never overflow; the result of `% modulus` always fits back into `u64`.
/// Panics if `modulus` is zero (division by zero).
#[inline(always)]
fn mod_mul64(
    a: u64,
    b: u64,
    modulus: u64,
) -> u64 {
    let wide_product = (a as u128) * (b as u128);
    (wide_product % (modulus as u128)) as u64
}

pub fn mul(a: &mut [u64], b: &[u64], modulo: u64){
    for _ in (0..1000).step_by(4) {
        a[0] = mod_mul64(b[0], a[7], modulo);
        a[1] = mod_mul64(b[1], a[6], modulo);
        a[2] = mod_mul64(b[2], a[5], modulo);
        a[3] = mod_mul64(b[3], a[4], modulo);
        a[4] = mod_mul64(b[4], a[3], modulo);
        a[5] = mod_mul64(b[5], a[2], modulo);
        a[6] = mod_mul64(b[6], a[1], modulo);
        a[7] = mod_mul64(b[7], a[0], modulo);
    }
}

#[allow(unused)]
pub fn main() {
    // Placeholder driver: each `todo!()` below panics as soon as it is
    // evaluated, so as written this `main` never actually reaches `mul`.
    let a: &mut[u64] = todo!();
    let b: &[u64] = todo!();
    let modulo = todo!();
    mul(a, b, modulo);
    println!("a: {:?}", a);
}

如 https://godbolt.org/z/h8zfadz3d 所示,即使开启了优化并将目标 CPU 设为 native,也没有生成任何 SIMD 指令(向量指令的助记符应以 v 开头)。

我知道这个 `mod_mul64` 实现可能不适合 SIMD。要怎样简单地修改它,才能自动获得 SIMD?

rust simd micro-optimization auto-vectorization
1个回答
0
投票

您当前的代码似乎有问题:它只是反复修改前 8 个数字。我假设您想要向量化的是一般形式的 `(a * b) % c` 运算。下面给出了多种 mod/rem 实现,但我只在 32 位值、且模数为编译期常量(并且不考虑溢出)的情况下让自动向量化生效。

#![allow(dead_code)]

// `cargo asm --lib --native mul32_const_asm_test 0`
// `cargo asm --lib --native mul32_const_asm_test_aligned 0`
// `cargo asm --lib --native mul64_const_asm_test_aligned 0`
// `cargo asm --lib --native mul_many 0`
// `cargo asm --lib --native mul_many_aligned_u64 0`

// MARK: Const modulus asm tests

/// Asm-inspection entry point: `mul_many` specialized to a fixed modulus
/// so the optimizer sees a compile-time-constant divisor.
#[no_mangle]
pub const fn mul32_const_asm_test(a: &[u32; 8], b: &[u32; 8]) -> [u32; 8] {
    // Arbitrary fixed value; a constant modulus is what enables
    // vectorization (see the note on `mul_many`).
    let fixed_modulus = 2232673653;
    mul_many(a, b, fixed_modulus)
}

/// Asm-inspection entry point: aligned-array variant of
/// `mul32_const_asm_test`, using the same fixed modulus.
#[no_mangle]
pub const fn mul32_const_asm_test_aligned(
    a: &Aligned64<[u32; 8]>,
    b: &Aligned64<[u32; 8]>,
) -> Aligned64<[u32; 8]> {
    // Arbitrary fixed value; a constant modulus is what enables
    // vectorization (see the note on `mul_many`).
    let fixed_modulus = 2232673653;
    mul_many_aligned(a, b, fixed_modulus)
}

/// Asm-inspection entry point: 64-bit aligned variant with the same
/// fixed modulus as the 32-bit tests.
#[no_mangle]
pub const fn mul64_const_asm_test_aligned(
    a: &Aligned64<[u64; 4]>,
    b: &Aligned64<[u64; 4]>,
) -> Aligned64<[u64; 4]> {
    // Same arbitrary constant as the 32-bit entry points.
    let fixed_modulus = 2232673653;
    mul_many_aligned_u64(a, b, fixed_modulus)
}

// MARK: Non const asm Tests

/// Pairwise modular products of `b[i]` and `a[7 - i]`, written to a fresh
/// output array (the inputs are never mutated).
///
/// NOTE: the scalar asm on its own can be vectorized if `modulo` is a
/// compile-time constant.
#[no_mangle]
pub const fn mul_many(a: &[u32; 8], b: &[u32; 8], modulo: u32) -> [u32; 8] {
    // Alternatives tried:
    //   mod_mul32_expanding     — not vectorized
    //   mod_mul32_triple_custom — vectorized, big
    //   mod_mul32_simple_custom — vectorized
    //   mod_mul32_triple        — vectorized
    let mul_fn = mod_mul32_simple; // vectorized

    let mut result = [0; 8];
    result[0] = mul_fn(b[0], a[7], modulo);
    result[1] = mul_fn(b[1], a[6], modulo);
    result[2] = mul_fn(b[2], a[5], modulo);
    result[3] = mul_fn(b[3], a[4], modulo);
    result[4] = mul_fn(b[4], a[3], modulo);
    result[5] = mul_fn(b[5], a[2], modulo);
    result[6] = mul_fn(b[6], a[1], modulo);
    result[7] = mul_fn(b[7], a[0], modulo);
    result
}

/// Aligned-array variant of `mul_many`; results go into a fresh
/// `Aligned64` output instead of mutating the inputs.
///
/// NOTE: the scalar asm on its own can be vectorized if `modulo` is a
/// compile-time constant.
#[no_mangle]
pub const fn mul_many_aligned(
    a: &Aligned64<[u32; 8]>,
    b: &Aligned64<[u32; 8]>,
    modulo: u32,
) -> Aligned64<[u32; 8]> {
    // Alternatives tried:
    //   mod_mul32_expanding     — not vectorized
    //   mod_mul32_triple_custom — vectorized, big
    //   mod_mul32_simple_custom — vectorized
    //   mod_mul32_triple        — vectorized
    let mul_fn = mod_mul32_simple; // vectorized

    let mut result = Aligned64([0; 8]);
    result.0[0] = mul_fn(b.0[0], a.0[7], modulo);
    result.0[1] = mul_fn(b.0[1], a.0[6], modulo);
    result.0[2] = mul_fn(b.0[2], a.0[5], modulo);
    result.0[3] = mul_fn(b.0[3], a.0[4], modulo);
    result.0[4] = mul_fn(b.0[4], a.0[3], modulo);
    result.0[5] = mul_fn(b.0[5], a.0[2], modulo);
    result.0[6] = mul_fn(b.0[6], a.0[1], modulo);
    result.0[7] = mul_fn(b.0[7], a.0[0], modulo);
    result
}

/// 64-bit variant of `mul_many_aligned`. The original author could not
/// get this one to auto-vectorize.
#[no_mangle]
pub const fn mul_many_aligned_u64(
    a: &Aligned64<[u64; 4]>,
    b: &Aligned64<[u64; 4]>,
    modulo: u64,
) -> Aligned64<[u64; 4]> {
    // Alternatives tried:
    //   mod_mul64_expanding — not vectorized
    //   mod_mul64_simple    — not vectorized
    let mul_fn = mod_mul64_simple_custom; // surprisingly, not vectorized either

    let mut result = Aligned64([0; 4]);
    result.0[0] = mul_fn(b.0[0], a.0[3], modulo);
    result.0[1] = mul_fn(b.0[1], a.0[2], modulo);
    result.0[2] = mul_fn(b.0[2], a.0[1], modulo);
    result.0[3] = mul_fn(b.0[3], a.0[0], modulo);
    result
}

// MARK: 32 bit

/// Modular multiplication that widens to `u64` first, so it never
/// overflows. Panics if `modulus` is zero.
#[inline(always)]
const fn mod_mul32_expanding(a: u32, b: u32, modulus: u32) -> u32 {
    let wide = (a as u64) * (b as u64);
    (wide % (modulus as u64)) as u32
}

/// Reduces both factors before multiplying, then reduces the product.
///
/// Fix: the original expression `(a % modulus * b % modulus) % modulus`
/// parses as `(((a % modulus) * b) % modulus) % modulus` — `*` and `%`
/// share a precedence level and associate left — so the *full* `b` was
/// multiplied (overflowing whenever `b` is large) and one `%` was
/// redundant. With explicit parentheses this is the intended triple
/// reduction; it can now only overflow when
/// `(a % modulus) * (b % modulus)` exceeds `u32::MAX`, i.e. when
/// `modulus` itself is huge. Panics if `modulus` is zero.
#[inline(always)]
const fn mod_mul32_triple(a: u32, b: u32, modulus: u32) -> u32 {
    ((a % modulus) * (b % modulus)) % modulus
}

/// Plain `(a * b) % modulus`. Overflows whenever the `u32` product
/// `a * b` overflows; panics if `modulus` is zero.
#[inline(always)]
const fn mod_mul32_simple(a: u32, b: u32, modulus: u32) -> u32 {
    let product = a * b;
    product % modulus
}

/// Triple reduction built on the hand-rolled `rem_u32` helper instead of
/// the `%` operator.
#[inline(always)]
const fn mod_mul32_triple_custom(a: u32, b: u32, modulus: u32) -> u32 {
    let a_reduced = rem_u32(a, modulus);
    let b_reduced = rem_u32(b, modulus);
    rem_u32(a_reduced * b_reduced, modulus)
}

/// Like `mod_mul32_simple`, but reduces via the hand-rolled `rem_u32`
/// instead of the `%` operator. Overflows if `a * b` overflows.
#[inline(always)]
const fn mod_mul32_simple_custom(a: u32, b: u32, modulus: u32) -> u32 {
    let product = a * b;
    rem_u32(product, modulus)
}

// MARK: 64 bit

/// Modular multiplication that widens to `u128` first, so it never
/// overflows. Panics if `modulus` is zero.
#[inline(always)]
const fn mod_mul64_expanding(a: u64, b: u64, modulus: u64) -> u64 {
    let wide = (a as u128) * (b as u128);
    (wide % (modulus as u128)) as u64
}

/// Plain `(a * b) % modulus`. Overflows whenever the `u64` product
/// `a * b` overflows; panics if `modulus` is zero.
#[inline(always)]
const fn mod_mul64_simple(a: u64, b: u64, modulus: u64) -> u64 {
    let product = a * b;
    product % modulus
}

/// Like `mod_mul64_simple`, but reduces via the hand-rolled `rem_u64`
/// instead of the `%` operator. Overflows if `a * b` overflows.
#[inline(always)]
const fn mod_mul64_simple_custom(a: u64, b: u64, modulus: u64) -> u64 {
    let product = a * b;
    rem_u64(product, modulus)
}

// MARK: Helpers

/// Hand-rolled `lhs % rhs` via the identity
/// `lhs % rhs == lhs - rhs * (lhs / rhs)`.
///
/// No step can overflow: integer division rounds toward zero, so
/// `rhs * (lhs / rhs)` is at most `lhs`, making both the product and the
/// subtraction safe. Matches `%` exactly for unsigned operands; panics
/// if `rhs` is zero, just like `%`.
#[inline(always)]
const fn rem_u32(lhs: u32, rhs: u32) -> u32 {
    let quotient = lhs / rhs;
    lhs - quotient * rhs
}

/// Hand-rolled `lhs % rhs` via the identity
/// `lhs % rhs == lhs - rhs * (lhs / rhs)`.
///
/// No step can overflow: integer division rounds toward zero, so
/// `rhs * (lhs / rhs)` is at most `lhs`, making both the product and the
/// subtraction safe. Matches `%` exactly for unsigned operands; panics
/// if `rhs` is zero, just like `%`.
#[inline(always)]
const fn rem_u64(lhs: u64, rhs: u64) -> u64 {
    let quotient = lhs / rhs;
    lhs - quotient * rhs
}

// Wrapper forcing 64-*byte* (512-bit) alignment on its payload.
// The `[u32; 8]` payloads above are 32 * 8 = 256 bits, so they fit
// within a single aligned block with room to spare.
#[repr(align(64))]
pub struct Aligned64<T>(pub T);

// Transparent read access: `&aligned` auto-derefs to `&T`.
impl<T> std::ops::Deref for Aligned64<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// Transparent write access: `&mut aligned` auto-derefs to `&mut T`.
impl<T> std::ops::DerefMut for Aligned64<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

对于 64 位值,您应该可以利用整数恒等式 `lhs % rhs = lhs - (rhs * (lhs / rhs))` 来手动向量化(不确定是否会溢出,我认为不会)。

自动矢量化的一些技巧:

  • RUSTFLAGS="-Ctarget-cpu=native" cargo build --release
  • 对齐数据
  • 使用数组代替切片
  • 早期断言切片长度相同
  • 不要就地修改输入;若修改了输入,之后不要再读取它
    • 尽量缩短依赖链,让运算可以并行而不是串行执行
  • 消除 panic 路径
    • 例如 div/rem 的除零、索引越界
    • 可用 NonZeroU32、get_unchecked()
  • 也许手动部分循环展开
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.