我有一个图像的验证数据集,可以通过我的CNN模型进行分类.我想用pytorch加载这些图像.torchvision.datasets.ImageFolder()函数不起作用,因为没有目标,因为数据集是未分类的.我假设我需要编写一个自定义的数据集类,我稍后会放入torch.utils.data.DataLoader().我已经在网上搜索过了,但我还是不太明白这门课应该是什么样子的.

我试过这个

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
import os


class Dset(Dataset):
    def __init__(self, dir: str, transform=None) -> None:
        self.transform = transform
        self.images = os.listdir(dir)
        self.dir = dir
    
    def __getitem__(self, index: int) -> torch.Tensor:
        image = read_image(f'{self.dir}/{self.images[index]}')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self) -> int:
        return len(self.images)

但在这个单元格之后(所有图像都在.data/中)

from torchvision import transforms

batch_size = 64
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
data = Dset('data', transform=transform)
trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

images, labels = iter(trainloader)

我遇到这个错误:TypeError: Input image tensor permitted channel values are [1, 3], but found 4

Update

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
import os


class Dset(Dataset):
    def __init__(self, dir: str, transform=None) -> None:
        self.transform = transform
        self.images = os.listdir(dir)
        self.dir = dir
    
    def __getitem__(self, index: int) -> torch.Tensor:
        image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self) -> int:
        return len(self.images)

该错误是由图像中的Alpha通道引起的.

在解决这个问题之后,我遇到了这个:TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

Update 2

from torchvision import transforms

batch_size = 64
transform = transforms.Compose(
[transforms.ToPILImage(), transforms.Resize((512, 512)),
transforms.Grayscale(), transforms.ToTensor()]
)
data = Dset('data', transform=transform)
trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

images = iter(trainloader)[0]

torchvision.transforms.ToTensor仅将PIL图像或numpy.ndarray转换为张量.

最后一行结果:TypeError: '_SingleProcessDataLoaderIter' object is not subscriptable

推荐答案

正如 comments 中所讨论的,问题是您的图像有一个Alpha通道.您可以修改read_image函数以从输入图像中删除Alpha通道,如下所示:

image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB)

对于其他模式,您可以勾选ImageReadMode class.

100

对于新的错误-根据documentation:

ToTensor类将PIL Image或ndarray转换为张量并相应地zoom 值.

但在这里,您提供的是tensor作为输入,而不是所需的PIL图像或ndarray.

要解决这个问题,您可以使用ToPILImage方法.

100

对于错误:TypeError: '_SingleProcessDataLoaderIter' object is not subscriptable

查看PyTorch教程中的How to Iterate through a DataLoader.

此外,您还可以try 使用for循环进行迭代,如下所示:

for images in trainloader:
    # Process images here
    break # update this break statement as per your requirement

Python相关问答推荐

比较两个数据帧并并排附加结果(获取性能警告)

为什么我的Python代码在if-else声明中的行之前执行if-else声明中的行?

不理解Value错误:在Python中使用迭代对象设置时必须具有相等的len键和值

django禁止直接分配到多对多集合的前端.使用user.set()

梯度下降:简化要素集的运行时间比原始要素集长

ThreadPoolExecutor和单个线程的超时

创建可序列化数据模型的最佳方法

如何在图中标记平均点?

改进大型数据集的框架性能

isinstance()在使用dill.dump和dill.load后,对列表中包含的对象失败

我的字符串搜索算法的平均时间复杂度和最坏时间复杂度是多少?

在Python中使用if else或使用regex将二进制数据如111转换为001""

Python避免mypy在相互引用中从另一个类重定义类时失败

Python pint将1/华氏度转换为1/摄氏度°°

Python 3试图访问在线程调用中实例化的类的对象

mdates定位器在图表中显示不存在的时间间隔

如何使用matplotlib查看并列直方图

如何在SQLAlchemy + Alembic中定义一个"Index()",在基表中的列上

我如何处理超类和子类的情况

Groupby并在组内比较单独行上的两个时间戳