You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2022/08/02 19:49:05 UTC
[tvm] branch main updated: [Relay][Op] Trilu operator implementation (#12124)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b8893b557a [Relay][Op] Trilu operator implementation (#12124)
b8893b557a is described below
commit b8893b557a6c213dfe06f4069fad3cf5ad70051e
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue Aug 2 12:48:59 2022 -0700
[Relay][Op] Trilu operator implementation (#12124)
* Added topi trilu implementation
* Implemented and tested full Trilu op.
* Fix test type.
* Add tril zero tests.
* Add pytorch trilu integration.
* Clean up torch integration.
* Readded skip for zero tests.
---
include/tvm/relay/attrs/transform.h | 9 ++++
python/tvm/relay/frontend/onnx.py | 15 +++++++
python/tvm/relay/frontend/pytorch.py | 35 ++++-----------
python/tvm/relay/op/_transform.py | 4 ++
python/tvm/relay/op/op_attrs.py | 5 +++
python/tvm/relay/op/strategy/generic.py | 28 ++++++++++++
python/tvm/relay/op/transform.py | 43 ++++++++++++++++++
python/tvm/topi/transform.py | 58 +++++++++++++++++++++++++
src/relay/op/tensor/transform.cc | 50 +++++++++++++++++++++
tests/python/frontend/onnx/test_forward.py | 16 -------
tests/python/frontend/pytorch/test_forward.py | 10 +++++
tests/python/relay/test_op_level3.py | 29 +++++++++++++
tests/python/topi/python/test_topi_transform.py | 39 +++++++++++++++++
13 files changed, 298 insertions(+), 43 deletions(-)
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b9f8c6e1e8..2741d68eec 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
}
}; // struct StftAttrs
+struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
+ bool upper;
+
+ TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
+ TVM_ATTR_FIELD(upper).set_default(true).describe(
+ "Whether to keep the upper or lower half of the diagonal.");
+ }
+}; // struct TriluAttrs
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3b5bf9acfa..e78e65dc4e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ class Einsum(OnnxOpConverter):
return _op.einsum(inputs, equation)
+class Trilu(OnnxOpConverter):
+ """Operator converter for Trilu"""
+
+ @classmethod
+ def _impl_v14(cls, inputs, attr, params):
+ upper = attr.get("upper", True)
+ if len(inputs) == 2:
+ data, k = inputs
+ else:
+ data = inputs[0]
+ k = 0
+ return _op.trilu(data, k, upper)
+
+
class RandomNormal(OnnxOpConverter):
"""Operator converter for random_normal"""
@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
+ "Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1bd3232871..74ea249a47 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -318,31 +318,6 @@ class PyTorchOpConverter:
(dtype,) = input_types
return _op.power(inputs[0], _expr.const(2, dtype))
- def tril(self, inputs, input_types):
- data = inputs[0]
- if len(inputs) == 2:
- k_value = inputs[1]
- else:
- k_value = 0
- input_shape = self.infer_shape(data)
- k1, k2 = input_shape[-2:]
- k1 = k_value + 1
- diag_input = _op.zeros(input_shape, dtype=input_types[0])
- return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
- def triu(self, inputs, input_types):
- data = inputs[0]
- if len(inputs) == 2:
- k_value = inputs[1]
- else:
- k_value = 0
- input_shape = self.infer_shape(data)
- k1, k2 = input_shape[-2:]
- k1 = (k1 * -1) - 1
- k2 = k_value - 1
- diag_input = _op.zeros(input_shape, dtype=input_types[0])
- return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
def lerp(self, inputs, input_types):
if len(inputs) != 3:
msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
@@ -3405,6 +3380,12 @@ class PyTorchOpConverter:
inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
)
+ def trilu(self, inputs, input_types, mode):
+ data = inputs[0]
+ k = inputs[1] if inputs[1] else 0
+ upper = True if mode == "triu" else False
+ return _op.trilu(data, k, upper)
+
# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -3567,8 +3548,8 @@ class PyTorchOpConverter:
"aten::sqrt": self.make_unary("sqrt"),
"aten::rsqrt": self.make_unary("rsqrt"),
"aten::square": self.square,
- "aten::tril": self.tril,
- "aten::triu": self.triu,
+ "aten::tril": functools.partial(self.trilu, mode="tril"),
+ "aten::triu": functools.partial(self.trilu, mode="triu"),
"aten::ceil": self.make_unary("ceil"),
"aten::floor": self.make_unary("floor"),
"aten::round": self.make_unary("round"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index baf616a946..951de06967 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
]
+# trilu
+_reg.register_strategy("trilu", strategy.trilu_strategy)
+
+
# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 8b92fdf267..7e8367abbb 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
class FixedPointMultiplyAttrs(Attrs):
"""Attributes used in fixed_point_multiply operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.TriluAttrs")
+class TriluAttrs(Attrs):
+ """Attributes used in trilu operators"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 6074b0a69c..95558b5f3d 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def wrap_compute_stft(topi_compute):
return _compute_stft
+# trilu
+@override_native_generic_func("trilu_strategy")
+def trilu_strategy(attrs, outs, out_type, target):
+ """trilu generic strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_trilu(topi.trilu),
+ wrap_topi_schedule(topi.generic.schedule_extern),
+ name="trilu.generic",
+ )
+ return strategy
+
+
+def wrap_compute_trilu(topi_compute):
+ """Wrap trilu compute"""
+
+ def _compute_trilu(attrs, inputs, output_type):
+ return [
+ topi_compute(
+ inputs[0],
+ inputs[1],
+ attrs.upper,
+ )
+ ]
+
+ return _compute_trilu
+
+
# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b5d44781e5..e7ae5f7d83 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
window = _make.ones([n_fft], "int32")
return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
+
+
+def trilu(data, k, upper=True):
+ """
+ Given a 2-D matrix or batches of 2-D matrices, returns the
+ upper or lower triangular part of the tensor.
+
+ Parameters
+ ----------
+ data: relay.Expr
+ The tensor that trilu will be applied to. Must be either
+ a 2D matrix or a tensor of batches of 2D matrices.
+
+ k: int
+ The number of diagonals above or below the main diagonal
+ to exclude or include.
+
+ upper: bool, optional
+ If True, only upper triangular values of input are kept,
+ if False, the lower triangular values are kept.
+
+
+ Returns
+ -------
+ ret : relay.Expr
+ The new tensor with appropriate diagonals set to zero.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = [[0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8]]
+
+ relay.trilu(x, 0, True) =
+ [[0, 1, 2],
+ [0, 4, 5],
+ [0, 0, 8]]
+ """
+ if not isinstance(k, Expr):
+ k = const(k, dtype="int32")
+ return _make.trilu(data, k, upper)
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index d99d6772b0..e12f80e2ef 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1001,3 +1001,61 @@ def sliding_window(data, axis, window_shape, strides):
The resulting tensor.
"""
return cpp.sliding_window(data, axis, window_shape, strides)
+
+
+def trilu(data, k, upper):
+ """
+ Given a 2-D matrix or batches of 2-D matrices, returns the
+ upper or lower triangular part of the tensor.
+
+ Parameters
+ ----------
+ data: tvm.te.Tensor
+ The tensor that trilu will be applied to. Must be either
+ a 2D matrix or a tensor of batches of 2D matrices.
+
+ k: tvm.te.Tensor
+ The number of diagonals above or below the main diagonal
+ to exclude or include.
+
+ upper: bool
+ If True, only upper triangular values of input are kept,
+ if False, the lower triangular values are kept.
+
+
+ Returns
+ -------
+ ret : tvm.te.Tensor
+ The new tensor with appropriate diagonals set to zero.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = [[0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8]]
+
+ relay.trilu(x, 0, True) =
+ [[0, 1, 2],
+ [0, 4, 5],
+ [0, 0, 8]]
+ """
+ # Make sure datatype is consistent.
+ if k.dtype != "int32":
+ k = tvm.tir.Cast("int32", k)
+
+ # Check either above or below diagonal depending on upper.
+ check_op = tvm.tir.GE
+ if upper:
+ check_op = tvm.tir.LE
+
+ def _apply_trilu(*indices):
+ row_index = indices[-2]
+ col_index = indices[-1]
+ other_indices = indices[:-2]
+ check_position = check_op(row_index, col_index - k)
+ value = data(*other_indices, row_index, col_index)
+ return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
+
+ return te.compute(data.shape, _apply_trilu, name="trilu")
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 989ab2ad25..f90cd91e92 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TOpIsStateful>("TOpIsStateful", false);
+// Trilu
+
+TVM_REGISTER_NODE_TYPE(TriluAttrs);
+
+bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types: [data, k, result]
+ ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
+ ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
+ auto data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ ICHECK(types[0].as<IncompleteTypeNode>())
+ << "Trilu: expect input type to be TensorType but get " << types[0];
+ return false;
+ }
+
+ auto k = types[1].as<TensorTypeNode>();
+ if (k == nullptr) {
+ ICHECK(types[1].as<IncompleteTypeNode>())
+ << "Trilu: expect k type to be TensorType but get " << types[1];
+ return false;
+ }
+
+ ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;
+
+ // Output shape is the same as input shape.
+ reporter->Assign(types[2], TensorType(data->shape, data->dtype));
+ return true;
+}
+
+Expr MakeTrilu(Expr data, Expr k, bool upper) {
+ auto attrs = make_object<TriluAttrs>();
+ attrs->upper = upper;
+ static const Op& op = Op::Get("trilu");
+ return Call(op, {data, k}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);
+
+RELAY_REGISTER_OP("trilu")
+ .describe(
+ R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
+ )code" TVM_ADD_FILELINE)
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor")
+ .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
+ .add_type_rel("trilu", TriluRel)
+ .set_support_level(3)
+ .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
} // namespace relay
} // namespace tvm
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0b2e51e544..e500f0902c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5242,23 +5242,7 @@ unsupported_onnx_tests = [
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
- "test_tril",
- "test_tril_pos",
- "test_tril_square",
- "test_tril_square_neg",
- "test_tril_neg",
- "test_tril_one_row_neg",
- "test_tril_out_neg",
- "test_tril_out_pos",
"test_tril_zero",
- "test_triu",
- "test_triu_one_row",
- "test_triu_out_neg_out",
- "test_triu_out_pos",
- "test_triu_neg",
- "test_triu_pos",
- "test_triu_square",
- "test_triu_square_neg",
"test_triu_zero",
"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f52c7168b3..1d07c780b7 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4616,5 +4616,15 @@ def test_lerp():
verify_model(test_fn, [x, y, w[0]])
+def test_trilu():
+ def _test_trilu(op, diagonal):
+ return lambda inp: op(inp, diagonal)
+
+ for op in [torch.triu, torch.tril]:
+ verify_model(_test_trilu(op, 0), [torch.rand(size=[3, 3])])
+ verify_model(_test_trilu(op, 1), [torch.rand(size=[6, 6])])
+ verify_model(_test_trilu(op, -2), [torch.rand(size=[6, 6])])
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f91a027de4..b641ba1fdb 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ class TestSTFT:
)
+def test_trilu(target="llvm", dev=tvm.cpu()):
+ def verify_trilu(data_shape, upper=True, k=0):
+ data = relay.var("data", relay.TensorType(data_shape, "float32"))
+ y = relay.trilu(data, k, upper)
+ mod = tvm.ir.IRModule.from_expr(y)
+
+ data_np = np.random.normal(size=data_shape).astype("float32")
+ tvm_res = (
+ relay.create_executor("graph", mod=mod, device=dev, target=target)
+ .evaluate()(data_np)
+ .numpy()
+ )
+ if upper:
+ np_res = np.triu(data_np, k)
+ else:
+ np_res = np.tril(data_np, k)
+ tvm.testing.assert_allclose(tvm_res, np_res)
+
+ # Test upper and lower triangle
+ verify_trilu((3, 3), True, 0)
+ verify_trilu((3, 3), False, 0)
+ # Test larger matrices with offset.
+ verify_trilu((6, 6), True, 1)
+ verify_trilu((6, 6), False, 2)
+ verify_trilu((6, 6), False, -2)
+ # Test batch size
+ verify_trilu((8, 6, 6), False, -2)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 180f267650..c3155c948a 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -812,6 +812,31 @@ def verify_adv_index(data_shape, index_shapes, indice_dtype="int64"):
check_device(target, dev)
+def verify_trilu(input_shape, upper, k=0):
+ x = te.placeholder(shape=input_shape, name="x", dtype="float32")
+ k_tir = tvm.tir.const(k, dtype="int32")
+ trilu_result = topi.transform.trilu(x, k_tir, upper)
+
+ def check_device(target, dev):
+ print("Running on target: %s" % target)
+ with tvm.target.Target(target):
+ s = tvm.topi.testing.get_injective_schedule(target)(trilu_result)
+ fn = tvm.build(s, [x, trilu_result], target, name="trilu")
+ x_npy = np.random.normal(size=input_shape).astype(x.dtype)
+ if upper:
+ out_npy = np.triu(x_npy, k)
+ else:
+ out_npy = np.tril(x_npy, k)
+ x_nd = tvm.nd.array(x_npy, dev)
+ out_nd = tvm.nd.array(np.empty(x_npy.shape).astype(trilu_result.dtype), dev)
+ fn(x_nd, out_nd)
+ out_topi = out_nd.numpy()
+ tvm.testing.assert_allclose(out_topi, out_npy)
+
+ for target, dev in tvm.testing.enabled_targets():
+ check_device(target, dev)
+
+
@tvm.testing.uses_gpu
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
@@ -1256,6 +1281,19 @@ def test_adv_index():
verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
+@tvm.testing.uses_gpu
+def test_trilu():
+ # Test upper and lower triangle
+ verify_trilu((3, 3), True, 0)
+ verify_trilu((3, 3), False, 0)
+ # Test larger matrices with offset.
+ verify_trilu((6, 6), True, 1)
+ verify_trilu((6, 6), False, 2)
+ verify_trilu((6, 6), False, -2)
+ # Test batch size
+ verify_trilu((8, 6, 6), False, -2)
+
+
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
@@ -1283,3 +1321,4 @@ if __name__ == "__main__":
test_sparse_to_dense()
test_matrix_set_diag()
test_adv_index()
+ test_trilu()