如何在Rust中提高元素乘法的性能?

问题描述 投票:6回答:1

我将使用10 ^ 6 +元素对多个向量进行元素乘法。这在标题中被标记为我的代码中最慢的部分之一,所以我该如何改进它?

/// element-wise multiplication for vecs
pub fn vec_mul<T>(v1: &Vec<T>, v2: &Vec<T>) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }
    let mut out: Vec<T> = Vec::with_capacity(v1.len());
    for i in 0..(v1.len()) {
        out.push(v1[i] * v2[i]);
    }
    out
}
performance optimization rust
1个回答
7
投票

Vec或切片上使用索引器运算符时,编译器必须检查索引是在边界还是超出范围。

但是,当您使用迭代器时,这些边界检查将被省略,因为迭代器已经过仔细编写,以确保它们永远不会读出边界。此外,由于借用如何在Rust中工作,当迭代器存在于该数据结构上时(通过迭代器本身除外),数据结构不能被变异,因此在迭代期间有效边界不可能发生变化。

由于您同时迭代两个不同的数据结构,因此您需要使用zip迭代器适配器。一旦迭代器耗尽,zip就会停止,所以它仍然与验证两个向量的长度相同。 zip生成元组的迭代器,其中每个元组包含两个原始迭代器中相同位置的项。然后你可以使用map将每个元组转换为两个值的乘积。最后,你需要将collect生成的新迭代器map变成Vec,然后你可以从你的函数返回。 collect使用size_hint使用Vec::with_capacity为载体预分配内存。

/// element-wise multiplication for vecs
pub fn vec_mul<T>(v1: &[T], v2: &[T]) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }

    v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
}

注意:我已经更改了签名以获取切片而不是对向量的引用。有关更多信息,请参阅Why is it discouraged to accept a reference to a String (&String), Vec (&Vec), or Box (&Box) as a function argument?

© www.soinside.com 2019 - 2024. All rights reserved.