Applying torchvision.Normalize should produce values below 0 and above 1, so after converting back to integers the values should end up below 0 or above 255.
However, when I look at the actual values, that does not seem to be the case. How is this handled?
Please find a code example below that reproduces the issue.
For some context: I am trying to integrate a neural network exported with onnx into C++ code, but I cannot reproduce my Python results, because values below 0 and above 1 get clipped.
import numpy as np
from PIL import Image
from torchvision import transforms


def make_transforms(normalize: bool = True) -> transforms.Compose:
    results = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    if normalize:
        results.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    return transforms.Compose(results)


def main() -> None:
    resize_a_white_image_without_normalization()  # works as expected
    resize_a_white_image_WITH_normalization()  # how are values > 255 handled???


def resize_a_white_image_WITH_normalization():
    # given a white image
    white = (255, 255, 255)
    white_img = Image.new("RGB", (300, 300), white)
    # when resizing the image, and normalizing it
    resized_img = make_transforms(normalize=True)(white_img)
    resized_img_pil = transforms.ToPILImage()(resized_img)
    # expected normalized value
    normalized_val = [(1.0 - 0.485) / 0.229, (1.0 - 0.456) / 0.224, (1 - 0.406) / 0.225]
    normalized_val_int = [int(i * 255) for i in normalized_val]
    print(normalized_val)  # [2.2489082969432315, 2.428571428571429, 2.6399999999999997] > 1 ??
    print(normalized_val_int)  # [573, 619, 673] > 255??
    print([between_0_and_255(i) for i in normalized_val_int])  # [63, 109, 163]
    print(np.array(resized_img_pil)[0, 0])  # [ 61 107 161] ???? still different from above!


def resize_a_white_image_without_normalization() -> None:
    # given a white image
    white = (255, 255, 255)
    white_img = Image.new("RGB", (300, 300), white)
    # when resizing the image, but not normalizing it,
    resized_img = make_transforms(normalize=False)(white_img)
    # then all pixels should remain white
    assert (np.array(resized_img) == np.ones_like(np.array(resized_img))).all()


def between_0_and_255(value: int):
    return value % 255


if __name__ == "__main__":
    main()
Looking at the documentation of ToPILImage, we can see that the PIL image mode is chosen based on the tensor; in your case:
"If the input has 3 channels, the mode is assumed to be RGB."
And from the PIL documentation:
"RGB (3x8-bit pixels, true color)"
So your tensor is automatically converted to an image whose pixel values live in the 0..255 range (8 bit). As a result, the out-of-range values overflow (wrap around) when the tensor is converted to a PIL image.
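To see that wrap-around in isolation, here is a minimal sketch: it feeds a single already-normalized "white" pixel straight into ToPILImage (the channel values are rounded versions of the ones your script prints; the exact out-of-range cast behaviour is an implementation detail, but it matches the values observed in your output):

import numpy as np
import torch
from torchvision import transforms

# One normalized "white" pixel per channel, e.g. (1.0 - 0.485) / 0.229 ≈ 2.249,
# i.e. well above 1.0 after Normalize.
pixel = torch.tensor([[[2.2489]], [[2.4286]], [[2.6400]]])  # shape (3, 1, 1)

# ToPILImage scales floating-point tensors by 255 and stores them as 8-bit
# pixels, so 2.249 * 255 ≈ 573 no longer fits in a byte and wraps around.
img = transforms.ToPILImage()(pixel)
print(np.array(img).flatten())  # [ 61 107 161] instead of anything >= 255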
On top of that, your helper

def between_0_and_255(value: int):
    return value % 255

should be

def between_0_and_255(value: int):
    return value % 256

because an 8-bit value wraps around modulo 256, not modulo 255.
With that change, the output becomes
[2.2489082969432315, 2.428571428571429, 2.6399999999999997]
[573, 619, 673]
[61, 107, 161]
[ 61 107 161]
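As a quick sanity check of the arithmetic (using the numbers printed by your script): reducing modulo 255 is what produced your earlier [63, 109, 163], while reducing modulo 256 reproduces exactly what ends up in the PIL image:

# Normalized white-pixel values scaled back to the integer range.
vals = [int(v * 255) for v in (2.2489082969432315, 2.428571428571429, 2.64)]
print(vals)                     # [573, 619, 673]

# Modulo 255 (the original helper) does NOT match the PIL output ...
print([v % 255 for v in vals])  # [63, 109, 163]

# ... but modulo 256 (8-bit wrap-around) does.
print([v % 256 for v in vals])  # [61, 107, 161]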