mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added a regularization parameter to cca()
This commit is contained in:
parent
3c8db58597
commit
542f8b1cad
@ -17,9 +17,9 @@ namespace dlib
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
matrix<T,0,1> compute_correlations (
|
||||
const matrix<T>& L,
|
||||
const matrix<T>& R
|
||||
matrix<typename T::type,0,1> compute_correlations (
|
||||
const matrix_exp<T>& L,
|
||||
const matrix_exp<T>& R
|
||||
)
|
||||
{
|
||||
DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
|
||||
@ -31,7 +31,8 @@ namespace dlib
|
||||
<< "\n\t R.nr(): " << R.nr()
|
||||
);
|
||||
|
||||
matrix<T> A, B, C;
|
||||
typedef typename T::type type;
|
||||
matrix<type> A, B, C;
|
||||
A = diag(trans(R)*L);
|
||||
B = sqrt(diag(trans(L)*L));
|
||||
C = sqrt(diag(trans(R)*R));
|
||||
@ -53,7 +54,8 @@ namespace dlib
|
||||
unsigned long num_correlations,
|
||||
unsigned long extra_rank,
|
||||
unsigned long q,
|
||||
unsigned long num_output_correlations
|
||||
unsigned long num_output_correlations,
|
||||
double regularization
|
||||
)
|
||||
{
|
||||
matrix<T> Ul, Vl;
|
||||
@ -70,8 +72,8 @@ namespace dlib
|
||||
// Zero out singular values that are essentially zero so they don't cause numerical
|
||||
// difficulties in the code below.
|
||||
const double eps = std::numeric_limits<T>::epsilon()*std::max(max(Dr),max(Dl))*100;
|
||||
Dl = round_zeros(Dl,eps);
|
||||
Dr = round_zeros(Dr,eps);
|
||||
Dl = round_zeros(Dl+regularization,eps);
|
||||
Dr = round_zeros(Dr+regularization,eps);
|
||||
|
||||
// This matrix is really small so we can do a normal full SVD on it. Note that we
|
||||
// also throw away the columns of Ul and Ur corresponding to zero singular values.
|
||||
@ -105,13 +107,16 @@ namespace dlib
|
||||
matrix<T>& Rtrans,
|
||||
unsigned long num_correlations,
|
||||
unsigned long extra_rank = 5,
|
||||
unsigned long q = 2
|
||||
unsigned long q = 2,
|
||||
double regularization = 0
|
||||
)
|
||||
{
|
||||
DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
|
||||
DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr() &&
|
||||
regularization >= 0,
|
||||
"\t matrix cca()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t num_correlations: " << num_correlations
|
||||
<< "\n\t regularization: " << regularization
|
||||
<< "\n\t L.size(): " << L.size()
|
||||
<< "\n\t R.size(): " << R.size()
|
||||
<< "\n\t L.nr(): " << L.nr()
|
||||
@ -120,7 +125,7 @@ namespace dlib
|
||||
|
||||
using std::min;
|
||||
const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc())));
|
||||
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n);
|
||||
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n, regularization);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
@ -133,14 +138,17 @@ namespace dlib
|
||||
matrix<T>& Rtrans,
|
||||
unsigned long num_correlations,
|
||||
unsigned long extra_rank = 5,
|
||||
unsigned long q = 2
|
||||
unsigned long q = 2,
|
||||
double regularization = 0
|
||||
)
|
||||
{
|
||||
DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() &&
|
||||
max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0,
|
||||
max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 &&
|
||||
regularization >= 0,
|
||||
"\t matrix cca()"
|
||||
<< "\n\t Invalid inputs were given to this function."
|
||||
<< "\n\t num_correlations: " << num_correlations
|
||||
<< "\n\t regularization: " << regularization
|
||||
<< "\n\t L.size(): " << L.size()
|
||||
<< "\n\t R.size(): " << R.size()
|
||||
<< "\n\t max_index_plus_one(L): " << max_index_plus_one(L)
|
||||
@ -150,7 +158,7 @@ namespace dlib
|
||||
using std::min;
|
||||
const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R));
|
||||
const unsigned long num_output_correlations = min(num_correlations, std::min<unsigned long>(R.size(),n));
|
||||
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations);
|
||||
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations, regularization);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -14,9 +14,9 @@ namespace dlib
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
matrix<T,0,1> compute_correlations (
|
||||
const matrix<T>& L,
|
||||
const matrix<T>& R
|
||||
matrix<typename T::type,0,1> compute_correlations (
|
||||
const matrix_exp<T>& L,
|
||||
const matrix_exp<T>& R
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
@ -47,7 +47,8 @@ namespace dlib
|
||||
matrix<T>& Rtrans,
|
||||
unsigned long num_correlations,
|
||||
unsigned long extra_rank = 5,
|
||||
unsigned long q = 2
|
||||
unsigned long q = 2,
|
||||
double regularization = 0
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
@ -55,6 +56,7 @@ namespace dlib
|
||||
- L.size() > 0
|
||||
- R.size() > 0
|
||||
- L.nr() == R.nr()
|
||||
- regularization >= 0
|
||||
ensures
|
||||
- This function performs a canonical correlation analysis between the row
|
||||
vectors in L and R. That is, it finds two transformation matrices, Ltrans
|
||||
@ -83,11 +85,16 @@ namespace dlib
|
||||
problems.
|
||||
- returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The
|
||||
returned vector should exactly match the output of compute_correlations()
|
||||
when the reduced rank approximation to L and R is accurate. However, when L
|
||||
and/or R are higher rank than num_correlations+extra_rank the return value of
|
||||
this function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans).
|
||||
This deviation can be used to check if the reduced rank approximation is
|
||||
working or you need to increase extra_rank.
|
||||
when the reduced rank approximation to L and R is accurate and regularization
|
||||
is set to 0. However, if this is not the case then the return value of this
|
||||
function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). This
|
||||
deviation can be used to check if the reduced rank approximation is working
|
||||
or you need to increase extra_rank.
|
||||
- This function performs the ridge regression version of Canonical Correlation
|
||||
Analysis when regularization is set to a value > 0. In particular, larger
|
||||
values indicate the solution should be more heavily regularized. This can be
|
||||
useful when the dimensionality of the data is larger than the number of
|
||||
samples.
|
||||
- A good discussion of CCA can be found in the paper "Canonical Correlation
|
||||
Analysis" by David Weenink. In particular, this function is implemented
|
||||
using equations 29 and 30 from his paper. We also use the idea of doing CCA
|
||||
@ -109,7 +116,8 @@ namespace dlib
|
||||
matrix<T>& Rtrans,
|
||||
unsigned long num_correlations,
|
||||
unsigned long extra_rank = 5,
|
||||
unsigned long q = 2
|
||||
unsigned long q = 2,
|
||||
double regularization = 0
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
@ -119,6 +127,7 @@ namespace dlib
|
||||
(i.e. L and R can't represent empty matrices)
|
||||
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
|
||||
for a definition of sparse vector)
|
||||
- regularization >= 0
|
||||
ensures
|
||||
- This is just an overload of the cca() function defined above. Except in this
|
||||
case we take a sparse representation of the input L and R matrices rather than
|
||||
@ -144,7 +153,8 @@ namespace dlib
|
||||
matrix<T>& Rtrans,
|
||||
unsigned long num_correlations,
|
||||
unsigned long extra_rank = 5,
|
||||
unsigned long q = 2
|
||||
unsigned long q = 2,
|
||||
double regularization = 0
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
@ -154,6 +164,7 @@ namespace dlib
|
||||
(i.e. L and R can't represent empty matrices)
|
||||
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
|
||||
for a definition of sparse vector)
|
||||
- regularization >= 0
|
||||
ensures
|
||||
- returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q)
|
||||
(i.e. this is just a convenience function for calling the cca() routine when
|
||||
|
Loading…
Reference in New Issue
Block a user