Added a regularization parameter to cca()

This commit is contained in:
Davis King 2013-03-21 20:08:51 -04:00
parent 3c8db58597
commit 542f8b1cad
2 changed files with 43 additions and 24 deletions

View File

@ -17,9 +17,9 @@ namespace dlib
template < template <
typename T typename T
> >
matrix<T,0,1> compute_correlations ( matrix<typename T::type,0,1> compute_correlations (
const matrix<T>& L, const matrix_exp<T>& L,
const matrix<T>& R const matrix_exp<T>& R
) )
{ {
DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(), DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
@ -31,7 +31,8 @@ namespace dlib
<< "\n\t R.nr(): " << R.nr() << "\n\t R.nr(): " << R.nr()
); );
matrix<T> A, B, C; typedef typename T::type type;
matrix<type> A, B, C;
A = diag(trans(R)*L); A = diag(trans(R)*L);
B = sqrt(diag(trans(L)*L)); B = sqrt(diag(trans(L)*L));
C = sqrt(diag(trans(R)*R)); C = sqrt(diag(trans(R)*R));
@ -53,7 +54,8 @@ namespace dlib
unsigned long num_correlations, unsigned long num_correlations,
unsigned long extra_rank, unsigned long extra_rank,
unsigned long q, unsigned long q,
unsigned long num_output_correlations unsigned long num_output_correlations,
double regularization
) )
{ {
matrix<T> Ul, Vl; matrix<T> Ul, Vl;
@ -70,8 +72,8 @@ namespace dlib
// Zero out singular values that are essentially zero so they don't cause numerical // Zero out singular values that are essentially zero so they don't cause numerical
// difficulties in the code below. // difficulties in the code below.
const double eps = std::numeric_limits<T>::epsilon()*std::max(max(Dr),max(Dl))*100; const double eps = std::numeric_limits<T>::epsilon()*std::max(max(Dr),max(Dl))*100;
Dl = round_zeros(Dl,eps); Dl = round_zeros(Dl+regularization,eps);
Dr = round_zeros(Dr,eps); Dr = round_zeros(Dr+regularization,eps);
// This matrix is really small so we can do a normal full SVD on it. Note that we // This matrix is really small so we can do a normal full SVD on it. Note that we
// also throw away the columns of Ul and Ur corresponding to zero singular values. // also throw away the columns of Ul and Ur corresponding to zero singular values.
@ -105,13 +107,16 @@ namespace dlib
matrix<T>& Rtrans, matrix<T>& Rtrans,
unsigned long num_correlations, unsigned long num_correlations,
unsigned long extra_rank = 5, unsigned long extra_rank = 5,
unsigned long q = 2 unsigned long q = 2,
double regularization = 0
) )
{ {
DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr(), DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr() &&
regularization >= 0,
"\t matrix cca()" "\t matrix cca()"
<< "\n\t Invalid inputs were given to this function." << "\n\t Invalid inputs were given to this function."
<< "\n\t num_correlations: " << num_correlations << "\n\t num_correlations: " << num_correlations
<< "\n\t regularization: " << regularization
<< "\n\t L.size(): " << L.size() << "\n\t L.size(): " << L.size()
<< "\n\t R.size(): " << R.size() << "\n\t R.size(): " << R.size()
<< "\n\t L.nr(): " << L.nr() << "\n\t L.nr(): " << L.nr()
@ -120,7 +125,7 @@ namespace dlib
using std::min; using std::min;
const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc()))); const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc())));
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n); return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n, regularization);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
@ -133,14 +138,17 @@ namespace dlib
matrix<T>& Rtrans, matrix<T>& Rtrans,
unsigned long num_correlations, unsigned long num_correlations,
unsigned long extra_rank = 5, unsigned long extra_rank = 5,
unsigned long q = 2 unsigned long q = 2,
double regularization = 0
) )
{ {
DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() && DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() &&
max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0, max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 &&
regularization >= 0,
"\t matrix cca()" "\t matrix cca()"
<< "\n\t Invalid inputs were given to this function." << "\n\t Invalid inputs were given to this function."
<< "\n\t num_correlations: " << num_correlations << "\n\t num_correlations: " << num_correlations
<< "\n\t regularization: " << regularization
<< "\n\t L.size(): " << L.size() << "\n\t L.size(): " << L.size()
<< "\n\t R.size(): " << R.size() << "\n\t R.size(): " << R.size()
<< "\n\t max_index_plus_one(L): " << max_index_plus_one(L) << "\n\t max_index_plus_one(L): " << max_index_plus_one(L)
@ -150,7 +158,7 @@ namespace dlib
using std::min; using std::min;
const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R)); const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R));
const unsigned long num_output_correlations = min(num_correlations, std::min<unsigned long>(R.size(),n)); const unsigned long num_output_correlations = min(num_correlations, std::min<unsigned long>(R.size(),n));
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations); return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations, regularization);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------

View File

@ -14,9 +14,9 @@ namespace dlib
template < template <
typename T typename T
> >
matrix<T,0,1> compute_correlations ( matrix<typename T::type,0,1> compute_correlations (
const matrix<T>& L, const matrix_exp<T>& L,
const matrix<T>& R const matrix_exp<T>& R
); );
/*! /*!
requires requires
@ -47,7 +47,8 @@ namespace dlib
matrix<T>& Rtrans, matrix<T>& Rtrans,
unsigned long num_correlations, unsigned long num_correlations,
unsigned long extra_rank = 5, unsigned long extra_rank = 5,
unsigned long q = 2 unsigned long q = 2,
double regularization = 0
); );
/*! /*!
requires requires
@ -55,6 +56,7 @@ namespace dlib
- L.size() > 0 - L.size() > 0
- R.size() > 0 - R.size() > 0
- L.nr() == R.nr() - L.nr() == R.nr()
- regularization >= 0
ensures ensures
- This function performs a canonical correlation analysis between the row - This function performs a canonical correlation analysis between the row
vectors in L and R. That is, it finds two transformation matrices, Ltrans vectors in L and R. That is, it finds two transformation matrices, Ltrans
@ -83,11 +85,16 @@ namespace dlib
problems. problems.
- returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The - returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The
returned vector should exactly match the output of compute_correlations() returned vector should exactly match the output of compute_correlations()
when the reduced rank approximation to L and R is accurate. However, when L when the reduced rank approximation to L and R is accurate and regularization
and/or R are higher rank than num_correlations+extra_rank the return value of is set to 0. However, if this is not the case then the return value of this
this function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). This
This deviation can be used to check if the reduced rank approximation is deviation can be used to check if the reduced rank approximation is working
working or you need to increase extra_rank. or you need to increase extra_rank.
- This function performs the ridge regression version of Canonical Correlation
Analysis when regularization is set to a value > 0. In particular, larger
values indicate the solution should be more heavily regularized. This can be
useful when the dimensionality of the data is larger than the number of
samples.
- A good discussion of CCA can be found in the paper "Canonical Correlation - A good discussion of CCA can be found in the paper "Canonical Correlation
Analysis" by David Weenink. In particular, this function is implemented Analysis" by David Weenink. In particular, this function is implemented
using equations 29 and 30 from his paper. We also use the idea of doing CCA using equations 29 and 30 from his paper. We also use the idea of doing CCA
@ -109,7 +116,8 @@ namespace dlib
matrix<T>& Rtrans, matrix<T>& Rtrans,
unsigned long num_correlations, unsigned long num_correlations,
unsigned long extra_rank = 5, unsigned long extra_rank = 5,
unsigned long q = 2 unsigned long q = 2,
double regularization = 0
); );
/*! /*!
requires requires
@ -119,6 +127,7 @@ namespace dlib
(i.e. L and R can't represent empty matrices) (i.e. L and R can't represent empty matrices)
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
for a definition of sparse vector) for a definition of sparse vector)
- regularization >= 0
ensures ensures
- This is just an overload of the cca() function defined above. Except in this - This is just an overload of the cca() function defined above. Except in this
case we take a sparse representation of the input L and R matrices rather than case we take a sparse representation of the input L and R matrices rather than
@ -144,7 +153,8 @@ namespace dlib
matrix<T>& Rtrans, matrix<T>& Rtrans,
unsigned long num_correlations, unsigned long num_correlations,
unsigned long extra_rank = 5, unsigned long extra_rank = 5,
unsigned long q = 2 unsigned long q = 2,
double regularization = 0
); );
/*! /*!
requires requires
@ -154,6 +164,7 @@ namespace dlib
(i.e. L and R can't represent empty matrices) (i.e. L and R can't represent empty matrices)
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
for a definition of sparse vector) for a definition of sparse vector)
- regularization >= 0
ensures ensures
- returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q) - returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q, regularization)
(i.e. this is just a convenience function for calling the cca() routine when (i.e. this is just a convenience function for calling the cca() routine when