#ifndef ResNet_H #define ResNet_H #include namespace resnet { using namespace dlib; // BN is bn_con or affine layer template class BN> struct def { // the resnet basic block, where BN is bn_con or affine template using basicblock = BN>>>>; // the resnet bottleneck block template using bottleneck = BN>>>>>>>; // the resnet residual, where BLOCK is either basicblock or bottleneck template class BLOCK, long num_filters, typename SUBNET> using residual = add_prev1>>; // a resnet residual that does subsampling on both paths template class BLOCK, long num_filters, typename SUBNET> using residual_down = add_prev2>>>>>; // residual block with optional downsampling template< template class, long, typename> class RESIDUAL, template class BLOCK, long num_filters, typename SUBNET > using residual_block = relu>; template using resbasicblock_down = residual_block; template using resbottleneck_down = residual_block; // some definitions to allow the use of the repeat layer template using resbasicblock_512 = residual_block; template using resbasicblock_256 = residual_block; template using resbasicblock_128 = residual_block; template using resbasicblock_64 = residual_block; template using resbottleneck_512 = residual_block; template using resbottleneck_256 = residual_block; template using resbottleneck_128 = residual_block; template using resbottleneck_64 = residual_block; // common processing for standard resnet inputs template using input_processing = max_pool<3, 3, 2, 2, relu>>>; // the resnet backbone with basicblocks template using backbone_basicblock = repeat>>>>>>>; // the resnet backbone with bottlenecks template using backbone_bottleneck = repeat>>>>>>>; // the backbones for the classic architectures template using backbone_18 = backbone_basicblock<1, 1, 1, 2, INPUT>; template using backbone_34 = backbone_basicblock<2, 5, 3, 3, INPUT>; template using backbone_50 = backbone_bottleneck<2, 5, 3, 3, INPUT>; template using backbone_101 = backbone_bottleneck<2, 22, 3, 3, INPUT>; template using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>; // the typical classifier models using n18 = loss_multiclass_log>>>; using n34 = loss_multiclass_log>>>; using n50 = loss_multiclass_log>>>; using n101 = loss_multiclass_log>>>; using n152 = loss_multiclass_log>>>; }; using train_18 = def::n18; using train_34 = def::n34; using train_50 = def::n50; using train_101 = def::n101; using train_152 = def::n152; using infer_18 = def::n18; using infer_34 = def::n34; using infer_50 = def::n50; using infer_101 = def::n101; using infer_152 = def::n152; } #endif // ResNet_H