mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added linpiece()
This commit is contained in:
parent
9d0f6796dc
commit
b22e9f2fc8
@ -3644,6 +3644,69 @@ namespace dlib
|
||||
return matrix_range_exp<double>(start,end,num,false);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename M>
|
||||
struct op_linpiece
|
||||
{
|
||||
op_linpiece(const double val_, const M& joints_) : joints(joints_), val(val_){}
|
||||
|
||||
const M& joints;
|
||||
const double val;
|
||||
|
||||
const static long cost = 10;
|
||||
|
||||
const static long NR = (M::NR*M::NC==0) ? (0) : (M::NR*M::NC-1);
|
||||
const static long NC = 1;
|
||||
typedef typename M::type type;
|
||||
typedef default_memory_manager mem_manager_type;
|
||||
typedef row_major_layout layout_type;
|
||||
|
||||
typedef type const_ret_type;
|
||||
const_ret_type apply (long i, long ) const
|
||||
{
|
||||
if (joints(i) < val)
|
||||
return std::min<type>(val,joints(i+1)) - joints(i);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
long nr () const { return joints.size()-1; }
|
||||
long nc () const { return 1; }
|
||||
|
||||
template <typename U> bool aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
|
||||
template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
|
||||
};
|
||||
|
||||
template < typename EXP >
|
||||
const matrix_op<op_linpiece<EXP> > linpiece (
|
||||
const double val,
|
||||
const matrix_exp<EXP>& joints
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(is_vector(joints) && joints.size() >= 2,
|
||||
"\t matrix_exp linpiece()"
|
||||
<< "\n\t Invalid inputs were given to this function "
|
||||
<< "\n\t is_vector(joints): " << is_vector(joints)
|
||||
<< "\n\t joints.size(): " << joints.size()
|
||||
);
|
||||
#ifdef ENABLE_ASSERTS
|
||||
for (long i = 1; i < joints.size(); ++i)
|
||||
{
|
||||
DLIB_ASSERT(joints(i-1) < joints(i),
|
||||
"\t matrix_exp linpiece()"
|
||||
<< "\n\t Invalid inputs were given to this function "
|
||||
<< "\n\t joints("<<i-1<<"): " << joints(i-1)
|
||||
<< "\n\t joints("<<i<<"): " << joints(i)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
typedef op_linpiece<EXP> op;
|
||||
return matrix_op<op>(op(val,joints.ref()));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
inline const matrix_log_range_exp<double> logspace (
|
||||
|
@ -390,6 +390,37 @@ namespace dlib
|
||||
- M(num-1) == 10^end
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
const matrix_exp linpiece (
|
||||
const double val,
|
||||
const matrix_exp& joints
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- is_vector(joints) == true
|
||||
- joints.size() >= 2
|
||||
- for all valid i < j:
|
||||
- joints(i) < joints(j)
|
||||
ensures
|
||||
- linpiece() is useful for creating piecewise linear functions of val. For
|
||||
example, if w is a parameter vector then you can represent a piecewise linear
|
||||
function of val as: f(val) = dot(w, linpiece(val, linspace(0,100,5))). In
|
||||
this case, f(val) is piecewise linear on the intervals [0,25], [25,50],
|
||||
[50,75], [75,100]. Moreover, w(i) defines the derivative of f(val) in the
|
||||
i-th interval. Finally, outside the interval [0,100] f(val) has a derivative
|
||||
of zero and f(0) == 0.
|
||||
- To be precise, this function returns a column vector L such that:
|
||||
- L.size() == joints.size()-1
|
||||
- is_col_vector(L) == true
|
||||
- L contains the same type of elements as joints.
|
||||
- for all valid i:
|
||||
- if (joints(i) < val)
|
||||
- L(i) == min(val,joints(i+1)) - joints(i)
|
||||
- else
|
||||
- L(i) == 0
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
|
@ -1015,6 +1015,54 @@ namespace
|
||||
DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0);
|
||||
}
|
||||
|
||||
void test_linpiece()
|
||||
{
|
||||
matrix<double,0,1> temp = linpiece(5, linspace(-1, 9, 2));
|
||||
DLIB_CASSERT(temp.size() == 1,"");
|
||||
DLIB_CASSERT(std::abs(temp(0) - 6) < 1e-13,"");
|
||||
|
||||
temp = linpiece(5, linspace(-1, 9, 6));
|
||||
DLIB_CASSERT(temp.size() == 5,"");
|
||||
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||
|
||||
temp = linpiece(4, linspace(-1, 9, 6));
|
||||
DLIB_CASSERT(temp.size() == 5,"");
|
||||
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(2) - 1) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||
|
||||
temp = linpiece(40, linspace(-1, 9, 6));
|
||||
DLIB_CASSERT(temp.size() == 5,"");
|
||||
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(3) - 2) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(4) - 2) < 1e-13,"");
|
||||
|
||||
temp = linpiece(-40, linspace(-1, 9, 6));
|
||||
DLIB_CASSERT(temp.size() == 5,"");
|
||||
DLIB_CASSERT(std::abs(temp(0) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||
|
||||
temp = linpiece(0, linspace(-1, 9, 6));
|
||||
DLIB_CASSERT(temp.size() == 5,"");
|
||||
DLIB_CASSERT(std::abs(temp(0) - 1) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||
|
||||
}
|
||||
|
||||
class matrix_tester : public tester
|
||||
{
|
||||
public:
|
||||
@ -1038,6 +1086,7 @@ namespace
|
||||
matrix_test();
|
||||
|
||||
test_complex();
|
||||
test_linpiece();
|
||||
}
|
||||
} a;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user