mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added a regularization parameter to cca()
This commit is contained in:
parent
3c8db58597
commit
542f8b1cad
@ -17,9 +17,9 @@ namespace dlib
|
|||||||
template <
|
template <
|
||||||
typename T
|
typename T
|
||||||
>
|
>
|
||||||
matrix<T,0,1> compute_correlations (
|
matrix<typename T::type,0,1> compute_correlations (
|
||||||
const matrix<T>& L,
|
const matrix_exp<T>& L,
|
||||||
const matrix<T>& R
|
const matrix_exp<T>& R
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
|
DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
|
||||||
@ -31,7 +31,8 @@ namespace dlib
|
|||||||
<< "\n\t R.nr(): " << R.nr()
|
<< "\n\t R.nr(): " << R.nr()
|
||||||
);
|
);
|
||||||
|
|
||||||
matrix<T> A, B, C;
|
typedef typename T::type type;
|
||||||
|
matrix<type> A, B, C;
|
||||||
A = diag(trans(R)*L);
|
A = diag(trans(R)*L);
|
||||||
B = sqrt(diag(trans(L)*L));
|
B = sqrt(diag(trans(L)*L));
|
||||||
C = sqrt(diag(trans(R)*R));
|
C = sqrt(diag(trans(R)*R));
|
||||||
@ -53,7 +54,8 @@ namespace dlib
|
|||||||
unsigned long num_correlations,
|
unsigned long num_correlations,
|
||||||
unsigned long extra_rank,
|
unsigned long extra_rank,
|
||||||
unsigned long q,
|
unsigned long q,
|
||||||
unsigned long num_output_correlations
|
unsigned long num_output_correlations,
|
||||||
|
double regularization
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
matrix<T> Ul, Vl;
|
matrix<T> Ul, Vl;
|
||||||
@ -70,8 +72,8 @@ namespace dlib
|
|||||||
// Zero out singular values that are essentially zero so they don't cause numerical
|
// Zero out singular values that are essentially zero so they don't cause numerical
|
||||||
// difficulties in the code below.
|
// difficulties in the code below.
|
||||||
const double eps = std::numeric_limits<T>::epsilon()*std::max(max(Dr),max(Dl))*100;
|
const double eps = std::numeric_limits<T>::epsilon()*std::max(max(Dr),max(Dl))*100;
|
||||||
Dl = round_zeros(Dl,eps);
|
Dl = round_zeros(Dl+regularization,eps);
|
||||||
Dr = round_zeros(Dr,eps);
|
Dr = round_zeros(Dr+regularization,eps);
|
||||||
|
|
||||||
// This matrix is really small so we can do a normal full SVD on it. Note that we
|
// This matrix is really small so we can do a normal full SVD on it. Note that we
|
||||||
// also throw away the columns of Ul and Ur corresponding to zero singular values.
|
// also throw away the columns of Ul and Ur corresponding to zero singular values.
|
||||||
@ -105,13 +107,16 @@ namespace dlib
|
|||||||
matrix<T>& Rtrans,
|
matrix<T>& Rtrans,
|
||||||
unsigned long num_correlations,
|
unsigned long num_correlations,
|
||||||
unsigned long extra_rank = 5,
|
unsigned long extra_rank = 5,
|
||||||
unsigned long q = 2
|
unsigned long q = 2,
|
||||||
|
double regularization = 0
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr(),
|
DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr() &&
|
||||||
|
regularization >= 0,
|
||||||
"\t matrix cca()"
|
"\t matrix cca()"
|
||||||
<< "\n\t Invalid inputs were given to this function."
|
<< "\n\t Invalid inputs were given to this function."
|
||||||
<< "\n\t num_correlations: " << num_correlations
|
<< "\n\t num_correlations: " << num_correlations
|
||||||
|
<< "\n\t regularization: " << regularization
|
||||||
<< "\n\t L.size(): " << L.size()
|
<< "\n\t L.size(): " << L.size()
|
||||||
<< "\n\t R.size(): " << R.size()
|
<< "\n\t R.size(): " << R.size()
|
||||||
<< "\n\t L.nr(): " << L.nr()
|
<< "\n\t L.nr(): " << L.nr()
|
||||||
@ -120,7 +125,7 @@ namespace dlib
|
|||||||
|
|
||||||
using std::min;
|
using std::min;
|
||||||
const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc())));
|
const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc())));
|
||||||
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n);
|
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n, regularization);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
@ -133,14 +138,17 @@ namespace dlib
|
|||||||
matrix<T>& Rtrans,
|
matrix<T>& Rtrans,
|
||||||
unsigned long num_correlations,
|
unsigned long num_correlations,
|
||||||
unsigned long extra_rank = 5,
|
unsigned long extra_rank = 5,
|
||||||
unsigned long q = 2
|
unsigned long q = 2,
|
||||||
|
double regularization = 0
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() &&
|
DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() &&
|
||||||
max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0,
|
max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 &&
|
||||||
|
regularization >= 0,
|
||||||
"\t matrix cca()"
|
"\t matrix cca()"
|
||||||
<< "\n\t Invalid inputs were given to this function."
|
<< "\n\t Invalid inputs were given to this function."
|
||||||
<< "\n\t num_correlations: " << num_correlations
|
<< "\n\t num_correlations: " << num_correlations
|
||||||
|
<< "\n\t regularization: " << regularization
|
||||||
<< "\n\t L.size(): " << L.size()
|
<< "\n\t L.size(): " << L.size()
|
||||||
<< "\n\t R.size(): " << R.size()
|
<< "\n\t R.size(): " << R.size()
|
||||||
<< "\n\t max_index_plus_one(L): " << max_index_plus_one(L)
|
<< "\n\t max_index_plus_one(L): " << max_index_plus_one(L)
|
||||||
@ -150,7 +158,7 @@ namespace dlib
|
|||||||
using std::min;
|
using std::min;
|
||||||
const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R));
|
const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R));
|
||||||
const unsigned long num_output_correlations = min(num_correlations, std::min<unsigned long>(R.size(),n));
|
const unsigned long num_output_correlations = min(num_correlations, std::min<unsigned long>(R.size(),n));
|
||||||
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations);
|
return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations, regularization);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
@ -14,9 +14,9 @@ namespace dlib
|
|||||||
template <
|
template <
|
||||||
typename T
|
typename T
|
||||||
>
|
>
|
||||||
matrix<T,0,1> compute_correlations (
|
matrix<typename T::type,0,1> compute_correlations (
|
||||||
const matrix<T>& L,
|
const matrix_exp<T>& L,
|
||||||
const matrix<T>& R
|
const matrix_exp<T>& R
|
||||||
);
|
);
|
||||||
/*!
|
/*!
|
||||||
requires
|
requires
|
||||||
@ -47,7 +47,8 @@ namespace dlib
|
|||||||
matrix<T>& Rtrans,
|
matrix<T>& Rtrans,
|
||||||
unsigned long num_correlations,
|
unsigned long num_correlations,
|
||||||
unsigned long extra_rank = 5,
|
unsigned long extra_rank = 5,
|
||||||
unsigned long q = 2
|
unsigned long q = 2,
|
||||||
|
double regularization = 0
|
||||||
);
|
);
|
||||||
/*!
|
/*!
|
||||||
requires
|
requires
|
||||||
@ -55,6 +56,7 @@ namespace dlib
|
|||||||
- L.size() > 0
|
- L.size() > 0
|
||||||
- R.size() > 0
|
- R.size() > 0
|
||||||
- L.nr() == R.nr()
|
- L.nr() == R.nr()
|
||||||
|
- regularization >= 0
|
||||||
ensures
|
ensures
|
||||||
- This function performs a canonical correlation analysis between the row
|
- This function performs a canonical correlation analysis between the row
|
||||||
vectors in L and R. That is, it finds two transformation matrices, Ltrans
|
vectors in L and R. That is, it finds two transformation matrices, Ltrans
|
||||||
@ -83,11 +85,16 @@ namespace dlib
|
|||||||
problems.
|
problems.
|
||||||
- returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The
|
- returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The
|
||||||
returned vector should exactly match the output of compute_correlations()
|
returned vector should exactly match the output of compute_correlations()
|
||||||
when the reduced rank approximation to L and R is accurate. However, when L
|
when the reduced rank approximation to L and R is accurate and regularization
|
||||||
and/or R are higher rank than num_correlations+extra_rank the return value of
|
is set to 0. However, if this is not the case then the return value of this
|
||||||
this function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans).
|
function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). This
|
||||||
This deviation can be used to check if the reduced rank approximation is
|
deviation can be used to check if the reduced rank approximation is working
|
||||||
working or you need to increase extra_rank.
|
or you need to increase extra_rank.
|
||||||
|
- This function performs the ridge regression version of Canonical Correlation
|
||||||
|
Analysis when regularization is set to a value > 0. In particular, larger
|
||||||
|
values indicate the solution should be more heavily regularized. This can be
|
||||||
|
useful when the dimensionality of the data is larger than the number of
|
||||||
|
samples.
|
||||||
- A good discussion of CCA can be found in the paper "Canonical Correlation
|
- A good discussion of CCA can be found in the paper "Canonical Correlation
|
||||||
Analysis" by David Weenink. In particular, this function is implemented
|
Analysis" by David Weenink. In particular, this function is implemented
|
||||||
using equations 29 and 30 from his paper. We also use the idea of doing CCA
|
using equations 29 and 30 from his paper. We also use the idea of doing CCA
|
||||||
@ -109,7 +116,8 @@ namespace dlib
|
|||||||
matrix<T>& Rtrans,
|
matrix<T>& Rtrans,
|
||||||
unsigned long num_correlations,
|
unsigned long num_correlations,
|
||||||
unsigned long extra_rank = 5,
|
unsigned long extra_rank = 5,
|
||||||
unsigned long q = 2
|
unsigned long q = 2,
|
||||||
|
double regularization = 0
|
||||||
);
|
);
|
||||||
/*!
|
/*!
|
||||||
requires
|
requires
|
||||||
@ -119,6 +127,7 @@ namespace dlib
|
|||||||
(i.e. L and R can't represent empty matrices)
|
(i.e. L and R can't represent empty matrices)
|
||||||
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
|
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
|
||||||
for a definition of sparse vector)
|
for a definition of sparse vector)
|
||||||
|
- regularization >= 0
|
||||||
ensures
|
ensures
|
||||||
- This is just an overload of the cca() function defined above. Except in this
|
- This is just an overload of the cca() function defined above. Except in this
|
||||||
case we take a sparse representation of the input L and R matrices rather than
|
case we take a sparse representation of the input L and R matrices rather than
|
||||||
@ -144,7 +153,8 @@ namespace dlib
|
|||||||
matrix<T>& Rtrans,
|
matrix<T>& Rtrans,
|
||||||
unsigned long num_correlations,
|
unsigned long num_correlations,
|
||||||
unsigned long extra_rank = 5,
|
unsigned long extra_rank = 5,
|
||||||
unsigned long q = 2
|
unsigned long q = 2,
|
||||||
|
double regularization = 0
|
||||||
);
|
);
|
||||||
/*!
|
/*!
|
||||||
requires
|
requires
|
||||||
@ -154,6 +164,7 @@ namespace dlib
|
|||||||
(i.e. L and R can't represent empty matrices)
|
(i.e. L and R can't represent empty matrices)
|
||||||
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
|
- L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h
|
||||||
for a definition of sparse vector)
|
for a definition of sparse vector)
|
||||||
|
- regularization >= 0
|
||||||
ensures
|
ensures
|
||||||
- returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q)
|
- returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q)
|
||||||
(i.e. this is just a convenience function for calling the cca() routine when
|
(i.e. this is just a convenience function for calling the cca() routine when
|
||||||
|
Loading…
Reference in New Issue
Block a user