GAN 中的训练问题

问题描述 投票:0回答:0

我一直在尝试在高光谱数据上训练 GAN。数据 (Indian Pines) 的大小为 145x145x200,在执行 PCA (K=3) 后已分解为 64x64x3 的块。生成器在接收到 100x1x1 潜在暗淡后生成此大小的图像,而鉴别器接收真实类标签并生成图像或真实图像块以提供两个单独的输出-

  1. 乙状结肠;对真实/假图像进行分类。
  2. 软最大;对图像的类别进行分类。

注意-数据集中有 16 个类别,我为生成的图像分配了第 17 个类别。

这是一个 colab 副本,以便于访问

使用的数据预处理实用程序:

def applyPCA(X, numComponents=75):
    newX = np.reshape(X, (-1, X.shape[2]))
    pca = PCA(n_components=numComponents, whiten=True)
    newX = pca.fit_transform(newX)
    newX = np.reshape(newX, (X.shape[0],X.shape[1], numComponents))
    return newX, pca

def padWithZeros(X, margin=2):
    newX = np.zeros((X.shape[0] + 2 * margin, X.shape[1] + 2* margin, X.shape[2]))
    x_offset = margin
    y_offset = margin
    newX[x_offset:X.shape[0] + x_offset, y_offset:X.shape[1] + y_offset, :] = X
    return newX

def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):
    margin = int((windowSize) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)
    # split patches
    patchesData = np.zeros((X.shape[0] * X.shape[1], windowSize, windowSize, X.shape[2]))
    patchesLabels = np.zeros((X.shape[0] * X.shape[1]))
    patchIndex = 0
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            patch = zeroPaddedX[r - margin:r + margin, c - margin:c + margin]   
            patchesData[patchIndex, :, :, :] = patch
            patchesLabels[patchIndex] = y[r-margin, c-margin]
            patchIndex = patchIndex + 1
    if removeZeroLabels:
        patchesData = patchesData[patchesLabels>0,:,:,:]
        patchesLabels = patchesLabels[patchesLabels>0]
        patchesLabels -= 1
    return patchesData, patchesLabels

def splitTrainTestSet(X, y, testRatio, randomState=345):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=testRatio, random_state=randomState, stratify=y)
    return X_train, X_test, y_train, y_test

模型架构和训练循环:

def build_generator(latent_dim=(100, 1, 1)):

    model = Sequential()
    model.add(Dense(2 * 13 * 13, input_shape=latent_dim))
    model.add(Reshape((13, 13, 200)))
    # model.add(Dropout(0.2))

    model.add(Conv2DTranspose(512, kernel_size=4, strides=1, padding='valid'))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Dropout(0.2))

    model.add(Conv2DTranspose(256, kernel_size=4, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Dropout(0.2))

    model.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Dropout(0.2))

    model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Dropout(0.2))
    
    model.add(Conv2D(3, kernel_size=4, strides=2, padding='same', activation='tanh'))
    
    return model

generator = build_generator()
generator.compile(
    loss='binary_crossentropy', optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
)
generator.summary()

from keras import backend as be

def build_discriminator_net(img_shape=(64,64,3), num_classes=17):

    img_input = Input(shape=img_shape)
    label_input = Input(shape=(1,), dtype='int32')

    label_embedding = Flatten()(Embedding(num_classes+1, np.prod(img_shape))(label_input))
    label_embedding = Reshape(img_shape)(label_embedding)
    x = Concatenate()([img_input, label_embedding])

    x = Conv2D(64, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Dropout(0.5)(x)

    x = Conv2D(128, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Dropout(0.5)(x)

    x = Conv2D(256, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Dropout(0.5)(x)

    x = Conv2D(512, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Dropout(0.5)(x)

    x = Conv2D(64, kernel_size=4, strides=1, padding='valid')(x)
    x = Flatten()(x)

    real_or_fake = Dense(2, activation='sigmoid', name='real_or_fake')(x)
    class_label = Dense(num_classes, activation='softmax', name='class_label')(x)

    model = Model(inputs=[img_input, label_input], outputs=[real_or_fake, class_label])

    return model

discriminator = build_discriminator_net()
discriminator.compile(
    loss=['binary_crossentropy', 'categorical_crossentropy'],
    optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=["accuracy"],
)
discriminator.summary()

gan_input_noise = keras.Input(shape=latent_dim)
gan_input_label = keras.Input(shape=(1,))
gan_output = discriminator([generator(gan_input_noise), gan_input_label])
gan = keras.Model(inputs=[gan_input_noise, gan_input_label], outputs=gan_output, name="gan")

gan.compile(
    loss=['binary_crossentropy', 'categorical_crossentropy'],
    optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
)

epochs = 500
batch_size = 64

def preprocess_images(images):
    images = (images - 127.5) / 127.5 
    return images

x_train = preprocess_images(Xtrain)
y_train = to_categorical(ytrain, num_classes+1)

def batch(batch_size=batch_size, x_train=x_train, y_train=y_train):
        idx = np.random.randint(0, Xtrain.shape[0], batch_size)
        imgs = x_train[idx]
        labels = y_train[idx]
        return imgs, labels

for epoch in range(epochs):

    real_images, real_labels = batch()

    noise = tf.random.normal([batch_size, 100, 1, 1])
    fake_images = generator.predict(noise)
    fake_labels = np.ones((batch_size, 17)) * 17

    y_real = np.ones((batch_size, 1), dtype=np.int32)
    y_fake = np.zeros((batch_size, 1), dtype=np.int32)

    d_loss_real = discriminator.train_on_batch([real_images, real_labels], [y_real, real_labels])
    d_loss_fake = discriminator.train_on_batch([fake_images, fake_labels], [y_fake, fake_labels])
    d_loss = d_loss_real + d_loss_fake

    noise = tf.random.normal([batch_size, 100, 1, 1])
    gen_labels = np.ones((batch_size, 1)) * 17
    g_loss = gan.train_on_batch(noise, gen_labels)

    if epoch % 100 == 0:
        print(f"Epoch {epoch+1}/{epochs}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")
tensorflow machine-learning deep-learning tensorflow2.0 generative-adversarial-network
© www.soinside.com 2019 - 2024. All rights reserved.