Added linpiece()

This commit is contained in:
Davis King 2013-03-30 16:51:27 -04:00
parent 9d0f6796dc
commit b22e9f2fc8
3 changed files with 143 additions and 0 deletions

View File

@ -3644,6 +3644,69 @@ namespace dlib
return matrix_range_exp<double>(start,end,num,false); return matrix_range_exp<double>(start,end,num,false);
} }
// ----------------------------------------------------------------------------------------
template <typename M>
struct op_linpiece
{
op_linpiece(const double val_, const M& joints_) : joints(joints_), val(val_){}
const M& joints;
const double val;
const static long cost = 10;
const static long NR = (M::NR*M::NC==0) ? (0) : (M::NR*M::NC-1);
const static long NC = 1;
typedef typename M::type type;
typedef default_memory_manager mem_manager_type;
typedef row_major_layout layout_type;
typedef type const_ret_type;
const_ret_type apply (long i, long ) const
{
if (joints(i) < val)
return std::min<type>(val,joints(i+1)) - joints(i);
else
return 0;
}
long nr () const { return joints.size()-1; }
long nc () const { return 1; }
template <typename U> bool aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
};
template < typename EXP >
const matrix_op<op_linpiece<EXP> > linpiece (
const double val,
const matrix_exp<EXP>& joints
)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_vector(joints) && joints.size() >= 2,
"\t matrix_exp linpiece()"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t is_vector(joints): " << is_vector(joints)
<< "\n\t joints.size(): " << joints.size()
);
#ifdef ENABLE_ASSERTS
for (long i = 1; i < joints.size(); ++i)
{
DLIB_ASSERT(joints(i-1) < joints(i),
"\t matrix_exp linpiece()"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t joints("<<i-1<<"): " << joints(i-1)
<< "\n\t joints("<<i<<"): " << joints(i)
);
}
#endif
typedef op_linpiece<EXP> op;
return matrix_op<op>(op(val,joints.ref()));
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline const matrix_log_range_exp<double> logspace ( inline const matrix_log_range_exp<double> logspace (

View File

@ -390,6 +390,37 @@ namespace dlib
- M(num-1) == 10^end - M(num-1) == 10^end
!*/ !*/
// ----------------------------------------------------------------------------------------
const matrix_exp linpiece (
const double val,
const matrix_exp& joints
);
/*!
requires
- is_vector(joints) == true
- joints.size() >= 2
- for all valid i < j:
- joints(i) < joints(j)
ensures
- linpiece() is useful for creating piecewise linear functions of val. For
example, if w is a parameter vector then you can represent a piecewise linear
function of val as: f(val) = dot(w, linpiece(val, linspace(0,100,5))). In
this case, f(val) is piecewise linear on the intervals [0,25], [25,50],
[50,75], [75,100]. Moreover, w(i) defines the derivative of f(val) in the
i-th interval. Finally, outside the interval [0,100] f(val) has a derivative
of zero and f(0) == 0.
- To be precise, this function returns a column vector L such that:
- L.size() == joints.size()-1
- is_col_vector(L) == true
- L contains the same type of elements as joints.
- for all valid i:
- if (joints(i) < val)
- L(i) == min(val,joints(i+1)) - joints(i)
- else
- L(i) == 0
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <

View File

@ -1015,6 +1015,54 @@ namespace
DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0); DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0);
} }
void test_linpiece()
{
matrix<double,0,1> temp = linpiece(5, linspace(-1, 9, 2));
DLIB_CASSERT(temp.size() == 1,"");
DLIB_CASSERT(std::abs(temp(0) - 6) < 1e-13,"");
temp = linpiece(5, linspace(-1, 9, 6));
DLIB_CASSERT(temp.size() == 5,"");
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
temp = linpiece(4, linspace(-1, 9, 6));
DLIB_CASSERT(temp.size() == 5,"");
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(2) - 1) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
temp = linpiece(40, linspace(-1, 9, 6));
DLIB_CASSERT(temp.size() == 5,"");
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(3) - 2) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(4) - 2) < 1e-13,"");
temp = linpiece(-40, linspace(-1, 9, 6));
DLIB_CASSERT(temp.size() == 5,"");
DLIB_CASSERT(std::abs(temp(0) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
temp = linpiece(0, linspace(-1, 9, 6));
DLIB_CASSERT(temp.size() == 5,"");
DLIB_CASSERT(std::abs(temp(0) - 1) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
}
class matrix_tester : public tester class matrix_tester : public tester
{ {
public: public:
@ -1038,6 +1086,7 @@ namespace
matrix_test(); matrix_test();
test_complex(); test_complex();
test_linpiece();
} }
} a; } a;