我想创建一个 3d ndarray 来表示三个 1d ndarray 的坐标,类似于 numpy 的网格网格的工作方式。例如python 的 rust 等价物是什么:
a = np.array(
np.meshgrid(
np.linspace(-1, 1, 50),
np.linspace(-1, 1, 50),
np.linspace(-1, 1, 50),
indexing="ij"
)
)
虽然它不是内置的,但我们可以编写自己的。我们生成尺寸最大为 6 的静态网格,这是
Dimension
中静态 ndarray
的最大尺寸。 (在此之后,我们必须使事情变得动态。)
use ndarray::{prelude::*, Ix};
enum Indexing {
Ij,
Xy,
}
macro_rules! meshgrid_gen {
($(($fn_name:ident, $n_arrs:literal)),* $(,)?) => {
$(fn $fn_name<T: Clone>(
arrs: [ArrayView1<T>; $n_arrs], indexing: Indexing,
) -> [Array<T, Dim<[Ix; $n_arrs]>>; $n_arrs] {
let mat_shape = {
let mut shape = arrs.map(|arr| arr.len());
if let Indexing::Xy = indexing {
shape.swap(0, 1);
}
shape
};
let mut out_arrs: [_; $n_arrs] = std::array::from_fn(|_| {
Array::<T, Dim<[Ix; $n_arrs]>>::uninit(mat_shape)
});
for (i, (out_arr, in_arr)) in
out_arrs.iter_mut().zip(&arrs).enumerate()
{
let axis = Axis(match indexing {
Indexing::Ij => i,
Indexing::Xy => if i <= 1 { 1 - i } else { i },
});
for lane in out_arr.view_mut().lanes_mut(axis) {
in_arr.assign_to(lane);
}
}
out_arrs.map(|arr| unsafe { arr.assume_init() })
})*
};
}
meshgrid_gen!(
(meshgrid1, 1),
(meshgrid2, 2),
(meshgrid3, 3),
(meshgrid4, 4),
(meshgrid5, 5),
(meshgrid6, 6),
);
fn main() {
let arr1 = Array1::linspace(-1.0, 1.0, 2);
let arr2 = Array1::linspace(-1.0, 1.0, 3);
let arr3 = Array1::linspace(-1.0, 1.0, 4);
let grid = meshgrid3([arr1.view(), arr2.view(), arr3.view()], Indexing::Ij);
println!("{grid:#?}");
}
[
[[[-1.0, -1.0, -1.0, -1.0],
[-1.0, -1.0, -1.0, -1.0],
[-1.0, -1.0, -1.0, -1.0]],
[[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0]]], shape=[2, 3, 4], strides=[12, 4, 1], layout=Cc (0x5), const ndim=3,
[[[-1.0, -1.0, -1.0, -1.0],
[0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0]],
[[-1.0, -1.0, -1.0, -1.0],
[0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0]]], shape=[2, 3, 4], strides=[12, 4, 1], layout=Cc (0x5), const ndim=3,
[[[-1.0, -0.33333333333333337, 0.33333333333333326, 1.0],
[-1.0, -0.33333333333333337, 0.33333333333333326, 1.0],
[-1.0, -0.33333333333333337, 0.33333333333333326, 1.0]],
[[-1.0, -0.33333333333333337, 0.33333333333333326, 1.0],
[-1.0, -0.33333333333333337, 0.33333333333333326, 1.0],
[-1.0, -0.33333333333333337, 0.33333333333333326, 1.0]]], shape=[2, 3, 4], strides=[12, 4, 1], layout=Cc (0x5), const ndim=3,
]