我正在尝试提交 Kaggle 竞赛。并发现了关于 CycleGAN 的非常有趣的paper。现在我正在尝试使用 TensorFlow 来实现这个:
def build_cyclegan(
generator_g, generator_f, discriminator_x, discriminator_y, lambda_cycle=10
):
real_x = tf.keras.layers.Input(shape=(256, 256, 3), name="x_real")
real_y = tf.keras.layers.Input(shape=(256, 256, 3), name="y_real")
fake_y = generator_g(real_x)
cycled_x = generator_f(fake_y)
fake_x = generator_f(real_y)
cycled_y = generator_g(fake_x)
disc_real_x = discriminator_x(real_x)
disc_real_y = discriminator_y(real_y)
disc_fake_x = discriminator_x(fake_x)
disc_fake_y = discriminator_y(fake_y)
cycle_gan = tf.keras.Model(
inputs=[real_x, real_y],
outputs=[
disc_real_x,
disc_real_y,
cycled_x,
cycled_y,
disc_fake_x,
disc_fake_y,
],
)
cycle_loss_10 = partial(cycle_consistency_loss, LAMBDA=lambda_cycle)
cycle_gan.compile(
loss=[
addversarial_loss_discriminator,
addversarial_loss_discriminator,
addversarial_loss_generator,
addversarial_loss_generator,
cycle_loss_10,
cycle_loss_10,
],
optimizer=tf.keras.optimizers.Adam(2e-4, beta_1=0.5),
)
return cycle_gan
我遇到了一个奇怪的异常,它表示我的模型需要 2 个输入,但我只传递 1 个输入。这很奇怪,因为我试图将 2 个图像传递到模型中image
这是我如何从
tfrec
文件获取数据:
import tensorflow as tf
def _parse_image_function(example_proto):
feature_description = {
"image": tf.io.FixedLenFeature([], tf.string),
}
parsed_features = tf.io.parse_single_example(example_proto, feature_description)
image = tf.io.decode_jpeg(parsed_features["image"], channels=3)
image = tf.reshape(image, (256, 256, 3))
return image
def create_dataset(filenames, repeat=True):
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_image_function)
if repeat:
dataset = dataset.repeat()
return dataset
photo_filenames = tf.data.Dataset.list_files(photos_path + "/*.tfrec")
monet_filenames = tf.data.Dataset.list_files(monet_path + "/*.tfrec")
photo_dataset = create_dataset(photo_filenames)
monet_dataset = create_dataset(monet_filenames)
combined_dataset = tf.data.Dataset.zip((photo_dataset, monet_dataset))
combined_dataset = combined_dataset.shuffle(
buffer_size=1000
)
batch_size = 1
dataset = combined_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
我不知道我做错了什么。我认为这个问题是因为使用 tf.data.Dataset.zip 方法,但在互联网上的任何地方我都发现这是正确的,所以我什至不知道该看什么(
如果没有错误消息,很难回答这个问题。 您是说模型在编译步骤失败了吗?或者当您尝试之后传递数据时?
您可以展示完整的通话内容吗?