You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/02/09 16:39:17 UTC

[GitHub] xiaolin-cheng opened a new issue #14106: Question on ResNet C++ example with Cifar10 dataset

xiaolin-cheng opened a new issue #14106: Question on ResNet C++ example with Cifar10 dataset
URL: https://github.com/apache/incubator-mxnet/issues/14106
 
 
   Hello all,
   
   I modified the ResNet C++ example into a ResNet18 model (4 levels, 2 blocks per level) for the Cifar10 dataset, but unfortunately it does not work: the validation accuracy always stays at ~0.1. Only ResNet6 (1 level, 2 blocks) and ResNet10 (2 levels, 2 blocks) worked. My code is below. Could you help me point out what I did wrong? Thank you very much!
   
   // Builds a bias-free convolution followed by batch normalization and,
   // optionally, a ReLU activation.
   //
   // Every symbol created here is named by appending a suffix to `name`:
   // "_w" (convolution weight), "_gamma", "_beta", "_mmean", "_mvar"
   // (batch-norm inputs), "_bn" (batch-norm op) and "_relu" (activation op).
   //
   //   name         base name; also the name of the convolution op itself
   //   data         input symbol
   //   num_filter   number of convolution filters
   //   kernel       convolution kernel shape
   //   stride       convolution stride
   //   pad          convolution padding
   //   with_relu    when true, the batch-norm output goes through a ReLU
   //   bn_momentum  momentum value forwarded to BatchNorm
   //
   // Returns the ReLU symbol when `with_relu` is set, otherwise the
   // batch-norm symbol.
   Symbol getConv(const std::string & name, Symbol data,
                  int  num_filter,
                  Shape kernel, Shape stride, Shape pad,
                  bool with_relu,
                  mx_float bn_momentum)
   {
     Symbol weight(name + "_w");
     // NOTE(review): the Shape(1, 1) after `stride` is presumably the dilation,
     // and the trailing 1 / 512 presumably num_group / workspace -- confirm
     // against the ConvolutionNoBias declaration in the mxnet-cpp headers.
     Symbol convolved = ConvolutionNoBias(name, data, weight,
                                          kernel, num_filter, stride, Shape(1, 1),
                                          pad, 1, 512);

     Symbol bn_gamma(name + "_gamma");
     Symbol bn_beta(name + "_beta");
     Symbol bn_moving_mean(name + "_mmean");
     Symbol bn_moving_var(name + "_mvar");
     // NOTE(review): 2e-5 is presumably the epsilon and the trailing `false`
     // presumably fix_gamma -- confirm against the BatchNorm declaration.
     Symbol normalized = BatchNorm(name + "_bn", convolved,
                                   bn_gamma, bn_beta, bn_moving_mean, bn_moving_var,
                                   2e-5, bn_momentum, false);

     // Without an activation the batch-norm output is the final result.
     if (!with_relu) {
       return normalized;
     }
     return Activation(name + "_relu", normalized, "relu");
   }
   
   // Builds one residual block: two 3x3 conv/batch-norm stages whose output is
   // added to a shortcut of the input and passed through a ReLU.
   //
   //   name         prefix for every symbol created in this block
   //   data         input symbol
   //   num_filter   number of filters for both convolutions (and the projection)
   //   dim_match    true  -> first conv uses stride 1 and the shortcut is `data`
   //                false -> first conv uses stride 2 and the shortcut is a
   //                         2x2, stride-2, bias-free convolution of `data`
   //                         named `name + "_proj"`
   //   bn_momentum  momentum value forwarded to the batch-norm layers
   //
   // Returns the symbol of the final ReLU, named `name + "_relu"`.
   Symbol makeBlock(const std::string & name, Symbol data, int num_filter,
                    bool dim_match, mx_float bn_momentum)
   {
     // Only the first convolution may downsample; the second always uses stride 1.
     const Shape first_stride = dim_match ? Shape(1, 1) : Shape(2, 2);

     Symbol stage1 = getConv(name + "_conv1", data, num_filter,
                             Shape(3, 3), first_stride, Shape(1, 1),
                             true, bn_momentum);

     // No ReLU here: the activation is applied after the residual addition.
     Symbol stage2 = getConv(name + "_conv2", stage1, num_filter,
                             Shape(3, 3), Shape(1, 1), Shape(1, 1),
                             false, bn_momentum);

     // Identity shortcut by default; replaced by a projection when the block
     // changes the dimensions of its input.
     Symbol residual = data;
     if (!dim_match) {
       Symbol proj_weight(name + "_proj_w");
       residual = ConvolutionNoBias(name + "_proj", data, proj_weight,
                                    Shape(2, 2), num_filter,
                                    Shape(2, 2), Shape(1, 1), Shape(0, 0),
                                    1, 512);
     }

     Symbol summed = residual + stage2;
     return Activation(name + "_relu", summed, "relu");
   }
   
   // Stacks `num_level` levels of `num_block` residual blocks on top of `data`.
   //
   // Blocks are named "level<L>_block<B>" with 1-based indices. Level L (0-based)
   // uses `num_filter * 2^L` filters. The first block of every level except the
   // first is created with dim_match == false; all other blocks are created
   // with dim_match == true.
   //
   //   data         input symbol
   //   num_level    number of levels
   //   num_block    number of residual blocks per level
   //   num_filter   filter count of the first level
   //   bn_momentum  momentum value forwarded to the batch-norm layers
   //
   // Returns the output symbol of the last block (or `data` unchanged when no
   // block is created).
   Symbol getBody(Symbol data, int num_level, int num_block, int num_filter, mx_float bn_momentum) {
     for (int level = 0; level < num_level; level++) {
       // Integer power of two: avoids the floating-point std::pow round trip and
       // the implicit double -> int narrowing when the value is passed on.
       const int level_filters = num_filter * (1 << level);
       const std::string level_name = "level" + std::to_string(level + 1);

       for (int block = 0; block < num_block; block++) {
         const bool dim_match = (level == 0 || block > 0);
         data = makeBlock(level_name + "_block" + std::to_string(block + 1),
                          data, level_filters, dim_match, bn_momentum);
       }
     }
     return data;
   }
   
   // Builds the full ResNet symbol: a conv + max-pool stem ("top"), the residual
   // body, a pooling / flatten / fully-connected head, and a softmax output.
   //
   // The graph reads its input from the variable "data" and its labels from the
   // variable "data_label".
   //
   //   num_class    number of outputs of the final fully-connected layer
   //   num_level    number of levels in the residual body
   //   num_block    number of residual blocks per level
   //   num_filter   filter count of the stem convolution and of the first level
   //   bn_momentum  momentum value forwarded to the batch-norm layers
   //
   // Returns the SoftmaxOutput symbol named "softmax".
   Symbol ResNetSymbol(int num_class, int num_level = 2, int num_block = 2, int num_filter = 64, mx_float bn_momentum = 0.9)
   {
     // data and label
     Symbol data = Symbol::Variable("data");
     Symbol data_label = Symbol::Variable("data_label");

     //===== top =====//
     Symbol conv = getConv("conv0", data, num_filter,
                           Shape(7, 7), Shape(2, 2), Shape(3, 3),
                           true, bn_momentum);
     // NOTE(review): arguments after the pool type are presumably global_pool,
     // cudnn_off, pooling convention, stride and pad -- confirm against the
     // Pooling declaration in the mxnet-cpp headers.
     Symbol max_pool = Pooling("max_pool", conv, Shape(3, 3), PoolingPoolType::kMax,
                               false, false, PoolingPoolingConvention::kValid,
                               Shape(2, 2), Shape(1, 1));

     //===== body =====//
     // Feed the pooled stem output into the body. Previously `conv` was passed
     // here, which left `max_pool` unused and dropped the pooling layer from the
     // graph.
     Symbol body = getBody(max_pool, num_level, num_block, num_filter, bn_momentum);

     //===== pool and fc =====//
     // NOTE(review): the `true` after the pool type is presumably global_pool,
     // in which case the Shape(7, 7) kernel would not determine the pooled
     // region -- confirm.
     Symbol avg_pool = Pooling("avg_pool", body, Shape(7, 7), PoolingPoolType::kAvg,
                               true, false, PoolingPoolingConvention::kValid);
     Symbol flatten = Flatten("flatten", avg_pool);
     Symbol fc_w("fc_w"), fc_b("fc_b");
     Symbol fc = FullyConnected("fc", flatten, fc_w, fc_b, num_class);

     return SoftmaxOutput("softmax", fc, data_label);
   }

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services