获取keras输入数据的numpy数组列表的任何解决方法?

问题描述 投票:0回答:1

我正在尝试在keras上训练CNN模型,我的数据看起来像这样

type(datain)
<class 'list'>
len(datain)
35000
type(datain[0])
<class 'numpy.ndarray'>
datain[0].shape
(256,256,1)

作为我的输入数据的数组列表,我在尝试训练网络时遇到此错误

AttributeError: 'list' object has no attribute 'shape'

但是当我尝试像np.array(datain)那样尝试像https://github.com/keras-team/keras/issues/4823这样的东西时,我的电脑会死机/崩溃。使用python list定义我的输入总共需要60秒,但如果我从头开始尝试numpy数组,但每个(256,256,1)数组需要1秒,如果我打算对我的网络进行各种测试和修改,那就太多了, 这个问题有什么解决方法吗? 任何使用keras列表的方法? 一种不同的方式来定义一个numpy数组? 还是我误解了什么?

python arrays numpy keras
1个回答
2
投票

从数据创建生成器。

generator是一个python概念,它循环并产生结果。对于Keras,你的发电机应该无限期地产生批量的X_trainy_train

所以,你可以制作一个简单的发电机:

def generator(batch_size,from_list_x,from_list_y):

    assert len(from_list_x) == len(from_list_y)
    total_size = len(from_list_x)

    while True #keras generators should be infinite

        for i in range(0,total_size,batch_size):
            yield np.array(from_list_x[i:i+batch_size]), np.array(from_list_y[i:i+batch_size])

在培训中使用发电机:

model.fit_generator(generator(size,datain,dataout),
                    steps_per_epoch=len(datain)//size, 
                    epochs=...,...)
© www.soinside.com 2019 - 2024. All rights reserved.