Made inv() handle singular matrices in a more reasonable way. Now it will make

some effort to detect them and output an identity matrix in that case.
This commit is contained in:
Davis King 2014-08-02 18:56:01 -04:00
parent f5c7248b20
commit f9fac06e16
2 changed files with 93 additions and 43 deletions

View File

@ -1120,7 +1120,11 @@ convergence:
typedef typename matrix_exp<EXP>::type type;
matrix<type, 1, 1, typename EXP::mem_manager_type> a;
a(0) = 1/m(0);
// if m is invertible
if (m(0) != 0)
a(0) = 1/m(0);
else
a(0) = 1;
return a;
}
};
@ -1138,11 +1142,20 @@ convergence:
typedef typename matrix_exp<EXP>::type type;
matrix<type, 2, 2, typename EXP::mem_manager_type> a;
type d = static_cast<type>(1.0/det(m));
a(0,0) = m(1,1)*d;
a(0,1) = m(0,1)*-d;
a(1,0) = m(1,0)*-d;
a(1,1) = m(0,0)*d;
type d = det(m);
if (d != 0)
{
d = static_cast<type>(1.0/d);
a(0,0) = m(1,1)*d;
a(0,1) = m(0,1)*-d;
a(1,0) = m(1,0)*-d;
a(1,1) = m(0,0)*d;
}
else
{
// Matrix isn't invertible so just return the identity matrix.
a = identity_matrix<type,2>();
}
return a;
}
};
@ -1160,28 +1173,36 @@ convergence:
typedef typename matrix_exp<EXP>::type type;
matrix<type, 3, 3, typename EXP::mem_manager_type> ret;
const type de = static_cast<type>(1.0/det(m));
const type a = m(0,0);
const type b = m(0,1);
const type c = m(0,2);
const type d = m(1,0);
const type e = m(1,1);
const type f = m(1,2);
const type g = m(2,0);
const type h = m(2,1);
const type i = m(2,2);
type de = det(m);
if (de != 0)
{
de = static_cast<type>(1.0/de);
const type a = m(0,0);
const type b = m(0,1);
const type c = m(0,2);
const type d = m(1,0);
const type e = m(1,1);
const type f = m(1,2);
const type g = m(2,0);
const type h = m(2,1);
const type i = m(2,2);
ret(0,0) = (e*i - f*h)*de;
ret(1,0) = (f*g - d*i)*de;
ret(2,0) = (d*h - e*g)*de;
ret(0,0) = (e*i - f*h)*de;
ret(1,0) = (f*g - d*i)*de;
ret(2,0) = (d*h - e*g)*de;
ret(0,1) = (c*h - b*i)*de;
ret(1,1) = (a*i - c*g)*de;
ret(2,1) = (b*g - a*h)*de;
ret(0,1) = (c*h - b*i)*de;
ret(1,1) = (a*i - c*g)*de;
ret(2,1) = (b*g - a*h)*de;
ret(0,2) = (b*f - c*e)*de;
ret(1,2) = (c*d - a*f)*de;
ret(2,2) = (a*e - b*d)*de;
ret(0,2) = (b*f - c*e)*de;
ret(1,2) = (c*d - a*f)*de;
ret(2,2) = (a*e - b*d)*de;
}
else
{
ret = identity_matrix<type,3>();
}
return ret;
}
@ -1200,28 +1221,36 @@ convergence:
typedef typename matrix_exp<EXP>::type type;
matrix<type, 4, 4, typename EXP::mem_manager_type> ret;
const type de = static_cast<type>(1.0/det(m));
ret(0,0) = det(removerc<0,0>(m));
ret(0,1) = -det(removerc<0,1>(m));
ret(0,2) = det(removerc<0,2>(m));
ret(0,3) = -det(removerc<0,3>(m));
type de = det(m);
if (de != 0)
{
de = static_cast<type>(1.0/de);
ret(0,0) = det(removerc<0,0>(m));
ret(0,1) = -det(removerc<0,1>(m));
ret(0,2) = det(removerc<0,2>(m));
ret(0,3) = -det(removerc<0,3>(m));
ret(1,0) = -det(removerc<1,0>(m));
ret(1,1) = det(removerc<1,1>(m));
ret(1,2) = -det(removerc<1,2>(m));
ret(1,3) = det(removerc<1,3>(m));
ret(1,0) = -det(removerc<1,0>(m));
ret(1,1) = det(removerc<1,1>(m));
ret(1,2) = -det(removerc<1,2>(m));
ret(1,3) = det(removerc<1,3>(m));
ret(2,0) = det(removerc<2,0>(m));
ret(2,1) = -det(removerc<2,1>(m));
ret(2,2) = det(removerc<2,2>(m));
ret(2,3) = -det(removerc<2,3>(m));
ret(2,0) = det(removerc<2,0>(m));
ret(2,1) = -det(removerc<2,1>(m));
ret(2,2) = det(removerc<2,2>(m));
ret(2,3) = -det(removerc<2,3>(m));
ret(3,0) = -det(removerc<3,0>(m));
ret(3,1) = det(removerc<3,1>(m));
ret(3,2) = -det(removerc<3,2>(m));
ret(3,3) = det(removerc<3,3>(m));
ret(3,0) = -det(removerc<3,0>(m));
ret(3,1) = det(removerc<3,1>(m));
ret(3,2) = -det(removerc<3,2>(m));
ret(3,3) = det(removerc<3,3>(m));
return trans(ret)*de;
return trans(ret)*de;
}
else
{
return identity_matrix<type,4>();
}
}
};

View File

@ -1102,6 +1102,27 @@ namespace
DLIB_TEST(m == m2);
}
{
print_spinner();
matrix<double,1,1> m1;
matrix<double,2,2> m2;
matrix<double,3,3> m3;
matrix<double,4,4> m4;
dlib::rand rnd;
for (int i = 0; i < 50; ++i)
{
m1 = randm(1,1,rnd);
m2 = randm(2,2,rnd);
m3 = randm(3,3,rnd);
m4 = randm(4,4,rnd);
DLIB_TEST(max(abs(m1*inv(m1) - identity_matrix(m1))) < 1e-13);
DLIB_TEST(max(abs(m2*inv(m2) - identity_matrix(m2))) < 1e-13);
DLIB_TEST(max(abs(m3*inv(m3) - identity_matrix(m3))) < 1e-13);
DLIB_TEST(max(abs(m4*inv(m4) - identity_matrix(m4))) < 1e-13);
}
}
}