我正在尝试根据新数据训练预训练的 keras 模型。我遇到了tensorflow的数据集api,我正在尝试将它与我的旧keras模型一起使用。据我所知,tf data api 返回张量,因此数据 api 和模型应该是同一图的一部分,并且数据 api 的输出应该作为模型的输入连接。这是代码
import tensorflow as tf
from data_pipeline import ImageDataGenerator
import os
import keras
from keras.engine import InputLayer
os.environ["CUDA_VISIBLE_DEVICES"]="0"
###################### to check visible devices ###############
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
###############################################################
_EPOCHS = 10
_NUM_CLASSES = 2
_BATCH_SIZE = 32
def training_pipeline():
# #############
# Load Dataset
# #############
training_set = ImageDataGenerator(directory="\\\\in-pdc-sem2\\training",
horizontal_flip=True, vertical_flip=True, rescale=True, normalize=True,
color_jitter=True, batch_size=_BATCH_SIZE,
num_cpus=8, epochs=60, output_patch_size=389, validation=False).dataset_pipeline()
testing_set = ImageDataGenerator(directory="\\\\in-pdc-sem2\\training",
horizontal_flip=False, vertical_flip=False, rescale=False, normalize=True,
color_jitter=False, batch_size=_BATCH_SIZE,
num_cpus=8, epochs=60, output_patch_size=389, validation=True).dataset_pipeline()
print(training_set.output_types, training_set.output_shapes)
iterator = tf.data.Iterator.from_structure(training_set.output_types, training_set.output_shapes)#((None, 389, 389, 3), (None)))
train_initializer = iterator.make_initializer(training_set)
validation_initializer = iterator.make_initializer(testing_set)
img, labels = iterator.get_next()
img = img.set_shape((None, 389, 389, 3))
model = baseline_model(img, labels) # keras model defined here
model.summary()
keras.backend.get_session().run(tf.global_variables_initializer())
for epoch in range(_EPOCHS):
# #############
# Train Model
# #############
keras.backend.get_session().run(train_initializer)
model.fit(
steps_per_epoch=1000000 // _BATCH_SIZE,
epochs=1,
# validation_steps=11970 // _BATCH_SIZE,
callbacks=callbacks(),
verbose = 1)
keras.backend.get_session().run(validation_initializer)
loss, acc, cross_entropy = model.evaluate(verbose=1, steps=11970 // 32)
filepath = "./weights/ResNet_16_Best/weights-improvement-Run1-" + str(epoch) + "-" + str(loss) + ".hdf5"
model.save_weights(filepath, overwrite=True)
def baseline_model(input_tensor, labels):
jsonFile = '\\\\in-pdc-sem2\\resnetV4_2Best.json'
weightsFile = '\\\\in-pdc-sem1\\resnetV4_2BestWeightsOnly.hdf5'
with open(jsonFile, "r") as file:
jsonDef = file.read()
from keras.models import model_from_json
model_single = model_from_json(jsonDef)
model_single.load_weights(weightsFile)
model_single.layers[0] = InputLayer(input_tensor=input_tensor, input_shape=(389, 389, 3))
model_single.compile(target_tensors=[labels], loss='categorical_crossentropy', optimizer='Adam', metrics=[keras.metrics.categorical_accuracy])
return model_single
def callbacks():
tensorboard = keras.callbacks.TensorBoard(log_dir='./tensorboard', write_grads=False, write_images=False, histogram_freq=0)
callbacks_list = [tensorboard]
return callbacks_list
if __name__ == '__main__':
training_pipeline()
“训练集”返回图像和标签元组,图像是形状为 (32, 389, 389, 3) 的张量,它是一批 32 张图像。我在单独的脚本中验证了形状,它是正确的。我使用 model.compile 部分中的张量和目标张量定义模型的输入层。
这就是 model.summary 输出的样子:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 389, 389, 3) 0
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 383, 383, 13) 1924 input_1[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 383, 383, 13) 52 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 383, 383, 13) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 191, 191, 13) 0 activation_1[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 191, 191, 4) 56 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 191, 191, 4) 16 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 191, 191, 4) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 191, 191, 4) 148 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 191, 191, 4) 16 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 191, 191, 4) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 191, 191, 8) 40 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 191, 191, 8) 112 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 191, 191, 8) 32 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 191, 191, 8) 32 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 191, 191, 8) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 191, 191, 8) 0 add_1[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 191, 191, 8) 32 activation_4[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 191, 191, 8) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 191, 191, 4) 292 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 191, 191, 4) 16 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 191, 191, 4) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 191, 191, 8) 40 activation_6[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 191, 191, 8) 0 res2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 191, 191, 8) 32 add_2[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 191, 191, 8) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 191, 191, 4) 292 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 191, 191, 4) 16 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 191, 191, 4) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 191, 191, 8) 40 activation_8[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 191, 191, 8) 0 res2c_branch2c[0][0]
add_2[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 96, 96, 8) 72 add_3[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 96, 96, 8) 32 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 96, 96, 8) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 96, 96, 8) 584 activation_9[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 96, 96, 8) 32 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 96, 96, 8) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 96, 96, 16) 144 activation_10[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 96, 96, 16) 144 add_3[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 96, 96, 16) 64 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 96, 96, 16) 64 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 96, 96, 16) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 96, 96, 16) 0 add_4[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 96, 96, 16) 64 activation_11[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 96, 96, 16) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 96, 96, 8) 1160 activation_12[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 96, 96, 8) 32 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 96, 96, 8) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 96, 96, 16) 144 activation_13[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 96, 96, 16) 0 res3b_branch2c[0][0]
activation_11[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 48, 48, 16) 272 add_5[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 48, 48, 16) 64 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 48, 48, 16) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 48, 48, 16) 2320 activation_14[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 48, 48, 16) 64 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 48, 48, 16) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 48, 48, 64) 1088 activation_15[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 48, 48, 64) 1088 add_5[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 48, 48, 64) 256 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 48, 48, 64) 256 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 48, 48, 64) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 48, 48, 64) 0 add_6[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 48, 48, 64) 256 activation_16[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 48, 48, 64) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 48, 48, 16) 9232 activation_17[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 48, 48, 16) 64 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 48, 48, 16) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 48, 48, 64) 1088 activation_18[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 48, 48, 64) 0 res4b_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 24, 24, 32) 2080 add_7[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 24, 24, 32) 128 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 24, 24, 32) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 24, 24, 32) 9248 activation_19[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 24, 24, 32) 128 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 24, 24, 32) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 24, 24, 128) 4224 activation_20[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 24, 24, 128) 8320 add_7[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 24, 24, 128) 512 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 24, 24, 128) 512 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 24, 24, 128) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 24, 24, 128) 0 add_8[0][0]
__________________________________________________________________________________________________
res6a_branch2a (Conv2D) (None, 12, 12, 64) 8256 activation_21[0][0]
__________________________________________________________________________________________________
bn6a_branch2a (BatchNormalizati (None, 12, 12, 64) 256 res6a_branch2a[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 12, 12, 64) 0 bn6a_branch2a[0][0]
__________________________________________________________________________________________________
res6a_branch2b (Conv2D) (None, 12, 12, 64) 36928 activation_22[0][0]
__________________________________________________________________________________________________
bn6a_branch2b (BatchNormalizati (None, 12, 12, 64) 256 res6a_branch2b[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 12, 12, 64) 0 bn6a_branch2b[0][0]
__________________________________________________________________________________________________
res6a_branch2c (Conv2D) (None, 12, 12, 512) 33280 activation_23[0][0]
__________________________________________________________________________________________________
res6a_branch1 (Conv2D) (None, 12, 12, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn6a_branch2c (BatchNormalizati (None, 12, 12, 512) 2048 res6a_branch2c[0][0]
__________________________________________________________________________________________________
bn6a_branch1 (BatchNormalizatio (None, 12, 12, 512) 2048 res6a_branch1[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 12, 12, 512) 0 bn6a_branch2c[0][0]
bn6a_branch1[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 12, 12, 512) 0 add_9[0][0]
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 512) 0 activation_24[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 512) 0 avg_pool[0][0]
__________________________________________________________________________________________________
FC1 (Dense) (None, 1) 513 dropout_1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 1) 0 FC1[0][0]
==================================================================================================
Total params: 196,557
Trainable params: 192,867
Non-trainable params: 3,690
一切看起来都正确。但是,当我运行代码时,出现以下错误:
Epoch 1/1
Traceback (most recent call last):
File "C:/Users/ASista162282/Desktop/code/camleyon_17/train.py", line 114, in <module>
training_pipeline()
File "C:/Users/ASista162282/Desktop/code/camleyon_17/train.py", line 71, in training_pipeline
verbose = 1)
File "C:\ProgramData\Miniconda3\lib\site-packages\keras\engine\training.py", line 1705, in fit
validation_steps=validation_steps)
File "C:\ProgramData\Miniconda3\lib\site-packages\keras\engine\training.py", line 1188, in _fit_loop
outs = f(ins)
File "C:\ProgramData\Miniconda3\lib\site-packages\keras\backend\tensorflow_backend.py", line 2478, in __call__
**self.session_kwargs)
File "C:\ProgramData\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 900, in run
run_metadata_ptr)
File "C:\ProgramData\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1111, in _run
str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'input_1:0', which has shape '(?, 389, 389, 3)'
这没有任何意义。我什至在定义模型之前添加了 set_shape 函数,但它仍然显示空形状。
您替换输入层的方式似乎没有正确连接新层。尝试更换这个:
model_single.layers[0] = InputLayer(input_tensor=input_tensor, input_shape=(389, 389, 3))
这样:
from keras.models import Model
model_single.layers.pop(0)
new_input = InputLayer(input_tensor=input_tensor, input_shape=(389, 389, 3))
new_output = model_single(new_input)
model_single = Model(new_input, new_output)