为什么我的图像张量在使用 Spatial Transformer Network Pytorch 处理后输出全为零?

问题描述 投票:0回答:1

我正在开发一个涉及空间变换网络(STN)来处理图像的小项目。我不小心上传了一个带有未经测试的代码的分支,现在我面临着一个问题,我的图像张量输出全为零。

Model Preview

这是我的代码的相关部分:

import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
from model import STModel
from typing import Union
import numpy as np

class STN:
    """
    Class to handle the processing of a single image using a Spatial Transformer Network (STN).

    Args:
        pretrained (Path): Path to the pre-trained model.
    """

    def __init__(self, pretrained: Union[str, Path]) -> None:
        self.pretrained: Path = pretrained
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model: STModel = STModel().to(self.device)
        self.model.load_state_dict(torch.load(self.pretrained, map_location=self.device))
        self.model.eval()

        self.transform: transforms.Compose = transforms.Compose([
            transforms.Resize((150, 120)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust normalization for single-channel input
        ])

    def process_image(self, input_path: Union[str, Path], output_path: Union[str, Path]) -> None:
        """
        Process a single image using the pre-trained model.

        Args:
            input_path (Union[str, Path]): Path to the input image.
            output_path (Union[str, Path]): Path where the output image will be saved.
        """
        input_path: Path = Path(input_path)
        output_path: Path = Path(output_path)
        image: Image.Image = Image.open(input_path).convert('L')  # Ensure the image is in greyscale
        print(f"Loaded image: {input_path}")
        print(f"Image size: {image.size}")
        print(f"Image mode: {image.mode}")

        input_tensor: torch.Tensor = self.transform(image).unsqueeze(0).to(self.device)
        print(f"Transformed tensor shape: {input_tensor.shape}")
        print(f"Transformed tensor min, max: {input_tensor.min().item()}, {input_tensor.max().item()}")

        with torch.no_grad():
            output_tensor: torch.Tensor = self.model(input_tensor)
        print(f"Output tensor shape: {output_tensor.shape}")
        print(f"Output tensor min, max: {output_tensor.min().item()}, {output_tensor.max().item()}")

        output_array = np.array([output_tensor.squeeze().cpu().detach()])

        print(f"Processed and saved output image: {output_path}")
        print(f"Output image content: {output_array}")
        print(f"Output tensor shape: {output_tensor.shape}")

if __name__ == "__main__":
    stn: STN = STN(pretrained="spt_model.pt")
    stn.process_image(
        input_path=Path("dataset/train/aaAGoBxqnJgoEGzD.jpg"),
        output_path=Path("output.jpg")
    )

当我运行代码时,

output_tensor
具有
(0.0, 0.0)
的最小值和最大值,当我打印出
output_array
时,它自然也全为零。这是打印输出的示例:

Loaded image: dataset/train/aaAGoBxqnJgoEGzD.jpg
Image size: (150, 120)
Image mode: L
Transformed tensor shape: torch.Size([1, 1, 150, 120])
Transformed tensor min, max: -2.8582..., 1.5927...
Output tensor shape: torch.Size([1, 1, 150, 120])
Output tensor min, max: 0.0, 0.0
Processed and saved output image: output.jpg
Output image content: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Output tensor shape: torch.Size([1, 1, 150, 120])

我怀疑图像的转换方式或模型处理图像的方式可能存在问题。任何关于为什么会发生这种情况以及我如何解决它的见解将不胜感激。

您可以在我的 GitHub 上找到完整的项目代码以获取更多上下文:License Plate STN

python numpy torch torchvision spatial-transformer-network
1个回答
0
投票

这可能是一个幼稚的答案,但只需查看代码的链接即可。首先,我看到您将 fc2 的本地化权重初始化为零在 model.py 中。您的前馈预测为您提供全零,这表明您正在使用 STN 的初始化实例,而不是经过训练的实例。

我在上面的代码中看到,您从光盘加载了预训练的模型。我看不到您在模型类中的何处定义了

pretrained
参数的行为,并且我不认为
nn.Module
采用您定义的方式的
pretrained
路径参数,但它确实需要
kwargs
。由于它不知道如何处理
pretrained
参数,因此这里只是忽略它。当您调用过程映像时,您正在使用新初始化的对象来执行此操作。

考虑加载 STN 模型的 state_dict。我认为这会起作用:

if __name__ == "__main__":
    stn: STN = STN()
    
    save_path = 'spt_model.pt'
    state_dict = torch.load(save_path)
    
    STN.load_state_dict(state_dict)
    ...

另请注意,在您的 main.py 参数中,第 17 行可能存在拼写错误。“spt_mode.pt”。希望您没有意外地将模型保存为错误的名称。

© www.soinside.com 2019 - 2024. All rights reserved.