Added the diagm(), svd2() and svd3() functions.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402627
This commit is contained in:
Davis King 2008-11-03 02:31:53 +00:00
parent e786884acc
commit 112eaab9b0
2 changed files with 554 additions and 15 deletions

View File

@ -776,6 +776,393 @@ namespace dlib
b(i) = sum/a(i,i);
}
}
// ------------------------------------------------------------------------------------
}
/*
Computes the singular value decomposition of the matrix expression a.

Parameters:
withu - if true, the left-hand transformations are accumulated into u.
        Note that u is resized to m x m and used as workspace (it receives
        a copy of a) even when withu == false.
withv - if true, v is resized to n x n and the right-hand transformations
        are accumulated into it.  When false, v is never touched.
a     - the m x n input matrix, with m = a.nr() and n = a.nc().  DLIB_ASSERT
        checks that a.nr() >= a.nc().
u     - output, resized to m x m.
q     - output, resized to n x 1; receives the singular values.
v     - output, resized to n x n (only when withv == true).

Returns 0 if every singular value converged.  Otherwise returns the index k
of the singular value whose iteration count exceeded 30; the routine stops
at that point, so the values at indices below k have not been processed.
*/
template <
typename EXP,
long qN, long qX,
long uM,
long vN,
typename MM1,
typename MM2,
typename MM3
>
long svd2 (
bool withu,
bool withv,
const matrix_exp<EXP>& a,
matrix<typename EXP::type,uM,uM,MM1>& u,
matrix<typename EXP::type,qN,qX,MM2>& q,
matrix<typename EXP::type,vN,vN,MM3>& v
)
{
/*
Singular value decomposition. Translated to 'C' from the
original Algol code in "Handbook for Automatic Computation,
vol. II, Linear Algebra", Springer-Verlag. Note that this
published algorithm is considered to be the best and numerically
stable approach to computing the real-valued svd and is referenced
repeatedly in ieee journal papers, etc where the svd is used.
This is almost an exact translation from the original, except that
an iteration counter is added to prevent stalls. This corresponds
to similar changes in other translations.
Returns an error code = 0, if no errors and 'k' if a failure to
converge at the 'kth' singular value.
USAGE: for an m*n matrix a with m >= n the decomposition has the form
a = (first n columns of u) * diagm(q) * v'
After the svd call u is an m x m matrix which is columnwise
orthogonal. q will be an n element vector consisting of singular values
and v an n x n orthogonal matrix.
eps and tol are tolerance constants.  They are not supplied by the caller;
they are derived below from std::numeric_limits<T> as
eps = epsilon() and tol = min()/eps.
If withu == false then u won't be computed and similarly if withv == false
then v won't be computed.
*/
const long NR = matrix_exp<EXP>::NR;
const long NC = matrix_exp<EXP>::NC;
// make sure the output matrices have valid dimensions if they are statically dimensioned
COMPILE_TIME_ASSERT(qX == 0 || qX == 1);
COMPILE_TIME_ASSERT(NR == 0 || uM == 0 || NR == uM);
COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
DLIB_ASSERT(a.nr() >= a.nc(),
"\tconst matrix_exp svd2()"
<< "\n\tYou have given an invalidly sized matrix"
<< "\n\ta.nr(): " << a.nr()
<< "\n\ta.nc(): " << a.nc()
);
typedef typename EXP::type T;
// bring the std overloads into scope so the unqualified abs()/sqrt() calls
// below can resolve for T
using std::abs;
using std::sqrt;
T eps = std::numeric_limits<T>::epsilon();
T tol = std::numeric_limits<T>::min()/eps;
const long m = a.nr();
const long n = a.nc();
long i, j, k, l, l1, iter, retval;
T c, f, g, h, s, x, y, z;
// e: off-diagonal elements of the bidiagonal form (q holds the diagonal)
matrix<T,qN,1,MM2> e(n,1);
q.set_size(n,1);
u.set_size(m,m);
retval = 0;
if (withv)
{
v.set_size(n,n);
}
/* Copy 'a' to 'u' */
// Only the first n columns of u are written here.  All further work is
// done in place on u, regardless of withu.
for (i=0; i<m; i++)
{
for (j=0; j<n; j++)
u(i,j) = a(i,j);
}
/* Householder's reduction to bidiagonal form. */
g = x = 0.0;
for (i=0; i<n; i++)
{
// e(0) is therefore 0, since g == 0 on the first pass
e(i) = g;
s = 0.0;
l = i + 1;
// squared norm of column i from the diagonal downwards
for (j=i; j<m; j++)
s += (u(j,i) * u(j,i));
// skip the column transformation when the norm is negligible
if (s < tol)
g = 0.0;
else
{
f = u(i,i);
// g takes the sign opposite to f
g = (f < 0) ? sqrt(s) : -sqrt(s);
h = f * g - s;
u(i,i) = f - g;
for (j=l; j<n; j++)
{
s = 0.0;
for (k=i; k<m; k++)
s += (u(k,i) * u(k,j));
f = s / h;
for (k=i; k<m; k++)
u(k,j) += (f * u(k,i));
} /* end j */
} /* end s */
q(i) = g;
s = 0.0;
// squared norm of row i to the right of the diagonal
for (j=l; j<n; j++)
s += (u(i,j) * u(i,j));
// when i == n-1 the sum above is empty, so this branch is taken
// and u(i,i+1) below is never read out of range
if (s < tol)
g = 0.0;
else
{
f = u(i,i+1);
g = (f < 0) ? sqrt(s) : -sqrt(s);
h = f * g - s;
u(i,i+1) = f - g;
for (j=l; j<n; j++)
e(j) = u(i,j) / h;
for (j=l; j<m; j++)
{
s = 0.0;
for (k=l; k<n; k++)
s += (u(j,k) * u(i,k));
for (k=l; k<n; k++)
u(j,k) += (s * e(k));
} /* end j */
} /* end s */
// x tracks the largest |q(i)| + |e(i)| seen so far; it scales eps below
y = abs(q(i)) + abs(e(i));
if (y > x)
x = y;
} /* end i */
/* accumulation of right-hand transformations */
// On entry g and l still hold the values left by the final pass of the
// reduction loop above (l == n there, so the j/k loops are empty for i == n-1).
if (withv)
{
for (i=n-1; i>=0; i--)
{
if (g != 0.0)
{
h = u(i,i+1) * g;
for (j=l; j<n; j++)
v(j,i) = u(i,j)/h;
for (j=l; j<n; j++)
{
s = 0.0;
for (k=l; k<n; k++)
s += (u(i,k) * v(k,j));
for (k=l; k<n; k++)
v(k,j) += (s * v(k,i));
} /* end j */
} /* end g */
for (j=l; j<n; j++)
v(i,j) = v(j,i) = 0.0;
v(i,i) = 1.0;
g = e(i);
l = i;
} /* end i */
} /* end withv, parens added for clarity */
/* accumulation of left-hand transformations */
// First set the trailing (m-n) x (m-n) block of u to the identity.  The
// copy of a only filled the first n columns.
if (withu)
{
for (i=n; i<m; i++)
{
for (j=n;j<m;j++)
u(i,j) = 0.0;
u(i,i) = 1.0;
}
}
if (withu)
{
for (i=n-1; i>=0; i--)
{
l = i + 1;
g = q(i);
for (j=l; j<m; j++) /* upper limit was 'n' */
u(i,j) = 0.0;
if (g != 0.0)
{
h = u(i,i) * g;
for (j=l; j<m; j++)
{ /* upper limit was 'n' */
s = 0.0;
for (k=l; k<m; k++)
s += (u(k,i) * u(k,j));
f = s / h;
for (k=i; k<m; k++)
u(k,j) += (f * u(k,i));
} /* end j */
for (j=i; j<m; j++)
u(j,i) /= g;
} /* end g */
else
{
for (j=i; j<m; j++)
u(j,i) = 0.0;
}
u(i,i) += 1.0;
} /* end i*/
} /* end withu, parens added for clarity */
/* diagonalization of the bidiagonal form */
// make the convergence threshold relative to the largest |q(i)| + |e(i)|
eps *= x;
for (k=n-1; k>=0; k--)
{
iter = 0;
test_f_splitting:
// NOTE(review): this loop appears to rely on e(0) == 0 (seeded from g == 0
// above and restored by "e(l) = 0.0" after each QR sweep), so that the
// l == 0 pass always leaves via the first test and q(l-1) is never read
// with l == 0 -- confirm before changing either statement.
for (l=k; l>=0; l--)
{
if (abs(e(l)) <= eps)
goto test_f_convergence;
if (abs(q(l-1)) <= eps)
goto cancellation;
} /* end l */
/* cancellation of e(l) if l > 0 */
cancellation:
c = 0.0;
s = 1.0;
l1 = l - 1;
for (i=l; i<=k; i++)
{
f = s * e(i);
e(i) *= c;
if (abs(f) <= eps)
goto test_f_convergence;
g = q(i);
h = q(i) = sqrt(f*f + g*g);
c = g / h;
s = -f / h;
if (withu)
{
// apply the rotation (c,s) to columns l1 and i of u
for (j=0; j<m; j++)
{
y = u(j,l1);
z = u(j,i);
u(j,l1) = y * c + z * s;
u(j,i) = -y * s + z * c;
} /* end j */
} /* end withu, parens added for clarity */
} /* end i */
test_f_convergence:
z = q(k);
// l == k means the loops above split off q(k) on its own, so it has converged
if (l == k)
goto convergence;
/* shift from bottom 2x2 minor */
iter++;
// iteration cap added to the original algorithm to prevent stalls; the
// break leaves the k loop, so indices below k are left unprocessed
if (iter > 30)
{
retval = k;
break;
}
x = q(l);
y = q(k-1);
g = e(k-1);
h = e(k);
f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y);
g = sqrt(f * f + 1.0);
f = ((x - z) * (x + z) + h * (y / ((f < 0)?(f - g) : (f + g)) - h)) / x;
/* next QR transformation */
c = s = 1.0;
for (i=l+1; i<=k; i++)
{
g = e(i);
y = q(i);
h = s * g;
g *= c;
e(i-1) = z = sqrt(f * f + h * h);
c = f / z;
s = h / z;
f = x * c + g * s;
g = -x * s + g * c;
h = y * s;
y *= c;
if (withv)
{
// apply the rotation (c,s) to columns i-1 and i of v
for (j=0;j<n;j++)
{
x = v(j,i-1);
z = v(j,i);
v(j,i-1) = x * c + z * s;
v(j,i) = -x * s + z * c;
} /* end j */
} /* end withv, parens added for clarity */
q(i-1) = z = sqrt(f * f + h * h);
c = f / z;
s = h / z;
f = c * g + s * y;
x = -s * g + c * y;
if (withu)
{
// apply the rotation (c,s) to columns i-1 and i of u
for (j=0; j<m; j++)
{
y = u(j,i-1);
z = u(j,i);
u(j,i-1) = y * c + z * s;
u(j,i) = -y * s + z * c;
} /* end j */
} /* end withu, parens added for clarity */
} /* end i */
e(l) = 0.0;
e(k) = f;
q(k) = x;
goto test_f_splitting;
convergence:
if (z < 0.0)
{
/* q(k) is made non-negative */
q(k) = -z;
// negate the matching column of v so the sign change of q(k) is compensated
if (withv)
{
for (j=0; j<n; j++)
v(j,k) = -v(j,k);
} /* end withv, parens added for clarity */
} /* end z */
} /* end k */
return retval;
}
// ----------------------------------------------------------------------------------------
@ -2261,6 +2648,53 @@ namespace dlib
return matrix_exp<exp>(exp(m,R));
}
// ----------------------------------------------------------------------------------------
struct op_diagm
{
    /*
        Operation policy used by diagm().  It presents a vector expression m
        as a square matrix whose (i,i) element is m(i) and whose other
        elements are 0.
    */
    template <typename EXP>
    struct op : has_destructive_aliasing
    {
        // Compile-time side length: the product of the operand's static dimensions.
        const static long N = EXP::NC*EXP::NR;
        const static long NR = N;
        const static long NC = N;
        typedef typename EXP::type type;
        typedef typename EXP::mem_manager_type mem_manager_type;

        // Returns element (r,c) of the diagonal matrix built from m.
        template <typename M>
        static type apply ( const M& m, long r, long c)
        {
            if (r != c)
                return 0;
            return m(r);
        }

        // Runtime row count: the larger of m's two dimensions.
        template <typename M>
        static long nr (const M& m)
        {
            const long rows = m.nr();
            const long cols = m.nc();
            return (cols < rows) ? rows : cols;
        }

        // Runtime column count: the result is square, so it equals nr(m).
        template <typename M>
        static long nc (const M& m) { return nr(m); }
    };
};
/*
    Wraps the row or column vector expression m in an op_diagm expression,
    i.e. a square matrix with m on its diagonal and 0 everywhere else.
    Nothing is copied here; the result is a matrix_exp built around m.
*/
template <typename EXP>
const matrix_exp<matrix_unary_exp<matrix_exp<EXP>,op_diagm> > diagm (
    const matrix_exp<EXP>& m
)
{
    // Only a row or column vector can be turned into a diagonal matrix
    COMPILE_TIME_ASSERT(EXP::NC == 0 || EXP::NC == 1 || EXP::NR == 0 || EXP::NR == 1);
    DLIB_ASSERT(m.nr() == 1 || m.nc() == 1,
        "\tconst matrix_exp diagm(const matrix_exp& m)"
        << "\n\tYou can only apply diagm() to a row or column matrix"
        << "\n\tm.nr(): " << m.nr()
        << "\n\tm.nc(): " << m.nc()
        );

    typedef matrix_unary_exp<matrix_exp<EXP>,op_diagm> diagm_exp_type;
    return matrix_exp<diagm_exp_type>(diagm_exp_type(m));
}
// ----------------------------------------------------------------------------------------
struct op_diag
@ -2379,6 +2813,47 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
/*
    Computes the singular value decomposition of m by running
    nric::svdcmp() on a copy of m.

    Parameters:
        m - the matrix expression to decompose
        u - output; assigned a copy of m and then handed to nric::svdcmp(),
            which works on it in place
        w - output; resized to m.nc() x 1 before the call and receives the
            singular values as a column vector
        v - output; resized to m.nc() x m.nc() before the call

    The return value of nric::svdcmp(), if it has one, is not inspected.
*/
template <
    typename EXP,
    long uNR,
    long uNC,
    long wN,
    long vN,
    long wX,
    typename MM1,
    typename MM2,
    typename MM3
    >
inline void svd3 (
    const matrix_exp<EXP>& m,
    matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1>& u,
    matrix<typename matrix_exp<EXP>::type, wN, wX,MM2>& w,
    matrix<typename matrix_exp<EXP>::type, vN, vN,MM3>& v
)
{
    typedef typename matrix_exp<EXP>::type T;
    const long NR = matrix_exp<EXP>::NR;
    const long NC = matrix_exp<EXP>::NC;

    // make sure the output matrices have valid dimensions if they are statically dimensioned
    COMPILE_TIME_ASSERT(NR == 0 || uNR == 0 || NR == uNR);
    COMPILE_TIME_ASSERT(NC == 0 || uNC == 0 || NC == uNC);
    COMPILE_TIME_ASSERT(NC == 0 || wN == 0 || NC == wN);
    COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
    COMPILE_TIME_ASSERT(wX == 0 || wX == 1);

    v.set_size(m.nc(),m.nc());

    // svdcmp() operates on u, so seed it with the input
    u = m;

    w.set_size(m.nc(),1);

    // scratch column vector required by svdcmp()
    matrix<T,NC,1,MM1> rv1(m.nc(),1);
    nric::svdcmp(u,w,v,rv1);
}
// ----------------------------------------------------------------------------------------
template <
@ -2408,20 +2883,9 @@ namespace dlib
COMPILE_TIME_ASSERT(NC == 0 || wN == 0 || NC == wN);
COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
w.set_size(m.nc(),m.nc());
v.set_size(m.nc(),m.nc());
typedef typename matrix_exp<EXP>::type T;
u = m;
matrix<T,matrix_exp<EXP>::NC,1,MM1> W(m.nc(),1);
matrix<T,matrix_exp<EXP>::NC,1,MM1> rv1(m.nc(),1);
set_all_elements(w,0);
nric::svdcmp(u,W,v,rv1);
for (long r = 0; r < W.nr(); ++r)
w(r,r) = W(r);
matrix<T,matrix_exp<EXP>::NC,1,MM1> W;
svd3(m,u,W,v);
w = diagm(W);
}
// ----------------------------------------------------------------------------------------

View File

@ -29,6 +29,20 @@ namespace dlib
of m in the order R(0)==m(0,0), R(1)==m(1,1), R(2)==m(2,2) and so on.
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp diagm (
const matrix_exp& m
);
/*!
requires
    - m is a row or column matrix
ensures
    - returns a square matrix M such that:
        - diag(M) == m
        - non diagonal elements of M are 0
        - M has as many rows (and columns) as m has elements
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp trans (
@ -730,7 +744,7 @@ namespace dlib
- m == #u*#w*trans(#v)
- trans(#u)*#u == identity matrix
- trans(#v)*#v == identity matrix
- diag(#w) == the sinular values of the matrix m in no
- diag(#w) == the singular values of the matrix m in no
particular order. All non-diagonal elements of #w are
set to 0.
- #u.nr() == m.nr()
@ -741,6 +755,67 @@ namespace dlib
- #v.nc() == m.nc()
!*/
// ----------------------------------------------------------------------------------------
long svd2 (
bool withu,
bool withv,
const matrix_exp& m,
matrix<matrix_exp::type>& u,
matrix<matrix_exp::type>& w,
matrix<matrix_exp::type>& v
);
/*!
requires
    - m.nr() >= m.nc()
ensures
    - computes the singular value decomposition of matrix m
    - m == subm(#u,get_rect(m))*diagm(#w)*trans(#v)
    - trans(#u)*#u == identity matrix
    - trans(#v)*#v == identity matrix
    - #w == the singular values of the matrix m in no
      particular order.
    - #u.nr() == m.nr()
    - #u.nc() == m.nr()
    - #w.nr() == m.nc()
    - #w.nc() == 1
    - #v.nr() == m.nc()
    - #v.nc() == m.nc()
    - if (withu == false) then
        - ignore the above regarding #u, it isn't computed and its
          output state is undefined.
    - if (withv == false) then
        - ignore the above regarding #v, it isn't computed and its
          output state is undefined.
    - returns an error code of 0, if no errors and 'k' if we fail to
      converge at the 'kth' singular value.
!*/
// ----------------------------------------------------------------------------------------
void svd3 (
const matrix_exp& m,
matrix<matrix_exp::type>& u,
matrix<matrix_exp::type>& w,
matrix<matrix_exp::type>& v
);
/*!
ensures
    - computes the singular value decomposition of m
    - m == #u*diagm(#w)*trans(#v)
    - trans(#u)*#u == identity matrix
    - trans(#v)*#v == identity matrix
    - #w == the singular values of the matrix m in no
      particular order, stored as a column vector (one
      singular value per row).
    - #u.nr() == m.nr()
    - #u.nc() == m.nc()
    - #w.nr() == m.nc()
    - #w.nc() == 1
    - #v.nr() == m.nc()
    - #v.nc() == m.nc()
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp::type det (