Posted to commits@tvm.apache.org by ru...@apache.org on 2023/03/14 17:55:29 UTC
[tvm] branch unity updated: [Unity][Op] Cumsum (#14297)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3c65ac467d [Unity][Op] Cumsum (#14297)
3c65ac467d is described below
commit 3c65ac467d992de1d3dc93bb65d0bd1a1551ea67
Author: Chaofan Lin <17...@qq.com>
AuthorDate: Wed Mar 15 01:55:19 2023 +0800
[Unity][Op] Cumsum (#14297)
This PR introduces the high-level cumulative sum operator.
It also replaces some `T.var("int64")` occurrences with `T.int64()` in `test_ast_printer`.
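As a quick orientation, here is a minimal usage sketch of the new op with the BlockBuilder API, mirroring the tests in this patch (the function name "main" is arbitrary; `R.cumsum` is also available directly in TVMScript):

    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("main", [x]):
        # cumulative sum along axis 1, accumulating (and returning) int32
        gv = bb.emit(relax.op.cumsum(x, axis=1, dtype="int32"))
        bb.emit_func_output(gv)
    mod = bb.get()  # IRModule whose "main" returns R.Tensor((2, 3, 4), "int32")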
---
include/tvm/relax/attrs/manipulate.h | 15 +++++
python/tvm/relax/frontend/torch/fx_translator.py | 40 ++++++++++++
python/tvm/relax/op/manipulate.py | 51 +++++++++++++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 5 ++
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/topi/scan.py | 2 +-
src/relax/op/tensor/manipulate.cc | 47 ++++++++++++++
src/relax/op/tensor/manipulate.h | 12 ++++
tests/python/relax/test_ast_printer.py | 4 +-
tests/python/relax/test_frontend_from_fx.py | 30 +++++++++
tests/python/relax/test_op_manipulate.py | 63 +++++++++++++++++++
.../test_transform_legalize_ops_manipulate.py | 73 ++++++++++++++++++++++
.../relax/test_tvmscript_parser_op_manipulate.py | 15 +++++
13 files changed, 356 insertions(+), 3 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 4982daf7e4..4aa51f2b73 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -125,6 +125,21 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
}
}; // struct TileAttrs
+/*! \brief Attributes used in cumsum operators */
+struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
+ Optional<Integer> axis;
+ DataType dtype;
+
+ TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") {
+ TVM_ATTR_FIELD(axis).describe(
+ "Axis along which the cumulative sum is computed."
+ "The default (None) is to compute the cumsum over the flattened array.");
+ TVM_ATTR_FIELD(dtype).describe(
+ "Type of the returned array and of the accumulator in which the elements are summed."
+ "If dtype is not specified, it defaults to the dtype of data.");
+ }
+}; // struct CumsumAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 41e8e775a4..0bd987cf2d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -212,6 +212,10 @@ class TorchFXImporter:
lhs, rhs = self.retrieve_args(node)
return self._call_binary_op(relax.op.less, lhs, rhs)
+ def _eq(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ return self._call_binary_op(relax.op.equal, lhs, rhs)
+
########## Creation ##########
def _arange(self, node: fx.node.Node) -> relax.Var:
@@ -461,6 +465,38 @@ class TorchFXImporter:
dim = None
return self.block_builder.emit(relax.op.squeeze(x, dim))
+ def _cumsum(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+
+ if "dim" in node.kwargs:
+ dim = node.kwargs["dim"]
+ elif len(node.args) > 1:
+ dim = node.args[1]
+ else:
+ dim = None
+ if "dtype" in node.kwargs:
+ dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ else:
+ dtype = None
+ if "out" in node.kwargs:
+ raise ValueError("specifying out for cumsum is not supported yet")
+
+ return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
+ def _index_select(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1]
+ index = self.env[node.args[2]]
+ return self.block_builder.emit(relax.op.take(x, index, dim))
+
+ def _masked_fill(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ mask = self.env[node.args[1]]
+ value = node.args[2]
+ rx_value = relax.const(value)
+ values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+ return self.block_builder.emit(relax.op.where(mask, values, x))
+
########## Search ##########
def _argmax_argmin(self, op: Callable) -> Callable:
@@ -877,6 +913,7 @@ class TorchFXImporter:
"sqrt": self._sqrt,
"round": self._round,
"lt": self._lt,
+ "eq": self._eq,
"truediv": self._truediv,
"fill_": self._inplace_fill,
"new_ones": self._new_ones,
@@ -902,6 +939,7 @@ class TorchFXImporter:
"permute": self._permute,
"reshape": self._reshape,
"split": self._split,
+ "cumsum": self._cumsum,
"chunk": self._chunk,
"transpose": self._transpose,
"squeeze": self._squeeze,
@@ -925,6 +963,8 @@ class TorchFXImporter:
"to": lambda node: self.env[node.args[0]],
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
"layer_norm": self._layer_norm,
+ "index_select": self._index_select,
+ "masked_fill": self._masked_fill,
}
def from_fx(
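For reference, a minimal sketch of driving the new translation path end to end, assuming the usual `from_fx` entry point (the module name and input shape are the ones used in the test added further down):

    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    class M(nn.Module):
        def forward(self, x):
            return torch.cumsum(x, dim=1, dtype=torch.int32)

    # input_info pairs each placeholder's shape with its dtype
    mod = from_fx(fx.symbolic_trace(M()), [([1, 2, 3, 4], "float32")])
    # mod["main"] now contains R.cumsum(..., axis=1, dtype="int32")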
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index c59ab793e0..e9c3ce79d7 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -17,6 +17,7 @@
"""Manipulation operators."""
from typing import List, Optional, Tuple, Union, Callable
+from tvm import DataType
from tvm.ir.expr import PrimExpr
from tvm.tir import IntImm, FloatImm, IndexMap
@@ -388,3 +389,53 @@ def tile(data: Expr, repeats: Union[int, Tuple[int], List[int]]) -> Expr:
if isinstance(repeats, int):
repeats = [repeats]
return _ffi_api.tile(data, repeats) # type: ignore
+
+
+def cumsum(data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None):
+ """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
+ a given axis.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ axis : Optional[int]
+ Axis along which the cumulative sum is computed. The default (None) is to compute
+ the cumsum over the flattened array.
+
+ dtype : Optional[Union[str, DataType]]
+ Type of the returned array and of the accumulator in which the elements are summed.
+ If dtype is not specified, it defaults to the dtype of data.
+
+ Returns
+ -------
+ result : relax.Expr
+ The result has the same size as data, and the same shape as data if axis is not None.
+ If axis is None, the result is a 1-d array.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ a = [[1, 2, 3], [4, 5, 6]]
+
+ cumsum(a) # if axis is not provided, cumsum is done over the flattened input.
+ -> [ 1, 3, 6, 10, 15, 21]
+
+ cumsum(a, dtype="float32")
+ -> [ 1., 3., 6., 10., 15., 21.]
+
+ cumsum(a, axis=0) # sum over rows for each of the 3 columns
+ -> [[1, 2, 3],
+ [5, 7, 9]]
+
+ cumsum(a, axis=1)
+ -> [[ 1, 3, 6],
+ [ 4, 9, 15]]
+
+ a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array
+ cumsum(a, dtype="int32") # dtype should be provided to get the expected results
+ -> [1, 1, 2, 2, 3, 4, 4]
+ """
+ return _ffi_api.cumsum(data, axis, dtype) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 7c67d5b26c..e7cae1af34 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -142,3 +142,8 @@ def _repeat(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.tile")
def _tile(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.tile, call.args[0], call.attrs.repeats)
+
+
+@register_legalize("relax.cumsum")
+def _cumsum(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype)
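This registration is all LegalizeOps needs to lower relax.cumsum: the high-level call is rewritten into an R.call_tir to a PrimFunc generated from topi.cumsum. A minimal sketch, mirroring the legalization test added below:

    from tvm.relax.transform import LegalizeOps
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor((3, 2, 3), "float32")):
            gv = R.cumsum(x, axis=1, dtype="int32")
            return gv

    # After the pass, main calls a generated "cumsum" PrimFunc via R.call_tir.
    lowered = LegalizeOps()(Mod)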
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index c658b6f77d..b3190ea334 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -56,6 +56,7 @@ from tvm.relax.op import (
concat,
cos,
cosh,
+ cumsum,
divide,
equal,
ewise_fma,
@@ -556,6 +557,7 @@ __all__ = [
"cos",
"cosh",
"const",
+ "cumsum",
"dataflow",
"divide",
"dtype",
diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py
index 32a7e297b0..22f9ff58a5 100644
--- a/python/tvm/topi/scan.py
+++ b/python/tvm/topi/scan.py
@@ -151,7 +151,7 @@ def scanop(
def cumsum(
data: tvm.te.Tensor,
axis: Optional[int] = None,
- dtype: Optional[int] = None,
+ dtype: Optional[str] = None,
exclusive: Optional[bool] = None,
) -> tvm.te.Tensor:
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index c3a673f8fa..d90fd41e1c 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1095,5 +1095,52 @@ TVM_REGISTER_OP("relax.tile")
.add_argument("data", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTile);
+/* relax.cumsum */
+TVM_REGISTER_NODE_TYPE(CumsumAttrs);
+
+Expr cumsum(Expr data, Optional<Integer> axis, DataType dtype) {
+ auto attrs = make_object<CumsumAttrs>();
+ attrs->axis = std::move(axis);
+ attrs->dtype = std::move(dtype);
+
+ static const Op& op = Op::Get("relax.cumsum");
+ return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum);
+
+StructInfo InferStructInfoCumsum(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ const auto* attrs = call->attrs.as<CumsumAttrs>();
+
+ DataType out_type = attrs->dtype.is_void() ? data_sinfo->dtype : attrs->dtype;
+
+ if (!attrs->axis.defined()) {
+ // flattened
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape == nullptr) {
+ return TensorStructInfo(out_type, data_sinfo->ndim);
+ } else {
+ PrimExpr flattened_d = 1;
+ for (const auto v : data_shape->values) {
+ flattened_d *= v;
+ }
+ return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flattened_d})), out_type);
+ }
+ }
+
+ if (data_sinfo->shape.defined()) {
+ return TensorStructInfo(data_sinfo->shape.value(), out_type);
+ } else {
+ return TensorStructInfo(out_type, data_sinfo->ndim);
+ }
+}
+
+TVM_REGISTER_OP("relax.cumsum")
+ .set_attrs_type<CumsumAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCumsum);
+
} // namespace relax
} // namespace tvm
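The inference rule above: with an explicit axis, the output keeps the input shape; with axis unset, the input is conceptually flattened, so the output is 1-D with length equal to the product of the input extents; a void dtype means the input dtype is inherited. A minimal sketch of what this yields, mirroring the struct-info tests below:

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
    # axis given: shape and dtype are preserved -> (2, 10, 4), float32
    print(bb.normalize(relax.op.cumsum(x, axis=1)).struct_info)
    # axis=None: flattened 1-D output -> (80,), float32
    print(bb.normalize(relax.op.cumsum(x)).struct_info)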
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index fb75664a1d..592d8b1347 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -160,6 +160,18 @@ Expr repeat(Expr data, int repeats, Optional<Integer> axis = NullOpt);
*/
Expr tile(Expr data, Array<Integer> repeats);
+/*!
+ * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
+ * a given axis.
+ * \param data The input tensor.
+ * \param axis Axis along which the cumulative sum is computed. The default (None) is to compute
+ * the cumsum over the flattened array.
+ * \param dtype Type of the returned array and of the accumulator in which the elements are summed.
+ * If dtype is not specified, it defaults to the dtype of data.
+ * \return The computed result.
+ */
+Expr cumsum(Expr data, Optional<Integer> axis = NullOpt, DataType dtype = DataType::Void());
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index 1929b546dc..84b8cb1d09 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -441,7 +441,7 @@ def test_call_tir():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")):
- m, n = T.var("int64"), T.var("int64")
+ m, n = T.int64(), T.int64()
gv0 = R.call_tir(TestCallTIR.addone, (x,), R.Tensor((m, n), dtype="float32"))
return gv0
@@ -495,7 +495,7 @@ def test_call_tir():
def test_call_dps_packed():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")):
- m, n = T.var("int64"), T.var("int64")
+ m, n = T.int64(), T.int64()
gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
return gv0
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 6467b6cf14..31b43070cb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1748,6 +1748,36 @@ def test_split():
verify_model(Split(), input_info, {}, expected1)
+@tvm.testing.requires_gpu
+def test_cumsum():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ input_info = [([1, 2, 3, 4], "float32")]
+
+ class Cumsum(Module):
+ def forward(self, input):
+ return torch.cumsum(input, dim=1, dtype=torch.int32)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tensor((1, 2, 3, 4), dtype="int32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32")
+ gv: R.Tensor((1, 2, 3, 4), dtype="int32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Cumsum(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_chunk():
import torch
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index e1f550cc38..af20639a8e 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -42,6 +42,7 @@ def test_op_correctness():
assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to")
y = relax.Var("x", R.Tensor((4, 5), "float32"))
assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like")
+ assert relax.op.cumsum(x, axis=1, dtype="int32").op == Op.get("relax.cumsum")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -2950,5 +2951,67 @@ def test_tile_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.tile(x2, r2))
+def test_cumsum_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 10, 4)))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor())
+
+ _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32"))
+ _check_inference(
+ bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(bb, relax.op.cumsum(x2, axis=1), relax.TensorStructInfo(dtype="float32"))
+ _check_inference(bb, relax.op.cumsum(x3, axis=1), relax.TensorStructInfo((2, 10, 4), dtype=""))
+ _check_inference(bb, relax.op.cumsum(x4, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
+ _check_inference(bb, relax.op.cumsum(x5, axis=1), relax.TensorStructInfo(dtype=""))
+ _check_inference(bb, relax.op.cumsum(x0), relax.TensorStructInfo((80,), "float32"))
+ _check_inference(
+ bb, relax.op.cumsum(x0, axis=1, dtype="int32"), relax.TensorStructInfo((2, 10, 4), "int32")
+ )
+
+
+def test_cumsum_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ c = tir.Var("c", "int64")
+ x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+ _check_inference(bb, relax.op.cumsum(x, axis=1), relax.TensorStructInfo((a, b, c), "float32"))
+ _check_inference(bb, relax.op.cumsum(x), relax.TensorStructInfo((a * b * c,), "float32"))
+
+
+def test_cumsum_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+
+ _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 3, 4), "float16"))
+ _check_inference(bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo((2, 3, 4), "int8"))
+
+
+def test_cumsum_wrong_input_number():
+ x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ y = relax.Var("y", R.Tensor((2, 3, 4), "float32"))
+
+ with pytest.raises(TVMError):
+ relax.op.cumsum(x, y)
+
+
+def test_cumsum_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.cumsum(x0, axis=1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.cumsum(x1, axis=1))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index ecfa1a487a..b50ba91089 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1085,5 +1085,78 @@ def test_tile_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_cumsum():
+ # fmt: off
+ @I.ir_module
+ class Cumsum:
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), "float32")):
+ gv = R.cumsum(x, axis=1, dtype="int32")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1)
+ with T.block("cumsum_generic"):
+ T.reads(rxplaceholder[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(3)])
+ T.writes(out_buf[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(3)])
+ for fused in T.parallel(T.int64(9)):
+ out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % [...]
+ for _k in range(T.int64(1)):
+ out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k [...]
+
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"):
+ cls = Expected
+ gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Cumsum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_cumsum_symbolic():
+ # fmt: off
+ @I.ir_module
+ class Cumsum:
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), "float32")):
+ gv = R.cumsum(x, axis=1, dtype="int32")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle):
+ T.func_attr({"tir.noalias": True})
+ a, b, c = T.int64(), T.int64(), T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1)
+ out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32")
+ with T.block("cumsum_generic"):
+ T.reads(rxplaceholder[T.int64(0):a, T.int64(0):b, T.int64(0):c])
+ T.writes(out_buf[T.int64(0):a, T.int64(0):b, T.int64(0):c])
+ for fused in T.parallel(a * c):
+ out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c])
+ for _k in range(b - T.int64(1)):
+ out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", [...]
+
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ cls = Expected
+ gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Cumsum)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 10b786ee4a..a797885e96 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -388,5 +388,20 @@ def test_tile():
_check(foo, bb.get()["foo"])
+def test_cumsum():
+ @R.function
+ def foo(x: R.Tensor((2, 3, 4), "float32")):
+ gv = R.cumsum(x, axis=1, dtype="int32")
+ return gv
+
+ x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.cumsum(x, axis=1, dtype="int32"))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()