我使用DataLoader和自定义batch_sampler来确保每一批都是类平衡的.如何防止迭代器在第一个历元中耗尽自己?

import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)
        
    def __len__(self):
        len(self.y)
        
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)

def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch) # stops after first epoch

if __name__=='__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)

训练在第一个历元后停止,next(iter(loader))给出StopIteration错误.

epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4

推荐答案

定制的批量取样器必须是Sampler或更高.在每一个纪元中,一个新的迭代器从这个可数生成.这意味着您实际上不需要手动创建迭代器(在第一个历元之后,迭代器将运行并提升StopIteration),但您可以提供您的列表,因此如果您删除iter():

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return batch_idx

Python相关问答推荐

try 在树叶 map 上应用覆盖磁贴

如何将双框框列中的成对变成两个新列

Python 约束无法解决n皇后之谜

查找两极rame中组之间的所有差异

将输入管道传输到正在运行的Python脚本中

如何从在虚拟Python环境中运行的脚本中运行需要宿主Python环境的Shell脚本?

NP.round解算数据后NP.unique

Telethon加入私有频道

isinstance()在使用dill.dump和dill.load后,对列表中包含的对象失败

如何指定列数据类型

关于两个表达式的区别

PYTHON、VLC、RTSP.屏幕截图不起作用

在二维NumPy数组中,如何 Select 内部数组的第一个和第二个元素?这可以通过索引来实现吗?

统计numpy. ndarray中的项目列表出现次数的最快方法

如何求相邻对序列中元素 Select 的最小代价

504未连接IB API TWS错误—即使API连接显示已接受''

如何在Python中自动创建数字文件夹和正在进行的文件夹?

类型对象';敌人';没有属性';损害';

运行从Airflow包导入的python文件,需要airflow实例?

Pandas:新列,从列表中采样,基于列值