You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/13 08:02:46 UTC

[tvm] branch unity updated: [Unity] [Relax] [ONNX frontend] [op] Add support for Trilu operator (#15299)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1e45ca4ce3 [Unity] [Relax] [ONNX frontend] [op] Add support for Trilu operator (#15299)
1e45ca4ce3 is described below

commit 1e45ca4ce3d714a007abc959c20905599a2c2133
Author: Civitasv <37...@users.noreply.github.com>
AuthorDate: Thu Jul 13 16:02:39 2023 +0800

    [Unity] [Relax] [ONNX frontend] [op] Add support for Trilu operator (#15299)
    
    * [unity] [onnx frontend] [op] add support for trilu operator
    
    * [unity] [onnx frontend] [op] fix ci and add test for triu.
    
    * [unity] [onnx frontend] [op] fix improper indents.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 18 ++++++++++++++++++
 tests/python/relax/test_frontend_onnx.py        |  8 ++++++++
 2 files changed, 26 insertions(+)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d653bb5511..9ec340c038 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -499,6 +499,23 @@ class Sqrt(OnnxOpConverter):
         return relax.op.sqrt(inputs[0])
 
 
+class Trilu(OnnxOpConverter):
+    """Given a 2-D matrix or batches of 2-D matrices, returns the upper or
+    lower triangular part of the tensor(s).
+    """
+
+    @classmethod
+    def _impl_v14(cls, bb, inputs, attr, params):
+        upper = attr.get("upper", True)
+        x = inputs[0]
+        k = inputs[1] if len(inputs) > 1 else 0
+
+        if upper:
+            return relax.op.triu(x, k)
+        else:
+            return relax.op.tril(x, k)
+
+
 class Relu(OnnxOpConverter):
     """Converts an onnx Relu node into an equivalent Relax expression."""
 
@@ -1712,6 +1729,7 @@ def _get_convert_map():
         "Shape": Shape,
         "Tanh": Tanh,
         "Sqrt": Sqrt,
+        "Trilu": Trilu,
         "Relu": Relu,
         "Conv": Conv,
         "Pow": Pow,
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 4c4d2d5a95..c5c094e115 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -528,6 +528,14 @@ def test_relu():
     verify_unary("Relu", [32, 32])
 
 
+def test_tril():
+    verify_unary("Trilu", [3, 5, 5], attrs={"upper": False})
+
+
+def test_triu():
+    verify_unary("Trilu", [3, 5, 5], attrs={"upper": True})
+
+
 def test_conv():
     def _verify_conv(input_shape, weight_shape, output_shape):
         bias_shape = [output_shape[1]]