我正在研究一个问题,其中一个矩阵必须迭代计算大量次.必要的矩阵乘法采用以下形式

t(X) %*% ( ( X %*% W %*% t(X) * mu0 ) * mu1 )

其中XN x PWsymmetric P x P矩阵.mu0mu1N x 1个向量,计算和输入各自的乘积元素很便宜.

不幸的是,N可能相当大,这导致了巨大的计算需求,因为X %*% W %*% t(X)N x N.我想知道是否有任何策略或计算技巧,例如基于矩阵分解的策略或计算技巧,可以用来加快计算速度.在每次迭代中,mu0mu1都会发生变化,但XW是固定的,因此任何包括这些矩阵的预计算都会起作用.

BENCHMARKING

到目前为止,我能想到的最快的方法是进行一些明显的预计算:

# fake data
N   = 2500
P   = 10
X   = matrix(rnorm(N*P), N, P)
W   = matrix(rnorm(P*P), P, P)
mu0 = rnorm(N)
mu1 = rnorm(N)

# precomputations
tX  = t(X)
XWX = X %*% W %*% t(X)

# functions
f_raw      = function(X, W, mu0, mu1){t(X) %*% ( ( X %*% W %*% t(X) * mu0 ) * mu1 )}
f_precomp  = function(XWX, tX, mu0, mu1){tX %*% ( ( XWX * mu0 ) * mu1 )}


# benchmark
microbenchmark::microbenchmark(f_raw(X, W, mu0, mu1), 
                               f_precomp(XWX, tX, mu0, mu1))


Unit: milliseconds
                         expr      min       lq     mean   median       uq      max neval
        f_raw(X, W, mu0, mu1) 283.5918 286.5080 299.4621 289.5151 302.9726 355.4271   100
 f_precomp(XWX, tX, mu0, mu1) 167.4169 168.7336 180.6468 171.0852 197.7475 263.8090   100

推荐答案

您的示例数据和try :

# fake data
N   = 2500
P   = 10
X   = matrix(rnorm(N*P), N, P)
W   = matrix(rnorm(P*P), P, P)
mu0 = rnorm(N)
mu1 = rnorm(N)

# precomputations
tX  = t(X)
XWX = X %*% W %*% t(X)

# functions
f_raw      = function(X, W, mu0, mu1){t(X) %*% ( ( X %*% W %*% t(X) * mu0 ) * mu1 )}
f_precomp  = function(XWX, tX, mu0, mu1){tX %*% ( ( XWX * mu0 ) * mu1 )}

更好的方法是预先计算

WX <- tcrossprod(W, X)

然后

f_better <- function (X, mu0, mu1, WX) crossprod(X, (mu0 * mu1) * X) %*% WX

笔记本电脑上的基准测试:

microbenchmark::microbenchmark(f_raw(X, W, mu0, mu1), 
                               f_precomp(XWX, tX, mu0, mu1),
                               f_better(X, mu0, mu1, WX))
#Unit: milliseconds
#                         expr        min         lq      mean     median
#        f_raw(X, W, mu0, mu1) 236.926979 238.984573 243.86319 242.212716
# f_precomp(XWX, tX, mu0, mu1) 151.190689 152.612059 156.25277 155.646093
#    f_better(X, mu0, mu1, WX)   1.113031   1.126434   1.17974   1.138207
#         uq        max neval
# 244.721163 270.477737   100
# 157.970378 182.935082   100
#   1.146352   5.130876   100

验证正确性:

ans_raw <- f_raw(X, W, mu0, mu1)
ans_precomp <- f_precomp(XWX, tX, mu0, mu1)
ans_better <- f_better(X, mu0, mu1, WX)

all.equal(ans_raw, ans_precomp)
# [1] TRUE
all.equal(ans_raw, ans_better)
# [1] TRUE

我不知道mu0mu1在您的实际应用中是否保证为非负.如果是这样,则以下速度会快2倍:

f_fastest <- function (X, mu0, mu1, WX) crossprod(sqrt(mu0 * mu1) * X) %*% WX

R相关问答推荐

在与ggplot 2和网格的最佳匹配线上绘制箭头

R:将列名的字符载体传递给可以 Select 接受多个参数的函数

从R中的地址提取街道名称

如何在热图中绘制一个图形,但在每个单元格中通过饼形图显示?

使用R中的gt对R中的html rmarkdown文件进行条件格式设置表的单元格

查找图下的面积

根据模式将一列拆分为多列,并在R中进行拆分

在R中替换函数中的特定符号

在for循环中转换rabrame

根据类别合并(汇总)某些行

从多个可选列中选取一个值到一个新列中

R -使用矩阵reshape 列表

在gggraph中显示来自不同数据帧的单个值

将数据集旋转到长格式,用于遵循特定名称模式的所有变量对

有没有办法一次粘贴所有列

如何删除R中除数字元素以外的所有元素

按组内中位数分类

在ggplot2上从多个数据框创建复杂的自定义图形

用多边形替换地块点

在使用SliderInput In Shiny(R)设置输入数据的子集时,保留一些情节痕迹