Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/07 16:10:06 UTC
[tvm] branch unity updated: [Unity][NN] Enhance ReLU and GELU support (#15885)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 969e31a978 [Unity][NN] Enhance ReLU and GELU support (#15885)
969e31a978 is described below
commit 969e31a97854d3f06b307f86009f44919be03021
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sun Oct 8 00:09:58 2023 +0800
[Unity][NN] Enhance ReLU and GELU support (#15885)
This PR adds support for ReLU in the NN modules and ops,
and also adds support for GELU in the NN modules.
---
python/tvm/relax/frontend/nn/__init__.py | 3 +++
python/tvm/relax/frontend/nn/modules.py | 18 ++++++++++---
python/tvm/relax/frontend/nn/op.py | 22 ++++++++++++++++
tests/python/relax/test_frontend_nn_modules.py | 36 ++++++++++++++++++++++++++
tests/python/relax/test_frontend_nn_op.py | 2 ++
5 files changed, 78 insertions(+), 3 deletions(-)
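For readers following the nn frontend, a minimal usage sketch of the new layers after this change (illustrative only; the TinyMLP name, layer sizes, and input shape are made up for the example and are not part of the commit):

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec

class TinyMLP(nn.Module):
    def __init__(self):
        self.fc1 = nn.Linear(16, 32)
        self.act = nn.ReLU()   # newly exported by this commit; nn.GELU() works the same way
        self.fc2 = nn.Linear(32, 16)

    def forward(self, x: nn.Tensor):
        return self.fc2(self.act(self.fc1(x)))

# Trace the module into a Relax IRModule plus its parameters.
mod, params = TinyMLP().export_tvm(
    spec={"forward": {"x": spec.Tensor((1, 16), "float32")}}
)
mod.show()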
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index f195e6be5c..59cf32eaa8 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -18,6 +18,7 @@
 from . import op, spec
 from .core import Effect, ExternModule, Module, ModuleList, Parameter, Tensor
 from .modules import (
+    GELU,
     Conv1D,
     ConvTranspose1D,
     Embedding,
@@ -26,7 +27,9 @@ from .modules import (
     LayerNorm,
     Linear,
     MultiLinear,
+    ReLU,
     RMSNorm,
+    SiLU,
 )
 from .op import *
 from .subroutine import SubroutineMixin
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 4e612dfbc2..16f27a43c8 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -71,15 +71,27 @@ def _print(_, array: NDArray) -> None:
print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}")
+class ReLU(Module):
+    """Module for ReLU activation layer."""
+
+    def forward(self, x: Tensor):
+        return op.relu(x)
+
+
 class SiLU(Module):
-    """
-    Module for SiLU activation layer.
-    """
+    """Module for SiLU activation layer."""
 
     def forward(self, x: Tensor):
         return op.silu(x)
 
 
+class GELU(Module):
+    """Module for GELU activation layer."""
+
+    def forward(self, x: Tensor):
+        return op.gelu(x)
+
+
 class Identity(Module):
     """Module that does nothing, sometimes useful for naming purposes."""
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 3e7b9d6bb2..c6f5737fa1 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -750,6 +750,28 @@ def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor:
     return _wrap_nested(_op.astype(x._expr, dtype), name)
+def relu(x: Tensor, name: str = "relu") -> Tensor:
+    r"""Rectified Linear Unit (ReLU) activation function.
+
+    .. math::
+        \text{ReLU}(x) = \text{max}(x, 0)
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return _wrap_nested(_op.nn.relu(x._expr), name)
+
+
 def silu(x: Tensor, name: str = "silu") -> Tensor:
     r"""Sigmoid Linear Unit function
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index cb207954e8..ed1c851815 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -28,6 +28,24 @@ from tvm.script import ir as I
from tvm.script import relax as R
+def test_relu():
+    @R.function
+    def forward(
+        x: R.Tensor((3, 3), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            relu: R.Tensor((3, 3), dtype="float32") = R.nn.relu(x)
+            gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = relu, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.ReLU()
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_silu():
     @R.function
     def forward(
@@ -46,6 +64,24 @@ def test_silu():
     assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_gelu():
+    @R.function
+    def forward(
+        x: R.Tensor((3, 3), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            gelu: R.Tensor((3, 3), dtype="float32") = R.nn.gelu(x)
+            gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = gelu, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.GELU()
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_identity():
     @R.function
     def forward(
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index c7bef23124..fd77b76f9f 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -273,6 +273,7 @@ def test_chunk():
 def test_nn():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
+            relu_out = op.relu(x)
             silu_out = op.silu(x)
             gelu_out = op.gelu(x)
             softmax_out = op.softmax(x, axis=2)
@@ -290,6 +291,7 @@ def test_nn():
     ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
         R.func_attr({"num_input": 4})
         with R.dataflow():
+            relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
             silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
             gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
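A small sketch of what the op-level test above exercises, calling the functional ops directly from a user-defined module (the Activations name, the chained-activation body, and the input shape are invented for illustration):

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, spec

class Activations(nn.Module):
    def forward(self, x: nn.Tensor):
        # op.relu is new in this commit; op.silu and op.gelu already existed
        return op.gelu(op.silu(op.relu(x)))

mod, _ = Activations().export_tvm(
    spec={"forward": {"x": spec.Tensor((2, 3, 4, 5), "float32")}}
)
mod.show()  # the printed Relax function contains the same R.nn.* calls the tests check for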