假设我有一个维度为
(N_i, N_j, N_k, ... N_n)
的 N 维数组。为了获得快速性能,该数组在代码中实现为长度为 N_i x N_j x N_k, ... x N_N
的平面数组/列表。例如,如果数组是维度为 (M, K)
的二维数组,那么它将表示为长度为 M x K
的平面数组。
让我将维度为
n_arr
的 n 维数组 dims
及其支持平面数组 f_arr
称为。如果我想将视图 n_arr[i]
返回到该数组中,例如n_arr[0]
,如果数组是二维数组,那么我发现:
let start_index = i * dims[1]
let end_index = i * dims[1] + dims[1]
n_arr[i] = f_arr[start_index:end_index]
但是,我无法推导出 3D 情况或任何高维情况的公式。查找返回 N 维数组视图的开始索引和结束索引的一般公式是什么?
我已经找到了解决这个问题的方法。为此,我做了一个 班级:
class Array:
def __init__(data, shape):
self.array = data
self.dims = shape
def ndims(self):
return len(self.dims)
还有一个方法:
fn get_flat_index(self, idx) {
let mut i = 0;
for j in 0..self.dims.len() {
assert(idx[j] < self.dims[j]);
if idx[j] >= self.dims[j] {
print("Index {} is out of bounds for dimension {} with size {}", idx[j], j, self.dims[j])
}
i = i * self.dims[j] + idx[j];
}
i
}
其中
get_flat_index([M, N, K...J])
返回平坦索引:
// 2d case
n_arr[i] = f_arr[get_flat_index(i, 0), get_flat_index(i, dims[1] - 1)
// 3d case
n_arr[i] = f_arr[get_flat_index(i, 0, 0), get_flat_index(i, dims[1] - 1, dims[2] - 1)
// N-d case
n_arr[i] = f_arr[get_flat_index(i, 0, 0, ... 0), get_flat_index(i, dims[n-1] - 1, dims[n-2] - 1, ... dims[n - n] - 1)
使用这个我实现了我的
index()
功能:
fn index(self, idx) {
let start_index_slice = vec![0; self.ndims()];
let end_index_slice = vec![0; self.ndims()];
start_index_slice[0] = idx;
for i in 1..self.ndims() {
end_index_slice[i] = self.dims[i] - 1;
}
let start_index = self.get_index(start_index_slice);
let end_index = self.get_index(end_index_slice);
self.array[start_index:end_index + 1]
}