mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Optimized the cost values for a few matrix expressions.
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403795
This commit is contained in:
parent
db98db26c5
commit
fb46e2fb4e
@ -100,8 +100,8 @@ namespace dlib
|
||||
#ifdef DLIB_USE_BLAS
|
||||
// if there are BLAS functions to be called then we want to make sure we
|
||||
// always evaluate any complex expressions so that the BLAS bindings can happen.
|
||||
const static bool lhs_is_costly = (LHS::cost > 1)&&(RHS::NC != 1 || LHS::cost >= 10000);
|
||||
const static bool rhs_is_costly = (RHS::cost > 1)&&(LHS::NR != 1 || RHS::cost >= 10000);
|
||||
const static bool lhs_is_costly = (LHS::cost > 2)&&(RHS::NC != 1 || LHS::cost >= 10000);
|
||||
const static bool rhs_is_costly = (RHS::cost > 2)&&(LHS::NR != 1 || RHS::cost >= 10000);
|
||||
#else
|
||||
const static bool lhs_is_costly = (LHS::cost > 4)&&(RHS::NC != 1);
|
||||
const static bool rhs_is_costly = (RHS::cost > 4)&&(LHS::NR != 1);
|
||||
@ -287,7 +287,7 @@ namespace dlib
|
||||
typedef typename LHS::layout_type layout_type;
|
||||
const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
|
||||
const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
|
||||
const static long cost = LHS::cost+RHS::cost;
|
||||
const static long cost = LHS::cost+RHS::cost+1;
|
||||
};
|
||||
|
||||
template <
|
||||
@ -394,7 +394,7 @@ namespace dlib
|
||||
typedef typename LHS::layout_type layout_type;
|
||||
const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR;
|
||||
const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC;
|
||||
const static long cost = LHS::cost+RHS::cost;
|
||||
const static long cost = LHS::cost+RHS::cost+1;
|
||||
};
|
||||
|
||||
template <
|
||||
|
@ -28,7 +28,7 @@ namespace dlib
|
||||
op_conj_trans( const M& m_) : m(m_){}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost;
|
||||
const static long cost = M::cost+1;
|
||||
const static long NR = M::NC;
|
||||
const static long NC = M::NR;
|
||||
typedef typename M::type type;
|
||||
|
@ -23,7 +23,7 @@ namespace dlib
|
||||
DLIB_DEFINE_FUNCTION_M(op_log10, log10, std::log10 ,7);
|
||||
DLIB_DEFINE_FUNCTION_M(op_exp, exp, std::exp ,7);
|
||||
|
||||
DLIB_DEFINE_FUNCTION_M(op_conj, conj, std::conj ,1);
|
||||
DLIB_DEFINE_FUNCTION_M(op_conj, conj, std::conj ,2);
|
||||
|
||||
DLIB_DEFINE_FUNCTION_M(op_ceil, ceil, std::ceil ,7);
|
||||
DLIB_DEFINE_FUNCTION_M(op_floor, floor, std::floor ,7);
|
||||
|
@ -762,7 +762,7 @@ namespace dlib
|
||||
op_remove_col( const M& m_) : m(m_){}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = M::NR;
|
||||
const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
|
||||
typedef typename M::type type;
|
||||
@ -795,7 +795,7 @@ namespace dlib
|
||||
const M& m;
|
||||
const long C;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = M::NR;
|
||||
const static long NC = (M::NC==0) ? 0 : (M::NC - 1);
|
||||
typedef typename M::type type;
|
||||
@ -870,7 +870,7 @@ namespace dlib
|
||||
op_remove_row( const M& m_) : m(m_){}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
|
||||
const static long NC = M::NC;
|
||||
typedef typename M::type type;
|
||||
@ -903,7 +903,7 @@ namespace dlib
|
||||
const M& m;
|
||||
const long R;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = (M::NR==0) ? 0 : (M::NR - 1);
|
||||
const static long NC = M::NC;
|
||||
typedef typename M::type type;
|
||||
@ -978,7 +978,7 @@ namespace dlib
|
||||
op_diagm( const M& m_) : m(m_){}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long N = M::NC*M::NR;
|
||||
const static long NR = N;
|
||||
const static long NC = N;
|
||||
@ -1063,7 +1063,7 @@ namespace dlib
|
||||
op_cast( const M& m_) : m(m_){}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = M::NR;
|
||||
const static long NC = M::NC;
|
||||
typedef target_type type;
|
||||
@ -1514,7 +1514,7 @@ namespace dlib
|
||||
op_sumr(const M& m_) : m(m_) {}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost;
|
||||
const static long cost = M::cost+10;
|
||||
const static long NR = 1;
|
||||
const static long NC = M::NC;
|
||||
typedef typename M::type type;
|
||||
@ -1560,7 +1560,7 @@ namespace dlib
|
||||
op_sumc(const M& m_) : m(m_) {}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost;
|
||||
const static long cost = M::cost + 10;
|
||||
const static long NR = M::NR;
|
||||
const static long NC = 1;
|
||||
typedef typename M::type type;
|
||||
@ -1975,7 +1975,7 @@ namespace dlib
|
||||
op_rotate(const M& m_) : m(m_) {}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost + 1;
|
||||
const static long cost = M::cost + 2;
|
||||
const static long NR = M::NR;
|
||||
const static long NC = M::NC;
|
||||
typedef typename M::type type;
|
||||
@ -2350,7 +2350,7 @@ namespace dlib
|
||||
|
||||
typedef typename M::type type;
|
||||
typedef const typename M::type const_ret_type;
|
||||
const static long cost = M::cost + 1;
|
||||
const static long cost = M::cost + 2;
|
||||
|
||||
const_ret_type apply ( long r, long c) const
|
||||
{
|
||||
@ -2428,7 +2428,7 @@ namespace dlib
|
||||
const long rows;
|
||||
const long cols;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = 0;
|
||||
const static long NC = 0;
|
||||
typedef typename M::type type;
|
||||
@ -2739,7 +2739,7 @@ namespace dlib
|
||||
{
|
||||
op_lowerm( const M& m_) : basic_op_m<M>(m_){}
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
typedef typename M::type type;
|
||||
typedef const typename M::type const_ret_type;
|
||||
const_ret_type apply ( long r, long c) const
|
||||
@ -2759,7 +2759,7 @@ namespace dlib
|
||||
|
||||
const type s;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
typedef const typename M::type const_ret_type;
|
||||
const_ret_type apply ( long r, long c) const
|
||||
{
|
||||
@ -2802,7 +2802,7 @@ namespace dlib
|
||||
{
|
||||
op_upperm( const M& m_) : basic_op_m<M>(m_){}
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
typedef typename M::type type;
|
||||
typedef const typename M::type const_ret_type;
|
||||
const_ret_type apply ( long r, long c) const
|
||||
@ -2822,7 +2822,7 @@ namespace dlib
|
||||
|
||||
const type s;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
typedef const typename M::type const_ret_type;
|
||||
const_ret_type apply ( long r, long c) const
|
||||
{
|
||||
@ -3013,7 +3013,7 @@ namespace dlib
|
||||
op_mat_to_vect(const M& m_) : m(m_) {}
|
||||
const M& m;
|
||||
|
||||
const static long cost = M::cost+1;
|
||||
const static long cost = M::cost+2;
|
||||
const static long NR = M::NC*M::NR;
|
||||
const static long NC = 1;
|
||||
typedef typename M::type type;
|
||||
|
@ -57,10 +57,22 @@ namespace
|
||||
b = a*a;
|
||||
DLIB_TEST(counter_gemm() == 1);
|
||||
|
||||
counter_gemm() = 0;
|
||||
b = a/2*a;
|
||||
DLIB_TEST(counter_gemm() == 1);
|
||||
|
||||
counter_gemm() = 0;
|
||||
b = a*trans(a) + a;
|
||||
DLIB_TEST(counter_gemm() == 1);
|
||||
|
||||
counter_gemm() = 0;
|
||||
b = (a+a)*(a+a);
|
||||
DLIB_TEST(counter_gemm() == 1);
|
||||
|
||||
counter_gemm() = 0;
|
||||
b = a*(a-a);
|
||||
DLIB_TEST(counter_gemm() == 1);
|
||||
|
||||
counter_gemm() = 0;
|
||||
b = trans(a)*trans(a) + a;
|
||||
DLIB_TEST(counter_gemm() == 1);
|
||||
|
@ -1196,6 +1196,22 @@ namespace
|
||||
DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
|
||||
DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
|
||||
}
|
||||
|
||||
{
|
||||
matrix<double> m1(3,3), m2(3,3);
|
||||
|
||||
m1 = 1;
|
||||
m2 = 1;
|
||||
m1 = m1*subm(m2,0,0,3,3);
|
||||
}
|
||||
{
|
||||
matrix<double,3,1> m1;
|
||||
matrix<double> m2(3,3);
|
||||
|
||||
m1 = 1;
|
||||
m2 = 1;
|
||||
m1 = subm(m2,0,0,3,3)*m1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user