Posted to commits@tvm.apache.org by mb...@apache.org on 2022/08/02 19:49:05 UTC

[tvm] branch main updated: [Relay][Op] Trilu operator implementation (#12124)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b8893b557a [Relay][Op] Trilu operator implementation (#12124)
b8893b557a is described below

commit b8893b557a6c213dfe06f4069fad3cf5ad70051e
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue Aug 2 12:48:59 2022 -0700

    [Relay][Op] Trilu operator implementation (#12124)
    
    * Added topi trilu implementation
    
    * Implemented and tested full Trilu op.
    
    * Fix test type.
    
    * Add tril zero tests.
    
    * Add pytorch trilu integration.
    
    * Clean up torch integration.
    
    * Readded skip for zero tests.
---
 include/tvm/relay/attrs/transform.h             |  9 ++++
 python/tvm/relay/frontend/onnx.py               | 15 +++++++
 python/tvm/relay/frontend/pytorch.py            | 35 ++++-----------
 python/tvm/relay/op/_transform.py               |  4 ++
 python/tvm/relay/op/op_attrs.py                 |  5 +++
 python/tvm/relay/op/strategy/generic.py         | 28 ++++++++++++
 python/tvm/relay/op/transform.py                | 43 ++++++++++++++++++
 python/tvm/topi/transform.py                    | 58 +++++++++++++++++++++++++
 src/relay/op/tensor/transform.cc                | 50 +++++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py      | 16 -------
 tests/python/frontend/pytorch/test_forward.py   | 10 +++++
 tests/python/relay/test_op_level3.py            | 29 +++++++++++++
 tests/python/topi/python/test_topi_transform.py | 39 +++++++++++++++++
 13 files changed, 298 insertions(+), 43 deletions(-)
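
For context, the headline addition is the new relay.trilu(data, k, upper) operator. Below is an illustrative sketch, not part of the patch, of how it can be exercised from Python, assuming an llvm target and NumPy as the reference result:

    import numpy as np
    import tvm
    from tvm import relay

    # Build a one-op module that zeroes everything below the k=1 diagonal (cf. np.triu(data, 1)).
    x = relay.var("x", shape=(3, 3), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.trilu(x, k=1, upper=True))

    data = np.arange(9, dtype="float32").reshape(3, 3)
    res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(data)
    np.testing.assert_allclose(res.numpy(), np.triu(data, 1))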

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b9f8c6e1e8..2741d68eec 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
   }
 };  // struct StftAttrs
 
+struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
+  bool upper;
+
+  TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
+    TVM_ATTR_FIELD(upper).set_default(true).describe(
+        "Whether to keep the upper or lower half of the diagonal.");
+  }
+};  // struct TriluAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3b5bf9acfa..e78e65dc4e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ class Einsum(OnnxOpConverter):
         return _op.einsum(inputs, equation)
 
 
+class Trilu(OnnxOpConverter):
+    """Operator converter for Trilu"""
+
+    @classmethod
+    def _impl_v14(cls, inputs, attr, params):
+        upper = attr.get("upper", True)
+        if len(inputs) == 2:
+            data, k = inputs
+        else:
+            data = inputs[0]
+            k = 0
+        return _op.trilu(data, k, upper)
+
+
 class RandomNormal(OnnxOpConverter):
     """Operator converter for random_normal"""
 
@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
         "CumSum": CumSum.get_converter(opset),
         "Unique": Unique.get_converter(opset),
         "Einsum": Einsum.get_converter(opset),
+        "Trilu": Trilu.get_converter(opset),
         # defs/control_flow
         "Loop": Loop.get_converter(opset),
         "If": If.get_converter(opset),
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1bd3232871..74ea249a47 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -318,31 +318,6 @@ class PyTorchOpConverter:
         (dtype,) = input_types
         return _op.power(inputs[0], _expr.const(2, dtype))
 
-    def tril(self, inputs, input_types):
-        data = inputs[0]
-        if len(inputs) == 2:
-            k_value = inputs[1]
-        else:
-            k_value = 0
-        input_shape = self.infer_shape(data)
-        k1, k2 = input_shape[-2:]
-        k1 = k_value + 1
-        diag_input = _op.zeros(input_shape, dtype=input_types[0])
-        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
-    def triu(self, inputs, input_types):
-        data = inputs[0]
-        if len(inputs) == 2:
-            k_value = inputs[1]
-        else:
-            k_value = 0
-        input_shape = self.infer_shape(data)
-        k1, k2 = input_shape[-2:]
-        k1 = (k1 * -1) - 1
-        k2 = k_value - 1
-        diag_input = _op.zeros(input_shape, dtype=input_types[0])
-        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
     def lerp(self, inputs, input_types):
         if len(inputs) != 3:
             msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
@@ -3405,6 +3380,12 @@ class PyTorchOpConverter:
             inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
         )
 
+    def trilu(self, inputs, input_types, mode):
+        data = inputs[0]
+        k = inputs[1] if inputs[1] else 0
+        upper = mode == "triu"
+        return _op.trilu(data, k, upper)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3567,8 +3548,8 @@ class PyTorchOpConverter:
             "aten::sqrt": self.make_unary("sqrt"),
             "aten::rsqrt": self.make_unary("rsqrt"),
             "aten::square": self.square,
-            "aten::tril": self.tril,
-            "aten::triu": self.triu,
+            "aten::tril": functools.partial(self.trilu, mode="tril"),
+            "aten::triu": functools.partial(self.trilu, mode="triu"),
             "aten::ceil": self.make_unary("ceil"),
             "aten::floor": self.make_unary("floor"),
             "aten::round": self.make_unary("round"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index baf616a946..951de06967 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
     ]
 
 
+# trilu
+_reg.register_strategy("trilu", strategy.trilu_strategy)
+
+
 # scatter_add
 @_reg.register_compute("scatter_add")
 def compute_scatter_add(attrs, inputs, output_type):
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 8b92fdf267..7e8367abbb 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
 class FixedPointMultiplyAttrs(Attrs):
     """Attributes used in fixed_point_multiply operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.TriluAttrs")
+class TriluAttrs(Attrs):
+    """Attributes used in trilu operators"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 6074b0a69c..95558b5f3d 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def wrap_compute_stft(topi_compute):
     return _compute_stft
 
 
+# trilu
+@override_native_generic_func("trilu_strategy")
+def trilu_strategy(attrs, outs, out_type, target):
+    """trilu generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_trilu(topi.trilu),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="trilu.generic",
+    )
+    return strategy
+
+
+def wrap_compute_trilu(topi_compute):
+    """Wrap trilu compute"""
+
+    def _compute_trilu(attrs, inputs, output_type):
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                attrs.upper,
+            )
+        ]
+
+    return _compute_trilu
+
+
 # roi_pool
 @generic_func
 def schedule_roi_pool(attrs, outs, target):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b5d44781e5..e7ae5f7d83 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
         window = _make.ones([n_fft], "int32")
 
     return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
+
+
+def trilu(data, k, upper=True):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    data: relay.Expr
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    k: int or relay.Expr
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
+    upper: bool, optional
+        If True, only upper triangular values of input are kept;
+        if False, the lower triangular values are kept.
+
+
+    Returns
+    -------
+    ret : relay.Expr
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        relay.trilu(x, k=0, upper=True) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    if not isinstance(k, Expr):
+        k = const(k, dtype="int32")
+    return _make.trilu(data, k, upper)
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index d99d6772b0..e12f80e2ef 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1001,3 +1001,61 @@ def sliding_window(data, axis, window_shape, strides):
         The resulting tensor.
     """
     return cpp.sliding_window(data, axis, window_shape, strides)
+
+
+def trilu(data, k, upper):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    data: tvm.te.Tensor
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    k: tvm.te.Tensor
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
+    upper: bool
+        If True, only upper triangular values of input are kept;
+        if False, the lower triangular values are kept.
+
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        topi.trilu(x, 0, True) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    # Make sure datatype is consistent.
+    if k.dtype != "int32":
+        k = tvm.tir.Cast("int32", k)
+
+    # Check either above or below diagonal depending on upper.
+    check_op = tvm.tir.GE
+    if upper:
+        check_op = tvm.tir.LE
+
+    def _apply_trilu(*indices):
+        row_index = indices[-2]
+        col_index = indices[-1]
+        other_indices = indices[:-2]
+        check_position = check_op(row_index, col_index - k)
+        value = data(*other_indices, row_index, col_index)
+        return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
+
+    return te.compute(data.shape, _apply_trilu, name="trilu")
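
The compute above keeps element (i, j) when i <= j - k (upper=True) or i >= j - k (upper=False) and zeroes it otherwise. A NumPy rendering of the same rule, for illustration only:

    import numpy as np

    def trilu_reference(data, k, upper):
        # Broadcast row/column indices over the last two axes and apply the same
        # condition the Select in the te.compute uses.
        rows = np.arange(data.shape[-2]).reshape(-1, 1)
        cols = np.arange(data.shape[-1]).reshape(1, -1)
        keep = (rows <= cols - k) if upper else (rows >= cols - k)
        return np.where(keep, data, 0)

    x = np.arange(9).reshape(3, 3)
    assert np.array_equal(trilu_reference(x, 0, True), np.triu(x, 0))
    assert np.array_equal(trilu_reference(x, -1, False), np.tril(x, -1))
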
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 989ab2ad25..f90cd91e92 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
     .set_attr<TOpPattern>("TOpPattern", kInjective)
     .set_attr<TOpIsStateful>("TOpIsStateful", false);
 
+// Trilu
+
+TVM_REGISTER_NODE_TYPE(TriluAttrs);
+
+bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+              const TypeReporter& reporter) {
+  // types: [data, k, result]
+  ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
+  ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
+  auto data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "Trilu: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  auto k = types[1].as<TensorTypeNode>();
+  if (k == nullptr) {
+    ICHECK(types[1].as<IncompleteTypeNode>())
+        << "Trilu: expect k type to be TensorType but get " << types[1];
+    return false;
+  }
+
+  ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;
+
+  // Output shape is the same as input shape.
+  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+Expr MakeTrilu(Expr data, Expr k, bool upper) {
+  auto attrs = make_object<TriluAttrs>();
+  attrs->upper = upper;
+  static const Op& op = Op::Get("trilu");
+  return Call(op, {data, k}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);
+
+RELAY_REGISTER_OP("trilu")
+    .describe(
+        R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
+    )code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor")
+    .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
+    .add_type_rel("trilu", TriluRel)
+    .set_support_level(3)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
 }  // namespace relay
 }  // namespace tvm
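
The TriluRel relation only checks that k is a 0-D tensor and then gives the output the same shape and dtype as the data input, which can be observed after type inference. A minimal sketch (illustrative only):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3, 3), dtype="float16")
    func = relay.Function([x], relay.trilu(x, 0, upper=False))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    # Expected to print a (2, 3, 3) float16 tensor type, matching the input.
    print(mod["main"].ret_type)
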
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0b2e51e544..e500f0902c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5242,23 +5242,7 @@ unsupported_onnx_tests = [
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
-    "test_tril",
-    "test_tril_pos",
-    "test_tril_square",
-    "test_tril_square_neg",
-    "test_tril_neg",
-    "test_tril_one_row_neg",
-    "test_tril_out_neg",
-    "test_tril_out_pos",
     "test_tril_zero",
-    "test_triu",
-    "test_triu_one_row",
-    "test_triu_out_neg_out",
-    "test_triu_out_pos",
-    "test_triu_neg",
-    "test_triu_pos",
-    "test_triu_square",
-    "test_triu_square_neg",
     "test_triu_zero",
     "test_unique_sorted_with_axis",
     "test_unique_sorted_with_axis_3d",
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f52c7168b3..1d07c780b7 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4616,5 +4616,15 @@ def test_lerp():
     verify_model(test_fn, [x, y, w[0]])
 
 
+def test_trilu():
+    def _test_trilu(op, diagonal):
+        return lambda inp: op(inp, diagonal)
+
+    for op in [torch.triu, torch.tril]:
+        verify_model(_test_trilu(op, 0), [torch.rand(size=[3, 3])])
+        verify_model(_test_trilu(op, 1), [torch.rand(size=[6, 6])])
+        verify_model(_test_trilu(op, -2), [torch.rand(size=[6, 6])])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f91a027de4..b641ba1fdb 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ class TestSTFT:
         )
 
 
+def test_trilu(target="llvm", dev=tvm.cpu()):
+    def verify_trilu(data_shape, upper=True, k=0):
+        data = relay.var("data", relay.TensorType(data_shape, "float32"))
+        y = relay.trilu(data, k, upper)
+        mod = tvm.ir.IRModule.from_expr(y)
+
+        data_np = np.random.normal(size=data_shape).astype("float32")
+        tvm_res = (
+            relay.create_executor("graph", mod=mod, device=dev, target=target)
+            .evaluate()(data_np)
+            .numpy()
+        )
+        if upper:
+            np_res = np.triu(data_np, k)
+        else:
+            np_res = np.tril(data_np, k)
+        tvm.testing.assert_allclose(tvm_res, np_res)
+
+    # Test upper and lower triangle
+    verify_trilu((3, 3), True, 0)
+    verify_trilu((3, 3), False, 0)
+    # Test larger matrices with offset.
+    verify_trilu((6, 6), True, 1)
+    verify_trilu((6, 6), False, 2)
+    verify_trilu((6, 6), False, -2)
+    # Test batch size
+    verify_trilu((8, 6, 6), False, -2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 180f267650..c3155c948a 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -812,6 +812,31 @@ def verify_adv_index(data_shape, index_shapes, indice_dtype="int64"):
         check_device(target, dev)
 
 
+def verify_trilu(input_shape, upper, k=0):
+    x = te.placeholder(shape=input_shape, name="x", dtype="float32")
+    k_tir = tvm.tir.const(k, dtype="int32")
+    trilu_result = topi.transform.trilu(x, k_tir, upper)
+
+    def check_device(target, dev):
+        print("Running on target: %s" % target)
+        with tvm.target.Target(target):
+            s = tvm.topi.testing.get_injective_schedule(target)(trilu_result)
+        fn = tvm.build(s, [x, trilu_result], target, name="trilu")
+        x_npy = np.random.normal(size=input_shape).astype(x.dtype)
+        if upper:
+            out_npy = np.triu(x_npy, k)
+        else:
+            out_npy = np.tril(x_npy, k)
+        x_nd = tvm.nd.array(x_npy, dev)
+        out_nd = tvm.nd.array(np.empty(x_npy.shape).astype(trilu_result.dtype), dev)
+        fn(x_nd, out_nd)
+        out_topi = out_nd.numpy()
+        tvm.testing.assert_allclose(out_topi, out_npy)
+
+    for target, dev in tvm.testing.enabled_targets():
+        check_device(target, dev)
+
+
 @tvm.testing.uses_gpu
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
@@ -1256,6 +1281,19 @@ def test_adv_index():
         verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
 
 
+@tvm.testing.uses_gpu
+def test_trilu():
+    # Test upper and lower triangle
+    verify_trilu((3, 3), True, 0)
+    verify_trilu((3, 3), False, 0)
+    # Test larger matrices with offset.
+    verify_trilu((6, 6), True, 1)
+    verify_trilu((6, 6), False, 2)
+    verify_trilu((6, 6), False, -2)
+    # Test batch size
+    verify_trilu((8, 6, 6), False, -2)
+
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -1283,3 +1321,4 @@ if __name__ == "__main__":
     test_sparse_to_dense()
     test_matrix_set_diag()
     test_adv_index()
+    test_trilu()