You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/04/17 14:54:32 UTC

[incubator-tvm] branch master updated: [RELAY][PYTORCH]GroupNorm op support added (#5358)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new f49fc36  [RELAY][PYTORCH]GroupNorm op support added (#5358)
f49fc36 is described below

commit f49fc366ce82f1b757b5140ee59f59226d6094fa
Author: Samuel <si...@huawei.com>
AuthorDate: Fri Apr 17 20:24:24 2020 +0530

    [RELAY][PYTORCH]GroupNorm op support added (#5358)
---
 include/tvm/relay/attrs/nn.h                  | 24 +++++++++
 python/tvm/relay/frontend/pytorch.py          | 21 ++++++++
 python/tvm/relay/op/nn/nn.py                  | 69 +++++++++++++++++++++++++
 src/relay/op/nn/nn.cc                         | 74 +++++++++++++++++++++++++++
 src/relay/transforms/simplify_inference.cc    | 66 ++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 23 +++++++++
 6 files changed, 277 insertions(+)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 536e414..f985a90 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -959,6 +959,30 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
 };  // struct LayerNormAttrs
 
 
+/*!
+ * \brief Attributes of the nn.group_norm operator.
+ *
+ * The extent of `axis` is split into `num_groups` groups that are normalized
+ * independently; `center` and `scale` toggle the beta offset and gamma multiplier.
+ */
+struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
+  int num_groups;  // how many groups the channel axis is divided into
+  int axis;        // position of the channel dimension
+  double epsilon;  // stabilizer added to the variance
+  bool center;     // apply the beta offset when true
+  bool scale;      // apply the gamma multiplier when true
+
+  TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") {
+    TVM_ATTR_FIELD(num_groups)
+        .set_default(0)
+        .describe("Specify number of groups to separate the channels into.");
+    TVM_ATTR_FIELD(axis)
+        .set_default(1)
+        .describe("Specify which shape axis denotes the channel.");
+    TVM_ATTR_FIELD(epsilon)
+        .set_default(1e-5)
+        .describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center)
+        .set_default(true)
+        .describe("If true, add offset of beta to normalized tensor; "
+                  "otherwise, beta is ignored.");
+    TVM_ATTR_FIELD(scale)
+        .set_default(true)
+        .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+  }
+};  // struct GroupNormAttrs
+
+
 /*! \brief Attributes for LRN operator */
 struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
   int size;
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ed31d34..9da3ecf 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -831,6 +831,26 @@ def _layer_norm():
                                  scale=True)
     return _impl
 
+
+def _group_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        gamma = inputs[2]
+        beta = inputs[3]
+        num_groups = inputs[1]
+        epsilon = float(inputs[4])
+
+        return _op.nn.group_norm(data,
+                                 gamma=gamma,
+                                 beta=beta,
+                                 num_groups=num_groups,
+                                 axis=1,
+                                 epsilon=epsilon,
+                                 center=True,
+                                 scale=True)
+    return _impl
+
+
 def _transpose(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1630,6 +1650,7 @@ def _get_convert_map(prelude):
         "aten::batch_norm"                      : _batch_norm(),
         "aten::instance_norm"                   : _instance_norm(),
         "aten::layer_norm"                      : _layer_norm(),
+        "aten::group_norm"                      : _group_norm(),
         "aten::transpose"                       : _transpose(prelude),
         "aten::transpose_"                      : _transpose(prelude),
         "aten::t"                               : _transpose(prelude),
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index d0a81bc..622b0fa 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1708,6 +1708,75 @@ def layer_norm(data,
     return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)
 
 
+def group_norm(data,
+               gamma,
+               beta,
+               num_groups,
+               axis=1,
+               epsilon=1e-5,
+               center=True,
+               scale=True):
+    r"""
+    Group normalization normalizes over group of channels for each training examples.
+    We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put
+    all the channels into a single group, group normalization becomes Layer normalization.
+    And, when we put each channel into different groups it becomes Instance normalization
+
+    https://arxiv.org/pdf/1803.08494.pdf
+
+    Applies group normalization to the n-dimensional input array by seperating the input channels
+    into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
+    The mean and standard-deviation are calculated separately over the each group. gamma and
+    beta are learnable per-channel affine transform parameter vectors of size num_channels.
+
+    .. math::
+
+        out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
+            * gamma + beta
+
+    Unlike batch normalization, the mean and var are computed along a group of channels.
+
+    If the input has size k on axis 1, then both gamma and beta have shape (k,).
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        Input to which group_norm will be applied.
+
+    gamma : tvm.relay.Expr
+        The gamma scale factor.
+
+    beta : tvm.relay.Expr
+        The beta offset factor.
+
+    num_groups : int
+        The number of groups to separate the channels into.
+
+    axis : int, optional, default=1
+        The axis of the channels.
+
+    epsilon : double, optional, default=1e-5
+        Small float added to variance to avoid dividing by zero.
+
+    center : boolean, optional, default=True
+        If True, add offset of beta to normalized tensor, If False,
+        beta is ignored.
+
+    scale : boolean, optional, default=True
+        If True, multiply by gamma. If False, gamma is not used.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The normalized data.
+    """
+    return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)
+
+
 def batch_matmul(x, y):
     r"""
     Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index b9ba74f..5cdca80 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -852,6 +852,80 @@ RELAY_REGISTER_OP("nn.layer_norm")
 .set_support_level(1)
 .add_type_rel("LayerNorm", LayerNormRel);
 
+// group_norm
+TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
+
+/*!
+ * \brief Type relation for nn.group_norm.
+ *
+ * types = [data, gamma, beta, output]. gamma and beta are 1-D with the extent of the
+ * channel axis; the output has the same shape and dtype as data. Invalid attributes
+ * (missing attrs, non-positive num_groups, out-of-range axis, or a static channel
+ * extent not divisible by num_groups) are rejected here instead of failing later.
+ */
+bool GroupNormRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  const GroupNormAttrs* param = attrs.as<GroupNormAttrs>();
+  CHECK(param != nullptr) << "nn.group_norm expects GroupNormAttrs";
+  CHECK_GT(param->num_groups, 0) << "nn.group_norm: num_groups must be positive";
+  // Normalize a negative axis in signed arithmetic; mixing int with size_t would
+  // wrap around for out-of-range negative values.
+  const int ndim = static_cast<int>(data->shape.size());
+  const int axis = param->axis >= 0 ? param->axis : param->axis + ndim;
+  CHECK(axis >= 0 && axis < ndim)
+      << "nn.group_norm: axis " << param->axis << " is out of range for a "
+      << ndim << "-D input";
+  // Only a statically known channel extent can be checked for divisibility.
+  if (const auto* channels = data->shape[axis].as<IntImmNode>()) {
+    CHECK_EQ(channels->value % param->num_groups, 0)
+        << "nn.group_norm: " << channels->value << " channels cannot be split into "
+        << param->num_groups << " groups";
+  }
+  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+
+  return true;
+}
+
+// Builds a call to nn.group_norm on (data, gamma, beta) carrying the given attributes.
+Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis,
+                   double epsilon, bool center, bool scale) {
+  static const Op& group_norm_op = Op::Get("nn.group_norm");
+  auto node = make_object<GroupNormAttrs>();
+  node->num_groups = num_groups;
+  node->axis = axis;
+  node->epsilon = epsilon;
+  node->center = center;
+  node->scale = scale;
+  return Call(group_norm_op, {data, gamma, beta}, Attrs(node), {});
+}
+
+// Packed-function entry point forwarding its 8 positional arguments to MakeGroupNorm.
+TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm")
+.set_body([](const TVMArgs& packed_args, TVMRetValue* result) {
+  runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, packed_args, result);
+});
+
+RELAY_REGISTER_OP("nn.group_norm")
+.describe(R"code(
+Group normalization normalizes over groups of channels for each training example.
+Group Norm sits between Instance Norm and Layer Norm: putting all the channels into
+a single group gives Layer normalization, and putting each channel into its own
+group gives Instance normalization.
+
+https://arxiv.org/pdf/1803.08494.pdf
+
+Applies group normalization to the n-dimensional input array by separating the input channels
+into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
+The mean and standard deviation are calculated separately over each group. gamma and
+beta are learnable per-channel affine transform parameter vectors of size num_channels.
+
+.. math::
+
+    out = \frac{data - mean(data, group)}{\sqrt{var(data, group)+\epsilon}}
+        * gamma + beta
+
+Unlike batch normalization, the mean and var are computed along a group of channels.
+
+If the input has size k on `axis`, then both gamma and beta have shape (k,).
+
+.. note::
+
+    This operator can be optimized away for inference.
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<GroupNormAttrs>()
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "Input to which group_norm will be applied.")
+.add_argument("gamma", "Tensor", "The gamma scale factor.")
+.add_argument("beta", "Tensor", "The beta offset factor.")
+.set_support_level(1)
+.add_type_rel("GroupNorm", GroupNormRel);
+
+
 // relay.nn.batch_matmul
 bool BatchMatmulRel(const Array<Type>& types,
                     int num_inputs,
diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc
index d349fdd..a9ceec2 100644
--- a/src/relay/transforms/simplify_inference.cc
+++ b/src/relay/transforms/simplify_inference.cc
@@ -64,6 +64,66 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
   return out;
 }
 
+
+/*!
+ * \brief Expand nn.group_norm into primitive ops for inference.
+ *
+ * The channel axis C is split into (num_groups, C / num_groups); mean and variance
+ * are computed over the per-group channels and every axis after the channel axis,
+ * the result is reshaped back to the input shape, and the optional per-channel
+ * gamma/beta are applied. Requires a fully static input shape; dynamic dimensions,
+ * a non-positive num_groups, or a channel count not divisible by num_groups are
+ * reported with a CHECK failure instead of a null dereference / division by zero.
+ */
+Expr GroupNormToInferUnpack(const Attrs attrs,
+                            Expr data,
+                            Expr gamma,
+                            Expr beta,
+                            Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  const auto* param = attrs.as<GroupNormAttrs>();
+  CHECK(param);
+
+  const int ndim = static_cast<int>(ttype->shape.size());
+  const int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
+  CHECK(axis >= 0 && axis < ndim) << "group_norm: axis " << param->axis << " is out of range";
+  const int num_groups = param->num_groups;
+  CHECK_GT(num_groups, 0) << "group_norm: num_groups must be positive";
+
+  // Static extent of dimension i; the reshapes below need concrete sizes.
+  auto static_dim = [&](int i) -> int64_t {
+    const auto* imm = ttype->shape[i].as<IntImmNode>();
+    CHECK(imm) << "group_norm: SimplifyInference requires a static shape, but dimension "
+               << i << " is dynamic";
+    return imm->value;
+  };
+
+  const int64_t channel = static_dim(axis);
+  CHECK(channel % num_groups == 0)
+      << "group_norm: channel count " << channel << " is not divisible by num_groups "
+      << num_groups;
+
+  Array<Integer> reduced_axes;
+  Array<Integer> new_shape;
+  Array<Integer> old_shape;
+
+  // old_shape    = N, C, H, W
+  // new_shape    = N, num_groups, C/num_groups, H, W
+  // reduced_axes = positions of (C/num_groups, H, W) in new_shape
+  for (int i = 0; i < ndim; ++i) {
+    const int64_t val = static_dim(i);
+    // Saved so the normalized tensor can be reshaped back afterwards.
+    old_shape.push_back(val);
+    if (i == axis) {
+      new_shape.push_back(num_groups);
+      new_shape.push_back(channel / num_groups);
+    } else {
+      new_shape.push_back(val);
+    }
+    // Splitting the channel axis shifts it and every later axis by one.
+    if (i >= axis) {
+      reduced_axes.push_back(i + 1);
+    }
+  }
+
+  data = Reshape(data, new_shape);
+
+  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
+  Expr mean = Mean(data, {reduced_axes}, true, false);
+  Expr var = Variance(data, mean, {reduced_axes}, true, false);
+  Expr denom = Sqrt(Add(var, epsilon));
+  Expr out = Divide(Subtract(data, mean), denom);
+
+  out = Reshape(out, old_shape);
+
+  if (param->scale) {
+    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
+  }
+  if (param->center) {
+    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
+  }
+
+  return out;
+}
+
 Expr LayerNormToInferUnpack(const Attrs attrs,
                             Expr data,
                             Expr gamma,
@@ -143,6 +203,7 @@ class InferenceSimplifier : public ExprMutator {
         dropout_op_(Op::Get("nn.dropout")),
         instance_norm_op_(Op::Get("nn.instance_norm")),
         layer_norm_op_(Op::Get("nn.layer_norm")),
+        group_norm_op_(Op::Get("nn.group_norm")),
         l2_norm_op_(Op::Get("nn.l2_normalize")) {}
 
   Expr VisitExpr_(const TupleGetItemNode* n) final {
@@ -170,6 +231,10 @@ class InferenceSimplifier : public ExprMutator {
       const auto* call = new_n.as<CallNode>();
       return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                     n->args[0]->checked_type());
+    } else if (n->op == group_norm_op_) {
+      const auto* call = new_n.as<CallNode>();
+      return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
+                                    n->args[0]->checked_type());
     } else if (n->op == instance_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
@@ -189,6 +254,7 @@ class InferenceSimplifier : public ExprMutator {
   const Op& dropout_op_;
   const Op& instance_norm_op_;
   const Op& layer_norm_op_;
+  const Op& group_norm_op_;
   const Op& l2_norm_op_;
   std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
 };
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 573fa7e..c692c5e 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -717,6 +717,28 @@ def test_forward_layernorm():
         init_weight(ln.eval())
         verify_model(ln.eval(), input_data=inp)
 
+
+def test_forward_groupnorm():
+    input_shape = [10, 6, 5, 5]
+    input_data = torch.rand(input_shape).float()
+
+    # Separate 6 channels into 3 groups
+    verify_model(torch.nn.GroupNorm(3, 6).eval(), input_data=input_data)
+
+    # Put all 6 channels into a single group (equivalent with LayerNorm)
+    verify_model(torch.nn.GroupNorm(1, 6).eval(), input_data=input_data)
+
+    # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
+    verify_model(torch.nn.GroupNorm(6, 6).eval(), input_data=input_data)
+
+    input_shape = [1, 10, 4, 7]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.GroupNorm(1, 10).eval(), input_data=input_data)
+    verify_model(torch.nn.GroupNorm(2, 10).eval(), input_data=input_data)
+    verify_model(torch.nn.GroupNorm(5, 10).eval(), input_data=input_data)
+    verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data)
+
+
 def test_forward_reshape():
     torch.set_grad_enabled(False)
     input_shape = [2, 1, 10, 1, 10]
@@ -1865,6 +1887,7 @@ if __name__ == "__main__":
     test_forward_batchnorm()
     test_forward_instancenorm()
     test_forward_layernorm()
+    test_forward_groupnorm()
     test_forward_transpose()
     test_forward_size()
     test_forward_view()