100:

我试图计算两个大张量(对于k近邻)中每两个样本之间的成对距离,即给定形状为(b1,c,h,w)的张量test和形状为(b2,c,h,w)的张量train,每ij需要|| test[i]-train[j] ||.(其中test[i]train[j]都具有形状(c,h,w),因为它们是批次中的样本).

100

traintest都很大,所以我无法将它们放入RAM中

100

首先,我并没有一次性构造这些张量——在构建它们的过程中,我分割了数据张量,并将它们分别保存到内存中,因此我最终得到了文件{Test\test_1,...,Test\test_n}{Train\train_1,...,Train\train_m}.

这个半伪代码可以解释

test_files = [f'Test\test_{i}' for i in range(n)]
train_files = [f'Train\train_{j}' for j in range(m)]
dist = lambda t1,t2: torch.cdist(t1.flatten(1), t2.flatten(1)) 
all_distances = []
for test_i in test_files:
    test_i = torch.load(test_i) # shape (c,h,w)
    dist_of_i_from_all_j = torch.Tensor([])
    for train_j in train_files:
        train_j = torch.load(train_j) # shape (c,h,w)
        dist_of_i_from_all_j = torch.cat((dist_of_i_from_all_j, dist(test_i,train_j))
    all_distances.append(dist_of_i_from_all_j)
# and now I can take the k-smallest from all_distances

100

我遇到了FAISS repository个,他们解释说这个过程可以加快(也许?)使用他们的解决方案,尽管我不太确定如何使用.无论如何,任何方法都会有帮助!

推荐答案

你查过FAISS documentation了吗?

如果您需要的是L2范数(torch.cidst使用p=2作为默认参数),那么它非常简单.下面的代码是FAISS docs对您的示例的改编:

import faiss
import numpy as np
d = 64                           # dimension
nb = 100000                      # database size
nq = 10000                       # nb of queries
np.random.seed(1234)             # make reproducible
x_test = np.random.random((nb, d)).astype('float32')
x_test[:, 0] += np.arange(nb) / 1000.
x_train = np.random.random((nq, d)).astype('float32')
x_train[:, 0] += np.arange(nq) / 1000.

index = faiss.IndexFlatL2(d)   # build the index
print(index.is_trained)
index.add(x_test)                  # add vectors to the index
print(index.ntotal)

k= 100 # take the 100 closest neighbors
D, I = index.search(x_train, k)     # actual search
print(I[:5])                   # neighbors of the 100 first queries
print(I[-5:])                  # neighbors of the 100 last queries

Python相关问答推荐

有症状地 destruct 了Python中的regex?

为什么这个带有List输入的简单numba函数这么慢

numpy卷积与有效

try 将一行连接到Tensorflow中的矩阵

如何在给定的条件下使numpy数组的计算速度最快?

Streamlit应用程序中的Plotly条形图中未正确显示Y轴刻度

如何设置视频语言时上传到YouTube与Python API客户端

所有列的滚动标准差,忽略NaN

driver. find_element无法通过class_name找到元素'""

用砂箱开发Web统计分析

基于形状而非距离的两个numpy数组相似性

基于行条件计算(pandas)

从源代码显示不同的输出(机器学习)(Python)

用两个字符串构建回文

用LAKEF划分实木地板AWS Wrangler

合并Pandas中的数据帧,但处理不存在的列

Django-修改后的管理表单返回对象而不是文本

Wagail:当通过外键访问索引页时,如何过滤索引页的子项

torch 二维张量与三维张量欧氏距离的计算

条件Python Polars cum_sum over a group,有更好的方法吗?