存储库中有一个预训练模型,其文件类型为.pyth.我在网上搜索了一下这个文件类型,以及哪种语言能够阅读这个文件,但我什么也没找到.由于我正在使用Pytork,是否可以在Pytork中读取此类文件?此外,通常情况下,如何读取和生成该信息?

更清楚的是,在repository of the TimeSformer模型中,预训练模型属于这种文件类型,例如,您可以在该存储库中找到以下命令:

import torch
from timesformer.models.vit import TimeSformer

model = TimeSformer(img_size=224, num_classes=400, num_frames=8, attention_type='divided_space_time',  pretrained_model='/path/to/pretrained/model.pyth')

dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)

pred = model(dummy_video,) # (2, 400)

推荐答案

文件扩展名可以是任何内容,它不会更改文件内容.如果你运行torch.load("file.pyth"),它将加载一个权重字典.您可以在您包含的回购协议的代码中找到这一点.他们使用以下代码保存模型:

path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
with PathManager.open(path_to_checkpoint, "wb") as f:
    torch.save(checkpoint, f)

get_path_to_checkpoint函数可以找到here:

def get_path_to_checkpoint(path_to_job, epoch):
    """
    Get the full path to a checkpoint file.
    Args:
        path_to_job (string): the path to the folder of the current job.
        epoch (int): the number of epoch for the checkpoint.
    """
    name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
    return os.path.join(get_checkpoint_dir(path_to_job), name)

所以,他们只传递一个扩展名为.pythtorch.save的文件名.

Python相关问答推荐

pyscript中的压痕问题

mypy无法推断类型参数.List和Iterable的区别

Python导入某些库时非法指令(核心转储)(beautifulsoup4."" yfinance)

旋转多边形而不改变内部空间关系

在Python中调用变量(特别是Tkinter)

Flask Jinja2如果语句总是计算为false&

将标签移动到matplotlib饼图中楔形块的开始处

搜索按钮不工作,Python tkinter

导入错误:无法导入名称';操作';

numpy数组和数组标量之间的不同行为

Python协议不兼容警告

使用np.fft.fft2和cv2.dft重现相位谱.为什么结果并不相似呢?

分解polars DataFrame列而不重复其他列值

根据过滤后的牛郎星图表中的数据计算新系列

如何在Polars中创建条件增量列?

如何有效地计算所有输出相对于参数的梯度?

使用_in链接操作管道传输的中间结果是否可用于链中的后续函数?

给定y的误差时,线性回归系数的计算误差

有条件的滚动平均数(面试问题)

在忽略on列中的重复值的同时连接polars重命名