Optimized the cost values for a few matrix expressions.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403795
This commit is contained in:
Davis King 2010-07-28 23:11:47 +00:00
parent db98db26c5
commit fb46e2fb4e
6 changed files with 50 additions and 22 deletions

View File

@ -100,8 +100,8 @@ namespace dlib
#ifdef DLIB_USE_BLAS
// if there are BLAS functions to be called then we want to make sure we
// always evaluate any complex expressions so that the BLAS bindings can happen.
const static bool lhs_is_costly = (LHS::cost > 1)&&(RHS::NC != 1 || LHS::cost >= 10000);
const static bool rhs_is_costly = (RHS::cost > 1)&&(LHS::NR != 1 || RHS::cost >= 10000);
const static bool lhs_is_costly = (LHS::cost > 2)&&(RHS::NC != 1 || LHS::cost >= 10000);
const static bool rhs_is_costly = (RHS::cost > 2)&&(LHS::NR != 1 || RHS::cost >= 10000);
#else
const static bool lhs_is_costly = (LHS::cost > 4)&&(RHS::NC != 1);
const static bool rhs_is_costly = (RHS::cost > 4)&&(LHS::NR != 1);
@ -287,7 +287,7 @@ namespace dlib
typedef typename LHS::layout_type layout_type;
const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
const static long cost = LHS::cost+RHS::cost;
const static long cost = LHS::cost+RHS::cost+1;
};
template <
@ -394,7 +394,7 @@ namespace dlib
typedef typename LHS::layout_type layout_type;
const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
const static long cost = LHS::cost+RHS::cost;
const static long cost = LHS::cost+RHS::cost+1;
};
template <

View File

@ -28,7 +28,7 @@ namespace dlib
op_conj_trans( const M& m_) : m(m_){}
const M& m;
const static long cost = M::cost;
const static long cost = M::cost+1;
const static long NR = M::NC;
const static long NC = M::NR;
typedef typename M::type type;

View File

@ -23,7 +23,7 @@ namespace dlib
DLIB_DEFINE_FUNCTION_M(op_log10, log10, std::log10 ,7);
DLIB_DEFINE_FUNCTION_M(op_exp, exp, std::exp ,7);
DLIB_DEFINE_FUNCTION_M(op_conj, conj, std::conj ,1);
DLIB_DEFINE_FUNCTION_M(op_conj, conj, std::conj ,2);
DLIB_DEFINE_FUNCTION_M(op_ceil, ceil, std::ceil ,7);
DLIB_DEFINE_FUNCTION_M(op_floor, floor, std::floor ,7);

View File

@ -762,7 +762,7 @@ namespace dlib
op_remove_col( const M& m_) : m(m_){}
const M& m;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long NR = M::NR;
const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
typedef typename M::type type;
@ -795,7 +795,7 @@ namespace dlib
const M& m;
const long C;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long NR = M::NR;
const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
typedef typename M::type type;
@ -870,7 +870,7 @@ namespace dlib
op_remove_row( const M& m_) : m(m_){}
const M& m;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
const static long NC = M::NC;
typedef typename M::type type;
@ -903,7 +903,7 @@ namespace dlib
const M& m;
const long R;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
const static long NC = M::NC;
typedef typename M::type type;
@ -978,7 +978,7 @@ namespace dlib
op_diagm( const M& m_) : m(m_){}
const M& m;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long N = M::NC*M::NR;
const static long NR = N;
const static long NC = N;
@ -1063,7 +1063,7 @@ namespace dlib
op_cast( const M& m_) : m(m_){}
const M& m;
const static long cost = M::cost;
const static long cost = M::cost+2;
const static long NR = M::NR;
const static long NC = M::NC;
typedef target_type type;
@ -1514,7 +1514,7 @@ namespace dlib
op_sumr(const M& m_) : m(m_) {}
const M& m;
const static long cost = M::cost;
const static long cost = M::cost+10;
const static long NR = 1;
const static long NC = M::NC;
typedef typename M::type type;
@ -1560,7 +1560,7 @@ namespace dlib
op_sumc(const M& m_) : m(m_) {}
const M& m;
const static long cost = M::cost;
const static long cost = M::cost + 10;
const static long NR = M::NR;
const static long NC = 1;
typedef typename M::type type;
@ -1975,7 +1975,7 @@ namespace dlib
op_rotate(const M& m_) : m(m_) {}
const M& m;
const static long cost = M::cost + 1;
const static long cost = M::cost + 2;
const static long NR = M::NR;
const static long NC = M::NC;
typedef typename M::type type;
@ -2350,7 +2350,7 @@ namespace dlib
typedef typename M::type type;
typedef const typename M::type const_ret_type;
const static long cost = M::cost + 1;
const static long cost = M::cost + 2;
const_ret_type apply ( long r, long c) const
{
@ -2428,7 +2428,7 @@ namespace dlib
const long rows;
const long cols;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long NR = 0;
const static long NC = 0;
typedef typename M::type type;
@ -2739,7 +2739,7 @@ namespace dlib
{
op_lowerm( const M& m_) : basic_op_m<M>(m_){}
const static long cost = M::cost+1;
const static long cost = M::cost+2;
typedef typename M::type type;
typedef const typename M::type const_ret_type;
const_ret_type apply ( long r, long c) const
@ -2759,7 +2759,7 @@ namespace dlib
const type s;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
typedef const typename M::type const_ret_type;
const_ret_type apply ( long r, long c) const
{
@ -2802,7 +2802,7 @@ namespace dlib
{
op_upperm( const M& m_) : basic_op_m<M>(m_){}
const static long cost = M::cost+1;
const static long cost = M::cost+2;
typedef typename M::type type;
typedef const typename M::type const_ret_type;
const_ret_type apply ( long r, long c) const
@ -2822,7 +2822,7 @@ namespace dlib
const type s;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
typedef const typename M::type const_ret_type;
const_ret_type apply ( long r, long c) const
{
@ -3013,7 +3013,7 @@ namespace dlib
op_mat_to_vect(const M& m_) : m(m_) {}
const M& m;
const static long cost = M::cost+1;
const static long cost = M::cost+2;
const static long NR = M::NC*M::NR;
const static long NC = 1;
typedef typename M::type type;

View File

@ -57,10 +57,22 @@ namespace
b = a*a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a/2*a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a*trans(a) + a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = (a+a)*(a+a);
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a*(a-a);
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = trans(a)*trans(a) + a;
DLIB_TEST(counter_gemm() == 1);

View File

@ -1196,6 +1196,22 @@ namespace
DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
}
{
matrix<double> m1(3,3), m2(3,3);
m1 = 1;
m2 = 1;
m1 = m1*subm(m2,0,0,3,3);
}
{
matrix<double,3,1> m1;
matrix<double> m2(3,3);
m1 = 1;
m2 = 1;
m1 = subm(m2,0,0,3,3)*m1;
}
}