You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/07 16:10:06 UTC

[tvm] branch unity updated: [Unity][NN] Enhance ReLU and GELU support (#15885)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 969e31a978 [Unity][NN] Enhance ReLU and GELU support (#15885)
969e31a978 is described below

commit 969e31a97854d3f06b307f86009f44919be03021
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sun Oct 8 00:09:58 2023 +0800

    [Unity][NN] Enhance ReLU and GELU support (#15885)
    
    This PR adds support for ReLU in NN module and op,
    also adds support for GELU in the NN modules.
---
 python/tvm/relax/frontend/nn/__init__.py       |  3 +++
 python/tvm/relax/frontend/nn/modules.py        | 18 ++++++++++---
 python/tvm/relax/frontend/nn/op.py             | 22 ++++++++++++++++
 tests/python/relax/test_frontend_nn_modules.py | 36 ++++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py      |  2 ++
 5 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index f195e6be5c..59cf32eaa8 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -18,6 +18,7 @@
 from . import op, spec
 from .core import Effect, ExternModule, Module, ModuleList, Parameter, Tensor
 from .modules import (
+    GELU,
     Conv1D,
     ConvTranspose1D,
     Embedding,
@@ -26,7 +27,9 @@ from .modules import (
     LayerNorm,
     Linear,
     MultiLinear,
+    ReLU,
     RMSNorm,
+    SiLU,
 )
 from .op import *
 from .subroutine import SubroutineMixin
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 4e612dfbc2..16f27a43c8 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -71,15 +71,27 @@ def _print(_, array: NDArray) -> None:
     print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}")
 
 
+class ReLU(Module):
+    """Module for ReLU activation layer."""
+
+    def forward(self, x: Tensor):
+        return op.relu(x)
+
+
 class SiLU(Module):
-    """
-    Module for SiLU activation layer.
-    """
+    """Module for SiLU activation layer."""
 
     def forward(self, x: Tensor):
         return op.silu(x)
 
 
+class GELU(Module):
+    """Module for GELU activation layer."""
+
+    def forward(self, x: Tensor):
+        return op.gelu(x)
+
+
 class Identity(Module):
     """Module that does nothing, sometimes useful for naming purposes."""
 
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 3e7b9d6bb2..c6f5737fa1 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -750,6 +750,28 @@ def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor:
     return _wrap_nested(_op.astype(x._expr, dtype), name)
 
 
+def relu(x: Tensor, name: str = "relu") -> Tensor:
+    r"""Rectified Linear Unit (ReLU) activation function.
+
+    .. math::
+        \text{ReLU}(x) = \text{max}(x, 0)
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return _wrap_nested(_op.nn.relu(x._expr), name)
+
+
 def silu(x: Tensor, name: str = "silu") -> Tensor:
     r"""Sigmoid Linear Unit function
 
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index cb207954e8..ed1c851815 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -28,6 +28,24 @@ from tvm.script import ir as I
 from tvm.script import relax as R
 
 
+def test_relu():
+    @R.function
+    def forward(
+        x: R.Tensor((3, 3), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            relu: R.Tensor((3, 3), dtype="float32") = R.nn.relu(x)
+            gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = relu, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.ReLU()
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_silu():
     @R.function
     def forward(
@@ -46,6 +64,24 @@ def test_silu():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_gelu():
+    @R.function
+    def forward(
+        x: R.Tensor((3, 3), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            gelu: R.Tensor((3, 3), dtype="float32") = R.nn.gelu(x)
+            gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = gelu, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.GELU()
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_identity():
     @R.function
     def forward(
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index c7bef23124..fd77b76f9f 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -273,6 +273,7 @@ def test_chunk():
 def test_nn():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
+            relu_out = op.relu(x)
             silu_out = op.silu(x)
             gelu_out = op.gelu(x)
             softmax_out = op.softmax(x, axis=2)
@@ -290,6 +291,7 @@ def test_nn():
     ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
         R.func_attr({"num_input": 4})
         with R.dataflow():
+            relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
             silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
             gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)