I'm using a DataLoader with a custom batch_sampler to make sure every batch is class-balanced. How do I prevent the iterator from exhausting itself during the first epoch?
import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)

def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch)  # stops after first epoch

if __name__ == '__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)
Training stops after the first epoch, and calling next(iter(loader)) afterwards raises StopIteration. The output is:
epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4
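
My guess is that custom_batch_sampler() returns a one-shot iterator: the DataLoader consumes it during epoch 0 and there is nothing left for later epochs. Would wrapping the batch indices in an object whose __iter__ hands back a fresh iterator on every call fix this? Below is a minimal sketch of what I have in mind (the RepeatableBatchSampler name is my own, not something provided by PyTorch):

class RepeatableBatchSampler:
    """Plain iterable of index batches; DataLoader calls iter() on it
    at the start of every epoch, so the batches are never used up."""
    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        # A fresh iterator over the stored batches is created each call.
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

# Intended usage (replacing custom_batch_sampler()):
# my_loader = torch.utils.data.DataLoader(
#     dataset=my_dataset,
#     batch_sampler=RepeatableBatchSampler([[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]])
# )

Is this the right approach, or does PyTorch provide a built-in way to do it?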