在从 numpy 数组创建的 ndarray 切片之间执行广播添加

问题描述 投票:0回答:1

我正在尝试编写可以从 Python 调用的 Rust 代码。为简单起见,此代码应仅采用二维布尔数组并将第二行与第一行进行异或。我尝试编写这段代码:

use numpy::PyArray2;
use pyo3::prelude::{pymodule, PyModule, PyResult, Python};

#[pymodule]
fn state_generator(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    #[pyfn(m)]
    fn random_cnots_transformation(_py: Python<'_>, x: &PyArray2<bool>) {
        let mut array = unsafe { x.as_array_mut() };
        let source = array.row(1);
        let mut target = array.row_mut(0);

        target ^= source;
    }

    Ok(())
}

但是使用

maturin
编译时会失败并出现以下错误:

error[E0271]: type mismatch resolving `<ViewRepr<&mut bool> as RawData>::Elem == ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ expected `bool`, found `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   |
   = note: expected type `bool`
            found struct `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

error[E0277]: the trait bound `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>: ScalarOperand` is not satisfied
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ the trait `ScalarOperand` is not implemented for `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   |
   = help: the following other types implement trait `ScalarOperand`:
             bool
             isize
             i8
             i16
             i32
             i64
             i128
             usize
           and 9 others
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

error[E0271]: type mismatch resolving `<ViewRepr<&bool> as RawData>::Elem == ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ expected `bool`, found `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   |
   = note: expected type `bool`
            found struct `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   = note: required for `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>` to implement `BitXorAssign`
   = note: 1 redundant requirement hidden
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

error[E0277]: the trait bound `ViewRepr<&bool>: DataMut` is not satisfied
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ the trait `DataMut` is not implemented for `ViewRepr<&bool>`
   |
   = help: the trait `DataMut` is implemented for `ViewRepr<&'a mut A>`
   = note: required for `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>` to implement `BitXorAssign`
   = note: 1 redundant requirement hidden
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

我想这可能是因为我使用了

^=
,所以我尝试使用
+=
代替,但失败并出现以下错误:

error[E0368]: binary assignment operation `+=` cannot be applied to type `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>`
  --> src/lib.rs:35:9
   |
35 |         target += source;
   |         ------^^^^^^^^^^
   |         |
   |         cannot use `+=` on type `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>`

我已经阅读了这个答案,但我不确定这里有什么区别。如果我没记错的话,区别在于在我的代码中,

array
是一个
ArrayViewMut<bool, lx2>
,而在答案中它是一个
Array2

我应该在代码中更改什么来执行此类操作?

如果这很重要,我正在使用

cargo 1.72.0
,这是我的
Cargo.toml

[package]
name = "state_generator"
version = "0.1.0"
authors = ["Tristan NEMOZ"]
edition = "2021"

[lib]
crate-type = ["cdylib"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ndarray = "0.15.6"
numpy = "0.19.0"
rand = "0.8.5"

[dependencies.pyo3]
version = "0.19.2"
features = ["extension-module"]
rust multidimensional-array pyo3 rust-ndarray
1个回答
0
投票

这个错误很令人困惑,但其关键在于按位异或仅使用

&ArrayView
(更准确地说,
&ArrayBase
)实现,而不是拥有的
ArrayView
,所以很简单:

target ^= &source;

但是你会面临另一个错误:

error: cannot borrow `array` as mutable because it is also borrowed as immutable
 --> src/lib.rs:5:22
  |
4 |     let source = array.row(1);
  |                  ------------ immutable borrow occurs here
5 |     let mut target = array.row_mut(0);
  |                      ^^^^^^^^^^^^^^^^ mutable borrow occurs here
6 |
7 |     target ^= &source;
  |               ------- immutable borrow later used here

这可以通过使用行上的迭代器来解决:

let mut rows = array.rows_mut().into_iter();
let mut target = rows.next().unwrap();
let source = rows.next().unwrap();

target ^= &source;
© www.soinside.com 2019 - 2024. All rights reserved.