我有一个图像的验证数据集,可以通过我的CNN模型进行分类.我想用pytorch
加载这些图像.torchvision.datasets.ImageFolder()
函数不起作用,因为没有目标,因为数据集是未分类的.我假设我需要编写一个自定义的数据集类,我稍后会放入torch.utils.data.DataLoader()
.我已经在网上搜索过了,但我还是不太明白这门课应该是什么样子的.
我试过这个
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
import os
class Dset(Dataset):
def __init__(self, dir: str, transform=None) -> None:
self.transform = transform
self.images = os.listdir(dir)
self.dir = dir
def __getitem__(self, index: int) -> torch.Tensor:
image = read_image(f'{self.dir}/{self.images[index]}')
if self.transform is not None:
image = self.transform(image)
return image
def __len__(self) -> int:
return len(self.images)
但在这个单元格之后(所有图像都在.data/
中)
from torchvision import transforms
batch_size = 64
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
data = Dset('data', transform=transform)
trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
images, labels = iter(trainloader)
我遇到这个错误:TypeError: Input image tensor permitted channel values are [1, 3], but found 4
Update个
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
import os
class Dset(Dataset):
def __init__(self, dir: str, transform=None) -> None:
self.transform = transform
self.images = os.listdir(dir)
self.dir = dir
def __getitem__(self, index: int) -> torch.Tensor:
image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB)
if self.transform is not None:
image = self.transform(image)
return image
def __len__(self) -> int:
return len(self.images)
该错误是由图像中的Alpha通道引起的.
在解决这个问题之后,我遇到了这个:TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
Update 2
from torchvision import transforms
batch_size = 64
transform = transforms.Compose(
[transforms.ToPILImage(), transforms.Resize((512, 512)),
transforms.Grayscale(), transforms.ToTensor()]
)
data = Dset('data', transform=transform)
trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
images = iter(trainloader)[0]
torchvision.transforms.ToTensor
仅将PIL图像或numpy.ndarray
转换为张量.
最后一行结果:TypeError: '_SingleProcessDataLoaderIter' object is not subscriptable