层需要 2 个输入,但它收到了 1 个

问题描述 投票:0回答:1

我正在尝试提交 Kaggle 竞赛。并发现了关于 CycleGAN 的非常有趣的paper。现在我正在尝试使用 TensorFlow 来实现这个:

def build_cyclegan(
    generator_g, generator_f, discriminator_x, discriminator_y, lambda_cycle=10
):
    real_x = tf.keras.layers.Input(shape=(256, 256, 3), name="x_real")
    real_y = tf.keras.layers.Input(shape=(256, 256, 3), name="y_real")

    fake_y = generator_g(real_x)
    cycled_x = generator_f(fake_y)

    fake_x = generator_f(real_y)
    cycled_y = generator_g(fake_x)

    disc_real_x = discriminator_x(real_x)
    disc_real_y = discriminator_y(real_y)

    disc_fake_x = discriminator_x(fake_x)
    disc_fake_y = discriminator_y(fake_y)

    cycle_gan = tf.keras.Model(
        inputs=[real_x, real_y],
        outputs=[
            disc_real_x,
            disc_real_y,
            cycled_x,
            cycled_y,
            disc_fake_x,
            disc_fake_y,
        ],
    )

    cycle_loss_10 = partial(cycle_consistency_loss, LAMBDA=lambda_cycle)

    cycle_gan.compile(
        loss=[
            addversarial_loss_discriminator,
            addversarial_loss_discriminator,
            addversarial_loss_generator,
            addversarial_loss_generator,
            cycle_loss_10,
            cycle_loss_10,
        ],
        optimizer=tf.keras.optimizers.Adam(2e-4, beta_1=0.5),
    )

    return cycle_gan

我遇到了一个奇怪的异常,它表示我的模型需要 2 个输入,但我只传递 1 个输入。这很奇怪,因为我试图将 2 个图像传递到模型中image

这是我如何从

tfrec
文件获取数据:

import tensorflow as tf


def _parse_image_function(example_proto):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.io.decode_jpeg(parsed_features["image"], channels=3)
    image = tf.reshape(image, (256, 256, 3))
    return image


def create_dataset(filenames, repeat=True):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_image_function)
    if repeat:
        dataset = dataset.repeat()
    return dataset


photo_filenames = tf.data.Dataset.list_files(photos_path + "/*.tfrec")
monet_filenames = tf.data.Dataset.list_files(monet_path + "/*.tfrec")

photo_dataset = create_dataset(photo_filenames)
monet_dataset = create_dataset(monet_filenames)

combined_dataset = tf.data.Dataset.zip((photo_dataset, monet_dataset))

combined_dataset = combined_dataset.shuffle(
    buffer_size=1000
)

batch_size = 1
dataset = combined_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

我不知道我做错了什么。我认为这个问题是因为使用 tf.data.Dataset.zip 方法,但在互联网上的任何地方我都发现这是正确的,所以我什至不知道该看什么(

python tensorflow artificial-intelligence
1个回答
0
投票

如果没有错误消息,很难回答这个问题。 您是说模型在编译步骤失败了吗?或者当您尝试之后传递数据时?

您可以展示完整的通话内容吗?

© www.soinside.com 2019 - 2024. All rights reserved.