我正在研究一个与人工智能相关的问题,我需要在视频中跟踪几个人体部位.我用图像创建了一个数据加载器,并在调用Dataset类时进行了多次转换.

下面是一个代码示例:

transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

dataset = NamedClassDataset(annotation_folder_path=path, transform=transform, img_size=img_size, normalized=normalize)
train_set, validation_set = torch.utils.data.random_split(dataset, get_train_test_size(dataset,train_percent))
train_loader = DataLoader(dataset=train_set, shuffle=shuffle, batch_size=batch_size,num_workers=num_workers,pin_memory=pin_memory)
validation_loader = DataLoader(dataset=validation_set, shuffle=shuffle, batch_size=batch_size,num_workers=num_workers, pin_memory=pin_memory)

The problem is :运行模型后,我显示带有预测点的图像,以查看其质量.但由于图像已调整大小和规格,我无法检索其原始质量和 colored颜色 .我想在原始图像上显示点,而不是转换图像,我想知道通常的方法是什么.

我已经想到了两种解决方案,各有其缺点:

  • 恢复转换,但在调用resize时不可能,因为我们丢失了信息
  • NamedClassDataset__getitem__方法中,返回索引作为第三个参数(以及图像和标签).但pytorch方法在使用__getitem__时只需要两个输出,即(图像、相关标签).

编辑:以下是我的NamedClassDataset类的getitem:

def __getitem__(self, index):
        (img_path, coords) = self.annotations.iloc[index].values
        img = Image.open(img_path).convert("RGB")
        w,h = img.size
        # Normalize by img size
        if self.img_size is not None:
            if self.normalized:
                coords = coords/(w,h) # Normalized
            else:
                n_h,n_w = self.img_size
                coords = coords/(w,h)*(n_w,n_h) # Not normalized 
            

        y_coords = torch.flatten(torch.tensor(coords)).float() # Flatten outputs and convert from double to float32

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_coords)

推荐答案

我成功地用原始图像声明了另一个数据集.

# Create the same dataset with untransformed images for visualization purposes
org_dataset = NamedClassDataset(annotation_folder_path="./12_labels/extracted_swimmers", transform=None, img_size=None, normalized=False)
viz_train_set, viz_validation_set = random_split(org_dataset, get_train_test_size(org_dataset,train_percent,_print_size=False), generator=torch.Generator().manual_seed(seed))

下面是我在__getitem__ when transform=None中所做的:

        if self.transform is not None:
            tr_img = self.transform(org_img)
            return (tr_img, y_coords)
        return (org_img, y_coords)

然后,我可以通过传递viz集作为参数来访问原始图像.请注意,这是一个数据集,而不是一个数据加载器,因此您需要考虑批大小以匹配预测.

plot_predictions(viz_set[0+i*batch_size][0], preds[0])

我打开了feed,因为我坚信可以提供更有效的答案.

Python相关问答推荐

在输入行运行时停止代码

使用Openpyxl从Excel中的折线图更改图表样式

当条件满足时停止ODE集成?

Polars map_使用多处理对UDF进行批处理

如何根据rame中的列值分别分组值

应用指定的规则构建数组

使用Scikit的ValueError-了解

如何使用Polars从AWS S3读取镶木地板文件

运行从Airflow包导入的python文件,需要airflow实例?

使用pyopencl、ArrayFire或另一个Python OpenCL库制作基于欧几里得距离的掩模

是否在DataFrame中将所有列设置为大写?

将COLUMN BY GROUP中的值连接为列表,并将其赋值给PANAS数据框中的变量

检测并显示网页更改

如何在JAX中训练具有多输出(向量值)损失函数的梯度下降模型?

如何在Polars DataFrame中使用`isin‘?

我怎样才能用python打印一个 map 对象?

合并Pandas 数据框中的一些列并复制其他列

用于判断x=()的&Quot;isInstance()和Not&Quot;vs&Quot;==&Quot;

使用CSS Select 器和::before抓取不会显示文本

CKEditor更新通知