Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/20 10:19:03 UTC

[GitHub] chi-hung commented on issue #11283: Group Norm

URL: https://github.com/apache/incubator-mxnet/issues/11283#issuecomment-440220226
 
 
   Well, I have implemented *GroupNorm*. It's slower than ```nn.BatchNorm```, but it works (see the code below):
   ```python
   from mxnet.gluon import nn

   class GroupNorm(nn.HybridBlock):
       """
       If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
       GroupNorm achieves good results even at small batch sizes.
       Reference:
         https://arxiv.org/pdf/1803.08494.pdf
       """
       def __init__(self, num_channels, num_groups=32, eps=1e-5,
                    multi_precision=False, **kwargs):
           super(GroupNorm, self).__init__(**kwargs)
   
           with self.name_scope():
               self.weight = self.params.get('weight', grad_req='write',
                                             shape=(1, num_channels, 1, 1))
               self.bias = self.params.get('bias', grad_req='write',
                                           shape=(1, num_channels, 1, 1))
           self.C = num_channels
           self.G = num_groups
           self.eps = eps
           self.multi_precision = multi_precision
   
           assert self.C % self.G == 0
   
       def hybrid_forward(self, F, x, weight, bias):
   
           x_new = F.reshape(x, (0, self.G, -1))                                # (N,C,H,W) -> (N,G,H*W*C//G)
   
           if self.multi_precision:
               mean = F.mean(F.cast(x_new, "float32"),
                             axis=-1, keepdims=True)                            # (N,G,H*W*C//G) -> (N,G,1)
               mean = F.cast(mean, "float16")
           else:
               mean = F.mean(x_new, axis=-1, keepdims=True)
   
           centered_x_new = F.broadcast_minus(x_new, mean)                      # (N,G,H*W*C//G)
   
           if self.multi_precision:
               var = F.mean(F.cast(F.square(centered_x_new), "float32"),
                            axis=-1, keepdims=True)                             # (N,G,H*W*C//G) -> (N,G,1)
               var = F.cast(var, "float16")
           else:
               var = F.mean(F.square(centered_x_new), axis=-1, keepdims=True)
   
           x_new = F.broadcast_div(centered_x_new, F.sqrt(var + self.eps)       # (N,G,H*W*C//G) -> (N,C,H,W)
                                   ).reshape_like(x)
           x_new = F.broadcast_add(F.broadcast_mul(x_new, weight), bias)
           return x_new
   ```
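
   A quick sanity check of the block above (just a sketch, assuming MXNet 1.x with Gluon; the channel count, group count and shapes are made up):
   ```python
   import mxnet as mx

   net = GroupNorm(num_channels=64, num_groups=32)
   net.initialize()

   x = mx.nd.random.normal(shape=(2, 64, 28, 28))   # small batch of 2, NCHW
   y = net(x)
   print(y.shape)   # (2, 64, 28, 28), same shape as the input
   ```
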
   Clearly there are several issues, for example:
   * An operator such as ```F.moments()``` (quite common in other frameworks) is not implemented in MXNet yet, so the mean and the variance here are computed in separate reduction passes; hence, my implementation could be slow (see the sketch after this list).
   * When training with mixed precision, the implementation above casts the FP16 input to FP32 to avoid loss of precision while calculating the mean and the variance. Casting an FP16 tensor to FP32 and then back to FP16 costs extra time; this step could be eliminated if we implemented this layer directly in CUDA, where the accumulation can be kept in FP32 inside the kernel.
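
   For reference, this is roughly what a fused ```moments```-style reduction could avoid: using the identity Var[x] = E[x^2] - (E[x])^2, both statistics can be formed from sums that a single pass over the input is enough to accumulate (just a sketch with made-up shapes; note that this one-pass form can be slightly less numerically stable than the centered computation used above):
   ```python
   import mxnet as mx

   x = mx.nd.random.normal(shape=(2, 32, 3136))   # (N, G, H*W*C//G), made-up sizes
   n = x.shape[-1]

   # sum(x) and sum(x^2) could both be accumulated while reading x only once;
   # here they are two separate ops, but a fused kernel would need a single pass.
   s1 = mx.nd.sum(x, axis=-1, keepdims=True)
   s2 = mx.nd.sum(mx.nd.square(x), axis=-1, keepdims=True)

   mean = s1 / n
   var = s2 / n - mx.nd.square(mean)               # Var[x] = E[x^2] - (E[x])^2
   ```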
   
   I think this layer is quite important, as not everyone has plenty of GPUs (if you do have plenty, then ```F.contrib.SyncBatchNorm()``` will work well).
   
   P.S. A question for the MXNet authors:
     There seems to be an op called ```F.SumSquare()``` (see: https://github.com/dmlc/gluon-cv/blob/0a699a5ccc21310c7ce41d4737f0de9f54fbf45a/gluoncv/model_zoo/syncbn.py#L206), which I guess is used to calculate the second-order moment. I couldn't find it in MXNet's API; does this op really exist?
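
   My guess at why such an op is useful (an assumption based on reading the linked syncbn.py, not on the actual MXNet source): per-channel sum and sum-of-squares can be added up across GPUs and only then turned into a mean and a variance, so the raw activations never have to leave their device. A small sketch of that aggregation, with made-up shapes:
   ```python
   import mxnet as mx

   # Two per-GPU activation shards in NCHW layout (shapes are made up);
   # BatchNorm-style statistics are computed per channel.
   parts = [mx.nd.random.normal(shape=(2, 64, 28, 28)) for _ in range(2)]

   # Each GPU only reports sum(x), sum(x^2) and its element count per channel.
   stats = [(mx.nd.sum(p, axis=(0, 2, 3)),
             mx.nd.sum(mx.nd.square(p), axis=(0, 2, 3)),
             p.shape[0] * p.shape[2] * p.shape[3])
            for p in parts]

   # Add the per-GPU pieces, then recover the global per-channel statistics.
   s1 = stats[0][0] + stats[1][0]
   s2 = stats[0][1] + stats[1][1]
   n = stats[0][2] + stats[1][2]
   mean = s1 / n
   var = s2 / n - mx.nd.square(mean)   # Var[x] = E[x^2] - (E[x])^2
   ```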
