我有以下代码:

import torch
from facenet_pytorch import InceptionResnetV1, MTCNN
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import pandas as pd
import os


workers = 0 if os.name == 'nt' else 4
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)

def collate_fn(x):
    return x[0]

dataset = datasets.ImageFolder('data/images/')
dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}
loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)
#print(dataset.idx_to_class)

aligned = []
names = []
i = 0
for x, y in loader:
    x_aligned, prob = mtcnn(x, return_prob=True)
    if x_aligned is not None:
        print('Face detected with probability: {:8f}'.format(prob))
        aligned.append(x_aligned)
        names.append(dataset.idx_to_class[y])
        i += 1
#print(i)

for name, param in mtcnn.named_parameters(): #Freezing everything but last layer
    #print(name)
    if name != "onet.dense6_3.bias":
        param.require_grad = False
    else:
        param.require_grad = True

现在我想重新训练这个模型来预测三个等级(现在它只预测一张脸的概率).假设我有data/images/个文件夹,分别是faces1faces2faces3.如何使用这三个文件夹重新训练此模型?我想要一个像[prob1, prob2, prob3]这样的张量,每个类都有一个图像的概率.谢谢

推荐答案

MTCN:该类加载经过预训练的P、R和O网络,并在给定原始输入图像的情况下,返回裁剪为仅包含面部的图像.

我假设您正在try 使用InceptionResnetV1对数据集进行分类.要重新训练初始模型,只需将所需的类数加载到模型中,然后对其进行训练.

resnet = InceptionResnetV1(
    classify=True,
    pretrained='vggface2',
    num_classes=3
)

完整的微调示例如下https://github.com/timesler/facenet-pytorch/blob/master/examples/finetune.ipynb

Python相关问答推荐

使用FASTCGI在IIS上运行Django频道

Pandas实际上如何对基于自定义的索引(integer和非integer)执行索引

连接两个具有不同标题的收件箱

max_of_three使用First_select、second_select、

在Python中处理大量CSV文件中的数据

追溯(最近最后一次调用):文件C:\Users\Diplom/PycharmProject\Yolo01\Roboflow-4.py,第4行,在模块导入roboflow中

运行总计基于多列pandas的分组和总和

为什么符号没有按顺序添加?

在Python Attrs包中,如何在field_Transformer函数中添加字段?

字符串合并语法在哪里记录

AES—256—CBC加密在Python和PHP中返回不同的结果,HELPPP

Geopandas未返回正确的缓冲区(单位:米)

为什么我的sundaram筛这么低效

以异步方式填充Pandas 数据帧

在Django中重命名我的表后,旧表中的项目不会被移动或删除

有没有办法让Re.Sub报告它所做的每一次替换?

提取最内层嵌套链接

一维不匹配两个数组上的广义ufunc

如何获取给定列中包含特定值的行号?

如何在networkx图中提取和绘制直接邻居(以及邻居的邻居)?