You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/13 08:02:46 UTC
[tvm] branch unity updated: [Unity] [Relax] [ONNX frontend] [op] Add support for Trilu operator (#15299)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1e45ca4ce3 [Unity] [Relax] [ONNX frontend] [op] Add support for Trilu operator (#15299)
1e45ca4ce3 is described below
commit 1e45ca4ce3d714a007abc959c20905599a2c2133
Author: Civitasv <37...@users.noreply.github.com>
AuthorDate: Thu Jul 13 16:02:39 2023 +0800
[Unity] [Relax] [ONNX frontend] [op] Add support for Trilu operator (#15299)
* [unity] [onnx frontend] [op] add support for trilu operator
* [unity] [onnx frontend] [op] fix ci and add test for triu.
* [unity] [onnx frontend] [op] fix improper indents.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 18 ++++++++++++++++++
tests/python/relax/test_frontend_onnx.py | 8 ++++++++
2 files changed, 26 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d653bb5511..9ec340c038 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -499,6 +499,23 @@ class Sqrt(OnnxOpConverter):
return relax.op.sqrt(inputs[0])
+class Trilu(OnnxOpConverter):
+ """Given a 2-D matrix or batches of 2-D matrices, returns the upper or
+ lower triangular part of the tensor(s)
+ """
+
+ @classmethod
+ def _impl_v14(cls, bb, inputs, attr, params):
+ upper = attr.get("upper", True)
+ x = inputs[0]
+ k = inputs[1] if len(inputs) > 1 else 0
+
+ if upper:
+ return relax.op.triu(x, k)
+ else:
+ return relax.op.tril(x, k)
+
+
class Relu(OnnxOpConverter):
"""Converts an onnx Relu node into an equivalent Relax expression."""
@@ -1712,6 +1729,7 @@ def _get_convert_map():
"Shape": Shape,
"Tanh": Tanh,
"Sqrt": Sqrt,
+ "Trilu": Trilu,
"Relu": Relu,
"Conv": Conv,
"Pow": Pow,
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 4c4d2d5a95..c5c094e115 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -528,6 +528,14 @@ def test_relu():
verify_unary("Relu", [32, 32])
+def test_tril():
+ verify_unary("Trilu", [3, 5, 5], attrs={"upper": False})
+
+
+def test_triu():
+ verify_unary("Trilu", [3, 5, 5], attrs={"upper": True})
+
+
def test_conv():
def _verify_conv(input_shape, weight_shape, output_shape):
bias_shape = [output_shape[1]]