RcppEigen更快的协方差

问题描述 投票:0回答:1

我已经拟合了一个回归模型,它为回归参数B输出一个协方差矩阵S.我需要通过乘以X对这个协方差矩阵进行操作,然后得到新的协方差和stderr向量

cov(X * B) = X * cov(B) * X.transpose()

因为我只需要cov(X * B)的对角线,我不需要进行全矩阵乘法,我可以得到每行X_i * B的协方差并求它们

#include <RcppEigen.h>
// [[Rcpp::depends(RcppEigen)]]

using Eigen::Map;
using Eigen::MatrixXd;
using Eigen::VectorXd;
using Eigen::SparseMatrix;
using Eigen::MappedSparseMatrix;
using namespace Rcpp;
using namespace Eigen;

double foo(const Eigen::MappedSparseMatrix<double>& mm, 
           const Eigen::MappedSparseMatrix<double>& vcov) {

  int n = mm.rows();
  double out = 0;
  SparseMatrix<double> mm_t = mm.adjoint();

  SparseMatrix<double> var(1, 1);
  var.setZero();

  for (int i = 0; i < n; i++) {
    var = mm.row(i) * vcov * mm_t.col(i);
    out += var.coeff(0, 0);
  }

  return out;
}

出于某种原因,这个函数在1M行上非常慢。我尝试使用“块”而不是逐行操作mm,认为通过操作一个值块可以使用vcov进行矩阵乘法更快。这并没有使功能更快。这是一个可重复的例子

require(Matrix)

set.seed(100)
N = 2.5e5
p = 100

mm = rsparsematrix(N, p, .01)
vcov = rsparsematrix(p, p, .5)

system.time(foo(mm, vcov))

有没有办法让这个功能更快?

c++ r rcpp
1个回答
4
投票

你可以使用一个简单的数学“技巧”,如果协方差矩阵是真实的和对称的(并且在你的情况下是一个协方差矩阵)。

x %*% b %*% t(b) %*% t(x)的对角元素之和可以计算为

sum((x %*% b)^2)

这是超级快。请注意,上面的公式将b %*% t(b)作为“三明治”的“火腿”部分,因此您需要计算cov(B)的平方根,然后您可以使用该公式。

或者,您可以直接在R中使用以下按元素生成的产品

sum((mm %*% vcov) * mm)

我不是那么精通RcppEigen和那里的稀疏矩阵所以以下可能会被优化,但它似乎很快

// [[Rcpp::export]]                                                                                                                        
double foo2(const Eigen::MappedSparseMatrix<double>& mm,
           const Eigen::MappedSparseMatrix<double>& vcov) {

  double out = 0;
  SparseMatrix<double> mat;

  mat = mm.cwiseProduct(mm*vcov);


  for (int k=0; k<mat.outerSize(); ++k) {
    for (SparseMatrix<double>::InnerIterator it(mat,k); it; ++it)
      {
        out +=it.value();
      }
  }

  return out;
}

这是一个简短的速度比较

> microbenchmark::microbenchmark(foo(mm, vcov), foo2(mm, vcov), sum((mm %*% vcov) * mm), times=2)
Unit: milliseconds
                    expr        min         lq       mean     median         uq
           foo(mm, vcov) 32575.5488 32575.5488 33587.4147 33587.4147 34599.2806
          foo2(mm, vcov)   463.9440   463.9440   492.4232   492.4232   520.9023
 sum((mm %*% vcov) * mm)   953.7902   953.7902   981.4750   981.4750  1009.1598
        max neval cld
 34599.2806     2   b
   520.9023     2  a 
  1009.1598     2  a 

相当一些改进。即使只是单独使用R。

© www.soinside.com 2019 - 2024. All rights reserved.