python – Pytorch transform.ToTensor() changes image

The problem is the channel axis: ToTensor() moves it from the last position to the first.

If you look at the torchvision.transforms documentation for ToTensor(), it says:

Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
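To make the documented change concrete, here is a small NumPy-only sketch that mimics the shape and range conversion described above. The helper `to_tensor_like` is hypothetical and is not torchvision's implementation; it only reproduces the documented (H, W, C) uint8 to (C, H, W) float conversion so you can see both effects without installing PyTorch.

```python
import numpy as np

def to_tensor_like(hwc_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for the documented ToTensor() behaviour:
    (H, W, C) uint8 in [0, 255] -> (C, H, W) float32 in [0.0, 1.0]."""
    chw = np.transpose(hwc_uint8, (2, 0, 1))  # move the channel axis to the front
    return chw.astype(np.float32) / 255.0     # rescale to [0.0, 1.0]

# A tiny made-up RGB "image": H=2, W=4, C=3, values 0..230.
img = np.arange(24, dtype=np.uint8).reshape(2, 4, 3) * 10
t = to_tensor_like(img)
print(img.shape, t.shape)  # (2, 4, 3) (3, 2, 4)
```

Both the axis order and the value range change, which is why the result has to be undone in two steps before it can be treated as an image again.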

So after the transformation, converting the tensor back with .numpy() gives an array of shape (C, H, W), not (H, W, C). To turn it back into an image, move the channel axis to the end and undo the scaling:

demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)

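This moves the channel axis to the end, giving shape (H, W, C), and multiplying by 255 undoes the scaling to [0.0, 1.0]. Once you cast that array to uint8 and convert it back to a PIL image, it shows the same image you started with.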
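A quick way to convince yourself that np.moveaxis does the right thing is to compare it against the equivalent transpose(1, 2, 0), and against reshape, which is a common but wrong shortcut. The array below is a made-up (C, H, W) stand-in for demo_img.numpy() * 255, not real image data.

```python
import numpy as np

# Made-up (C, H, W) array standing in for demo_img.numpy() * 255.
chw = np.arange(24, dtype=np.float32).reshape(3, 2, 4)

hwc = np.moveaxis(chw, 0, -1)   # channel axis moved to the end -> (H, W, C)
hwc_t = chw.transpose(1, 2, 0)  # equivalent spelling with transpose

print(hwc.shape)                   # (2, 4, 3)
print(np.array_equal(hwc, hwc_t))  # True

# reshape produces the right *shape* but keeps the memory order,
# so channel values end up in the wrong pixels (a scrambled image).
wrong = chw.reshape(2, 4, 3)
print(np.array_equal(hwc, wrong))  # False
```

In other words, only an axis permutation (moveaxis or transpose) restores the layout; reshaping merely reinterprets the same flat buffer.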

Putting it all together:

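import numpy as np
from PIL import Image
from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor()])

demo = Image.open(img).convert("RGB")  # img is the path to your image; RGB guarantees 3 channels
demo_img = trans(demo)  # tensor of shape (C, H, W), values in [0.0, 1.0]
demo_array = np.moveaxis(demo_img.numpy() * 255, 0, -1)  # back to (H, W, C), values in [0, 255]
Image.fromarray(np.round(demo_array).astype(np.uint8)).show()  # round before casting to uint8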
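The reason for rounding before astype(np.uint8) is that the cast truncates toward zero, so a value that float arithmetic left at something like 254.99998 would silently become 254. The NumPy-only sketch below checks the full round trip for every possible uint8 value; the helper `roundtrip` is hypothetical and emulates ToTensor() with a transpose and a division by 255 rather than calling torchvision.

```python
import numpy as np

def roundtrip(hwc_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: ToTensor-like forward step, then the inverse used above."""
    chw = np.transpose(hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0  # (C, H, W) in [0, 1]
    back = np.moveaxis(chw * 255, 0, -1)                                 # (H, W, C) in [0, 255]
    return np.round(back).astype(np.uint8)                               # round, then cast

# A made-up 16x16 RGB image that contains every value 0..255 once per channel.
img = np.repeat(np.arange(256, dtype=np.uint8).reshape(16, 16, 1), 3, axis=2)
out = roundtrip(img)
print(np.array_equal(out, img))  # True
```

With rounding in place, the restored array matches the original pixel for pixel, so Image.fromarray produces the same image.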
