diff --git a/dlib/matrix/matrix_utilities.h b/dlib/matrix/matrix_utilities.h index e8bd5d1ab..d7fd0ac39 100644 --- a/dlib/matrix/matrix_utilities.h +++ b/dlib/matrix/matrix_utilities.h @@ -3644,6 +3644,69 @@ namespace dlib return matrix_range_exp(start,end,num,false); } +// ---------------------------------------------------------------------------------------- + + template + struct op_linpiece + { + op_linpiece(const double val_, const M& joints_) : joints(joints_), val(val_){} + + const M& joints; + const double val; + + const static long cost = 10; + + const static long NR = (M::NR*M::NC==0) ? (0) : (M::NR*M::NC-1); + const static long NC = 1; + typedef typename M::type type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + typedef type const_ret_type; + const_ret_type apply (long i, long ) const + { + if (joints(i) < val) + return std::min(val,joints(i+1)) - joints(i); + else + return 0; + } + + long nr () const { return joints.size()-1; } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return joints.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return joints.aliases(item); } + }; + + template < typename EXP > + const matrix_op > linpiece ( + const double val, + const matrix_exp& joints + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(joints) && joints.size() >= 2, + "\t matrix_exp linpiece()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_vector(joints): " << is_vector(joints) + << "\n\t joints.size(): " << joints.size() + ); +#ifdef ENABLE_ASSERTS + for (long i = 1; i < joints.size(); ++i) + { + DLIB_ASSERT(joints(i-1) < joints(i), + "\t matrix_exp linpiece()" + << "\n\t Invalid inputs were given to this function " + << "\n\t joints("< op; + return matrix_op(op(val,joints.ref())); + } + // ---------------------------------------------------------------------------------------- inline const matrix_log_range_exp logspace ( diff --git a/dlib/matrix/matrix_utilities_abstract.h b/dlib/matrix/matrix_utilities_abstract.h index 99b879cb6..8ba1d24f2 100644 --- a/dlib/matrix/matrix_utilities_abstract.h +++ b/dlib/matrix/matrix_utilities_abstract.h @@ -390,6 +390,37 @@ namespace dlib - M(num-1) == 10^end !*/ +// ---------------------------------------------------------------------------------------- + + const matrix_exp linpiece ( + const double val, + const matrix_exp& joints + ); + /*! + requires + - is_vector(joints) == true + - joints.size() >= 2 + - for all valid i < j: + - joints(i) < joints(j) + ensures + - linpiece() is useful for creating piecewise linear functions of val. For + example, if w is a parameter vector then you can represent a piecewise linear + function of val as: f(val) = dot(w, linpiece(val, linspace(0,100,5))). In + this case, f(val) is piecewise linear on the intervals [0,25], [25,50], + [50,75], [75,100]. Moreover, w(i) defines the derivative of f(val) in the + i-th interval. Finally, outside the interval [0,100] f(val) has a derivative + of zero and f(0) == 0. + - To be precise, this function returns a column vector L such that: + - L.size() == joints.size()-1 + - is_col_vector(L) == true + - L contains the same type of elements as joints. + - for all valid i: + - if (joints(i) < val) + - L(i) == min(val,joints(i+1)) - joints(i) + - else + - L(i) == 0 + !*/ + // ---------------------------------------------------------------------------------------- template < diff --git a/dlib/test/matrix4.cpp b/dlib/test/matrix4.cpp index 426e8a640..060c801e1 100644 --- a/dlib/test/matrix4.cpp +++ b/dlib/test/matrix4.cpp @@ -1015,6 +1015,54 @@ namespace DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0); } + void test_linpiece() + { + matrix temp = linpiece(5, linspace(-1, 9, 2)); + DLIB_CASSERT(temp.size() == 1,""); + DLIB_CASSERT(std::abs(temp(0) - 6) < 1e-13,""); + + temp = linpiece(5, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + temp = linpiece(4, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 1) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + temp = linpiece(40, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 2) < 1e-13,""); + + temp = linpiece(-40, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + temp = linpiece(0, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 1) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + } + class matrix_tester : public tester { public: @@ -1038,6 +1086,7 @@ namespace matrix_test(); test_complex(); + test_linpiece(); } } a;