mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Set up the qr_decomposition to use LAPACK when available. Also removed the
qr_decomposition::get_householder() function since I don't currently have any way to test it or precisely define what it does. --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403839
This commit is contained in:
parent
f0fd26ba5a
commit
b54d9d2b4a
199
dlib/matrix/lapack/ormqr.h
Normal file
199
dlib/matrix/lapack/ormqr.h
Normal file
@ -0,0 +1,199 @@
|
||||
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_LAPACk_ORMQR_H__
|
||||
#define DLIB_LAPACk_ORMQR_H__
|
||||
|
||||
#include "fortran_id.h"
|
||||
#include "../matrix.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
namespace lapack
|
||||
{
|
||||
namespace binding
|
||||
{
|
||||
extern "C"
|
||||
{
|
||||
void DLIB_FORTRAN_ID(dormqr) (char *side, char *trans, integer *m, integer *n,
|
||||
integer *k, const double *a, integer *lda, const double *tau,
|
||||
double * c__, integer *ldc, double *work, integer *lwork,
|
||||
integer *info);
|
||||
|
||||
void DLIB_FORTRAN_ID(sormqr) (char *side, char *trans, integer *m, integer *n,
|
||||
integer *k, const float *a, integer *lda, const float *tau,
|
||||
float * c__, integer *ldc, float *work, integer *lwork,
|
||||
integer *info);
|
||||
|
||||
}
|
||||
|
||||
inline int ormqr (char side, char trans, integer m, integer n,
|
||||
integer k, const double *a, integer lda, const double *tau,
|
||||
double *c__, integer ldc, double *work, integer lwork)
|
||||
{
|
||||
integer info = 0;
|
||||
DLIB_FORTRAN_ID(dormqr)(&side, &trans, &m, &n,
|
||||
&k, a, &lda, tau,
|
||||
c__, &ldc, work, &lwork, &info);
|
||||
return info;
|
||||
}
|
||||
|
||||
inline int ormqr (char side, char trans, integer m, integer n,
|
||||
integer k, const float *a, integer lda, const float *tau,
|
||||
float *c__, integer ldc, float *work, integer lwork)
|
||||
{
|
||||
integer info = 0;
|
||||
DLIB_FORTRAN_ID(sormqr)(&side, &trans, &m, &n,
|
||||
&k, a, &lda, tau,
|
||||
c__, &ldc, work, &lwork, &info);
|
||||
return info;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
/* -- LAPACK routine (version 3.1) -- */
|
||||
/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
|
||||
/* November 2006 */
|
||||
|
||||
/* .. Scalar Arguments .. */
|
||||
/* .. */
|
||||
/* .. Array Arguments .. */
|
||||
/* .. */
|
||||
|
||||
/* Purpose */
|
||||
/* ======= */
|
||||
|
||||
/* DORMQR overwrites the general real M-by-N matrix C with */
|
||||
|
||||
/* SIDE = 'L' SIDE = 'R' */
|
||||
/* TRANS = 'N': Q * C C * Q */
|
||||
/* TRANS = 'T': Q**T * C C * Q**T */
|
||||
|
||||
/* where Q is a real orthogonal matrix defined as the product of k */
|
||||
/* elementary reflectors */
|
||||
|
||||
/* Q = H(1) H(2) . . . H(k) */
|
||||
|
||||
/* as returned by DGEQRF. Q is of order M if SIDE = 'L' and of order N */
|
||||
/* if SIDE = 'R'. */
|
||||
|
||||
/* Arguments */
|
||||
/* ========= */
|
||||
|
||||
/* SIDE (input) CHARACTER*1 */
|
||||
/* = 'L': apply Q or Q**T from the Left; */
|
||||
/* = 'R': apply Q or Q**T from the Right. */
|
||||
|
||||
/* TRANS (input) CHARACTER*1 */
|
||||
/* = 'N': No transpose, apply Q; */
|
||||
/* = 'T': Transpose, apply Q**T. */
|
||||
|
||||
/* M (input) INTEGER */
|
||||
/* The number of rows of the matrix C. M >= 0. */
|
||||
|
||||
/* N (input) INTEGER */
|
||||
/* The number of columns of the matrix C. N >= 0. */
|
||||
|
||||
/* K (input) INTEGER */
|
||||
/* The number of elementary reflectors whose product defines */
|
||||
/* the matrix Q. */
|
||||
/* If SIDE = 'L', M >= K >= 0; */
|
||||
/* if SIDE = 'R', N >= K >= 0. */
|
||||
|
||||
/* A (input) DOUBLE PRECISION array, dimension (LDA,K) */
|
||||
/* The i-th column must contain the vector which defines the */
|
||||
/* elementary reflector H(i), for i = 1,2,...,k, as returned by */
|
||||
/* DGEQRF in the first k columns of its array argument A. */
|
||||
/* A is modified by the routine but restored on exit. */
|
||||
|
||||
/* LDA (input) INTEGER */
|
||||
/* The leading dimension of the array A. */
|
||||
/* If SIDE = 'L', LDA >= max(1,M); */
|
||||
/* if SIDE = 'R', LDA >= max(1,N). */
|
||||
|
||||
/* TAU (input) DOUBLE PRECISION array, dimension (K) */
|
||||
/* TAU(i) must contain the scalar factor of the elementary */
|
||||
/* reflector H(i), as returned by DGEQRF. */
|
||||
|
||||
/* C (input/output) DOUBLE PRECISION array, dimension (LDC,N) */
|
||||
/* On entry, the M-by-N matrix C. */
|
||||
/* On exit, C is overwritten by Q*C or Q**T*C or C*Q**T or C*Q. */
|
||||
|
||||
/* LDC (input) INTEGER */
|
||||
/* The leading dimension of the array C. LDC >= max(1,M). */
|
||||
|
||||
/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
|
||||
/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */
|
||||
|
||||
/* LWORK (input) INTEGER */
|
||||
/* The dimension of the array WORK. */
|
||||
/* If SIDE = 'L', LWORK >= max(1,N); */
|
||||
/* if SIDE = 'R', LWORK >= max(1,M). */
|
||||
/* For optimum performance LWORK >= N*NB if SIDE = 'L', and */
|
||||
/* LWORK >= M*NB if SIDE = 'R', where NB is the optimal */
|
||||
/* blocksize. */
|
||||
|
||||
/* If LWORK = -1, then a workspace query is assumed; the routine */
|
||||
/* only calculates the optimal size of the WORK array, returns */
|
||||
/* this value as the first entry of the WORK array, and no error */
|
||||
/* message related to LWORK is issued by XERBLA. */
|
||||
|
||||
/* INFO (output) INTEGER */
|
||||
/* = 0: successful exit */
|
||||
/* < 0: if INFO = -i, the i-th argument had an illegal value */
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename T,
|
||||
long NR1, long NR2, long NR3,
|
||||
long NC1, long NC2, long NC3,
|
||||
typename MM
|
||||
>
|
||||
int ormqr (
|
||||
char side,
|
||||
char trans,
|
||||
const matrix<T,NR1,NC1,MM,column_major_layout>& a,
|
||||
const matrix<T,NR2,NC2,MM,column_major_layout>& tau,
|
||||
matrix<T,NR3,NC3,MM,column_major_layout>& c
|
||||
)
|
||||
{
|
||||
const long m = c.nr();
|
||||
const long n = c.nc();
|
||||
const long k = a.nc();
|
||||
|
||||
matrix<T,0,1,MM,column_major_layout> work;
|
||||
|
||||
// figure out how big the workspace needs to be.
|
||||
T work_size = 1;
|
||||
int info = binding::ormqr(side, trans, m, n,
|
||||
k, &a(0,0), a.nr(), &tau(0,0),
|
||||
&c(0,0), c.nr(), &work_size, -1);
|
||||
|
||||
if (info != 0)
|
||||
return info;
|
||||
|
||||
if (work.size() < work_size)
|
||||
work.set_size(static_cast<long>(work_size), 1);
|
||||
|
||||
// compute the actual result
|
||||
info = binding::ormqr(side, trans, m, n,
|
||||
k, &a(0,0), a.nr(), &tau(0,0),
|
||||
&c(0,0), c.nr(), &work(0,0), work.size());
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
#endif // DLIB_LAPACk_ORMQR_H__
|
||||
|
@ -504,6 +504,9 @@ namespace dlib
|
||||
The Q and R factors can be retrieved via the get_q() and get_r()
|
||||
methods. Furthermore, a solve() method is provided to find the
|
||||
least squares solution of Ax=b using the QR factors.
|
||||
|
||||
If DLIB_USE_LAPACK is #defined then the xGEQRF routine
|
||||
from LAPACK is used to compute the QR decomposition.
|
||||
!*/
|
||||
|
||||
public:
|
||||
@ -515,7 +518,6 @@ namespace dlib
|
||||
typedef typename matrix_exp_type::layout_type layout_type;
|
||||
|
||||
typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
|
||||
typedef matrix<type,0,1,mem_manager_type,layout_type> column_vector_type;
|
||||
|
||||
template <typename EXP>
|
||||
qr_decomposition(
|
||||
@ -556,17 +558,6 @@ namespace dlib
|
||||
- returns the number of columns in the input matrix
|
||||
!*/
|
||||
|
||||
const matrix_type get_householder (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns a matrix H such that:
|
||||
- H is the lower trapezoidal matrix whose columns define the
|
||||
Householder reflection vectors from QR factorization
|
||||
- H.nr() == nr()
|
||||
- H.nc() == nc()
|
||||
!*/
|
||||
|
||||
const matrix_type get_r (
|
||||
) const;
|
||||
/*!
|
||||
|
@ -9,6 +9,13 @@
|
||||
#include "matrix_utilities.h"
|
||||
#include "matrix_subexp.h"
|
||||
|
||||
#ifdef DLIB_USE_LAPACK
|
||||
#include "lapack/geqrf.h"
|
||||
#include "lapack/ormqr.h"
|
||||
#endif
|
||||
|
||||
#include "matrix_trsm.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
@ -27,7 +34,6 @@ namespace dlib
|
||||
typedef typename matrix_exp_type::layout_type layout_type;
|
||||
|
||||
typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
|
||||
typedef matrix<type,0,1,mem_manager_type,layout_type> column_vector_type;
|
||||
|
||||
// You have supplied an invalid type of matrix_exp_type. You have
|
||||
// to use this object with matrices that contain float or double type data.
|
||||
@ -50,9 +56,6 @@ namespace dlib
|
||||
long nc(
|
||||
) const;
|
||||
|
||||
const matrix_type get_householder (
|
||||
) const;
|
||||
|
||||
const matrix_type get_r (
|
||||
) const;
|
||||
|
||||
@ -66,6 +69,7 @@ namespace dlib
|
||||
|
||||
private:
|
||||
|
||||
#ifndef DLIB_USE_LAPACK
|
||||
template <typename EXP>
|
||||
const matrix_type solve_mat (
|
||||
const matrix_exp<EXP>& B
|
||||
@ -75,12 +79,13 @@ namespace dlib
|
||||
const matrix_type solve_vect (
|
||||
const matrix_exp<EXP>& B
|
||||
) const;
|
||||
#endif
|
||||
|
||||
|
||||
/** Array for internal storage of decomposition.
|
||||
@serial internal array storage.
|
||||
*/
|
||||
matrix_type QR_;
|
||||
matrix<type,0,0,mem_manager_type,column_major_layout> QR_;
|
||||
|
||||
/** Row and column dimensions.
|
||||
@serial column dimension.
|
||||
@ -91,6 +96,8 @@ namespace dlib
|
||||
/** Array for internal storage of diagonal of R.
|
||||
@serial diagonal of R.
|
||||
*/
|
||||
typedef matrix<type,0,1,mem_manager_type,column_major_layout> column_vector_type;
|
||||
column_vector_type tau;
|
||||
column_vector_type Rdiag;
|
||||
|
||||
|
||||
@ -125,6 +132,13 @@ namespace dlib
|
||||
QR_ = A;
|
||||
m = A.nr();
|
||||
n = A.nc();
|
||||
|
||||
#ifdef DLIB_USE_LAPACK
|
||||
|
||||
lapack::geqrf(QR_, tau);
|
||||
Rdiag = diag(QR_);
|
||||
|
||||
#else
|
||||
Rdiag.set_size(n);
|
||||
long i=0, j=0, k=0;
|
||||
|
||||
@ -168,6 +182,7 @@ namespace dlib
|
||||
}
|
||||
Rdiag(k) = -nrm;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
@ -207,16 +222,6 @@ namespace dlib
|
||||
return min(abs(Rdiag)) > eps;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename matrix_exp_type>
|
||||
const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
|
||||
get_householder (
|
||||
) const
|
||||
{
|
||||
return lowerm(QR_);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename matrix_exp_type>
|
||||
@ -253,6 +258,17 @@ namespace dlib
|
||||
get_q(
|
||||
) const
|
||||
{
|
||||
#ifdef DLIB_USE_LAPACK
|
||||
matrix<type,0,0,mem_manager_type,column_major_layout> X;
|
||||
// Take only the first n columns of an identity matrix. This way
|
||||
// X ends up being an m by n matrix.
|
||||
X = colm(identity_matrix<type>(m), range(0,n-1));
|
||||
|
||||
// Compute Y = Q*X
|
||||
lapack::ormqr('L','N', QR_, tau, X);
|
||||
return X;
|
||||
|
||||
#else
|
||||
long i=0, j=0, k=0;
|
||||
|
||||
matrix_type Q(m,n);
|
||||
@ -281,6 +297,7 @@ namespace dlib
|
||||
}
|
||||
}
|
||||
return Q;
|
||||
#endif
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
@ -303,11 +320,25 @@ namespace dlib
|
||||
<< "\n\tthis: " << this
|
||||
);
|
||||
|
||||
#ifdef DLIB_USE_LAPACK
|
||||
|
||||
using namespace blas_bindings;
|
||||
matrix<type,0,0,mem_manager_type,column_major_layout> X(B);
|
||||
// Compute Y = transpose(Q)*B
|
||||
lapack::ormqr('L','T',QR_, tau, X);
|
||||
// Solve R*X = Y;
|
||||
triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, QR_, X, n);
|
||||
|
||||
/* return the n x B.nc() portion of X */
|
||||
return subm(X,0,0,n,B.nc());
|
||||
|
||||
#else
|
||||
// just call the right version of the solve function
|
||||
if (B.nc() == 1)
|
||||
return solve_vect(B);
|
||||
else
|
||||
return solve_mat(B);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
@ -316,6 +347,8 @@ namespace dlib
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
#ifndef DLIB_USE_LAPACK
|
||||
|
||||
template <typename matrix_exp_type>
|
||||
template <typename EXP>
|
||||
const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
|
||||
@ -407,6 +440,7 @@ namespace dlib
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
#endif // DLIB_USE_LAPACK not defined
|
||||
|
||||
}
|
||||
|
||||
|
@ -591,6 +591,29 @@ namespace dlib
|
||||
alpha, &A(0,0), A.nr(), &B(0,0), B.nr());
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename T,
|
||||
long NR1, long NR2,
|
||||
long NC1, long NC2,
|
||||
typename MM
|
||||
>
|
||||
inline void triangular_solver (
|
||||
const enum CBLAS_SIDE Side,
|
||||
const enum CBLAS_UPLO Uplo,
|
||||
const enum CBLAS_TRANSPOSE TransA,
|
||||
const enum CBLAS_DIAG Diag,
|
||||
const matrix<T,NR1,NC1,MM,column_major_layout>& A,
|
||||
matrix<T,NR2,NC2,MM,column_major_layout>& B,
|
||||
long rows_of_B
|
||||
)
|
||||
{
|
||||
const T alpha = 1;
|
||||
cblas_trsm(CblasColMajor, Side, Uplo, TransA, Diag, rows_of_B, B.nc(),
|
||||
alpha, &A(0,0), A.nr(), &B(0,0), B.nr());
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
|
Loading…
Reference in New Issue
Block a user