我正在尝试在keras上训练CNN模型,我的数据看起来像这样
type(datain)
<class 'list'>
len(datain)
35000
type(datain[0])
<class 'numpy.ndarray'>
datain[0].shape
(256,256,1)
作为我的输入数据的数组列表,我在尝试训练网络时遇到此错误
AttributeError: 'list' object has no attribute 'shape'
但是当我尝试像np.array(datain)
那样尝试像https://github.com/keras-team/keras/issues/4823这样的东西时,我的电脑会死机/崩溃。使用python list定义我的输入总共需要60秒,但如果我从头开始尝试numpy数组,但每个(256,256,1)
数组需要1秒,如果我打算对我的网络进行各种测试和修改,那就太多了,
这个问题有什么解决方法吗?
任何使用keras列表的方法?
一种不同的方式来定义一个numpy数组?
还是我误解了什么?
从数据创建生成器。
generator
是一个python概念,它循环并产生结果。对于Keras,你的发电机应该无限期地产生批量的X_train
和y_train
。
所以,你可以制作一个简单的发电机:
def generator(batch_size,from_list_x,from_list_y):
assert len(from_list_x) == len(from_list_y)
total_size = len(from_list_x)
while True #keras generators should be infinite
for i in range(0,total_size,batch_size):
yield np.array(from_list_x[i:i+batch_size]), np.array(from_list_y[i:i+batch_size])
在培训中使用发电机:
model.fit_generator(generator(size,datain,dataout),
steps_per_epoch=len(datain)//size,
epochs=...,...)