mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Added linpiece()
This commit is contained in:
parent
9d0f6796dc
commit
b22e9f2fc8
@ -3644,6 +3644,69 @@ namespace dlib
|
|||||||
return matrix_range_exp<double>(start,end,num,false);
|
return matrix_range_exp<double>(start,end,num,false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template <typename M>
|
||||||
|
struct op_linpiece
|
||||||
|
{
|
||||||
|
op_linpiece(const double val_, const M& joints_) : joints(joints_), val(val_){}
|
||||||
|
|
||||||
|
const M& joints;
|
||||||
|
const double val;
|
||||||
|
|
||||||
|
const static long cost = 10;
|
||||||
|
|
||||||
|
const static long NR = (M::NR*M::NC==0) ? (0) : (M::NR*M::NC-1);
|
||||||
|
const static long NC = 1;
|
||||||
|
typedef typename M::type type;
|
||||||
|
typedef default_memory_manager mem_manager_type;
|
||||||
|
typedef row_major_layout layout_type;
|
||||||
|
|
||||||
|
typedef type const_ret_type;
|
||||||
|
const_ret_type apply (long i, long ) const
|
||||||
|
{
|
||||||
|
if (joints(i) < val)
|
||||||
|
return std::min<type>(val,joints(i+1)) - joints(i);
|
||||||
|
else
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
long nr () const { return joints.size()-1; }
|
||||||
|
long nc () const { return 1; }
|
||||||
|
|
||||||
|
template <typename U> bool aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
|
||||||
|
template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return joints.aliases(item); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template < typename EXP >
|
||||||
|
const matrix_op<op_linpiece<EXP> > linpiece (
|
||||||
|
const double val,
|
||||||
|
const matrix_exp<EXP>& joints
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(is_vector(joints) && joints.size() >= 2,
|
||||||
|
"\t matrix_exp linpiece()"
|
||||||
|
<< "\n\t Invalid inputs were given to this function "
|
||||||
|
<< "\n\t is_vector(joints): " << is_vector(joints)
|
||||||
|
<< "\n\t joints.size(): " << joints.size()
|
||||||
|
);
|
||||||
|
#ifdef ENABLE_ASSERTS
|
||||||
|
for (long i = 1; i < joints.size(); ++i)
|
||||||
|
{
|
||||||
|
DLIB_ASSERT(joints(i-1) < joints(i),
|
||||||
|
"\t matrix_exp linpiece()"
|
||||||
|
<< "\n\t Invalid inputs were given to this function "
|
||||||
|
<< "\n\t joints("<<i-1<<"): " << joints(i-1)
|
||||||
|
<< "\n\t joints("<<i<<"): " << joints(i)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef op_linpiece<EXP> op;
|
||||||
|
return matrix_op<op>(op(val,joints.ref()));
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
inline const matrix_log_range_exp<double> logspace (
|
inline const matrix_log_range_exp<double> logspace (
|
||||||
|
@ -390,6 +390,37 @@ namespace dlib
|
|||||||
- M(num-1) == 10^end
|
- M(num-1) == 10^end
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
const matrix_exp linpiece (
|
||||||
|
const double val,
|
||||||
|
const matrix_exp& joints
|
||||||
|
);
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- is_vector(joints) == true
|
||||||
|
- joints.size() >= 2
|
||||||
|
- for all valid i < j:
|
||||||
|
- joints(i) < joints(j)
|
||||||
|
ensures
|
||||||
|
- linpiece() is useful for creating piecewise linear functions of val. For
|
||||||
|
example, if w is a parameter vector then you can represent a piecewise linear
|
||||||
|
function of val as: f(val) = dot(w, linpiece(val, linspace(0,100,5))). In
|
||||||
|
this case, f(val) is piecewise linear on the intervals [0,25], [25,50],
|
||||||
|
[50,75], [75,100]. Moreover, w(i) defines the derivative of f(val) in the
|
||||||
|
i-th interval. Finally, outside the interval [0,100] f(val) has a derivative
|
||||||
|
of zero and f(0) == 0.
|
||||||
|
- To be precise, this function returns a column vector L such that:
|
||||||
|
- L.size() == joints.size()-1
|
||||||
|
- is_col_vector(L) == true
|
||||||
|
- L contains the same type of elements as joints.
|
||||||
|
- for all valid i:
|
||||||
|
- if (joints(i) < val)
|
||||||
|
- L(i) == min(val,joints(i+1)) - joints(i)
|
||||||
|
- else
|
||||||
|
- L(i) == 0
|
||||||
|
!*/
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
template <
|
template <
|
||||||
|
@ -1015,6 +1015,54 @@ namespace
|
|||||||
DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0);
|
DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_linpiece()
|
||||||
|
{
|
||||||
|
matrix<double,0,1> temp = linpiece(5, linspace(-1, 9, 2));
|
||||||
|
DLIB_CASSERT(temp.size() == 1,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(0) - 6) < 1e-13,"");
|
||||||
|
|
||||||
|
temp = linpiece(5, linspace(-1, 9, 6));
|
||||||
|
DLIB_CASSERT(temp.size() == 5,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||||
|
|
||||||
|
temp = linpiece(4, linspace(-1, 9, 6));
|
||||||
|
DLIB_CASSERT(temp.size() == 5,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(2) - 1) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||||
|
|
||||||
|
temp = linpiece(40, linspace(-1, 9, 6));
|
||||||
|
DLIB_CASSERT(temp.size() == 5,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(3) - 2) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(4) - 2) < 1e-13,"");
|
||||||
|
|
||||||
|
temp = linpiece(-40, linspace(-1, 9, 6));
|
||||||
|
DLIB_CASSERT(temp.size() == 5,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(0) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||||
|
|
||||||
|
temp = linpiece(0, linspace(-1, 9, 6));
|
||||||
|
DLIB_CASSERT(temp.size() == 5,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(0) - 1) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,"");
|
||||||
|
DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,"");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
class matrix_tester : public tester
|
class matrix_tester : public tester
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -1038,6 +1086,7 @@ namespace
|
|||||||
matrix_test();
|
matrix_test();
|
||||||
|
|
||||||
test_complex();
|
test_complex();
|
||||||
|
test_linpiece();
|
||||||
}
|
}
|
||||||
} a;
|
} a;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user