我想使用图像和表格数据构建一个多模态深度学习模型。
我通过generator加载数据构建Keras子类化模型,但是在开始模型拟合的时候出现错误
当我使用函数式 API 时它运行良好,但当我使用子类化时它不起作用。
我不知道怎么了。请帮助我。
我会附上我的代码和错误。
# data generator
# In a data frame, the first column is the image path, the last column is the label,
# and the remaining column is the table data.
def generator_train():
for item in df_train.values:
yield (item[0], item[1:-1], item[-1])
def generator_valid():
for item in df_valid.values:
yield (item[0], item[1:-1], item[-1])
dataset_train = tf.data.Dataset.from_generator(
generator_train,
(tf.string, tf.float32, tf.int32),
((), (23,), ())
)
dataset_valid = tf.data.Dataset.from_generator(
generator_valid,
(tf.string, tf.float32, tf.int32),
((), (23,), ())
)
def augmentation(image):
h, w = image.shape[0], image.shape[1]
size = h if h >= w else w
dst = pad_to_bounding_box(image, int((size-h)/2), int((size-w)/2), size, size)
dst = resize(dst, size=(224, 224))
dst = random_flip_left_right(dst)
dst = tf.cast(dst, tf.float32)
dst /= 255
return dst
def preprocessing(path, category, label):
label = tf.one_hot(label, 2)
bin = tf.io.read_file(path)
image = tf.io.decode_png(bin, channels=3)
image= tf.py_function(augmentation, [image], [tf.float32])
image = tf.squeeze(image)
return (image, category), label
dt = dataset_train.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dt = dt.batch(32).prefetch(3)
dv = dataset_valid.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dv = dv.batch(32).prefetch(3)
# modeling
def CustomDense(channels):
model = tf.keras.models.Sequential([
Dense(channels),
BatchNormalization(),
Activation('relu'),
Dropout(0.5)
])
return model
# subclassing model
class ImageFeature(tf.keras.Model):
def __init__(self):
super(ImageFeature, self).__init__()
self.backbone = tf.keras.applications.resnet50.ResNet50(include_top=False, input_shape=[224, 224, 3], weights='imagenet')
self.backbone.trainable = True
self.fc1 = CustomDense(1024)
self.fc2 = CustomDense(512)
self.fc3 = CustomDense(256)
def call(self, input):
print(input)
print('image shape:', input.shape)
x = self.backbone(input)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
class TabularFeature(tf.keras.Model):
def __init__(self):
super(TabularFeature, self).__init__()
self.fc1 = CustomDense(128)
self.fc2 = CustomDense(256)
def call(self, input):
x = self.fc1(input)
x = self.fc2(x)
return x
class Classification(tf.keras.Model):
def __init__(self, n_class):
super(Classification, self).__init__()
self.ImageFeature = ImageFeature()
self.TabularFeature = TabularFeature()
self.fc1 = CustomDense(64)
self.fc2 = CustomDense(16)
self.cls = Dense(n_class, activation='softmax')
def call(self, input):
img_feature = self.ImageFeature(input[0])
tabular_feature = self.TabularFeature(input[1])
feature = concatenate([img_feature, tabular_feature], dim=-1)
x = self.fc1(feature)
x = self.fc2(x)
output = self.cls(x)
return output
mc = tf.keras.callbacks.ModelCheckpoint("mymodel/save/path", monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
es = tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", mode="max", verbose=1, patience=10)
model = Classification(2)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=['accuracy', tf.keras.metrics.Recall()])
history = model.fit((dt), epochs=10, validation_data=(dv), callbacks=[mc, es])
# error
TypeError Traceback (most recent call last)
<ipython-input-81-59113bfbc1ce> in <module>
----> 1 history = model.fit((dt), epochs=10, validation_data=(dv), callbacks=[mc, es])
2 frames
/tmp/__autograph_generated_file_38rxzpj.py in tf__call(self, input)
10 img_feature = ag__.converted_call(ag__.ld(self).ImageFeature, (ag__.ld(input)[0],), None, fscope)
11 tabular_feature = ag__.converted_call(ag__.ld(self).TabularFeature, (ag__.ld(input)[1],), None, fscope)
---> 12 feature = ag__.converted_call(ag__.ld(concatenate), ([ag__.ld(img_feature), ag__.ld(tabular_feature)],), dict(dim=(- 1)), fscope)
13 x = ag__.converted_call(ag__.ld(self).fc1, (ag__.ld(feature),), None, fscope)
14 x = ag__.converted_call(ag__.ld(self).fc2, (ag__.ld(x),), None, fscope)
TypeError: in user code:
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1249, in train_function *
return step_function(self, iterator)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1233, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1222, in run_step **
outputs = model.train_step(data)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1023, in train_step
y_pred = self(x, training=True)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file_38rxzpj.py", line 12, in tf__call
feature = ag__.converted_call(ag__.ld(concatenate), ([ag__.ld(img_feature), ag__.ld(tabular_feature)],), dict(dim=(- 1)), fscope)
TypeError: Exception encountered when calling layer 'classification_13' (type Classification).
in user code:
File "<ipython-input-44-d6336779b263>", line 52, in call *
feature = concatenate([img_feature, tabular_feature], dim=-1)
File "/usr/local/lib/python3.8/dist-packages/keras/layers/merging/concatenate.py", line 231, in concatenate **
return Concatenate(axis=axis, **kwargs)(inputs)
File "/usr/local/lib/python3.8/dist-packages/keras/layers/merging/concatenate.py", line 89, in __init__
super().__init__(**kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/layers/merging/base_merge.py", line 36, in __init__
super().__init__(**kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 340, in __init__
generic_utils.validate_kwargs(kwargs, allowed_kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/generic_utils.py", line 515, in validate_kwargs
raise TypeError(error_message, kwarg)
TypeError: ('Keyword argument not understood:', 'dim')
Call arguments received by layer 'classification_13' (type Classification):
• input=('tf.Tensor(shape=(None, 224, 224, 3), dtype=float32)', 'tf.Tensor(shape=(None, 23), dtype=float32)')
数据框中,第一列是图片路径,最后一列是标签,剩下一列是表格数据。