获取 n 维数组视图的通用函数(伪代码)

问题描述 投票:0回答:1

假设我有一个维度为

(N_i, N_j, N_k, ... N_n)
的 N 维数组。为了获得快速性能,该数组在代码中实现为长度为
N_i x N_j x N_k, ... x N_N
的平面数组/列表。例如,如果数组是维度为
(M, K)
的二维数组,那么它将表示为长度为
M x K
的平面数组。

让我将维度为

n_arr
的 n 维数组
dims
及其支持平面数组
f_arr
称为。如果我想将视图
n_arr[i]
返回到该数组中,例如
n_arr[0]
,如果数组是二维数组,那么我发现:

let start_index = i * dims[1]
let end_index = i * dims[1] + dims[1]
n_arr[i] = f_arr[start_index:end_index]

但是,我无法推导出 3D 情况或任何高维情况的公式。查找返回 N 维数组视图的开始索引和结束索引的一般公式是什么?

multidimensional-array indexing
1个回答
0
投票

我已经找到了解决这个问题的方法。为此,我做了一个 班级:

class Array:
    def __init__(data, shape):
        self.array = data
        self.dims = shape

    def ndims(self):
        return len(self.dims)

还有一个方法:

fn get_flat_index(self, idx) {
    let mut i = 0;
    for j in 0..self.dims.len() {
        assert(idx[j] < self.dims[j]);
        if idx[j] >= self.dims[j] {
            print("Index {} is out of bounds for dimension {} with size {}", idx[j], j, self.dims[j])
            }
            i = i * self.dims[j] + idx[j];
        }
        i
}

其中

get_flat_index([M, N, K...J])
返回平坦索引:

// 2d case
n_arr[i] = f_arr[get_flat_index(i, 0), get_flat_index(i, dims[1] - 1)

// 3d case
n_arr[i] = f_arr[get_flat_index(i, 0, 0), get_flat_index(i, dims[1] - 1, dims[2] - 1)

// N-d case
n_arr[i] = f_arr[get_flat_index(i, 0, 0, ... 0), get_flat_index(i, dims[n-1] - 1, dims[n-2] - 1, ... dims[n - n] - 1)

使用这个我实现了我的

index()
功能:

fn index(self, idx) {
    let start_index_slice = vec![0; self.ndims()];
    let end_index_slice = vec![0; self.ndims()];
    start_index_slice[0] = idx;
    for i in 1..self.ndims() {
        end_index_slice[i] = self.dims[i] - 1;
    }
    let start_index = self.get_index(start_index_slice);
    let end_index = self.get_index(end_index_slice);
    self.array[start_index:end_index + 1]
}
© www.soinside.com 2019 - 2024. All rights reserved.