Perceptual loss function in a diffusion model


I am trying to compute a perceptual loss between the images generated by a diffusion model and the ground-truth images (I am using the model for image-to-image translation, and the images are grayscale). Here is the code for the loss function:

from .custom_loss import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class VGGPerceptualLoss(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGGPerceptualLoss, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        # conv_index picks where to tap VGG19 features:
        # '22' -> up to conv2_2 (first 8 layers), '54' -> up to conv5_4 (first 35 layers)
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])

        # Normalize inputs with ImageNet statistics before the VGG forward pass.
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        # VGG is used as a fixed feature extractor, so freeze all parameters.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        # sr: generated image, hr: reference image (EDSR-style naming).
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr[0])
        with torch.no_grad():
            # No gradients are needed for the reference features.
            vgg_hr = _forward(hr[0].detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss
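
(Aside: the pretrained=True call is what triggers the torchvision deprecation warning shown in the log below; per that warning, on torchvision >= 0.13 the equivalent call would be:)

import torchvision.models as models

# select the pretrained weights via an enum instead of pretrained=True
vgg_features = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features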

Here is the relevant part of the diffusion model's training code:

    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        # Split the batch into microbatches to bound memory use.
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }

            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses_segmentation,
                self.ddp_model,
                self.classifier,
                self.prior,
                self.posterior,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses1 = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses1 = compute_losses()

            # Unpack the diffusion losses and the predicted sample first, so
            # `losses` is defined before the sampler update below uses it.
            losses = losses1[0]
            sample = losses1[1]

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach())

            conv_index = '22'
            perceptual_loss_fn = VGGPerceptualLoss(conv_index=conv_index)
            # Note: this compares the microbatch prediction against the full
            # batch tensor; `micro` may be what is intended here.
            perceptual_loss = perceptual_loss_fn(sample, batch)
            loss = (losses["loss"] * weights).mean() + perceptual_loss * self.perceptual_loss_weight
            # tensor_losses = torch.cat((losses["loss"],perceptual_loss))
            # loss = (tensor_losses* weights).mean()
            # print(f'weights: {weights}')
            # print(f'loss: {losses["loss"]}')
            # loss = (losses["loss"] * weights).mean()
            lossseg = (losses["mse"] * weights).mean().detach()
            losscls = (losses["vb"] * weights).mean().detach()
            lossrec = loss * 0

            log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
            self.mp_trainer.backward(loss)
            return lossseg.detach(), losscls.detach(), lossrec.detach(), sample
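
One thing worth flagging in this loop, separate from the crash below: VGGPerceptualLoss is re-instantiated (and the pretrained VGG19 re-loaded) on every microbatch, and it is never moved to the training device. A minimal sketch of caching it once instead, assuming the trainer's __init__ is the natural place and reusing dist_util.dev() as the rest of this code does:

# hypothetical placement in the trainer's __init__:
self.perceptual_loss_fn = VGGPerceptualLoss(conv_index='22').to(dist_util.dev())

# then, inside forward_backward, reuse the cached module:
perceptual_loss = self.perceptual_loss_fn(sample, batch)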

Here is the error. How can I fix it?

/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Traceback (most recent call last):
  File "/content/drive/MyDrive/seg/seg/scripts/segmentation_train.py", line 147, in <module>
    main()
  File "/content/drive/MyDrive/seg/seg/scripts/segmentation_train.py", line 118, in main
    ).run_loop()
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 195, in run_loop
    self.run_step(batch, cond)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 220, in run_step
    lossseg, losscls, lossrec, sample = self.forward_backward(batch, cond)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/train_util.py", line 267, in forward_backward
    perceptual_loss = perceptual_loss_fn(sample, batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/vgg.py", line 30, in forward
    vgg_sr = _forward(sr[0])
  File "/content/drive/MyDrive/seg/seg/./guided_diffusion/vgg.py", line 26, in _forward
    x = self.sub_mean(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead

Tags: python, pytorch, loss-function
1 Answer

In your VGGPerceptualLoss:

vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)

this normalization layer (and the pretrained VGG19 behind it) only accepts 3-channel RGB input.
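
For reference, common.MeanShift in EDSR-style codebases (which this loss appears to be adapted from) is a frozen 1x1 convolution hard-wired to 3 channels, which is exactly the weight of size [3, 3, 1, 1] in the traceback. A sketch of that typical implementation, in case your common module differs:

import torch
import torch.nn as nn

class MeanShift(nn.Conv2d):
    # A frozen 1x1 conv computing (x - mean) / std per RGB channel.
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)  # weight: [3, 3, 1, 1]
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False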

However, the error message

RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[1, 1, 128, 128] to have 3 channels, but got 1 channels instead

shows that you are feeding it a grayscale tensor of size (1, 1, 128, 128), i.e. an image with a single channel.

You can therefore open your grayscale image data with the following code,

from PIL import Image
img = Image.open("[your training image]").convert("RGB")

which converts it to 3-channel RGB format before it reaches the network.
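
Alternatively, if changing the data pipeline is inconvenient (the tensors inside forward_backward are already single-channel), you can replicate the channel at the tensor level just before the VGG forward pass. A minimal sketch; the _to_rgb helper is hypothetical and not part of the original code:

def _to_rgb(x):
    # Replicate a grayscale (N, 1, H, W) tensor to (N, 3, H, W); pass RGB through.
    return x.repeat(1, 3, 1, 1) if x.size(1) == 1 else x

# inside VGGPerceptualLoss._forward, before the normalization:
# x = self.sub_mean(_to_rgb(x))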
