Posted to commits@tvm.apache.org by sy...@apache.org on 2023/08/19 04:41:17 UTC
[tvm] branch unity updated: [Unity][Frontend][NN] Add GroupNorm Layer (#15592)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 497b7481d2 [Unity][Frontend][NN] Add GroupNorm Layer (#15592)
497b7481d2 is described below
commit 497b7481d2f55c0ddf414590e31ff562a1ecb8cd
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Fri Aug 18 21:41:10 2023 -0700
[Unity][Frontend][NN] Add GroupNorm Layer (#15592)
* Add nn group_norm module
* Remove debugging print
---
python/tvm/relax/frontend/nn/modules.py | 53 ++++++++++++++++++++++----
python/tvm/relax/frontend/nn/op.py | 53 ++++++++++++++++++++++++++
tests/python/relax/test_frontend_nn_modules.py | 23 +++++++++++
tests/python/relax/test_frontend_nn_op.py | 25 ++++++++----
4 files changed, 140 insertions(+), 14 deletions(-)
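For context, a minimal usage sketch of the new module (not part of the patch; shapes are illustrative and mirror the unit test added below):

    from tvm.relax.frontend.nn import modules, spec

    # Normalize 4 channels split into 2 groups; weight/bias are per-channel.
    norm = modules.GroupNorm(num_groups=2, num_channels=4)
    tvm_mod, params = norm.export_tvm(
        spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}
    )
    # forward lowers to R.nn.group_norm(x, weight, bias, num_groups=2,
    # channel_axis=1, axes=[2]).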
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 228f6bdd6e..03b47619f0 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=too-many-arguments,invalid-name,protected-access
"""Builtin Modules."""
from typing import List, Optional, Sequence, Union
@@ -65,7 +66,7 @@ class Linear(Module):
Module for linear layer.
"""
- def __init__( # pylint: disable=too-many-arguments
+ def __init__(
self,
in_features: int,
out_features: int,
@@ -83,7 +84,7 @@ class Linear(Module):
else:
self.bias = None
- def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name
+ def forward(self, x: Tensor) -> Tensor:
"""
Forward method for linear layer.
@@ -99,7 +100,7 @@ class Linear(Module):
"""
# x: [*B, in_features]
# w: [in_features, out_features]
- w = op.permute_dims(self.weight) # pylint: disable=invalid-name
+ w = op.permute_dims(self.weight)
# x: [*B, out_features]
x = op.matmul(x, w, out_dtype=self.out_dtype)
if self.bias is not None:
@@ -129,7 +130,6 @@ class RMSNorm(Module):
else:
self.bias = None
- # pylint: disable=invalid-name
def forward(self, x: Tensor):
"""
Forward method for rms norm layer.
@@ -149,7 +149,46 @@ class RMSNorm(Module):
out = op.add(out, self.bias)
return out
- # pylint: enable=invalid-name
+
+class GroupNorm(Module):
+ """
+ Module for group norm layer.
+ """
+
+ def __init__(
+ self,
+ num_groups: int,
+ num_channels: int,
+ eps: float = 1e-5,
+ affine: bool = True,
+ dtype: Optional[str] = None,
+ ):
+ super().__init__()
+ self.num_groups = num_groups
+ self.num_channels = num_channels
+ self.eps = eps
+ if affine:
+ self.weight = Parameter((num_channels,), dtype=dtype)
+ self.bias = Parameter((num_channels,), dtype=dtype)
+ else:
+ self.weight = None
+ self.bias = None
+
+ def forward(self, x: Tensor):
+ """
+ Forward method for group norm layer.
+
+ Parameters
+ ----------
+ x : Tensor
+ The input tensor.
+
+ Returns
+ -------
+ ret : Tensor
+ The output tensor for the group norm layer.
+ """
+ return op.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class KVCache(Effect):
@@ -284,7 +323,7 @@ class KVCache(Effect):
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_append"),
- args=[self.cache, new_element._expr], # pylint: disable=protected-access
+ args=[self.cache, new_element._expr],
sinfo_args=[rx.ObjectStructInfo()],
)
)
@@ -300,7 +339,7 @@ class Embedding(Module):
self.dim = dim
self.weight = Parameter((num, dim), dtype=dtype)
- def forward(self, x: Tensor): # pylint: disable=invalid-name
+ def forward(self, x: Tensor):
"""
Forward method for embedding layer.
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 9a98aa1adb..d4821d8185 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -527,6 +527,59 @@ def rms_norm(
return _wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, axes, epsilon), name)
+def group_norm(
+ x: Tensor,
+ num_groups: int,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ eps: float = 1e-5,
+ name: str = "group_norm",
+) -> Tensor:
+ r"""
+ Applies Group Normalization over a mini-batch of inputs as described in
+ the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
+
+ .. math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ Parameters
+ ----------
+ x : Tensor
+ Input to which group_norm will be applied.
+
+ num_groups : int
+ Number of groups to separate the channels into.
+
+ weight : Optional[Tensor]
+ The gamma scale factor.
+
+ bias : Optional[Tensor]
+ The beta offset factor.
+
+ eps : float
+ Small float added to the variance to avoid division by zero.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ if weight is not None:
+ weight = weight._expr
+ if bias is not None:
+ bias = bias._expr
+ dim = len(x._expr.struct_info.shape)
+ return _wrap_nested(
+ _op.nn.group_norm(
+ x._expr, weight, bias, num_groups, channel_axis=1, axes=list(range(2, dim)), epsilon=eps
+ ),
+ name,
+ )
+
+
def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor:
"""Return the upper triangular part of a matrix or a batch of matrices.
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 27bbd68329..8302257648 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -68,6 +68,29 @@ def test_rms_norm():
assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_group_norm():
+ @R.function
+ def forward(
+ x: R.Tensor((2, 4, 8), dtype="float32"),
+ weight: R.Tensor((4,), dtype="float32"),
+ bias: R.Tensor((4,), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ group_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.group_norm(
+ x, weight, bias, num_groups=2, channel_axis=1, axes=[2]
+ )
+ gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)) = group_norm, (
+ _io,
+ )
+ R.output(gv1)
+ return gv1
+
+ mod = modules.GroupNorm(num_groups=2, num_channels=4)
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}})
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_embedding():
@R.function
def forward(
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 0346a6f871..41a29fa629 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -137,25 +137,36 @@ def test_datatype():
def test_nn():
class Model(Module):
- def test(self, x: Tensor, weight: Tensor):
+ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
silu_out = op.silu(x)
softmax_out = op.softmax(x, axis=2)
rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
+ group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias)
return x
- # fmt: off
@R.function
- def test(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4, 5), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
+ def test(
+ x: R.Tensor((2, 3, 4, 5), dtype="float32"),
+ weight: R.Tensor((4, 5), dtype="float32"),
+ bias: R.Tensor((3,), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
with R.dataflow():
silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
- rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05)
- rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05)
+ rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
+ x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
+ )
+ rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
+ x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
+ )
+ group_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.group_norm(
+ x, bias, bias, num_groups=1, channel_axis=1, axes=[2, 3]
+ )
gv1: R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
R.output(gv1)
return gv1
- # fmt: on
m = Model()
irmodule, params = m.export_tvm(
@@ -163,7 +174,7 @@ def test_nn():
"test": {
"x": spec.Tensor([2, 3, 4, 5], "float32"),
"weight": spec.Tensor([4, 5], "float32"),
- "bias": spec.Tensor([4, 5], "float32"),
+ "bias": spec.Tensor([3], "float32"),
}
}
)
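For completeness, the test above also pins down the default normalization axes: for the rank-4 input, op.group_norm normalizes over axes=list(range(2, dim)) = [2, 3], i.e. every axis after the channel axis. A quick sanity check against the NumPy reference sketched earlier (assuming group_norm_ref from that sketch; shapes mirror the test):

    import numpy as np

    x = np.random.rand(2, 3, 4, 5).astype("float32")
    bias = np.random.rand(3).astype("float32")
    y = group_norm_ref(x, num_groups=1, gamma=bias, beta=bias)
    # Before the affine transform, each (sample, group) slice has mean ~0
    # and variance ~1 over the channel and both spatial axes.
    assert y.shape == x.shape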