Posted to commits@tvm.apache.org by sy...@apache.org on 2023/08/19 04:41:17 UTC

[tvm] branch unity updated: [Unity][Frontend][NN] Add GroupNorm Layer (#15592)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 497b7481d2 [Unity][Frontend][NN] Add GroupNorm Layer (#15592)
497b7481d2 is described below

commit 497b7481d2f55c0ddf414590e31ff562a1ecb8cd
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Fri Aug 18 21:41:10 2023 -0700

    [Unity][Frontend][NN] Add GroupNorm Layer (#15592)
    
    * Add nn group_norm module
    
    * Remove debugging print
---
 python/tvm/relax/frontend/nn/modules.py        | 53 ++++++++++++++++++++++----
 python/tvm/relax/frontend/nn/op.py             | 53 ++++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_modules.py | 23 +++++++++++
 tests/python/relax/test_frontend_nn_op.py      | 25 ++++++++----
 4 files changed, 140 insertions(+), 14 deletions(-)
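
For reference, a minimal usage sketch of the new module, mirroring the test added
in test_frontend_nn_modules.py below. Import paths are inferred from the file
layout in this diff (python/tvm/relax/frontend/nn/) and may differ slightly:

    # Inferred imports; the diff below places GroupNorm in
    # python/tvm/relax/frontend/nn/modules.py.
    from tvm.relax.frontend.nn import modules, spec

    # A GroupNorm layer over 4 channels split into 2 groups.
    norm = modules.GroupNorm(num_groups=2, num_channels=4)

    # Export the forward method to a Relax IRModule; each group is
    # normalized over its channels and all trailing (spatial) axes.
    tvm_mod, params = norm.export_tvm(
        spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}
    )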

diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 228f6bdd6e..03b47619f0 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=too-many-arguments,invalid-name,protected-access
 """Builtin Modules."""
 from typing import List, Optional, Sequence, Union
 
@@ -65,7 +66,7 @@ class Linear(Module):
     Module for linear layer.
     """
 
-    def __init__(  # pylint: disable=too-many-arguments
+    def __init__(
         self,
         in_features: int,
         out_features: int,
@@ -83,7 +84,7 @@ class Linear(Module):
         else:
             self.bias = None
 
-    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
+    def forward(self, x: Tensor) -> Tensor:
         """
         Forward method for linear layer.
 
@@ -99,7 +100,7 @@ class Linear(Module):
         """
         # x: [*B, in_features]
         # w: [in_features, out_features]
-        w = op.permute_dims(self.weight)  # pylint: disable=invalid-name
+        w = op.permute_dims(self.weight)
         # x: [*B, out_features]
         x = op.matmul(x, w, out_dtype=self.out_dtype)
         if self.bias is not None:
@@ -129,7 +130,6 @@ class RMSNorm(Module):
         else:
             self.bias = None
 
-    # pylint: disable=invalid-name
     def forward(self, x: Tensor):
         """
         Forward method for rms norm layer.
@@ -149,7 +149,46 @@ class RMSNorm(Module):
             out = op.add(out, self.bias)
         return out
 
-    # pylint: enable=invalid-name
+
+class GroupNorm(Module):
+    """
+    Module for group norm layer.
+    """
+
+    def __init__(
+        self,
+        num_groups: int,
+        num_channels: int,
+        eps: float = 1e-5,
+        affine: bool = True,
+        dtype: Optional[str] = None,
+    ):
+        super().__init__()
+        self.num_groups = num_groups
+        self.num_channels = num_channels
+        self.eps = eps
+        if affine:
+            self.weight = Parameter((num_channels,), dtype=dtype)
+            self.bias = Parameter((num_channels,), dtype=dtype)
+        else:
+            self.weight = None
+            self.bias = None
+
+    def forward(self, x: Tensor):
+        """
+        Forward method for group norm layer.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input tensor.
+
+        Returns
+        -------
+        ret : Tensor
+            The output tensor for the group norm layer.
+        """
+        return op.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
 
 class KVCache(Effect):
@@ -284,7 +323,7 @@ class KVCache(Effect):
         self.cache = rx.BlockBuilder.current().emit(
             rx.Call(
                 rx.extern("vm.builtin.attention_kv_cache_append"),
-                args=[self.cache, new_element._expr],  # pylint: disable=protected-access
+                args=[self.cache, new_element._expr],
                 sinfo_args=[rx.ObjectStructInfo()],
             )
         )
@@ -300,7 +339,7 @@ class Embedding(Module):
         self.dim = dim
         self.weight = Parameter((num, dim), dtype=dtype)
 
-    def forward(self, x: Tensor):  # pylint: disable=invalid-name
+    def forward(self, x: Tensor):
         """
         Forward method for embedding layer.
 
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 9a98aa1adb..d4821d8185 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -527,6 +527,59 @@ def rms_norm(
     return _wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, axes, epsilon), name)
 
 
+def group_norm(
+    x: Tensor,
+    num_groups: int,
+    weight: Optional[Tensor],
+    bias: Optional[Tensor],
+    eps: float = 1e-5,
+    name: str = "group_norm",
+) -> Tensor:
+    r"""
+    Applies Group Normalization over a mini-batch of inputs as described in
+    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Parameters
+    ----------
+    x : Tensor
+        Input to which group_norm will be applied.
+
+    num_groups : int
+        Number of groups to separate the channels into.
+
+    weight : Optional[Tensor]
+        The gamma scale factor.
+
+    bias : Optional[Tensor]
+        The beta offset factor.
+
+    eps : float
+        Small float added to the variance to avoid dividing by zero.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    if weight is not None:
+        weight = weight._expr
+    if bias is not None:
+        bias = bias._expr
+    dim = len(x._expr.struct_info.shape)
+    return _wrap_nested(
+        _op.nn.group_norm(
+            x._expr, weight, bias, num_groups, channel_axis=1, axes=list(range(2, dim)), epsilon=eps
+        ),
+        name,
+    )
+
+
 def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor:
     """Return the upper triangular part of a matrix or a batch of matrices.
 
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 27bbd68329..8302257648 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -68,6 +68,29 @@ def test_rms_norm():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_group_norm():
+    @R.function
+    def forward(
+        x: R.Tensor((2, 4, 8), dtype="float32"),
+        weight: R.Tensor((4,), dtype="float32"),
+        bias: R.Tensor((4,), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            group_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.group_norm(
+                x, weight, bias, num_groups=2, channel_axis=1, axes=[2]
+            )
+            gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)) = group_norm, (
+                _io,
+            )
+            R.output(gv1)
+        return gv1
+
+    mod = modules.GroupNorm(num_groups=2, num_channels=4)
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}})
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_embedding():
     @R.function
     def forward(
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 0346a6f871..41a29fa629 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -137,25 +137,36 @@ def test_datatype():
 
 def test_nn():
     class Model(Module):
-        def test(self, x: Tensor, weight: Tensor):
+        def test(self, x: Tensor, weight: Tensor, bias: Tensor):
             silu_out = op.silu(x)
             softmax_out = op.softmax(x, axis=2)
             rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
             rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
+            group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias)
             return x
 
-    # fmt: off
     @R.function
-    def test(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4, 5), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
+    def test(
+        x: R.Tensor((2, 3, 4, 5), dtype="float32"),
+        weight: R.Tensor((4, 5), dtype="float32"),
+        bias: R.Tensor((3,), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
         with R.dataflow():
             silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
-            rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05)
-            rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05)
+            rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
+                x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
+            )
+            rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
+                x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
+            )
+            group_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.group_norm(
+                x, bias, bias, num_groups=1, channel_axis=1, axes=[2, 3]
+            )
             gv1: R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
             R.output(gv1)
         return gv1
-    # fmt: on
 
     m = Model()
     irmodule, params = m.export_tvm(
@@ -163,7 +174,7 @@ def test_nn():
             "test": {
                 "x": spec.Tensor([2, 3, 4, 5], "float32"),
                 "weight": spec.Tensor([4, 5], "float32"),
-                "bias": spec.Tensor([4, 5], "float32"),
+                "bias": spec.Tensor([3], "float32"),
             }
         }
     )
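
The functional form added in op.py can also be called directly from a Module,
as in test_frontend_nn_op.py above. A minimal sketch, again with imports
inferred from the diff's file layout:

    # Inferred imports; Module, Tensor, op and spec are assumed to be
    # exposed under tvm.relax.frontend.nn as used in the tests above.
    from tvm.relax.frontend.nn import Module, Tensor, op, spec

    class Model(Module):
        def forward(self, x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            # weight and bias carry one element per channel (axis 1 of x);
            # normalization runs per group over the remaining axes.
            return op.group_norm(x, num_groups=2, weight=weight, bias=bias)

    irmodule, params = Model().export_tvm(
        spec={
            "forward": {
                "x": spec.Tensor((2, 4, 8), "float32"),
                "weight": spec.Tensor((4,), "float32"),
                "bias": spec.Tensor((4,), "float32"),
            }
        }
    )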