Reading data from a CSV file into an ndarray in Rust returns a ShapeError


Reading a CSV file and returning its contents as an ndarray fails with an error. Here is the code:

use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use ndarray::prelude::*;
use ndarray::{OwnedRepr, Array2, ShapeError};

fn read_file(filename:&str, sep:&str, header:bool) -> Result<ArrayBase<OwnedRepr<Vec<f64>>, Dim<[usize; 2]>>, ShapeError>{
    
    let path = Path::new(filename);
    let input = File::open(path).unwrap();
    let reader = BufReader::new(input);
    let mut vector:Vec<Vec<f64>> = Vec::new();

    let skip = if header {1} else {0};
    for line in reader.lines().skip(skip) {
        let data:Vec<String> = line.unwrap().split(sep).into_iter().map(|x| {x.to_string()}).collect();
        let data = data.into_iter().map(|x| x.parse().unwrap()).collect::<Vec<f64>>();
        println!("{:?}", data);
        vector.extend_from_slice(&[data]);
    }


    println!("{}, {}", vector.len(), vector[0].len());
    Array2::from_shape_vec((vector.len(), vector[0].len()), vector)
    
}

For completeness and reproducibility, here are the contents of the CSV file:

CD4,CD8b,CD3,CD8
199,420,132,226
294,311,241,164
85,79,14,218
19,1,141,130

Calling

read_file(filename, ",", true)
returns
Err(ShapeError/OutOfBounds: out of bounds indexing)
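The size mismatch behind this error can be checked with the standard library alone. The following sketch is an illustration and was not part of the original question. It assumes the comma-separated sample data with a header row, and `SAMPLE`, `parse_rows` and `flatten` are hypothetical names. Like the code above, it collects one `Vec<f64>` per row and then compares that count with the number of elements a `(rows, cols)` shape needs.

```rust
// Std-only sketch (no ndarray needed): why a nested Vec cannot fill a 2-D shape.
// SAMPLE, parse_rows and flatten are hypothetical names used for illustration.
const SAMPLE: &str = "CD4,CD8b,CD3,CD8\n199,420,132,226\n294,311,241,164\n85,79,14,218\n19,1,141,130";

/// Mirrors the question's approach: skip the header, keep one Vec<f64> per row.
fn parse_rows(csv: &str) -> Vec<Vec<f64>> {
    csv.lines()
        .skip(1)
        .map(|line| {
            line.split(',')
                .map(|x| x.trim().parse::<f64>().unwrap())
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Concatenates the rows in order, giving the row-major layout a (rows, cols) shape expects.
fn flatten(nested: Vec<Vec<f64>>) -> Vec<f64> {
    nested.into_iter().flatten().collect()
}

fn main() {
    let nested = parse_rows(SAMPLE);
    let (rows, cols) = (nested.len(), nested[0].len());
    let needed = rows * cols;
    // The nested vector has one element per row, not one element per value.
    let have = nested.len();
    println!("shape ({rows}, {cols}) needs {needed} elements; nested vector has {have}");

    let flat = flatten(nested);
    let flat_len = flat.len();
    println!("flattened vector has {flat_len} elements");
}
```

The program should report that the shape (4, 4) needs 16 elements, while the nested vector holds only 4 and the flattened one holds 16. `Array2::from_shape_vec` only accepts a vector whose length matches the requested shape, so the question's nested vector is rejected.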

rust rust-ndarray
1 Answer

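Here is a corrected but unoptimized version. It accounts for the fact that Array2::from_shape_vec expects a single flat input vector holding all elements in row-major order. The question's code passes a Vec<Vec<f64>> with one element per row (4 for the sample file), while the shape (4, 4) needs 4 * 4 = 16 elements, hence the error. The version below appends every value to one Vec<f64>, derives the row count from its length, and changes the element type in the return type from Vec<f64> to f64: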

use std::io::Cursor;
use std::io::BufRead;
use std::io::BufReader;
use ndarray::prelude::*;
use ndarray::{OwnedRepr, Array2, ShapeError};

fn read_file(input: Cursor<Vec<u8>>, sep:&str, header:bool) -> Result<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ShapeError>{
    let reader = BufReader::new(input);
    let mut vector: Vec<f64> = Vec::new();

    let skip = if header {1} else {0};
    let mut columns = 0;
    for line in reader.lines().skip(skip) {
        let data:Vec<String> = line.unwrap().split(sep).into_iter().map(|x| {x.to_string()}).collect();
        let data = data.into_iter().map(|x| x.parse().unwrap()).collect::<Vec<f64>>();
        println!("{:?}", data);
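        // Append this row's values to the single flat, row-major buffer.
        vector.extend_from_slice(&data);
        // Remember the row width; the row count is then vector.len() / columns.
        columns = data.len();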
    }

    println!("{}, {}", vector.len() / columns, columns);
    Array2::from_shape_vec((vector.len() / columns, columns), vector)
}

pub fn main() {
    let input = r#"CD4,CD8b,CD3,CD8
199,420,132,226
294,311,241,164
85,79,14,218
19,1,141,130"#;
    
    let array = read_file(Cursor::new(input.to_owned().into()), ",", true).unwrap();

    println!("{array:?}");
}

(File I/O was replaced with an in-memory data structure so that the answer is self-contained.)
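The version above can still panic on malformed input. Each `unwrap` aborts on an unparsable field. If there are no data rows, `columns` stays 0 and `vector.len() / columns` divides by zero. Row widths are never compared either, so a ragged row can trigger a `ShapeError` or, if the total happens to divide evenly, produce a misaligned array with no error at all. The following sketch returns errors instead. It assumes plain delimited data with no quoted fields. `parse_flat` is a hypothetical helper, not an ndarray or CSV-crate API. The sketch stops before building the array, so it compiles with the standard library alone.

```rust
use std::io::{BufRead, Cursor};

/// Hypothetical helper: parses delimited numeric text into one row-major buffer.
/// Returns (rows, cols, data), or a message describing the first problem found.
/// Assumes simple delimited data: no quoting, no escaped separators.
fn parse_flat<R: BufRead>(
    reader: R,
    sep: &str,
    header: bool,
) -> Result<(usize, usize, Vec<f64>), String> {
    let mut data = Vec::new();
    let mut cols: Option<usize> = None;
    let mut rows = 0;

    // enumerate() before skip() so `n` is the physical line number, header included.
    for (i, line) in reader.lines().enumerate().skip(usize::from(header)) {
        let n = i + 1;
        let line = line.map_err(|e| format!("line {n}: {e}"))?;
        if line.trim().is_empty() {
            continue; // tolerate blank lines
        }
        let before = data.len();
        for field in line.split(sep) {
            let value: f64 = field
                .trim()
                .parse()
                .map_err(|e| format!("line {n}: cannot parse {field:?}: {e}"))?;
            data.push(value);
        }
        // The first data row fixes the width; every later row must match it.
        let width = data.len() - before;
        let expected = *cols.get_or_insert(width);
        if width != expected {
            return Err(format!("line {n}: expected {expected} fields, found {width}"));
        }
        rows += 1;
    }

    // No data rows: report an error instead of dividing by zero.
    let cols = cols.ok_or("no data rows")?;
    Ok((rows, cols, data))
}

fn main() {
    let input = "CD4,CD8b,CD3,CD8\n199,420,132,226\n294,311,241,164\n85,79,14,218\n19,1,141,130";
    match parse_flat(Cursor::new(input), ",", true) {
        Ok((rows, cols, data)) => {
            let first_row = &data[..cols];
            println!("{rows} x {cols}, first row: {first_row:?}");
        }
        Err(msg) => eprintln!("error: {msg}"),
    }
}
```

For the sample data, this should print `4 x 4, first row: [199.0, 420.0, 132.0, 226.0]`. Because `data.len()` equals `rows * cols` by construction, the `Ok` value can be passed directly to `Array2::from_shape_vec((rows, cols), data)`. For input with quoted fields or other CSV dialect features, a dedicated CSV parser would be safer than `split`.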

© www.soinside.com 2019 - 2024. All rights reserved.