经过训练的 U-Net 将在 60 个 epoch 后根据参数/权重预测 CT 图像

问题描述 投票:0回答:1

我的 U-Net 将根据 MRI 预测 CT 图像。我已经使用级联 3D MRI/CT 训练和验证了 U-Net,如下所示:

def train_valid_model(X, Y, x_val, y_val):

    s = X.shape
    model = UNet_model_2D(s[1], s[2], 1)  # s[1]=256, s[2]=256 
    
    callbacks = [
        tf.keras.callbacks.TensorBoard(
            log_dir='logs',
            histogram_freq=1,
            write_graph=True,
            write_images=True,
        )
    ]
    
    # return a History object whose attribute '.history ' is a record of  
    # training loss, metrics, validation loss, and validation metrics values
    results = model.fit(
        x=X,  # concatenated 3D MRIs
        y=Y,  # concatenated 3D reference CTs
        batch_size=16, 
        epochs=200,
        verbose=1,
        callbacks=callbacks,
        validation_data = (x_val, y_val),  # concatenated 3D MRIs/CTs
    )
    
    tmp = list(results.history.values())
    
    train_loss=tmp[0][:]  # train loss
    val_loss=tmp[1][:]  # val loss
    
    # write/append csv file
    f = open('log_train_loss_TF_CT.csv', 'a')
    writer = csv.writer(f)
    writer.writerow(train_loss)
    f.close()
    f = open('log_val_loss_TF_CT.csv', 'a')
    writer = csv.writer(f)
    writer.writerow(val_loss)
    f.close()
    
    model.save('pCT_2D_deep_large_batch16', save_format='tf')

查看 TensorBoard 中的损失函数图,我发现在 60 个 epoch 后,进一步收敛和过拟合之间有一个很好的折衷。因此,我现在想要根据 60 个 epoch 后的模型参数/权重,从串联测试 MRI 中预测 CT。我该怎么做?

到目前为止我有以下方法:

# load trained & validated model
model_name = 'pCT_2D_deep_large_batch16'
model = tf.keras.models.load_model(model_name, compile=False)

# load concatenated test MRIs
X_test = nib.load('test_MRIs.nii.gz').get_fdata()

# predict sCTs
predicted_data = model.predict(X_test, verbose=1)

# save predicted sCTs as concatenated NIfTI file
image = nib.Nifti1Image(predicted_data, affine=None)
nib.save(image, 'predicted_sCTs.nii.gz')

在 Spyder 控制台中,出现以下内容:

有什么办法可以停在60/200吗?有人可以帮我吗?

python tensorflow keras deep-learning prediction
1个回答
0
投票

如果唯一保存的模型是在 200 个时期之后,则没有直接的方法来检索 60 个时期的权重。

如果可以重新训练,最简单的解决方案是重新训练,但使用

model.fit(..., epoch=60, ...)
并查看损失大致与之前训练中的情况相同。

此外,Keras 的 ModelCheckpoint 回调可用于在训练期间保存检查点,以便中间模型权重不会丢失。

© www.soinside.com 2019 - 2024. All rights reserved.