mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Made svd_fast() accept a wider range of matrices as arguments.
This commit is contained in:
parent
6391a03440
commit
a99fc5661a
@ -835,12 +835,17 @@ convergence:
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
template <
|
||||
typename T,
|
||||
long Anr, long Anc,
|
||||
typename MM,
|
||||
typename L
|
||||
>
|
||||
void find_matrix_range (
|
||||
const matrix<T>& A,
|
||||
const matrix<T,Anr,Anc,MM,L>& A,
|
||||
unsigned long l,
|
||||
matrix<T>& Q,
|
||||
unsigned long q = 0
|
||||
matrix<T,Anr,0,MM,L>& Q,
|
||||
unsigned long q
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
@ -882,12 +887,20 @@ convergence:
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
template <
|
||||
typename T,
|
||||
long Anr, long Anc,
|
||||
long Unr, long Unc,
|
||||
long Wnr, long Wnc,
|
||||
long Vnr, long Vnc,
|
||||
typename MM,
|
||||
typename L
|
||||
>
|
||||
void svd_fast (
|
||||
const matrix<T>& A,
|
||||
matrix<T>& u,
|
||||
matrix<T,0,1>& w,
|
||||
matrix<T>& v,
|
||||
const matrix<T,Anr,Anc,MM,L>& A,
|
||||
matrix<T,Unr,Unc,MM,L>& u,
|
||||
matrix<T,Wnr,Wnc,MM,L>& w,
|
||||
matrix<T,Vnr,Vnc,MM,L>& v,
|
||||
unsigned long l,
|
||||
unsigned long q = 1
|
||||
)
|
||||
@ -901,26 +914,31 @@ convergence:
|
||||
<< "\n\t A.size(): " << A.size()
|
||||
);
|
||||
|
||||
matrix<T> Q;
|
||||
matrix<T,Anr,0,MM,L> Q;
|
||||
find_matrix_range(A, k, Q, q);
|
||||
|
||||
// Compute trans(B) = trans(Q)*A. The reason we store B transposed
|
||||
// is so that when we take its SVD later using svd3() it doesn't consume
|
||||
// a whole lot of RAM. That is, we make sure the square matrix coming out
|
||||
// of svd3() has size lxl rather than the potentially much larger nxn.
|
||||
matrix<T> B = trans(A)*Q;
|
||||
matrix<T,0,0,MM,L> B = trans(A)*Q;
|
||||
svd3(B, v,w,u);
|
||||
u = Q*u;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename sparse_vector_type, typename T>
|
||||
template <
|
||||
typename sparse_vector_type,
|
||||
typename T,
|
||||
typename MM,
|
||||
typename L
|
||||
>
|
||||
void find_matrix_range (
|
||||
const std::vector<sparse_vector_type>& A,
|
||||
unsigned long l,
|
||||
matrix<T>& Q,
|
||||
unsigned long q = 0
|
||||
matrix<T,0,0,MM,L>& Q,
|
||||
unsigned long q
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
@ -962,7 +980,7 @@ convergence:
|
||||
const unsigned long n = max_index_plus_one(A);
|
||||
for (unsigned long itr = 0; itr < q; ++itr)
|
||||
{
|
||||
matrix<T> Z(n, l);
|
||||
matrix<T,0,0,MM,L> Z(n, l);
|
||||
// Compute Z = trans(A)*Q
|
||||
Z = 0;
|
||||
for (unsigned long m = 0; m < A.size(); ++m)
|
||||
@ -1001,12 +1019,20 @@ convergence:
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename sparse_vector_type, typename T>
|
||||
template <
|
||||
typename sparse_vector_type,
|
||||
typename T,
|
||||
long Unr, long Unc,
|
||||
long Wnr, long Wnc,
|
||||
long Vnr, long Vnc,
|
||||
typename MM,
|
||||
typename L
|
||||
>
|
||||
void svd_fast (
|
||||
const std::vector<sparse_vector_type>& A,
|
||||
matrix<T>& u,
|
||||
matrix<T,0,1>& w,
|
||||
matrix<T>& v,
|
||||
matrix<T,Unr,Unc,MM,L>& u,
|
||||
matrix<T,Wnr,Wnc,MM,L>& w,
|
||||
matrix<T,Vnr,Vnc,MM,L>& v,
|
||||
unsigned long l,
|
||||
unsigned long q = 1
|
||||
)
|
||||
@ -1022,14 +1048,14 @@ convergence:
|
||||
<< "\n\t A.size(): " << A.size()
|
||||
);
|
||||
|
||||
matrix<T> Q;
|
||||
matrix<T,0,0,MM,L> Q;
|
||||
find_matrix_range(A, k, Q, q);
|
||||
|
||||
// Compute trans(B) = trans(Q)*A. The reason we store B transposed
|
||||
// is so that when we take its SVD later using svd3() it doesn't consume
|
||||
// a whole lot of RAM. That is, we make sure the square matrix coming out
|
||||
// of svd3() has size lxl rather than the potentially much larger nxn.
|
||||
matrix<T> B(n,k);
|
||||
matrix<T,0,0,MM,L> B(n,k);
|
||||
B = 0;
|
||||
for (unsigned long m = 0; m < A.size(); ++m)
|
||||
{
|
||||
|
@ -141,7 +141,7 @@ namespace dlib
|
||||
void svd_fast (
|
||||
const matrix<T>& A,
|
||||
matrix<T>& u,
|
||||
matrix<T,0,1>& w,
|
||||
matrix<T>& w,
|
||||
matrix<T>& v,
|
||||
unsigned long l,
|
||||
unsigned long q = 1
|
||||
@ -191,7 +191,7 @@ namespace dlib
|
||||
void svd_fast (
|
||||
const std::vector<sparse_vector_type>& A,
|
||||
matrix<T>& u,
|
||||
matrix<T,0,1>& w,
|
||||
matrix<T>& w,
|
||||
matrix<T>& v,
|
||||
unsigned long l,
|
||||
unsigned long q = 1
|
||||
|
Loading…
Reference in New Issue
Block a user