tensorflow data api with keras(将张量传递给keras模型)

问题描述 投票:0回答:1

我正在尝试根据新数据训练预训练的 keras 模型。我遇到了tensorflow的数据集api,我正在尝试将它与我的旧keras模型一起使用。据我所知,tf data api 返回张量,因此数据 api 和模型应该是同一图的一部分,并且数据 api 的输出应该作为模型的输入连接。这是代码

import tensorflow as tf   
from data_pipeline import ImageDataGenerator
import os
import keras
from keras.engine import InputLayer

os.environ["CUDA_VISIBLE_DEVICES"]="0"
###################### to check visible devices ###############
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
###############################################################

_EPOCHS      = 10
_NUM_CLASSES = 2
_BATCH_SIZE  = 32


def training_pipeline():
  # #############
  # Load Dataset
  # #############
  training_set = ImageDataGenerator(directory="\\\\in-pdc-sem2\\training",
                                  horizontal_flip=True, vertical_flip=True, rescale=True, normalize=True,
                                  color_jitter=True, batch_size=_BATCH_SIZE,
                                  num_cpus=8, epochs=60, output_patch_size=389, validation=False).dataset_pipeline()
  testing_set = ImageDataGenerator(directory="\\\\in-pdc-sem2\\training",
                                  horizontal_flip=False, vertical_flip=False, rescale=False, normalize=True,
                                  color_jitter=False, batch_size=_BATCH_SIZE,
                                  num_cpus=8, epochs=60, output_patch_size=389, validation=True).dataset_pipeline()

  print(training_set.output_types, training_set.output_shapes)

  iterator = tf.data.Iterator.from_structure(training_set.output_types, training_set.output_shapes)#((None, 389, 389, 3), (None)))

  train_initializer = iterator.make_initializer(training_set)
  validation_initializer = iterator.make_initializer(testing_set)

  img, labels = iterator.get_next()
  img = img.set_shape((None, 389, 389, 3))

  model = baseline_model(img, labels)  # keras model defined here
  model.summary()

  keras.backend.get_session().run(tf.global_variables_initializer())
  for epoch in range(_EPOCHS):

      # #############
      # Train Model
      # #############
      keras.backend.get_session().run(train_initializer)
      model.fit(
          steps_per_epoch=1000000 // _BATCH_SIZE,
          epochs=1,
          # validation_steps=11970 // _BATCH_SIZE,
          callbacks=callbacks(),
          verbose = 1)

      keras.backend.get_session().run(validation_initializer)

      loss, acc, cross_entropy = model.evaluate(verbose=1, steps=11970 // 32)
      filepath = "./weights/ResNet_16_Best/weights-improvement-Run1-" + str(epoch) + "-" + str(loss) + ".hdf5"
      model.save_weights(filepath, overwrite=True)


def baseline_model(input_tensor, labels):
    jsonFile = '\\\\in-pdc-sem2\\resnetV4_2Best.json'
    weightsFile = '\\\\in-pdc-sem1\\resnetV4_2BestWeightsOnly.hdf5'
    with open(jsonFile, "r") as file:
        jsonDef = file.read()
    from keras.models import model_from_json
    model_single = model_from_json(jsonDef)

    model_single.load_weights(weightsFile)
    model_single.layers[0] = InputLayer(input_tensor=input_tensor, input_shape=(389, 389, 3))
    model_single.compile(target_tensors=[labels], loss='categorical_crossentropy', optimizer='Adam', metrics=[keras.metrics.categorical_accuracy])
    return model_single

def callbacks():
    tensorboard = keras.callbacks.TensorBoard(log_dir='./tensorboard', write_grads=False, write_images=False, histogram_freq=0)
    callbacks_list = [tensorboard]
    return callbacks_list

if __name__ == '__main__':
    training_pipeline()

“训练集”返回图像和标签元组,图像是形状为 (32, 389, 389, 3) 的张量,它是一批 32 张图像。我在单独的脚本中验证了形状,它是正确的。我使用 model.compile 部分中的张量和目标张量定义模型的输入层。

这就是 model.summary 输出的样子:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 389, 389, 3)  0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 383, 383, 13) 1924        input_1[0][0]                    
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 383, 383, 13) 52          conv1[0][0]                      
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 383, 383, 13) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 191, 191, 13) 0           activation_1[0][0]               
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 191, 191, 4)  56          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 191, 191, 4)  16          res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 191, 191, 4)  0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 191, 191, 4)  148         activation_2[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 191, 191, 4)  16          res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 191, 191, 4)  0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 191, 191, 8)  40          activation_3[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 191, 191, 8)  112         max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 191, 191, 8)  32          res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 191, 191, 8)  32          res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, 191, 191, 8)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 191, 191, 8)  0           add_1[0][0]                      
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 191, 191, 8)  32          activation_4[0][0]               
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 191, 191, 8)  0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 191, 191, 4)  292         activation_5[0][0]               
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 191, 191, 4)  16          res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 191, 191, 4)  0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 191, 191, 8)  40          activation_6[0][0]               
__________________________________________________________________________________________________
add_2 (Add)                     (None, 191, 191, 8)  0           res2b_branch2c[0][0]             
                                                                 activation_4[0][0]               
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 191, 191, 8)  32          add_2[0][0]                      
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 191, 191, 8)  0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 191, 191, 4)  292         activation_7[0][0]               
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 191, 191, 4)  16          res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 191, 191, 4)  0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 191, 191, 8)  40          activation_8[0][0]               
__________________________________________________________________________________________________
add_3 (Add)                     (None, 191, 191, 8)  0           res2c_branch2c[0][0]             
                                                                 add_2[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 96, 96, 8)    72          add_3[0][0]                      
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 96, 96, 8)    32          res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 96, 96, 8)    0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 96, 96, 8)    584         activation_9[0][0]               
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 96, 96, 8)    32          res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 96, 96, 8)    0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 96, 96, 16)   144         activation_10[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 96, 96, 16)   144         add_3[0][0]                      
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 96, 96, 16)   64          res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 96, 96, 16)   64          res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, 96, 96, 16)   0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 96, 96, 16)   0           add_4[0][0]                      
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 96, 96, 16)   64          activation_11[0][0]              
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 96, 96, 16)   0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 96, 96, 8)    1160        activation_12[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 96, 96, 8)    32          res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 96, 96, 8)    0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 96, 96, 16)   144         activation_13[0][0]              
__________________________________________________________________________________________________
add_5 (Add)                     (None, 96, 96, 16)   0           res3b_branch2c[0][0]             
                                                                 activation_11[0][0]              
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 48, 48, 16)   272         add_5[0][0]                      
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 48, 48, 16)   64          res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 48, 48, 16)   0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 48, 48, 16)   2320        activation_14[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 48, 48, 16)   64          res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 48, 48, 16)   0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 48, 48, 64)   1088        activation_15[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 48, 48, 64)   1088        add_5[0][0]                      
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 48, 48, 64)   256         res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 48, 48, 64)   256         res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_6 (Add)                     (None, 48, 48, 64)   0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 48, 48, 64)   0           add_6[0][0]                      
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 48, 48, 64)   256         activation_16[0][0]              
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 48, 48, 64)   0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 48, 48, 16)   9232        activation_17[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 48, 48, 16)   64          res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 48, 48, 16)   0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 48, 48, 64)   1088        activation_18[0][0]              
__________________________________________________________________________________________________
add_7 (Add)                     (None, 48, 48, 64)   0           res4b_branch2c[0][0]             
                                                                 activation_16[0][0]              
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 24, 24, 32)   2080        add_7[0][0]                      
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 24, 24, 32)   128         res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 24, 24, 32)   0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 24, 24, 32)   9248        activation_19[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 24, 24, 32)   128         res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 24, 24, 32)   0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 24, 24, 128)  4224        activation_20[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 24, 24, 128)  8320        add_7[0][0]                      
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 24, 24, 128)  512         res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 24, 24, 128)  512         res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, 24, 24, 128)  0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 24, 24, 128)  0           add_8[0][0]                      
__________________________________________________________________________________________________
res6a_branch2a (Conv2D)         (None, 12, 12, 64)   8256        activation_21[0][0]              
__________________________________________________________________________________________________
bn6a_branch2a (BatchNormalizati (None, 12, 12, 64)   256         res6a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 12, 12, 64)   0           bn6a_branch2a[0][0]              
__________________________________________________________________________________________________
res6a_branch2b (Conv2D)         (None, 12, 12, 64)   36928       activation_22[0][0]              
__________________________________________________________________________________________________
bn6a_branch2b (BatchNormalizati (None, 12, 12, 64)   256         res6a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 12, 12, 64)   0           bn6a_branch2b[0][0]              
__________________________________________________________________________________________________
res6a_branch2c (Conv2D)         (None, 12, 12, 512)  33280       activation_23[0][0]              
__________________________________________________________________________________________________
res6a_branch1 (Conv2D)          (None, 12, 12, 512)  66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn6a_branch2c (BatchNormalizati (None, 12, 12, 512)  2048        res6a_branch2c[0][0]             
__________________________________________________________________________________________________
bn6a_branch1 (BatchNormalizatio (None, 12, 12, 512)  2048        res6a_branch1[0][0]              
__________________________________________________________________________________________________
add_9 (Add)                     (None, 12, 12, 512)  0           bn6a_branch2c[0][0]              
                                                                 bn6a_branch1[0][0]               
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 12, 12, 512)  0           add_9[0][0]                      
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 512)          0           activation_24[0][0]              
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512)          0           avg_pool[0][0]                   
__________________________________________________________________________________________________
FC1 (Dense)                     (None, 1)            513         dropout_1[0][0]                  
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 1)            0           FC1[0][0]                        
==================================================================================================
Total params: 196,557
Trainable params: 192,867
Non-trainable params: 3,690

一切看起来都正确。但是,当我运行代码时,出现以下错误:

Epoch 1/1
Traceback (most recent call last):
  File "C:/Users/ASista162282/Desktop/code/camleyon_17/train.py", line 114, in <module>
    training_pipeline()
  File "C:/Users/ASista162282/Desktop/code/camleyon_17/train.py", line 71, in training_pipeline
    verbose = 1)
  File "C:\ProgramData\Miniconda3\lib\site-packages\keras\engine\training.py", line 1705, in fit
    validation_steps=validation_steps)
  File "C:\ProgramData\Miniconda3\lib\site-packages\keras\engine\training.py", line 1188, in _fit_loop
    outs = f(ins)
  File "C:\ProgramData\Miniconda3\lib\site-packages\keras\backend\tensorflow_backend.py", line 2478, in __call__
    **self.session_kwargs)
  File "C:\ProgramData\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 900, in run
    run_metadata_ptr)
  File "C:\ProgramData\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1111, in _run
    str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'input_1:0', which has shape '(?, 389, 389, 3)'

这没有任何意义。我什至在定义模型之前添加了 set_shape 函数,但它仍然显示空形状。

tensorflow keras tensorflow-datasets
1个回答
1
投票

您替换输入层的方式似乎没有正确连接新层。尝试更换这个:

model_single.layers[0] = InputLayer(input_tensor=input_tensor, input_shape=(389, 389, 3))

这样:

from keras.models import Model
model_single.layers.pop(0)
new_input = InputLayer(input_tensor=input_tensor, input_shape=(389, 389, 3))
new_output = model_single(new_input)
model_single = Model(new_input, new_output)
© www.soinside.com 2019 - 2024. All rights reserved.