Posted to commits@tvm.apache.org by ru...@apache.org on 2023/03/14 17:55:29 UTC

[tvm] branch unity updated: [Unity][Op] Cumsum (#14297)

This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3c65ac467d [Unity][Op] Cumsum (#14297)
3c65ac467d is described below

commit 3c65ac467d992de1d3dc93bb65d0bd1a1551ea67
Author: Chaofan Lin <17...@qq.com>
AuthorDate: Wed Mar 15 01:55:19 2023 +0800

    [Unity][Op] Cumsum (#14297)
    
    This PR introduces the high-level cumulative sum operator `cumsum`.
    
    It also replaces some `T.var("int64")` usages with `T.int64()` in `test_ast_printer.py`.
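    
    A minimal usage sketch (mirroring the TVMScript parser test added below;
    the function name "foo" is illustrative) of emitting the new operator
    through the existing Relax BlockBuilder API:
    
        from tvm import relax
        from tvm.script import relax as R
    
        # Build a small Relax function that applies cumsum along axis 1,
        # accumulating into int32. dtype is optional and defaults to the
        # input dtype when omitted.
        x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
        bb = relax.BlockBuilder()
        with bb.function("foo", [x]):
            gv = bb.emit(relax.op.cumsum(x, axis=1, dtype="int32"))
            bb.emit_func_output(gv)
        mod = bb.get()  # IRModule whose "foo" contains an R.cumsum call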
---
 include/tvm/relax/attrs/manipulate.h               | 15 +++++
 python/tvm/relax/frontend/torch/fx_translator.py   | 40 ++++++++++++
 python/tvm/relax/op/manipulate.py                  | 51 +++++++++++++++
 .../tvm/relax/transform/legalize_ops/manipulate.py |  5 ++
 python/tvm/script/ir_builder/relax/ir.py           |  2 +
 python/tvm/topi/scan.py                            |  2 +-
 src/relax/op/tensor/manipulate.cc                  | 47 ++++++++++++++
 src/relax/op/tensor/manipulate.h                   | 12 ++++
 tests/python/relax/test_ast_printer.py             |  4 +-
 tests/python/relax/test_frontend_from_fx.py        | 30 +++++++++
 tests/python/relax/test_op_manipulate.py           | 63 +++++++++++++++++++
 .../test_transform_legalize_ops_manipulate.py      | 73 ++++++++++++++++++++++
 .../relax/test_tvmscript_parser_op_manipulate.py   | 15 +++++
 13 files changed, 356 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 4982daf7e4..4aa51f2b73 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -125,6 +125,21 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
   }
 };  // struct TileAttrs
 
+/*! \brief Attributes used in cumsum operators */
+struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
+  Optional<Integer> axis;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "Axis along which the cumulative sum is computed."
+        "The default (None) is to compute the cumsum over the flattened array.");
+    TVM_ATTR_FIELD(dtype).describe(
+        "Type of the returned array and of the accumulator in which the elements are summed."
+        "If dtype is not specified, it defaults to the dtype of data.");
+  }
+};  // struct CumsumAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 41e8e775a4..0bd987cf2d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -212,6 +212,10 @@ class TorchFXImporter:
         lhs, rhs = self.retrieve_args(node)
         return self._call_binary_op(relax.op.less, lhs, rhs)
 
+    def _eq(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        return self._call_binary_op(relax.op.equal, lhs, rhs)
+
     ########## Creation ##########
 
     def _arange(self, node: fx.node.Node) -> relax.Var:
@@ -461,6 +465,38 @@ class TorchFXImporter:
             dim = None
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _cumsum(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 1:
+            dim = node.args[1]
+        else:
+            dim = None
+        if "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
+
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
+    def _index_select(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
+    def _masked_fill(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        value = node.args[2]
+        rx_value = relax.const(value)
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        return self.block_builder.emit(relax.op.where(mask, values, x))
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
@@ -877,6 +913,7 @@ class TorchFXImporter:
             "sqrt": self._sqrt,
             "round": self._round,
             "lt": self._lt,
+            "eq": self._eq,
             "truediv": self._truediv,
             "fill_": self._inplace_fill,
             "new_ones": self._new_ones,
@@ -902,6 +939,7 @@ class TorchFXImporter:
             "permute": self._permute,
             "reshape": self._reshape,
             "split": self._split,
+            "cumsum": self._cumsum,
             "chunk": self._chunk,
             "transpose": self._transpose,
             "squeeze": self._squeeze,
@@ -925,6 +963,8 @@ class TorchFXImporter:
             "to": lambda node: self.env[node.args[0]],
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
             "layer_norm": self._layer_norm,
+            "index_select": self._index_select,
+            "masked_fill": self._masked_fill,
         }
 
     def from_fx(
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index c59ab793e0..e9c3ce79d7 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -17,6 +17,7 @@
 """Manipulation operators."""
 from typing import List, Optional, Tuple, Union, Callable
 
+from tvm import DataType
 from tvm.ir.expr import PrimExpr
 from tvm.tir import IntImm, FloatImm, IndexMap
 
@@ -388,3 +389,53 @@ def tile(data: Expr, repeats: Union[int, Tuple[int], List[int]]) -> Expr:
     if isinstance(repeats, int):
         repeats = [repeats]
     return _ffi_api.tile(data, repeats)  # type: ignore
+
+
+def cumsum(data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None):
+    """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    axis : Optional[int]
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : Optional[Union[str, DataType]]
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return _ffi_api.cumsum(data, axis, dtype)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 7c67d5b26c..e7cae1af34 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -142,3 +142,8 @@ def _repeat(bb: BlockBuilder, call: Call) -> Expr:
 @register_legalize("relax.tile")
 def _tile(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.tile, call.args[0], call.attrs.repeats)
+
+
+@register_legalize("relax.cumsum")
+def _cumsum(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype)
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index c658b6f77d..b3190ea334 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -56,6 +56,7 @@ from tvm.relax.op import (
     concat,
     cos,
     cosh,
+    cumsum,
     divide,
     equal,
     ewise_fma,
@@ -556,6 +557,7 @@ __all__ = [
     "cos",
     "cosh",
     "const",
+    "cumsum",
     "dataflow",
     "divide",
     "dtype",
diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py
index 32a7e297b0..22f9ff58a5 100644
--- a/python/tvm/topi/scan.py
+++ b/python/tvm/topi/scan.py
@@ -151,7 +151,7 @@ def scanop(
 def cumsum(
     data: tvm.te.Tensor,
     axis: Optional[int] = None,
-    dtype: Optional[int] = None,
+    dtype: Optional[str] = None,
     exclusive: Optional[bool] = None,
 ) -> tvm.te.Tensor:
     """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index c3a673f8fa..d90fd41e1c 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1095,5 +1095,52 @@ TVM_REGISTER_OP("relax.tile")
     .add_argument("data", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTile);
 
+/* relax.cumsum */
+TVM_REGISTER_NODE_TYPE(CumsumAttrs);
+
+Expr cumsum(Expr data, Optional<Integer> axis, DataType dtype) {
+  auto attrs = make_object<CumsumAttrs>();
+  attrs->axis = std::move(axis);
+  attrs->dtype = std::move(dtype);
+
+  static const Op& op = Op::Get("relax.cumsum");
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum);
+
+StructInfo InferStructInfoCumsum(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<CumsumAttrs>();
+
+  DataType out_type = attrs->dtype.is_void() ? data_sinfo->dtype : attrs->dtype;
+
+  if (!attrs->axis.defined()) {
+    // flattened
+    const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+    if (data_shape == nullptr) {
+      // The flattened cumsum always yields a 1-D result, even when the input shape is unknown.
+      return TensorStructInfo(out_type, /*ndim=*/1);
+    } else {
+      PrimExpr flattened_d = 1;
+      for (const auto v : data_shape->values) {
+        flattened_d *= v;
+      }
+      return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flattened_d})), out_type);
+    }
+  }
+
+  if (data_sinfo->shape.defined()) {
+    return TensorStructInfo(data_sinfo->shape.value(), out_type);
+  } else {
+    return TensorStructInfo(out_type, data_sinfo->ndim);
+  }
+}
+
+TVM_REGISTER_OP("relax.cumsum")
+    .set_attrs_type<CumsumAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCumsum);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index fb75664a1d..592d8b1347 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -160,6 +160,18 @@ Expr repeat(Expr data, int repeats, Optional<Integer> axis = NullOpt);
  */
 Expr tile(Expr data, Array<Integer> repeats);
 
+/*!
+ * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
+ * a given axis.
+ * \param data The input tensor.
+ * \param axis Axis along which the cumulative sum is computed. The default (None) is to compute
+ * the cumsum over the flattened array.
+ * \param dtype Type of the returned array and of the accumulator in which the elements are summed.
+ * If dtype is not specified, it defaults to the dtype of data.
+ * \return The computed result.
+ */
+Expr cumsum(Expr data, Optional<Integer> axis = NullOpt, DataType dtype = DataType::Void());
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index 1929b546dc..84b8cb1d09 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -441,7 +441,7 @@ def test_call_tir():
 
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")):
-            m, n = T.var("int64"), T.var("int64")
+            m, n = T.int64(), T.int64()
             gv0 = R.call_tir(TestCallTIR.addone, (x,), R.Tensor((m, n), dtype="float32"))
             return gv0
 
@@ -495,7 +495,7 @@ def test_call_tir():
 def test_call_dps_packed():
     @R.function
     def foo(x: R.Tensor(("m", "n"), "float32")):
-        m, n = T.var("int64"), T.var("int64")
+        m, n = T.int64(), T.int64()
         gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
         return gv0
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 6467b6cf14..31b43070cb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1748,6 +1748,36 @@ def test_split():
     verify_model(Split(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_cumsum():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    class Cumsum(Module):
+        def forward(self, input):
+            return torch.cumsum(input, dim=1, dtype=torch.int32)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((1, 2, 3, 4), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32")
+                gv: R.Tensor((1, 2, 3, 4), dtype="int32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cumsum(), input_info, {}, expected1)
+
+
 @tvm.testing.requires_gpu
 def test_chunk():
     import torch
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index e1f550cc38..af20639a8e 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -42,6 +42,7 @@ def test_op_correctness():
     assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to")
     y = relax.Var("x", R.Tensor((4, 5), "float32"))
     assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like")
+    assert relax.op.cumsum(x, axis=1, dtype="int32").op == Op.get("relax.cumsum")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -2950,5 +2951,67 @@ def test_tile_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.tile(x2, r2))
 
 
+def test_cumsum_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 10, 4)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+
+    _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32"))
+    _check_inference(
+        bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.cumsum(x2, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.cumsum(x3, axis=1), relax.TensorStructInfo((2, 10, 4), dtype=""))
+    _check_inference(bb, relax.op.cumsum(x4, axis=1), relax.TensorStructInfo(dtype="", ndim=3))
+    _check_inference(bb, relax.op.cumsum(x5, axis=1), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.cumsum(x0), relax.TensorStructInfo((80,), "float32"))
+    _check_inference(
+        bb, relax.op.cumsum(x0, axis=1, dtype="int32"), relax.TensorStructInfo((2, 10, 4), "int32")
+    )
+
+
+def test_cumsum_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+    _check_inference(bb, relax.op.cumsum(x, axis=1), relax.TensorStructInfo((a, b, c), "float32"))
+    _check_inference(bb, relax.op.cumsum(x), relax.TensorStructInfo((a * b * c,), "float32"))
+
+
+def test_cumsum_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+
+    _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 3, 4), "float16"))
+    _check_inference(bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo((2, 3, 4), "int8"))
+
+
+def test_cumsum_wrong_input_number():
+    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    y = relax.Var("y", R.Tensor((2, 3, 4), "float32"))
+
+    with pytest.raises(TVMError):
+        relax.op.cumsum(x, y)
+
+
+def test_cumsum_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.cumsum(x0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.cumsum(x1, axis=1))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index ecfa1a487a..b50ba91089 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1085,5 +1085,78 @@ def test_tile_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_cumsum():
+    # fmt: off
+    @I.ir_module
+    class Cumsum:
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), "float32")):
+            gv = R.cumsum(x, axis=1, dtype="int32")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1)
+            with T.block("cumsum_generic"):
+                T.reads(rxplaceholder[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(3)])
+                T.writes(out_buf[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(3)])
+                for fused in T.parallel(T.int64(9)):
+                    out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % [...]
+                    for _k in range(T.int64(1)):
+                        out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k [...]
+
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"):
+            cls = Expected
+            gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Cumsum)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_cumsum_symbolic():
+    # fmt: off
+    @I.ir_module
+    class Cumsum:
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), "float32")):
+            gv = R.cumsum(x, axis=1, dtype="int32")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle):
+            T.func_attr({"tir.noalias": True})
+            a, b, c = T.int64(), T.int64(), T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1)
+            out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32")
+            with T.block("cumsum_generic"):
+                T.reads(rxplaceholder[T.int64(0):a, T.int64(0):b, T.int64(0):c])
+                T.writes(out_buf[T.int64(0):a, T.int64(0):b, T.int64(0):c])
+                for fused in T.parallel(a * c):
+                    out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c])
+                    for _k in range(b - T.int64(1)):
+                        out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", [...]
+
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            cls = Expected
+            gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Cumsum)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 10b786ee4a..a797885e96 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -388,5 +388,20 @@ def test_tile():
     _check(foo, bb.get()["foo"])
 
 
+def test_cumsum():
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4), "float32")):
+        gv = R.cumsum(x, axis=1, dtype="int32")
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.cumsum(x, axis=1, dtype="int32"))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()