From 542f8b1cad6b3d00c5d472faa0710fb2f9ca9c1a Mon Sep 17 00:00:00 2001 From: Davis King Date: Thu, 21 Mar 2013 20:08:51 -0400 Subject: [PATCH] Added a regularization parameter to cca() --- dlib/statistics/cca.h | 34 +++++++++++++++++++++------------- dlib/statistics/cca_abstract.h | 33 ++++++++++++++++++++++----------- 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/dlib/statistics/cca.h b/dlib/statistics/cca.h index 8558a832c..ae6bce0f8 100644 --- a/dlib/statistics/cca.h +++ b/dlib/statistics/cca.h @@ -17,9 +17,9 @@ namespace dlib template < typename T > - matrix compute_correlations ( - const matrix& L, - const matrix& R + matrix compute_correlations ( + const matrix_exp& L, + const matrix_exp& R ) { DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(), @@ -31,7 +31,8 @@ namespace dlib << "\n\t R.nr(): " << R.nr() ); - matrix A, B, C; + typedef typename T::type type; + matrix A, B, C; A = diag(trans(R)*L); B = sqrt(diag(trans(L)*L)); C = sqrt(diag(trans(R)*R)); @@ -53,7 +54,8 @@ namespace dlib unsigned long num_correlations, unsigned long extra_rank, unsigned long q, - unsigned long num_output_correlations + unsigned long num_output_correlations, + double regularization ) { matrix Ul, Vl; @@ -70,8 +72,8 @@ namespace dlib // Zero out singular values that are essentially zero so they don't cause numerical // difficulties in the code below. const double eps = std::numeric_limits::epsilon()*std::max(max(Dr),max(Dl))*100; - Dl = round_zeros(Dl,eps); - Dr = round_zeros(Dr,eps); + Dl = round_zeros(Dl+regularization,eps); + Dr = round_zeros(Dr+regularization,eps); // This matrix is really small so we can do a normal full SVD on it. Note that we // also throw away the columns of Ul and Ur corresponding to zero singular values. @@ -105,13 +107,16 @@ namespace dlib matrix& Rtrans, unsigned long num_correlations, unsigned long extra_rank = 5, - unsigned long q = 2 + unsigned long q = 2, + double regularization = 0 ) { - DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr(), + DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr() && + regularization >= 0, "\t matrix cca()" << "\n\t Invalid inputs were given to this function." << "\n\t num_correlations: " << num_correlations + << "\n\t regularization: " << regularization << "\n\t L.size(): " << L.size() << "\n\t R.size(): " << R.size() << "\n\t L.nr(): " << L.nr() @@ -120,7 +125,7 @@ namespace dlib using std::min; const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc()))); - return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n); + return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n, regularization); } // ---------------------------------------------------------------------------------------- @@ -133,14 +138,17 @@ namespace dlib matrix& Rtrans, unsigned long num_correlations, unsigned long extra_rank = 5, - unsigned long q = 2 + unsigned long q = 2, + double regularization = 0 ) { DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() && - max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0, + max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 && + regularization >= 0, "\t matrix cca()" << "\n\t Invalid inputs were given to this function." << "\n\t num_correlations: " << num_correlations + << "\n\t regularization: " << regularization << "\n\t L.size(): " << L.size() << "\n\t R.size(): " << R.size() << "\n\t max_index_plus_one(L): " << max_index_plus_one(L) @@ -150,7 +158,7 @@ namespace dlib using std::min; const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R)); const unsigned long num_output_correlations = min(num_correlations, std::min(R.size(),n)); - return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations); + return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations, regularization); } // ---------------------------------------------------------------------------------------- diff --git a/dlib/statistics/cca_abstract.h b/dlib/statistics/cca_abstract.h index 6ce443073..420a6859a 100644 --- a/dlib/statistics/cca_abstract.h +++ b/dlib/statistics/cca_abstract.h @@ -14,9 +14,9 @@ namespace dlib template < typename T > - matrix compute_correlations ( - const matrix& L, - const matrix& R + matrix compute_correlations ( + const matrix_exp& L, + const matrix_exp& R ); /*! requires @@ -47,7 +47,8 @@ namespace dlib matrix& Rtrans, unsigned long num_correlations, unsigned long extra_rank = 5, - unsigned long q = 2 + unsigned long q = 2, + double regularization = 0 ); /*! requires @@ -55,6 +56,7 @@ namespace dlib - L.size() > 0 - R.size() > 0 - L.nr() == R.nr() + - regularization >= 0 ensures - This function performs a canonical correlation analysis between the row vectors in L and R. That is, it finds two transformation matrices, Ltrans @@ -83,11 +85,16 @@ namespace dlib problems. - returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The returned vector should exactly match the output of compute_correlations() - when the reduced rank approximation to L and R is accurate. However, when L - and/or R are higher rank than num_correlations+extra_rank the return value of - this function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). - This deviation can be used to check if the reduced rank approximation is - working or you need to increase extra_rank. + when the reduced rank approximation to L and R is accurate and regularization + is set to 0. However, if this is not the case then the return value of this + function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). This + deviation can be used to check if the reduced rank approximation is working + or you need to increase extra_rank. + - This function performs the ridge regression version of Canonical Correlation + Analysis when regularization is set to a value > 0. In particular, larger + values indicate the solution should be more heavily regularized. This can be + useful when the dimensionality of the data is larger than the number of + samples. - A good discussion of CCA can be found in the paper "Canonical Correlation Analysis" by David Weenink. In particular, this function is implemented using equations 29 and 30 from his paper. We also use the idea of doing CCA @@ -109,7 +116,8 @@ namespace dlib matrix& Rtrans, unsigned long num_correlations, unsigned long extra_rank = 5, - unsigned long q = 2 + unsigned long q = 2, + double regularization = 0 ); /*! requires @@ -119,6 +127,7 @@ namespace dlib (i.e. L and R can't represent empty matrices) - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h for a definition of sparse vector) + - regularization >= 0 ensures - This is just an overload of the cca() function defined above. Except in this case we take a sparse representation of the input L and R matrices rather than @@ -144,7 +153,8 @@ namespace dlib matrix& Rtrans, unsigned long num_correlations, unsigned long extra_rank = 5, - unsigned long q = 2 + unsigned long q = 2, + double regularization = 0 ); /*! requires @@ -154,6 +164,7 @@ namespace dlib (i.e. L and R can't represent empty matrices) - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h for a definition of sparse vector) + - regularization >= 0 ensures - returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q) (i.e. this is just a convenience function for calling the cca() routine when