Dans R, la multiplication matricielle est très optimisée, c'est-à-dire qu'il s'agit simplement d'un appel à BLAS/LAPACK. Cependant, je suis surpris que ce code C++ très naïf pour la multiplication matrice-vecteur semble fiable 30% plus rapide.
library(Rcpp)
# Simple C++ code for matrix multiplication
mm_code =
"NumericVector my_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
double v_j;
for(int j = 0; j < nCol; j++){
v_j = v[j];
for(int i = 0; i < nRow; i++){
ans[i] += m(i,j) * v_j;
}
}
return(ans);
}
"
# Compiling
my_mm = cppFunction(code = mm_code)
# Simulating data to use
nRow = 10^4
nCol = 10^4
m = matrix(rnorm(nRow * nCol), nrow = nRow)
v = rnorm(nCol)
system.time(my_ans <- my_mm(m, v))
#> user system elapsed
#> 0.103 0.001 0.103
system.time(r_ans <- m %*% v)
#> user system elapsed
#> 0.154 0.001 0.154
# Double checking answer is correct
max(abs(my_ans - r_ans))
#> [1] 0
Les R de base %*%
Effectuent-ils un certain type de vérification des données que je saute?
MODIFIER:
Après avoir compris ce qui se passe (merci SO!), Il convient de noter que c'est le pire des cas pour les R %*%
, C'est-à-dire matrice par vecteur. Par exemple, @RalfStubner a souligné que l'utilisation d'une implémentation RcppArmadillo d'une multiplication matrice-vecteur est encore plus rapide que l'implémentation naïve que j'ai démontrée, impliquant beaucoup plus rapide que la base R, mais est pratiquement identique à la base R %*%
Pour multiplication matrice-matrice (lorsque les deux matrices sont grandes et carrées):
arma_code <-
"arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
return m * m2;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
nRow = 10^3
nCol = 10^3
mat1 = matrix(rnorm(nRow * nCol),
nrow = nRow)
mat2 = matrix(rnorm(nRow * nCol),
nrow = nRow)
system.time(arma_mm(mat1, mat2))
#> user system elapsed
#> 0.798 0.008 0.814
system.time(mat1 %*% mat2)
#> user system elapsed
#> 0.807 0.005 0.822
Le courant de R (v3.5.0) %*%
Est donc presque optimal pour matrice-matrice, mais pourrait être considérablement accéléré pour matrice-vecteur si vous êtes d'accord de sauter la vérification.
Un coup d'œil rapide dans names.c
( ici en particulier ) vous indique do_matprod
, la fonction C appelée par %*%
et qui se trouve dans le fichier array.c
. (Fait intéressant, il s'avère que crossprod
et tcrossprod
envoient également à la même fonction). Voici un lien vers le code de do_matprod
.
En parcourant la fonction, vous pouvez voir qu'elle prend en charge un certain nombre de choses que votre implémentation naïve ne fait pas, notamment:
%*%
sont des classes pour lesquelles de telles méthodes ont été fournies. (C'est ce qui se passe dans cette partie de la fonction.)Vers la fin de la fonction , il distribue à matprod
ou ou cmatprod
. Fait intéressant (du moins pour moi), dans le cas de matrices réelles, si l'une ou l'autre matrice peut contenir NaN
ou Inf
valeurs, puis matprod
envoie ( ici ) à une fonction appelée simple_matprod
qui est à peu près aussi simple et direct que le vôtre. Sinon, il est envoyé à l'une des deux routines BLAS Fortran qui, vraisemblablement, sont plus rapides, si des éléments matriciels uniformément "bien comportés" peuvent être garantis.
La réponse de Josh explique pourquoi la multiplication matricielle de R n'est pas aussi rapide que cette approche naïve. J'étais curieux de voir combien on pouvait gagner en utilisant RcppArmadillo. Le code est assez simple:
arma_code <-
"arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
return m * v;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
Référence:
> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 71.23347 75.22364 90.13766 96.88279 98.07348 98.50182 10
m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751 10
arma_mm(m, v) 41.13348 41.42314 41.89311 41.81979 42.39311 42.78396 10
Donc, RcppArmadillo nous donne une syntaxe plus agréable et de meilleures performances.
La curiosité a pris le dessus sur moi. Voici une solution pour utiliser directement BLAS:
blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
char trans = 'N';
double one = 1.0, zero = 0.0;
int ione = 1;
F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
&ione, &zero, ans.begin(), &ione);
return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")
Référence:
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 72.61298 75.40050 89.75529 96.04413 96.59283 98.29938 10
m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572 10
arma_mm(m, v) 41.06718 41.70331 42.62366 42.47320 43.22625 45.19704 10
blas_mm(m, v) 41.58618 42.14718 42.89853 42.68584 43.39182 44.46577 10
Armadillo et BLAS (OpenBLAS dans mon cas) sont presque les mêmes. Et le code BLAS est aussi ce que fait R au final. Donc 2/3 de ce que fait R est la vérification des erreurs, etc.