mirror of
https://github.com/davisking/dlib.git
synced 2024-11-01 10:14:53 +08:00
simplify resnet definition by reusing struct template parameter (#2010)
* simplify definition by reusing struct template parameter * put resnet into its own namespace * fix infer names * rename struct impl to def
This commit is contained in:
parent
3a53c78ad2
commit
c832d3b2fc
@ -29,7 +29,7 @@ namespace model
|
|||||||
using net_type = loss_metric<
|
using net_type = loss_metric<
|
||||||
fc_no_bias<128,
|
fc_no_bias<128,
|
||||||
avg_pool_everything<
|
avg_pool_everything<
|
||||||
typename resnet<BN>::template backbone_50<
|
typename resnet::def<BN>::template backbone_50<
|
||||||
input_rgb_image
|
input_rgb_image
|
||||||
>>>>;
|
>>>>;
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ int main() try
|
|||||||
{
|
{
|
||||||
// Now, let's define the classic ResNet50 network and load the pretrained model on
|
// Now, let's define the classic ResNet50 network and load the pretrained model on
|
||||||
// ImageNet.
|
// ImageNet.
|
||||||
resnet<bn_con>::n50 resnet50;
|
resnet::train_50 resnet50;
|
||||||
std::vector<string> labels;
|
std::vector<string> labels;
|
||||||
deserialize("resnet50_1000_imagenet_classifier.dnn") >> resnet50 >> labels;
|
deserialize("resnet50_1000_imagenet_classifier.dnn") >> resnet50 >> labels;
|
||||||
|
|
||||||
|
@ -3,85 +3,77 @@
|
|||||||
|
|
||||||
#include <dlib/dnn.h>
|
#include <dlib/dnn.h>
|
||||||
|
|
||||||
// BATCHNORM must be bn_con or affine layer
|
namespace resnet
|
||||||
template<template<typename> class BATCHNORM>
|
{
|
||||||
struct resnet
|
using namespace dlib;
|
||||||
|
// BN is bn_con or affine layer
|
||||||
|
template<template<typename> class BN>
|
||||||
|
struct def
|
||||||
{
|
{
|
||||||
// the resnet basic block, where BN is bn_con or affine
|
// the resnet basic block, where BN is bn_con or affine
|
||||||
template<long num_filters, template<typename> class BN, int stride, typename SUBNET>
|
template<long num_filters, int stride, typename SUBNET>
|
||||||
using basicblock = BN<dlib::con<num_filters, 3, 3, 1, 1,
|
using basicblock = BN<con<num_filters, 3, 3, 1, 1,
|
||||||
dlib::relu<BN<dlib::con<num_filters, 3, 3, stride, stride, SUBNET>>>>>;
|
relu<BN<con<num_filters, 3, 3, stride, stride, SUBNET>>>>>;
|
||||||
|
|
||||||
// the resnet bottleneck block
|
// the resnet bottleneck block
|
||||||
template<long num_filters, template<typename> class BN, int stride, typename SUBNET>
|
template<long num_filters, int stride, typename SUBNET>
|
||||||
using bottleneck = BN<dlib::con<4 * num_filters, 1, 1, 1, 1,
|
using bottleneck = BN<con<4 * num_filters, 1, 1, 1, 1,
|
||||||
dlib::relu<BN<dlib::con<num_filters, 3, 3, stride, stride,
|
relu<BN<con<num_filters, 3, 3, stride, stride,
|
||||||
dlib::relu<BN<dlib::con<num_filters, 1, 1, 1, 1, SUBNET>>>>>>>>;
|
relu<BN<con<num_filters, 1, 1, 1, 1, SUBNET>>>>>>>>;
|
||||||
|
|
||||||
// the resnet residual
|
// the resnet residual, where BLOCK is either basicblock or bottleneck
|
||||||
template<
|
template<template<long, int, typename> class BLOCK, long num_filters, typename SUBNET>
|
||||||
template<long, template<typename> class, int, typename> class BLOCK, // basicblock or bottleneck
|
using residual = add_prev1<BLOCK<num_filters, 1, tag1<SUBNET>>>;
|
||||||
long num_filters,
|
|
||||||
template<typename> class BN, // bn_con or affine
|
|
||||||
typename SUBNET
|
|
||||||
> // adds the block to the result of tag1 (the subnet)
|
|
||||||
using residual = dlib::add_prev1<BLOCK<num_filters, BN, 1, dlib::tag1<SUBNET>>>;
|
|
||||||
|
|
||||||
// a resnet residual that does subsampling on both paths
|
// a resnet residual that does subsampling on both paths
|
||||||
|
template<template<long, int, typename> class BLOCK, long num_filters, typename SUBNET>
|
||||||
|
using residual_down = add_prev2<avg_pool<2, 2, 2, 2,
|
||||||
|
skip1<tag2<BLOCK<num_filters, 2,
|
||||||
|
tag1<SUBNET>>>>>>;
|
||||||
|
|
||||||
|
// residual block with optional downsampling
|
||||||
template<
|
template<
|
||||||
template<long, template<typename> class, int, typename> class BLOCK, // basicblock or bottleneck
|
template<template<long, int, typename> class, long, typename> class RESIDUAL,
|
||||||
|
template<long, int, typename> class BLOCK,
|
||||||
long num_filters,
|
long num_filters,
|
||||||
template<typename> class BN, // bn_con or affine
|
|
||||||
typename SUBNET
|
typename SUBNET
|
||||||
>
|
>
|
||||||
using residual_down = dlib::add_prev2<dlib::avg_pool<2, 2, 2, 2,
|
using residual_block = relu<RESIDUAL<BLOCK, num_filters, SUBNET>>;
|
||||||
dlib::skip1<dlib::tag2<BLOCK<num_filters, BN, 2,
|
|
||||||
dlib::tag1<SUBNET>>>>>>;
|
|
||||||
|
|
||||||
// residual block with optional downsampling and custom regularization (bn_con or affine)
|
|
||||||
template<
|
|
||||||
template<template<long, template<typename> class, int, typename> class, long, template<typename>class, typename> class RESIDUAL,
|
|
||||||
template<long, template<typename> class, int, typename> class BLOCK,
|
|
||||||
long num_filters,
|
|
||||||
template<typename> class BN, // bn_con or affine
|
|
||||||
typename SUBNET
|
|
||||||
>
|
|
||||||
using residual_block = dlib::relu<RESIDUAL<BLOCK, num_filters, BN, SUBNET>>;
|
|
||||||
|
|
||||||
template<long num_filters, typename SUBNET>
|
template<long num_filters, typename SUBNET>
|
||||||
using resbasicblock_down = residual_block<residual_down, basicblock, num_filters, BATCHNORM, SUBNET>;
|
using resbasicblock_down = residual_block<residual_down, basicblock, num_filters, SUBNET>;
|
||||||
template<long num_filters, typename SUBNET>
|
template<long num_filters, typename SUBNET>
|
||||||
using resbottleneck_down = residual_block<residual_down, bottleneck, num_filters, BATCHNORM, SUBNET>;
|
using resbottleneck_down = residual_block<residual_down, bottleneck, num_filters, SUBNET>;
|
||||||
|
|
||||||
// some definitions to allow the use of the repeat layer
|
// some definitions to allow the use of the repeat layer
|
||||||
template<typename SUBNET> using resbasicblock_512 = residual_block<residual, basicblock, 512, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbasicblock_512 = residual_block<residual, basicblock, 512, SUBNET>;
|
||||||
template<typename SUBNET> using resbasicblock_256 = residual_block<residual, basicblock, 256, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbasicblock_256 = residual_block<residual, basicblock, 256, SUBNET>;
|
||||||
template<typename SUBNET> using resbasicblock_128 = residual_block<residual, basicblock, 128, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbasicblock_128 = residual_block<residual, basicblock, 128, SUBNET>;
|
||||||
template<typename SUBNET> using resbasicblock_64 = residual_block<residual, basicblock, 64, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbasicblock_64 = residual_block<residual, basicblock, 64, SUBNET>;
|
||||||
template<typename SUBNET> using resbottleneck_512 = residual_block<residual, bottleneck, 512, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbottleneck_512 = residual_block<residual, bottleneck, 512, SUBNET>;
|
||||||
template<typename SUBNET> using resbottleneck_256 = residual_block<residual, bottleneck, 256, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbottleneck_256 = residual_block<residual, bottleneck, 256, SUBNET>;
|
||||||
template<typename SUBNET> using resbottleneck_128 = residual_block<residual, bottleneck, 128, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbottleneck_128 = residual_block<residual, bottleneck, 128, SUBNET>;
|
||||||
template<typename SUBNET> using resbottleneck_64 = residual_block<residual, bottleneck, 64, BATCHNORM, SUBNET>;
|
template<typename SUBNET> using resbottleneck_64 = residual_block<residual, bottleneck, 64, SUBNET>;
|
||||||
|
|
||||||
// common processing for standard resnet inputs
|
// common processing for standard resnet inputs
|
||||||
template<template<typename> class BN, typename INPUT>
|
template<typename INPUT>
|
||||||
using input_processing = dlib::max_pool<3, 3, 2, 2, dlib::relu<BN<dlib::con<64, 7, 7, 2, 2, INPUT>>>>;
|
using input_processing = max_pool<3, 3, 2, 2, relu<BN<con<64, 7, 7, 2, 2, INPUT>>>>;
|
||||||
|
|
||||||
// the resnet backbone with basicblocks
|
// the resnet backbone with basicblocks
|
||||||
template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
|
template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
|
||||||
using backbone_basicblock =
|
using backbone_basicblock =
|
||||||
dlib::repeat<nb_512, resbasicblock_512, resbasicblock_down<512,
|
repeat<nb_512, resbasicblock_512, resbasicblock_down<512,
|
||||||
dlib::repeat<nb_256, resbasicblock_256, resbasicblock_down<256,
|
repeat<nb_256, resbasicblock_256, resbasicblock_down<256,
|
||||||
dlib::repeat<nb_128, resbasicblock_128, resbasicblock_down<128,
|
repeat<nb_128, resbasicblock_128, resbasicblock_down<128,
|
||||||
dlib::repeat<nb_64, resbasicblock_64, input_processing<BATCHNORM, INPUT>>>>>>>>;
|
repeat<nb_64, resbasicblock_64, input_processing<INPUT>>>>>>>>;
|
||||||
|
|
||||||
// the resnet backbone with bottlenecks
|
// the resnet backbone with bottlenecks
|
||||||
template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
|
template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
|
||||||
using backbone_bottleneck =
|
using backbone_bottleneck =
|
||||||
dlib::repeat<nb_512, resbottleneck_512, resbottleneck_down<512,
|
repeat<nb_512, resbottleneck_512, resbottleneck_down<512,
|
||||||
dlib::repeat<nb_256, resbottleneck_256, resbottleneck_down<256,
|
repeat<nb_256, resbottleneck_256, resbottleneck_down<256,
|
||||||
dlib::repeat<nb_128, resbottleneck_128, resbottleneck_down<128,
|
repeat<nb_128, resbottleneck_128, resbottleneck_down<128,
|
||||||
dlib::repeat<nb_64, resbottleneck_64, input_processing<BATCHNORM, INPUT>>>>>>>>;
|
repeat<nb_64, resbottleneck_64, input_processing<INPUT>>>>>>>>;
|
||||||
|
|
||||||
// the backbones for the classic architectures
|
// the backbones for the classic architectures
|
||||||
template<typename INPUT> using backbone_18 = backbone_basicblock<1, 1, 1, 2, INPUT>;
|
template<typename INPUT> using backbone_18 = backbone_basicblock<1, 1, 1, 2, INPUT>;
|
||||||
@ -91,11 +83,24 @@ struct resnet
|
|||||||
template<typename INPUT> using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>;
|
template<typename INPUT> using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>;
|
||||||
|
|
||||||
// the typical classifier models
|
// the typical classifier models
|
||||||
using n18 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_18<dlib::input_rgb_image>>>>;
|
using n18 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_18<input_rgb_image>>>>;
|
||||||
using n34 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_34<dlib::input_rgb_image>>>>;
|
using n34 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_34<input_rgb_image>>>>;
|
||||||
using n50 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_50<dlib::input_rgb_image>>>>;
|
using n50 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_50<input_rgb_image>>>>;
|
||||||
using n101 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_101<dlib::input_rgb_image>>>>;
|
using n101 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_101<input_rgb_image>>>>;
|
||||||
using n152 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_152<dlib::input_rgb_image>>>>;
|
using n152 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_152<input_rgb_image>>>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using train_18 = def<bn_con>::n18;
|
||||||
|
using train_34 = def<bn_con>::n34;
|
||||||
|
using train_50 = def<bn_con>::n50;
|
||||||
|
using train_101 = def<bn_con>::n101;
|
||||||
|
using train_152 = def<bn_con>::n152;
|
||||||
|
|
||||||
|
using infer_18 = def<affine>::n18;
|
||||||
|
using infer_34 = def<affine>::n34;
|
||||||
|
using infer_50 = def<affine>::n50;
|
||||||
|
using infer_101 = def<affine>::n101;
|
||||||
|
using infer_152 = def<affine>::n152;
|
||||||
|
}
|
||||||
|
|
||||||
#endif // ResNet_H
|
#endif // ResNet_H
|
||||||
|
Loading…
Reference in New Issue
Block a user