Set up the qr_decomposition to use LAPACK when available. Also removed the

qr_decomposition::get_householder() function since I don't currently have any
way to test it or precisely define what it does.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403839
This commit is contained in:
Davis King 2010-09-14 01:33:56 +00:00
parent f0fd26ba5a
commit b54d9d2b4a
4 changed files with 274 additions and 27 deletions

199
dlib/matrix/lapack/ormqr.h Normal file
View File

@ -0,0 +1,199 @@
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_LAPACk_ORMQR_H__
#define DLIB_LAPACk_ORMQR_H__
#include "fortran_id.h"
#include "../matrix.h"
namespace dlib
{
namespace lapack
{
namespace binding
{
extern "C"
{
    // Raw prototypes of the Fortran LAPACK xORMQR routines.  Every argument,
    // scalars included, is handed over by address.  The argument meanings are
    // described in the LAPACK documentation block further down in this file.

    // Double precision version.
    void DLIB_FORTRAN_ID(dormqr) (
        char *side,
        char *trans,
        integer *m,
        integer *n,
        integer *k,
        const double *a,
        integer *lda,
        const double *tau,
        double * c__,
        integer *ldc,
        double *work,
        integer *lwork,
        integer *info
    );

    // Single precision version.
    void DLIB_FORTRAN_ID(sormqr) (
        char *side,
        char *trans,
        integer *m,
        integer *n,
        integer *k,
        const float *a,
        integer *lda,
        const float *tau,
        float * c__,
        integer *ldc,
        float *work,
        integer *lwork,
        integer *info
    );
}
inline int ormqr (char side, char trans, integer m, integer n,
integer k, const double *a, integer lda, const double *tau,
double *c__, integer ldc, double *work, integer lwork)
{
integer info = 0;
DLIB_FORTRAN_ID(dormqr)(&side, &trans, &m, &n,
&k, a, &lda, tau,
c__, &ldc, work, &lwork, &info);
return info;
}
inline int ormqr (char side, char trans, integer m, integer n,
integer k, const float *a, integer lda, const float *tau,
float *c__, integer ldc, float *work, integer lwork)
{
integer info = 0;
DLIB_FORTRAN_ID(sormqr)(&side, &trans, &m, &n,
&k, a, &lda, tau,
c__, &ldc, work, &lwork, &info);
return info;
}
}
// ------------------------------------------------------------------------------------
/* -- LAPACK routine (version 3.1) -- */
/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
/* November 2006 */
/* .. Scalar Arguments .. */
/* .. */
/* .. Array Arguments .. */
/* .. */
/* Purpose */
/* ======= */
/* DORMQR overwrites the general real M-by-N matrix C with */
/* SIDE = 'L' SIDE = 'R' */
/* TRANS = 'N': Q * C C * Q */
/* TRANS = 'T': Q**T * C C * Q**T */
/* where Q is a real orthogonal matrix defined as the product of k */
/* elementary reflectors */
/* Q = H(1) H(2) . . . H(k) */
/* as returned by DGEQRF. Q is of order M if SIDE = 'L' and of order N */
/* if SIDE = 'R'. */
/* Arguments */
/* ========= */
/* SIDE (input) CHARACTER*1 */
/* = 'L': apply Q or Q**T from the Left; */
/* = 'R': apply Q or Q**T from the Right. */
/* TRANS (input) CHARACTER*1 */
/* = 'N': No transpose, apply Q; */
/* = 'T': Transpose, apply Q**T. */
/* M (input) INTEGER */
/* The number of rows of the matrix C. M >= 0. */
/* N (input) INTEGER */
/* The number of columns of the matrix C. N >= 0. */
/* K (input) INTEGER */
/* The number of elementary reflectors whose product defines */
/* the matrix Q. */
/* If SIDE = 'L', M >= K >= 0; */
/* if SIDE = 'R', N >= K >= 0. */
/* A (input) DOUBLE PRECISION array, dimension (LDA,K) */
/* The i-th column must contain the vector which defines the */
/* elementary reflector H(i), for i = 1,2,...,k, as returned by */
/* DGEQRF in the first k columns of its array argument A. */
/* A is modified by the routine but restored on exit. */
/* LDA (input) INTEGER */
/* The leading dimension of the array A. */
/* If SIDE = 'L', LDA >= max(1,M); */
/* if SIDE = 'R', LDA >= max(1,N). */
/* TAU (input) DOUBLE PRECISION array, dimension (K) */
/* TAU(i) must contain the scalar factor of the elementary */
/* reflector H(i), as returned by DGEQRF. */
/* C (input/output) DOUBLE PRECISION array, dimension (LDC,N) */
/* On entry, the M-by-N matrix C. */
/* On exit, C is overwritten by Q*C or Q**T*C or C*Q**T or C*Q. */
/* LDC (input) INTEGER */
/* The leading dimension of the array C. LDC >= max(1,M). */
/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */
/* LWORK (input) INTEGER */
/* The dimension of the array WORK. */
/* If SIDE = 'L', LWORK >= max(1,N); */
/* if SIDE = 'R', LWORK >= max(1,M). */
/* For optimum performance LWORK >= N*NB if SIDE = 'L', and */
/* LWORK >= M*NB if SIDE = 'R', where NB is the optimal */
/* blocksize. */
/* If LWORK = -1, then a workspace query is assumed; the routine */
/* only calculates the optimal size of the WORK array, returns */
/* this value as the first entry of the WORK array, and no error */
/* message related to LWORK is issued by XERBLA. */
/* INFO (output) INTEGER */
/* = 0: successful exit */
/* < 0: if INFO = -i, the i-th argument had an illegal value */
// ------------------------------------------------------------------------------------
template <
typename T,
long NR1, long NR2, long NR3,
long NC1, long NC2, long NC3,
typename MM
>
int ormqr (
char side,
char trans,
const matrix<T,NR1,NC1,MM,column_major_layout>& a,
const matrix<T,NR2,NC2,MM,column_major_layout>& tau,
matrix<T,NR3,NC3,MM,column_major_layout>& c
)
{
const long m = c.nr();
const long n = c.nc();
const long k = a.nc();
matrix<T,0,1,MM,column_major_layout> work;
// figure out how big the workspace needs to be.
T work_size = 1;
int info = binding::ormqr(side, trans, m, n,
k, &a(0,0), a.nr(), &tau(0,0),
&c(0,0), c.nr(), &work_size, -1);
if (info != 0)
return info;
if (work.size() < work_size)
work.set_size(static_cast<long>(work_size), 1);
// compute the actual result
info = binding::ormqr(side, trans, m, n,
k, &a(0,0), a.nr(), &tau(0,0),
&c(0,0), c.nr(), &work(0,0), work.size());
return info;
}
// ------------------------------------------------------------------------------------
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_LAPACk_ORMQR_H__

View File

@ -504,6 +504,9 @@ namespace dlib
The Q and R factors can be retrieved via the get_q() and get_r()
methods. Furthermore, a solve() method is provided to find the
least squares solution of Ax=b using the QR factors.
If DLIB_USE_LAPACK is #defined then the xGEQRF routine
from LAPACK is used to compute the QR decomposition.
!*/
public:
@ -515,7 +518,6 @@ namespace dlib
typedef typename matrix_exp_type::layout_type layout_type;
typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
typedef matrix<type,0,1,mem_manager_type,layout_type> column_vector_type;
template <typename EXP>
qr_decomposition(
@ -556,17 +558,6 @@ namespace dlib
- returns the number of columns in the input matrix
!*/
const matrix_type get_householder (
) const;
/*!
ensures
- returns a matrix H such that:
- H is the lower trapezoidal matrix whose columns define the
Householder reflection vectors from QR factorization
- H.nr() == nr()
- H.nc() == nc()
!*/
const matrix_type get_r (
) const;
/*!

View File

@ -9,6 +9,13 @@
#include "matrix_utilities.h"
#include "matrix_subexp.h"
#ifdef DLIB_USE_LAPACK
#include "lapack/geqrf.h"
#include "lapack/ormqr.h"
#endif
#include "matrix_trsm.h"
namespace dlib
{
@ -27,7 +34,6 @@ namespace dlib
typedef typename matrix_exp_type::layout_type layout_type;
typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
typedef matrix<type,0,1,mem_manager_type,layout_type> column_vector_type;
// You have supplied an invalid type of matrix_exp_type. You have
// to use this object with matrices that contain float or double type data.
@ -50,9 +56,6 @@ namespace dlib
long nc(
) const;
const matrix_type get_householder (
) const;
const matrix_type get_r (
) const;
@ -66,6 +69,7 @@ namespace dlib
private:
#ifndef DLIB_USE_LAPACK
template <typename EXP>
const matrix_type solve_mat (
const matrix_exp<EXP>& B
@ -75,12 +79,13 @@ namespace dlib
const matrix_type solve_vect (
const matrix_exp<EXP>& B
) const;
#endif
/** Array for internal storage of decomposition.
@serial internal array storage.
*/
matrix_type QR_;
matrix<type,0,0,mem_manager_type,column_major_layout> QR_;
/** Row and column dimensions.
@serial column dimension.
@ -91,6 +96,8 @@ namespace dlib
/** Array for internal storage of diagonal of R.
@serial diagonal of R.
*/
typedef matrix<type,0,1,mem_manager_type,column_major_layout> column_vector_type;
column_vector_type tau;
column_vector_type Rdiag;
@ -125,6 +132,13 @@ namespace dlib
QR_ = A;
m = A.nr();
n = A.nc();
#ifdef DLIB_USE_LAPACK
lapack::geqrf(QR_, tau);
Rdiag = diag(QR_);
#else
Rdiag.set_size(n);
long i=0, j=0, k=0;
@ -168,6 +182,7 @@ namespace dlib
}
Rdiag(k) = -nrm;
}
#endif
}
// ----------------------------------------------------------------------------------------
@ -207,16 +222,6 @@ namespace dlib
return min(abs(Rdiag)) > eps;
}
// ----------------------------------------------------------------------------------------
template <typename matrix_exp_type>
const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
get_householder (
) const
{
return lowerm(QR_);
}
// ----------------------------------------------------------------------------------------
template <typename matrix_exp_type>
@ -253,6 +258,17 @@ namespace dlib
get_q(
) const
{
#ifdef DLIB_USE_LAPACK
matrix<type,0,0,mem_manager_type,column_major_layout> X;
// Take only the first n columns of an identity matrix. This way
// X ends up being an m by n matrix.
X = colm(identity_matrix<type>(m), range(0,n-1));
// Compute Y = Q*X
lapack::ormqr('L','N', QR_, tau, X);
return X;
#else
long i=0, j=0, k=0;
matrix_type Q(m,n);
@ -281,6 +297,7 @@ namespace dlib
}
}
return Q;
#endif
}
// ----------------------------------------------------------------------------------------
@ -303,11 +320,25 @@ namespace dlib
<< "\n\tthis: " << this
);
#ifdef DLIB_USE_LAPACK
using namespace blas_bindings;
matrix<type,0,0,mem_manager_type,column_major_layout> X(B);
// Compute Y = transpose(Q)*B
lapack::ormqr('L','T',QR_, tau, X);
// Solve R*X = Y;
triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, QR_, X, n);
/* return the n x B.nc() portion of X */
return subm(X,0,0,n,B.nc());
#else
// just call the right version of the solve function
if (B.nc() == 1)
return solve_vect(B);
else
return solve_mat(B);
#endif
}
// ----------------------------------------------------------------------------------------
@ -316,6 +347,8 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
#ifndef DLIB_USE_LAPACK
template <typename matrix_exp_type>
template <typename EXP>
const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
@ -407,6 +440,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
#endif // DLIB_USE_LAPACK not defined
}

View File

@ -591,6 +591,29 @@ namespace dlib
alpha, &A(0,0), A.nr(), &B(0,0), B.nr());
}
// ------------------------------------------------------------------------------------
/*
    Overload of triangular_solver for column major matrices that lets the
    caller state how many rows of B take part in the solve.  It forwards to
    cblas_trsm with alpha == 1, passing rows_of_B as the row count while still
    using B.nr() as B's leading dimension, so only the leading rows_of_B rows
    of B are handed to the solver.  B is modified in place.
*/
template <
    typename T,
    long NR1, long NR2,
    long NC1, long NC2,
    typename MM
    >
inline void triangular_solver (
    const enum CBLAS_SIDE Side,
    const enum CBLAS_UPLO Uplo,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_DIAG Diag,
    const matrix<T,NR1,NC1,MM,column_major_layout>& A,
    matrix<T,NR2,NC2,MM,column_major_layout>& B,
    long rows_of_B
)
{
    // No scaling of the right hand side.
    const T unit_scale = 1;

    cblas_trsm(
        CblasColMajor, Side, Uplo, TransA, Diag,
        rows_of_B, B.nc(),
        unit_scale,
        &A(0,0), A.nr(),
        &B(0,0), B.nr());
}
// ------------------------------------------------------------------------------------
template <