mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
Simplify the segmentation network structure; make the object detection network more complex in turn
This commit is contained in:
parent
316d62b254
commit
73a1ba8a86
@ -66,17 +66,17 @@ dlib::rectangle get_cropping_rect(const dlib::rectangle& rectangle)
|
||||
// 5x5 convolution with stride 2 in both directions (the "d" variant downsamples).
// dlib::con takes <num_filters, nr, nc, stride_y, stride_x, SUBNET>.
template <long num_filters, typename SUBNET> using con5d = dlib::con<num_filters,5,5,2,2,SUBNET>;
|
||||
// 5x5 convolution with stride 1 in both directions (no strided downsampling).
template <long num_filters, typename SUBNET> using con5 = dlib::con<num_filters,5,5,1,1,SUBNET>;
|
||||
|
||||
template <typename SUBNET> using bdownsampler = dlib::relu<dlib::bn_con<con5d<64,dlib::relu<dlib::bn_con<con5d<64,dlib::relu<dlib::bn_con<con5d<16,SUBNET>>>>>>>>>;
|
||||
template <typename SUBNET> using adownsampler = dlib::relu<dlib::affine<con5d<64,dlib::relu<dlib::affine<con5d<64,dlib::relu<dlib::affine<con5d<16,SUBNET>>>>>>>>>;
|
||||
// Detector front end, batch-norm variant: three stride-2 con5d stages with
// 32, 128 and 128 filters (applied in that order), each followed by bn_con and relu.
template <typename SUBNET> using bdownsampler = dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<32,SUBNET>>>>>>>>>;
|
||||
// Detector front end, affine variant: same layout as bdownsampler (three stride-2
// con5d stages with 32, 128 and 128 filters), but with dlib::affine in place of bn_con.
template <typename SUBNET> using adownsampler = dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<32,SUBNET>>>>>>>>>;
|
||||
|
||||
template <typename SUBNET> using brcon5 = dlib::relu<dlib::bn_con<con5<128,SUBNET>>>;
|
||||
template <typename SUBNET> using arcon5 = dlib::relu<dlib::affine<con5<128,SUBNET>>>;
|
||||
// 256-filter stride-1 5x5 convolution followed by bn_con and relu.
template <typename SUBNET> using brcon5 = dlib::relu<dlib::bn_con<con5<256,SUBNET>>>;
|
||||
// 256-filter stride-1 5x5 convolution followed by affine and relu
// (same layout as brcon5 with dlib::affine in place of bn_con).
template <typename SUBNET> using arcon5 = dlib::relu<dlib::affine<con5<256,SUBNET>>>;
|
||||
|
||||
// Object detection network, bn_con variant: RGB image pyramid input (pyramid_down<6>)
// -> bdownsampler -> three brcon5 layers -> single-filter 9x9 stride-1 con -> loss_mmod.
// NOTE(review): presumably the training form, by analogy with seg_bnet_type - confirm.
using det_bnet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,brcon5<brcon5<brcon5<bdownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
|
||||
// Object detection network, affine variant: identical layout to det_bnet_type, but built
// from adownsampler and arcon5 (dlib::affine in place of bn_con).
// NOTE(review): presumably the inference form, by analogy with seg_anet_type - confirm.
using det_anet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,arcon5<arcon5<arcon5<adownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
|
||||
|
||||
// The segmentation network.
|
||||
// For the time being, this is very much copy-paste from dnn_semantic_segmentation.h.
|
||||
// For the time being, this is very much copy-paste from dnn_semantic_segmentation.h, although the network is made narrower (fewer filters per layer).
|
||||
|
||||
template <int N, template <typename> class BN, int stride, typename SUBNET>
|
||||
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
|
||||
@ -102,34 +102,34 @@ template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename SUBNET> using res64 = res<64,SUBNET>;
|
||||
template <typename SUBNET> using res128 = res<128,SUBNET>;
|
||||
template <typename SUBNET> using res256 = res<256,SUBNET>;
|
||||
template <typename SUBNET> using res512 = res<512,SUBNET>;
|
||||
template <typename SUBNET> using ares64 = ares<64,SUBNET>;
|
||||
template <typename SUBNET> using ares128 = ares<128,SUBNET>;
|
||||
template <typename SUBNET> using ares256 = ares<256,SUBNET>;
|
||||
template <typename SUBNET> using ares512 = ares<512,SUBNET>;
|
||||
// Single-argument alias for res<16,...> (res is defined outside this excerpt), so that it
// can be handed to dlib::repeat, which expects a template taking only SUBNET.
template <typename SUBNET> using res16 = res<16,SUBNET>;
|
||||
// Single-argument alias for res<24,...>, usable as the repeated layer of dlib::repeat.
template <typename SUBNET> using res24 = res<24,SUBNET>;
|
||||
// Single-argument alias for res<32,...>, usable as the repeated layer of dlib::repeat.
template <typename SUBNET> using res32 = res<32,SUBNET>;
|
||||
// Single-argument alias for res<48,...>, usable as the repeated layer of dlib::repeat.
template <typename SUBNET> using res48 = res<48,SUBNET>;
|
||||
// Single-argument alias for ares<16,...> (ares is defined outside this excerpt), so that
// it can be handed to dlib::repeat, which expects a template taking only SUBNET.
template <typename SUBNET> using ares16 = ares<16,SUBNET>;
|
||||
// Single-argument alias for ares<24,...>, usable as the repeated layer of dlib::repeat.
template <typename SUBNET> using ares24 = ares<24,SUBNET>;
|
||||
// Single-argument alias for ares<32,...>, usable as the repeated layer of dlib::repeat.
template <typename SUBNET> using ares32 = ares<32,SUBNET>;
|
||||
// Single-argument alias for ares<48,...>, usable as the repeated layer of dlib::repeat.
template <typename SUBNET> using ares48 = ares<48,SUBNET>;
|
||||
|
||||
template <typename SUBNET> using level1 = dlib::repeat<2,res64,res<64,SUBNET>>;
|
||||
template <typename SUBNET> using level2 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
|
||||
template <typename SUBNET> using level3 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
|
||||
template <typename SUBNET> using level4 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
|
||||
// Encoder level 1: two res16 blocks (via dlib::repeat) stacked on one res<16> block,
// i.e. three 16-filter res blocks in total.
template <typename SUBNET> using level1 = dlib::repeat<2,res16,res<16,SUBNET>>;
|
||||
// Encoder level 2: two res24 blocks stacked on one res_down<24> block.
// NOTE(review): res_down is defined outside this excerpt; presumably the strided
// (downsampling) residual block - confirm.
template <typename SUBNET> using level2 = dlib::repeat<2,res24,res_down<24,SUBNET>>;
|
||||
// Encoder level 3: two res32 blocks stacked on one res_down<32> block
// (res_down is defined outside this excerpt).
template <typename SUBNET> using level3 = dlib::repeat<2,res32,res_down<32,SUBNET>>;
|
||||
// Encoder level 4: two res48 blocks stacked on one res_down<48> block
// (res_down is defined outside this excerpt).
template <typename SUBNET> using level4 = dlib::repeat<2,res48,res_down<48,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
|
||||
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
|
||||
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
|
||||
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
|
||||
// Affine counterpart of level1: two ares16 blocks stacked on one ares<16> block.
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares16,ares<16,SUBNET>>;
|
||||
// Affine counterpart of level2: two ares24 blocks stacked on one ares_down<24> block
// (ares_down is defined outside this excerpt).
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares24,ares_down<24,SUBNET>>;
|
||||
// Affine counterpart of level3: two ares32 blocks stacked on one ares_down<32> block
// (ares_down is defined outside this excerpt).
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares32,ares_down<32,SUBNET>>;
|
||||
// Affine counterpart of level4: two ares48 blocks stacked on one ares_down<48> block
// (ares_down is defined outside this excerpt).
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares48,ares_down<48,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using level1t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
|
||||
template <typename SUBNET> using level2t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
|
||||
template <typename SUBNET> using level3t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
|
||||
template <typename SUBNET> using level4t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
|
||||
// Decoder ("t") level 1: two res16 blocks stacked on one res_up<16> block.
// NOTE(review): res_up is defined outside this excerpt; presumably the upsampling
// residual block - confirm.
template <typename SUBNET> using level1t = dlib::repeat<2,res16,res_up<16,SUBNET>>;
|
||||
// Decoder ("t") level 2: two res24 blocks stacked on one res_up<24> block
// (res_up is defined outside this excerpt).
template <typename SUBNET> using level2t = dlib::repeat<2,res24,res_up<24,SUBNET>>;
|
||||
// Decoder ("t") level 3: two res32 blocks stacked on one res_up<32> block
// (res_up is defined outside this excerpt).
template <typename SUBNET> using level3t = dlib::repeat<2,res32,res_up<32,SUBNET>>;
|
||||
// Decoder ("t") level 4: two res48 blocks stacked on one res_up<48> block
// (res_up is defined outside this excerpt).
template <typename SUBNET> using level4t = dlib::repeat<2,res48,res_up<48,SUBNET>>;
|
||||
|
||||
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
|
||||
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
|
||||
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
|
||||
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
|
||||
// Affine counterpart of level1t: two ares16 blocks stacked on one ares_up<16> block.
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares16,ares_up<16,SUBNET>>;
|
||||
// Affine counterpart of level2t: two ares24 blocks stacked on one ares_up<24> block.
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares24,ares_up<24,SUBNET>>;
|
||||
// Affine counterpart of level3t: two ares32 blocks stacked on one ares_up<32> block.
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares32,ares_up<32,SUBNET>>;
|
||||
// Affine counterpart of level4t: two ares48 blocks stacked on one ares_up<48> block.
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares48,ares_up<48,SUBNET>>;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
@ -166,7 +166,7 @@ static const char* instance_segmentation_net_filename = "instance_segmentation_v
|
||||
// training network type
|
||||
using seg_bnet_type = dlib::loss_multiclass_log_per_pixel<
|
||||
dlib::cont<2,1,1,1,1,
|
||||
dlib::relu<dlib::bn_con<dlib::cont<64,7,7,2,2,
|
||||
dlib::relu<dlib::bn_con<dlib::cont<16,7,7,2,2,
|
||||
concat_utag1<level1t<
|
||||
concat_utag2<level2t<
|
||||
concat_utag3<level3t<
|
||||
@ -175,14 +175,14 @@ using seg_bnet_type = dlib::loss_multiclass_log_per_pixel<
|
||||
level3<utag3<
|
||||
level2<utag2<
|
||||
level1<dlib::max_pool<3,3,2,2,utag1<
|
||||
dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
|
||||
dlib::relu<dlib::bn_con<dlib::con<16,7,7,2,2,
|
||||
dlib::input<dlib::matrix<dlib::rgb_pixel>>
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>>;
|
||||
|
||||
// testing network type (replaced batch normalization with fixed affine transforms)
|
||||
using seg_anet_type = dlib::loss_multiclass_log_per_pixel<
|
||||
dlib::cont<2,1,1,1,1,
|
||||
dlib::relu<dlib::affine<dlib::cont<64,7,7,2,2,
|
||||
dlib::relu<dlib::affine<dlib::cont<16,7,7,2,2,
|
||||
concat_utag1<alevel1t<
|
||||
concat_utag2<alevel2t<
|
||||
concat_utag3<alevel3t<
|
||||
@ -191,7 +191,7 @@ using seg_anet_type = dlib::loss_multiclass_log_per_pixel<
|
||||
alevel3<utag3<
|
||||
alevel2<utag2<
|
||||
alevel1<dlib::max_pool<3,3,2,2,utag1<
|
||||
dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
|
||||
dlib::relu<dlib::affine<dlib::con<16,7,7,2,2,
|
||||
dlib::input<dlib::matrix<dlib::rgb_pixel>>
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>>;
|
||||
|
||||
|
@ -663,7 +663,7 @@ int main(int argc, char** argv) try
|
||||
|
||||
// mini-batches smaller than the default can be used with GPUs having less memory
|
||||
const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 60;
|
||||
const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 20;
|
||||
const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 100;
|
||||
cout << "det mini-batch size: " << det_minibatch_size << endl;
|
||||
cout << "seg mini-batch size: " << seg_minibatch_size << endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user