When calling torchvision.transforms.Normalize and then converting to a PIL.Image, how are values greater than 1 handled?


Applying torchvision.transforms.Normalize should produce values below 0 and above 1, so after conversion to integers some values should fall below 0 and above 255.

However, when inspecting the actual values, this does not appear to be the case. How are they handled?

Please find a code sample below that reproduces the issue.

For some context: I am trying to integrate a neural network via ONNX into C++ code, and I cannot reproduce my Python results because values below 0 and above 1 get clipped.

import numpy as np
from PIL import Image
from torchvision import transforms




def make_transforms(normalize: bool = True) -> transforms.Compose:    
    results = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]

    if normalize:
        results.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

    return transforms.Compose(results)


def main() -> None:

    resize_a_white_image_without_normalization() # works as expected

    resize_a_white_image_WITH_normalization() # how are value > 255 handled???


def resize_a_white_image_WITH_normalization():
    
    # given a white image
    white = (255, 255, 255)
    white_img = Image.new("RGB", (300, 300), white)

    # when resizing the image, and normalizing it
    resized_img = make_transforms(normalize=True)(white_img)
    resized_img_pil = transforms.ToPILImage()(resized_img)

    # expected normalized values
    normalized_val = [(1.0 - 0.485) / 0.229, (1.0 - 0.456) / 0.224, (1 - 0.406) / 0.225]
    normalized_val_int = [int(i * 255) for i in normalized_val]

    print(normalized_val) # [2.2489082969432315, 2.428571428571429, 2.6399999999999997] > 1 ??
    print(normalized_val_int)                                   # [573, 619, 673] > 255??
    print([between_0_and_255(i) for i in normalized_val_int])   # [63, 109, 163]
    print(np.array(resized_img_pil)[0,0])   # [ 61 107 161] ???? still different from above!


def resize_a_white_image_without_normalization() -> None:

    # given a white image
    white = (255, 255, 255)
    white_img = Image.new("RGB", (300, 300), white)


    # when resizing the image, but not normalizing it, 
    resized_img = make_transforms(normalize=False)(white_img)
    # then all pixels should remain white
    assert (np.array(resized_img) == np.ones_like(np.array(resized_img))).all()



def between_0_and_255(value: int):
    return value % 255


if __name__ == "__main__":
    main()
Tags: python, pytorch, python-imaging-library, torch, torchvision
1 Answer

The values are truncated to 8 bits because you end up with an 8-bit PIL image.

Looking at the documentation of ToPILImage, we can see that the PIL image mode is chosen based on the tensor; in your case:

If input has 3 channels, the mode is assumed to be RGB.

and, from the PIL documentation:

RGB (3x8-bit pixels, true color)

So your tensor is automatically converted to an image whose pixel values live in the 8-bit range 0..255. Values scaled outside that range overflow when written into the PIL image: they wrap around modulo 256 rather than being clipped, which is why the observed pixels match the corrected modulo helper below.
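The wrap-around can be reproduced with plain NumPy, independently of torchvision (a minimal sketch, using the scaled integer values from the question): casting integers above 255 to uint8 keeps only the low 8 bits, i.e. value % 256.

```python
import numpy as np

# Scaled normalized values for a white pixel, as computed in the question
scaled = np.array([573, 619, 673])

# Casting to uint8 keeps only the low 8 bits (value % 256); it does not clip
wrapped = scaled.astype(np.uint8)
print(wrapped)  # [ 61 107 161] — matches the pixels read back from the PIL image
```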

Your own range-reduction function also has a bug:

def between_0_and_255(value: int):
    return value % 255

It should be:

def between_0_and_255(value: int):
    return value % 256

With that change, the output becomes:

[2.2489082969432315, 2.428571428571429, 2.6399999999999997]
[573, 619, 673]
[61, 107, 161]
[ 61 107 161]
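If the goal is to compare against the C++ side without the 8-bit wrap-around, one option (a sketch, not part of the answer above) is to undo the normalization before any conversion back to an 8-bit image, or to keep the data as floats throughout:

```python
import numpy as np

# ImageNet statistics used by the question's make_transforms
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

pixel = np.array([1.0, 1.0, 1.0])      # white pixel after ToTensor
normalized = (pixel - mean) / std      # values above 1, fed to the network

# Undo the normalization before any 8-bit conversion; the round trip is exact
restored = normalized * std + mean
print(np.allclose(restored, pixel))    # True
```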