You are viewing a plain-text version of this content. The canonical link to it (the "here" link in the original page) was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/26 23:54:55 UTC

[tvm] branch main updated: [Relay, TOPI] Add numpy style cumsum op (#7334)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e0d356  [Relay, TOPI] Add numpy style cumsum op (#7334)
1e0d356 is described below

commit 1e0d3569b94f650243f4d0ac204d196e3be8b0aa
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Jan 27 08:54:36 2021 +0900

    [Relay, TOPI] Add numpy style cumsum op (#7334)
    
    * Add cumsum relay/topi op
    
    * relay tests working
    
    * add torch frontend converter
    
    * fix for importing detr
    
    * fix bad merge
    
    * begin cuda cumsum
    
    * support non innermost axis
    
    * support rank higher than 3
    
    * making binop parameter
    
    * fix overflow issue in thrust scan
    
    * generic binop parameter working
    
    * relay test working
    
    * fixed for bool input
    
    * remove pytorch change
    
    * fix pylint
    
    * doc update
    
    * Update python/tvm/topi/cumsum.py
    
    Co-authored-by: Tristan Konolige <tr...@gmail.com>
    
    * Update tests/python/relay/test_op_level3.py
    
    Co-authored-by: Tristan Konolige <tr...@gmail.com>
    
    * add example outputs
    
    * add supported input and output dtype in thrust log
    
    * adding more loop var names
    
    * fix cpplint
    
    * fix missing check for the cuda target in nms thrust sort
    
    * parallelize cpu cumsum
    
    * making binop argument tir function
    
    * update doc for binop
    
    * doc update
    
    Co-authored-by: Tristan Konolige <tr...@gmail.com>
---
 include/tvm/relay/attrs/transform.h          |  10 ++
 python/tvm/relay/op/_transform.py            |  12 +-
 python/tvm/relay/op/strategy/cuda.py         |  12 ++
 python/tvm/relay/op/strategy/generic.py      |  21 +++
 python/tvm/relay/op/transform.py             |  49 +++++
 python/tvm/topi/__init__.py                  |   1 +
 python/tvm/topi/cuda/__init__.py             |   1 +
 python/tvm/topi/cuda/nms.py                  |   3 +-
 python/tvm/topi/cuda/scan.py                 | 255 +++++++++++++++++++--------
 python/tvm/topi/cuda/sort.py                 |   7 +-
 python/tvm/topi/cumsum.py                    | 106 +++++++++++
 python/tvm/topi/utils.py                     |   5 +
 src/relay/op/tensor/transform.cc             |  52 ++++++
 src/runtime/contrib/thrust/thrust.cu         |  73 ++++++--
 tests/python/contrib/test_thrust.py          |   4 +-
 tests/python/relay/test_op_level3.py         |  36 ++++
 tests/python/topi/python/test_topi_cumsum.py |  72 ++++++++
 17 files changed, 625 insertions(+), 94 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index efa44e0..4316624 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -438,6 +438,16 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
   }
 };  // struct MatrixSetDiagAttrs
 
+/*! \brief Attributes used in cumsum operator */
+struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
+  Integer axis;
+  DataType dtype;
+  TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
+    TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 05ca6d2..fd07c98 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -103,7 +103,7 @@ def compute_scatter_add(attrs, inputs, output_type):
 
 _reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
 
-# scatter
+# scatter_nd
 @_reg.register_compute("scatter_nd")
 def compute_scatter_nd(attrs, inputs, output_type):
     """Compute definition of scatter_nd"""
@@ -112,6 +112,16 @@ def compute_scatter_nd(attrs, inputs, output_type):
 
 _reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
 
+# cumsum
+@_reg.register_compute("cumsum")
+def compute_cumsum(attrs, inputs, output_type):
+    """Compute definition of cumsum"""
+    return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]
+
+
+_reg.register_strategy("cumsum", strategy.cumsum_strategy)
+_reg.register_shape_func("cumsum", False, elemwise_shape_func)
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 3863df0..346e934 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -996,3 +996,15 @@ def argwhere_strategy_cuda(attrs, inputs, out_type, target):
         name="argwhere.cuda",
     )
     return strategy
+
+
+@cumsum_strategy.register(["cuda", "gpu"])
+def cumsum_strategy_cuda(attrs, inputs, out_type, target):
+    """cumsum cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_cumsum(topi.cuda.cumsum),
+        wrap_topi_schedule(topi.cuda.schedule_scan),
+        name="cumsum.cuda",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 8dd9dc5..3ad75fa 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1361,3 +1361,24 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
         name="threefry_split.generic",
     )
     return strategy
+
+
+def wrap_compute_cumsum(topi_compute):
+    """Wrap cumsum topi compute"""
+
+    def _compute_cumsum(attrs, inputs, _):
+        return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]
+
+    return _compute_cumsum
+
+
+@override_native_generic_func("cumsum_strategy")
+def cumsum_strategy(attrs, inputs, out_type, target):
+    """cumsum generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_cumsum(topi.cumsum),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="cumsum.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 7e7f9b2..6785ff2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1320,3 +1320,52 @@ def adv_index(inputs):
         Output tensor.
     """
     return _make.adv_index(Tuple(inputs))
+
+
+def cumsum(data, axis=None, dtype=None):
+    """NumPy-style cumsum op. Return the cumulative inclusive sum of the elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    Returns
+    -------
+    result : relay.Expr
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+        a = [[1,2,3], [4,5,6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype must be given (as a string) to get integer sums
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return _make.cumsum(data, axis, dtype)
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index cb94b5b..873901d 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@ from .sort import *
 from .scatter import *
 from .scatter_add import *
 from .argwhere import *
+from .cumsum import *
 from . import generic
 from . import nn
 from . import x86
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 42bf980..e0ff5a1 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -56,3 +56,4 @@ from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
 from .argwhere import *
+from .scan import *
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 32691da..2d6e1e4 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -609,7 +609,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
         tag="fetch_score",
     )
 
-    if is_thrust_available():
+    target = tvm.target.Target.current()
+    if target and target.kind.name == "cuda" and is_thrust_available():
         sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
     else:
         sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index f19e4a1..232d679 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -19,30 +19,41 @@
 import tvm
 from tvm import te
 from tvm._ffi import get_global_func
-from ..transform import expand_dims, squeeze
-from ..utils import ceil_div
+from ..transform import expand_dims, squeeze, transpose, reshape
+from ..utils import ceil_div, swap, prod, get_const_int
 from ..math import cast
 from .. import tag
 from .injective import schedule_injective_from_existing
 
 
-def exclusive_sum_scan2d_ir(data, output, reduction=None):
+def _get_thrust_func_name(tvmop):
+    tvmop_to_thrust_func_name = {tvm.tir.generic.add: "tvm.contrib.thrust.sum_scan"}
+    assert tvmop in tvmop_to_thrust_func_name, "{} not supported by thrust".format(tvmop)
+    return tvmop_to_thrust_func_name[tvmop]
+
+
+def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
     """Low level IR to do exclusive sum scan along rows of 2D input.
 
     Parameters
     ----------
     data : Buffer
-        Input data. 2-D Buffer with shape [batch_size, scan_axis_size].
+        Input N-D Buffer. Scan is done over the innermost axis.
 
     output: Buffer
-        A buffer to store the output scan, of the same size as data
+        A buffer to store the output scan, of the same shape as data
 
     reduction: Buffer, optional
-        1D Buffer of size [batch_size], to store the sum of each row.
+        (N-1)-D Buffer, to store the sum of each scan axis.
+
+    binop: function, optional
+        A binary associative op to use for scan. The function takes two TIR expressions
+        and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
+        prefix sum.
     """
 
-    batch_size = data.shape[0]
-    scan_axis_size = data.shape[1]
+    batch_size = prod(data.shape[:-1])
+    scan_axis_size = data.shape[-1]
 
     ib = tvm.tir.ir_builder.create()
 
@@ -76,7 +87,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None):
             ib.scope_attr(by, "thread_extent", nthread_by)
             tid = bx * nthread_tx + tx
             with ib.if_scope(tid < scan_axis_size):
-                output[by, tid] = data[by, tid]
+                output[by * scan_axis_size + tid] = cast(data[by * scan_axis_size + tid], out_dtype)
 
         nthread_tx = max_threads
         nthread_bx = ceil_div(scan_axis_size, max_threads)
@@ -111,9 +122,10 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None):
                     middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                     end[0] = tvm.te.min(start[0] + width, scan_axis_size)
                     with ib.if_scope(middle[0] < scan_axis_size):
-                        output[by * scan_axis_size + end[0] - 1] += output[
-                            by * scan_axis_size + middle[0] - 1
-                        ]
+                        output[by * scan_axis_size + end[0] - 1] = binop(
+                            output[by * scan_axis_size + end[0] - 1],
+                            output[by * scan_axis_size + middle[0] - 1],
+                        )
 
         # Down Sweep of exclusive scan
         with ib.new_scope():
@@ -153,28 +165,33 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None):
                         output[by * scan_axis_size + middle[0] - 1] = output[
                             by * scan_axis_size + end[0] - 1
                         ]
-                        output[by * scan_axis_size + end[0] - 1] += tmp[0]
+                        output[by * scan_axis_size + end[0] - 1] = binop(
+                            output[by * scan_axis_size + end[0] - 1], tmp[0]
+                        )
     return ib.get()
 
 
-def get_reduction_from_exclusive_scan(data, ex_scan_output):
+def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generic.add):
     """Return the sum of the last element of data and the exclusive scan output.
     The is the reduction of data along each row (for 2-D case).
 
     Parameters
     ----------
     data : tvm.te.Tensor
-        Input data. 1-D tensor with shape [scan_axis_size], or
-        2-D tensor with shape [batch_size, scan_axis_size].
+        Input data of any shape
 
     ex_scan_output : tvm.te.Tensor
-        1-D tensor that is the exclusive scan of the input, or
-        2-D tensor storing the exclusive scan of each row.
+        The output of exclusive scan on data
+
+    binop: function, optional
+        A binary associative op to use for scan. The function takes two TIR expressions
+        and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
+        prefix sum.
 
     Returns
     -------
     reduction : tvm.te.Tensor
-        1-D tensor storing the reduction of each row.
+        (N-1)-D tensor storing the reduction of each scan axis.
     """
     ndim = len(data.shape)
     if ndim == 1:
@@ -182,8 +199,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
         ex_scan_output = expand_dims(ex_scan_output, axis=0)
 
     def ir(data, data_ex_scan, reduction):
-        batch_size = data.shape[0]
-        num_anchors = data.shape[1]
+        batch_size = prod(data.shape[:-1])
+        scan_axis_size = data.shape[-1]
 
         ib = tvm.tir.ir_builder.create()
 
@@ -201,21 +218,23 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
             ib.scope_attr(bx, "thread_extent", nthread_bx)
             tid = bx * max_threads + tx
             with ib.if_scope(tid < batch_size):
-                with ib.if_scope(num_anchors > 0):
-                    reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1]
+                with ib.if_scope(scan_axis_size > 0):
+                    reduction[tid] = binop(
+                        data_ex_scan[tid * scan_axis_size + scan_axis_size - 1],
+                        data[tid, scan_axis_size - 1],
+                    )
                 with ib.else_scope():
                     reduction[tid] = 0
 
         return ib.get()
 
-    assert len(data.shape) == 2, "Only 2D input supported for now"
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8)
     ex_scan_output_buf = tvm.tir.decl_buffer(
         ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8
     )
 
     reduction = te.extern(
-        [(data.shape[0],)],
+        [data.shape[:-1]],
         [data, ex_scan_output],
         lambda ins, outs: ir(ins[0], ins[1], outs[0]),
         dtype=[ex_scan_output.dtype],
@@ -235,14 +254,15 @@ def is_thrust_available():
     return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
 
 
-def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False):
-    """Do exclusive scan on 1D input or along rows of 2D input, using thrust.
+def scan_thrust(
+    data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add
+):
+    """Do exclusive or inclusive scan on 1D or multidimensional input, using thrust.
 
     Parameters
     ----------
     data : tvm.te.Tensor
-        Input data. 1-D tensor with shape [scan_axis_size], or
-        2-D tensor with shape [batch_size, scan_axis_size].
+        Input data of any shape. The scan is done over the innermost axis.
 
     output_dtype: string
         The dtype of the output scan tensor.
@@ -251,99 +271,104 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False):
         Whether or not do exclusive or inclusive scan.
 
     return_reduction: bool, optional
-        Whether or not return a 1-D tensor storing the reduction of each row.
+        Whether or not return a (N-1)-D tensor storing the reduction of each scan axis.
         Reductions are computed as part of the upsweep pass, so there is no extra cost.
-        If False, reductions are ignored.
+        If False, reductions are ignored. It must be False when exclusive is False.
+
+    binop: function, optional
+        A binary associative op to use for scan. Since we need to lookup the corresponding
+        thrust function, arbitrariy callables are not supported. Currently only
+        tvm.tir.generic.add can be passed in.
 
     Returns
     -------
     output : tvm.te.Tensor
-        1-D tensor that is the exclusive scan of the input, or
-        2-D tensor storing the exclusive scan of each row.
+        A N-D tensor of the same rank N and shape as the input data.
 
     reduction : tvm.te.Tensor, optional
-        1-D tensor storing the reduction of each row.
+        (N-1)-D tensor storing the reduction of each scan axis.
         Returned if return_reduction is True.
     """
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
+
     output = te.extern(
         [data.shape],
         [data],
         lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
+            _get_thrust_func_name(binop), ins[0], outs[0], exclusive
         ),
         dtype=[output_dtype],
         in_buffers=[data_buf],
         out_buffers=[output_buf],
-        name="exclusive_sum_scan2d",
-        tag="exclusive_sum_scan2d_gpu",
+        name="exclusive_scan_thrust",
+        tag="exclusive_scan_thrust_gpu",
     )
 
     if return_reduction:
         assert exclusive, "return_reduction should be False for inclusive scan"
-        reduction = get_reduction_from_exclusive_scan(data, output)
+        reduction = get_reduction_from_exclusive_scan(data, output, binop)
         return output, reduction
 
     return output
 
 
-def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
-    """Do exclusive scan on 1D input or along rows of 2D input.
+def exclusive_scan(
+    data, axis=-1, return_reduction=False, output_dtype=None, binop=tvm.tir.generic.add
+):
+    """Do exclusive scan on 1D or multidimensional input.
 
     Parameters
     ----------
     data : tvm.te.Tensor
-        Input data. 1-D tensor with shape [scan_axis_size], or
-        2-D tensor with shape [batch_size, scan_axis_size].
+        Input data of any shape.
 
     axis: int, optional
-        The axis to do scan on. For now, only the inner most axis is supported.
+        The axis to do scan on. By default, scan is done on the innermost axis.
 
     return_reduction: bool, optional
-        Whether or not return a 1-D tensor storing the reduction of each row.
+        Whether or not return a tensor storing the reduction over each scan axis.
+        If the input rank is N, this tensor is of rank N - 1.
         Reductions are computed as part of the upsweep pass, so there is no extra cost.
         If False, reductions are ignored.
 
     output_dtype: string, optional
         The dtype of the output scan tensor. If not provided, the dtype of the input is used.
 
+    binop: function, optional
+        A binary associative op to use for scan. The function takes two TIR expressions
+        and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
+        prefix sum.
+
     Returns
     -------
     output : tvm.te.Tensor
-        1-D tensor that is the exclusive scan of the input, or
-        2-D tensor storing the exclusive scan of each row.
+        A N-D tensor of the same rank N and shape as the input data.
 
     reduction : tvm.te.Tensor, optional
-        1-D tensor storing the reduction of each row.
+        (N-1)-D tensor storing the reduction of each scan axis.
         Returned if return_reduction is True.
     """
-    # TODO(masahi): Support other binary operators
-    ndim = len(data.shape)
-    if axis < 0:
-        axis += ndim
-    assert axis == ndim - 1, "Only support scan on the inner most axis."
 
-    if output_dtype is None:
-        output_dtype = data.dtype
+    def do_scan(data, output_dtype):
+        target = tvm.target.Target.current()
+        if target and target.kind.name == "cuda" and is_thrust_available():
+            return scan_thrust(
+                data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
+            )
 
-    target = tvm.target.Target.current()
-    if target and target.kind.name == "cuda" and is_thrust_available():
-        return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction)
+        if ndim == 1:
+            # TIR exclusive scan accepts only 2D or higher-rank inputs.
+            data = expand_dims(data, axis=0)
 
-    if ndim == 1:
-        # TIR exclusive scan accepts only 2D inputs.
-        data = expand_dims(data, axis=0)
+        data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+        output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
 
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-    output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
-
-    if len(data.shape) == 2:
         if return_reduction:
             output, reduction = te.extern(
-                [data.shape, (data.shape[0],)],
+                [data.shape, data.shape[:-1]],
                 [data],
-                lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]),
+                lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop),
                 dtype=[data.dtype, output_dtype],
                 in_buffers=[data_buf],
                 name="exclusive_scan",
@@ -353,7 +378,7 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
             output = te.extern(
                 [data.shape],
                 [data],
-                lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
+                lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop),
                 dtype=[output_dtype],
                 in_buffers=[data_buf],
                 out_buffers=[output_buf],
@@ -361,13 +386,38 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
                 tag="exclusive_scan_gpu",
             )
             reduction = None
-    else:
-        assert False, "Unsupported dimension {}".format(ndim)
 
-    if ndim == 1:
-        output = squeeze(output, 0)
+        if ndim == 1:
+            output = squeeze(output, 0)
+            if return_reduction:
+                reduction = squeeze(reduction, 0)
+
         if return_reduction:
-            reduction = squeeze(reduction, 0)
+            return output, reduction
+
+        return output
+
+    if output_dtype is None or output_dtype == "":
+        output_dtype = data.dtype
+
+    ndim = len(data.shape)
+    if axis < 0:
+        axis += ndim
+
+    # If scan axis is not the innermost one, swap the scan and the innermost axes
+    # Scan is always done on the innermost axis, for performance reason.
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
+    if return_reduction:
+        output, reduction = do_scan(data, output_dtype)
+    else:
+        output = do_scan(data, output_dtype)
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        output = transpose(output, axes)
 
     if return_reduction:
         return output, reduction
@@ -375,6 +425,38 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
     return output
 
 
+def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add):
+    """Do inclusive scan on 1D or multidimensional input.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input data of any shape.
+
+    axis: int, optional
+        The axis to do scan on. By default, scan is done on the innermost axis.
+
+    output_dtype: string, optional
+        The dtype of the output scan tensor. If not provided, the dtype of the input is used.
+
+    binop: function, optional
+        A binary associative op to use for scan. The function takes two TIR expressions
+        and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
+        prefix sum.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        A N-D tensor of the same rank N as the input data.
+    """
+    ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop)
+
+    if output_dtype is not None and data.dtype != output_dtype and output_dtype != "":
+        data = cast(data, output_dtype)
+
+    return binop(data, ex_scan)
+
+
 def schedule_scan(outs):
     """Schedule for scan operator.
 
@@ -404,3 +486,32 @@ def schedule_scan(outs):
     for out in outs:
         traverse(out.op)
     return s
+
+
+def cumsum(data, axis=None, dtype=None):
+    """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    if axis is None:
+        axis = 0
+        data = reshape(data, (prod(data.shape),))
+    axis = get_const_int(axis)
+    return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 1834038..c0f076f 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -23,12 +23,7 @@ from tvm._ffi import get_global_func
 from .injective import schedule_injective_from_existing
 from ..transform import strided_slice, transpose
 from .. import tag
-from ..utils import ceil_div
-
-
-def swap(arr, axis):
-    """ swap arr[axis] and arr[-1] """
-    return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
+from ..utils import ceil_div, swap
 
 
 def _schedule_sort(outs):
diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py
new file mode 100644
index 0000000..855427b
--- /dev/null
+++ b/python/tvm/topi/cumsum.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Cumsum operator"""
+from ..tir import decl_buffer, ir_builder
+from ..te import extern
+from .utils import prod, get_const_int
+from .math import cast
+
+
+def cumsum(data, axis=None, dtype=None):
+    """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    if dtype is None or dtype == "":
+        dtype = data.dtype
+
+    def maybe_cast(x):
+        if dtype != data.dtype:
+            return cast(x, dtype)
+        return x
+
+    axis_mul_before = 1
+    axis_mul_after = 1
+
+    if axis is None:
+        axis = 0
+        cumsum_axis_len = prod(data.shape)
+        shape = (cumsum_axis_len,)
+    else:
+        if not isinstance(axis, int):
+            axis = get_const_int(axis)
+
+        shape = data.shape
+        cumsum_axis_len = shape[axis]
+
+        if axis < 0:
+            axis = len(shape) + axis
+
+        for i, value in enumerate(shape, 0):
+            if i < axis:
+                axis_mul_before *= value
+            elif i > axis:
+                axis_mul_after *= value
+
+    def gen_ir(data_buf, out_buf):
+        ib = ir_builder.create()
+        data_buf = ib.buffer_ptr(data_buf)
+        out_buf = ib.buffer_ptr(out_buf)
+
+        with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", kind="parallel") as fused:
+            i = fused // axis_mul_after
+            j = fused % axis_mul_after
+            base_idx = i * cumsum_axis_len * axis_mul_after + j
+            out_buf[base_idx] = maybe_cast(data_buf[base_idx])
+            with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
+                k = _k + 1
+                cur_idx = base_idx + k * axis_mul_after
+                prev_idx = base_idx + (k - 1) * axis_mul_after
+                out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
+
+        return ib.get()
+
+    out_buf = decl_buffer(shape, dtype, "out_buf")
+
+    return extern(
+        [shape],
+        [data],
+        lambda ins, outs: gen_ir(ins[0], outs[0]),
+        dtype=dtype,
+        out_buffers=[out_buf],
+        name="cumsum_generic",
+        tag="cumsum_generic",
+    )
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index dfc226f..cd9f0c6 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -492,3 +492,8 @@ def is_empty_shape(shape):
 def ceil_div(a, b):
     """Return ceil division of a by b"""
     return tvm.tir.indexdiv(a + (b - 1), b)
+
+
+def swap(arr, axis):
+    """Return a new list equal to ``arr`` with ``arr[axis]`` and ``arr[-1]`` exchanged.
+
+    ``axis`` may be negative. When ``axis`` already names the last position
+    (``-1`` or ``len(arr) - 1``) the result is an unchanged copy of ``arr``.
+    ``arr`` may be any sequence (list or tuple); the result is always a list.
+    Raises ``IndexError`` if ``arr`` is empty or ``axis`` is out of range.
+    """
+    swapped = list(arr)
+    # The right-hand side is evaluated before either assignment, so this is a
+    # correct swap even when axis and -1 refer to the same element.
+    swapped[axis], swapped[-1] = swapped[-1], swapped[axis]
+    return swapped
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index ecfde35..0e868cd 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3674,5 +3674,57 @@ RELAY_REGISTER_OP("adv_index")
     .set_attr<TOpPattern>("TOpPattern", kInjective)
     .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
 
+TVM_REGISTER_NODE_TYPE(CumsumAttrs);
+
+// Type relation for cumsum: infers the output TensorType from the input type
+// and CumsumAttrs. Returns false (defer) while the input type is still incomplete.
+bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // types: [data, output]
+  ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "cumsum: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto* param = attrs.as<CumsumAttrs>();
+  ICHECK(param != nullptr) << "cumsum: expect CumsumAttrs";
+
+  // A void dtype means the output keeps the input dtype.
+  auto dtype = param->dtype;
+  if (dtype.is_void()) {
+    dtype = data->dtype;
+  }
+
+  if (param->axis.defined()) {
+    // Scan along a single axis: the output has the same shape as the input.
+    reporter->Assign(types[1], TensorType(data->shape, dtype));
+  } else {
+    // No axis: the input is flattened, so the output is 1-D with prod(shape) elements.
+    // Starting from 1 (instead of shape[0]) keeps a rank-0 input, whose shape is
+    // empty, from indexing out of bounds; it yields a single-element output.
+    IndexExpr prod = 1;
+    for (const auto& dim : data->shape) {
+      prod = prod * dim;
+    }
+    reporter->Assign(types[1], TensorType({prod}, dtype));
+  }
+
+  return true;
+}
+
+// Build a relay call to the "cumsum" op. `axis` may be undefined (flatten the
+// input) and `dtype` may be void (keep the input dtype); both are stored as attrs.
+Expr MakeCumsum(Expr data, Integer axis, DataType dtype) {
+  static const Op& cumsum_op = Op::Get("cumsum");
+  auto cumsum_attrs = make_object<CumsumAttrs>();
+  cumsum_attrs->axis = axis;
+  cumsum_attrs->dtype = dtype;
+  return Call(cumsum_op, {data}, Attrs(cumsum_attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum);
+
+// Relay registration for cumsum: one tensor input, output type inferred by
+// CumsumRel. The op pattern is kOpaque; presumably because the TOPI compute is
+// an extern (ir_builder) kernel that should not be fused — confirm.
+RELAY_REGISTER_OP("cumsum")
+    .describe(
+        R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(3)
+    .add_type_rel("Cumsum", CumsumRel)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 4e3e3a8..7295d4c 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -275,9 +275,22 @@ void thrust_scan(DLTensor* data,
 
   if (scan_size == 0) return;
 
-  if (data->ndim == 1 || (data->ndim == 2 && data->shape[0] == 1)) {
-    if (exclusive) {
+  size_t size = 1;
+  for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];
+
+  const bool need_cast = std::is_same<InType, OutType>::value == false;
+
+  auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) {
+    return static_cast<OutType>(v);
+  }); // NOLINT(*)
+
+  if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
+    if (exclusive && need_cast) {
+      thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
+    } else if (exclusive && !need_cast) {
       thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
+    } else if (!exclusive && need_cast) {
+      thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
     } else {
       thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
     }
@@ -288,17 +301,19 @@ void thrust_scan(DLTensor* data,
 
     // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,...,
     // without materializing the sequence vector
-    auto counting_iter = thrust::counting_iterator<int64_t>(0);
+    auto counting_iter = thrust::counting_iterator<size_t>(0);
     // Without __host__ annotation, cub crashes
-    auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) {
+    auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) {
         return i / scan_size;
     }; // NOLINT(*)
     auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);
-    int64_t size = 1;
-    for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];
 
-    if (exclusive) {
+    if (exclusive && need_cast) {
+      thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
+    } else if (exclusive && !need_cast) {
       thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
+    } else if (!exclusive && need_cast) {
+      thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
     } else {
       thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
     }
@@ -315,28 +330,62 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
   auto in_dtype = DLDataType2String(data->dtype);
   auto out_dtype = DLDataType2String(output->dtype);
 
-  if (in_dtype == "int32") {
+  if (in_dtype == "bool") {
+    if (out_dtype == "int32") {
+      thrust_scan<bool, int>(data, output, exclusive);
+    } else if (out_dtype == "int64") {
+      thrust_scan<bool, int64_t>(data, output, exclusive);
+    } else if (out_dtype == "float32") {
+      thrust_scan<bool, float>(data, output, exclusive);
+    } else if (out_dtype == "float64") {
+      thrust_scan<bool, double>(data, output, exclusive);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                 << ". Supported output dtypes are int32, int64, float32, and float64";
+    }
+  } else if (in_dtype == "int32") {
     if (out_dtype == "int32") {
       thrust_scan<int, int>(data, output, exclusive);
     } else if (out_dtype == "int64") {
       thrust_scan<int, int64_t>(data, output, exclusive);
+    } else if (out_dtype == "float32") {
+      thrust_scan<int, float>(data, output, exclusive);
+    } else if (out_dtype == "float64") {
+      thrust_scan<int, double>(data, output, exclusive);
     } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                 << ". Supported output dtypes are int32, int64, float32, and float64";
     }
   } else if (in_dtype == "int64") {
     if (out_dtype == "int64") {
       thrust_scan<int64_t, int64_t>(data, output, exclusive);
+    } else if (out_dtype == "float32") {
+      thrust_scan<int64_t, float>(data, output, exclusive);
+    } else if (out_dtype == "float64") {
+      thrust_scan<int64_t, double>(data, output, exclusive);
     } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                 << ". Supported output dtypes are int64, float32, and float64";
     }
   } else if (in_dtype == "float32") {
     if (out_dtype == "float32") {
       thrust_scan<float, float>(data, output, exclusive);
+    } else if (out_dtype == "float64") {
+      thrust_scan<float, double>(data, output, exclusive);
     } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                 << ". Supported output dtypes are float32, and float64";
+    }
+  } else if (in_dtype == "float64") {
+    if (out_dtype == "float64") {
+      thrust_scan<double, double>(data, output, exclusive);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                 << ". Supported output dtype is float64";
     }
   } else {
-    LOG(FATAL) << "Unsupported input dtype: " << in_dtype;
+    LOG(FATAL) << "Unsupported input dtype: " << in_dtype
+               << ". Supported input dtypes are bool, int32, int64, float32, and float64";
   }
 });
 
diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py
index 5f66d46..c5b6a29 100644
--- a/tests/python/contrib/test_thrust.py
+++ b/tests/python/contrib/test_thrust.py
@@ -59,7 +59,7 @@ def test_exclusive_scan():
         print("skip because thrust is not enabled...")
         return
 
-    for ishape in [(1,), (10, 10)]:
+    for ishape in [(10,), (10, 10), (10, 10, 10)]:
         values = te.placeholder(ishape, name="values", dtype="int32")
 
         with tvm.target.Target("cuda"):
@@ -75,7 +75,7 @@ def test_exclusive_scan():
         if len(ishape) == 1:
             reduction_shape = ()
         else:
-            reduction_shape = (ishape[0],)
+            reduction_shape = ishape[:-1]
 
         reduction_np_out = np.zeros(reduction_shape, np.int32)
 
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 5e44170..559eb24 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1311,6 +1311,7 @@ def test_sparse_to_dense():
     # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
 
 
+@tvm.testing.uses_gpu
 def test_adv_index():
     def verify_adv_index(data_shape, index_shapes):
         dtype = "float32"
@@ -1342,6 +1343,40 @@ def test_adv_index():
     verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
 
 
+@tvm.testing.parametrize_targets
+def test_cumsum(target, ctx):
+    """Check relay.op.cumsum against np.cumsum with the graph and debug executors."""
+
+    def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e-5):
+        data_var = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype)))
+        func = relay.Function([data_var], relay.op.cumsum(data_var, axis, out_dtype))
+
+        for executor_kind in ("graph", "debug"):
+            executor = relay.create_executor(executor_kind, ctx=ctx, target=target)
+            result = executor.evaluate(func)(data_np)
+            tvm.testing.assert_allclose(result.asnumpy(), np_out, rtol=rtol, atol=atol)
+
+    # Small integer vector: default output dtype, then widened to int64.
+    small = np.array([2, 3, 0])
+    verify_cumsum(small, np.cumsum(small))
+    verify_cumsum(small, np.cumsum(small), out_dtype="int64")
+
+    # 2-D float64: flattened (axis=None), then along each axis.
+    mat = np.random.randn(10, 10)
+    verify_cumsum(mat, np.cumsum(mat))
+    for ax in (0, 1):
+        verify_cumsum(mat, np.cumsum(mat, axis=ax), axis=ax)
+
+    # 3-D float32 with a looser tolerance, including a negative axis.
+    cube = np.random.randn(10, 5, 10).astype("float32")
+    loose = {"rtol": 1e-4, "atol": 1e-4}
+    verify_cumsum(cube, np.cumsum(cube), **loose)
+    for ax in (0, 1, -1):
+        verify_cumsum(cube, np.cumsum(cube, axis=ax), axis=ax, **loose)
+
+    # 0/1 values stored as int32: int32 output, then int64 output.
+    mask = (np.random.rand(10) > 0.5).astype(np.int32)
+    verify_cumsum(mask, np.cumsum(mask, dtype=np.int32))
+    verify_cumsum(mask, np.cumsum(mask, dtype="int64"), out_dtype="int64")
+
+
 if __name__ == "__main__":
     test_cast()
     test_zeros_ones()
@@ -1379,3 +1414,4 @@ if __name__ == "__main__":
     test_sparse_to_dense()
     test_fixed_point_multiply()
     test_adv_index()
+    test_cumsum()
diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
new file mode 100644
index 0000000..a01a496
--- /dev/null
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import topi
+import tvm.topi.testing
+
+
+@tvm.testing.parametrize_targets
+def test_cumsum(ctx, target):
+    """Compare topi cumsum (generic and CUDA/NVPTX schedules) against np.cumsum.
+
+    Wherever the TVM output dtype is int64, the numpy reference is computed with an
+    explicit ``dtype="int64"``. A bare ``np.cumsum`` accumulates in the platform
+    default integer, which is 32-bit on some platforms. That would give a reference
+    whose dtype differs from the TVM output, and would overflow on the large-value case.
+    """
+
+    def check_cumsum(np_ref, data, axis=None, dtype=None):
+        implementations = {
+            "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
+            "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+        }
+        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+        tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
+
+    data = np.array([2, 3, 0])
+    check_cumsum(np.cumsum(data), data)
+
+    # 0/1 values stored as int32, accumulated in int32 and in int64.
+    data = np.random.rand(10) > 0.5
+    data = data.astype(np.int32)
+    check_cumsum(np.cumsum(data, dtype=np.int32), data)
+    check_cumsum(np.cumsum(data, dtype="int64"), data, dtype="int64")
+
+    # Boolean input accumulated into int32.
+    data = np.random.rand(10) > 0.5
+    check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
+
+    for in_dtype in ["float32", "float64"]:
+        data = np.random.randn(10, 10).astype(in_dtype)
+        check_cumsum(np.cumsum(data), data)
+        check_cumsum(np.cumsum(data, axis=0), data, axis=0)
+        check_cumsum(np.cumsum(data, axis=1), data, axis=1)
+
+        data = np.random.randn(10, 5, 10).astype(in_dtype)
+        check_cumsum(np.cumsum(data), data)
+        check_cumsum(np.cumsum(data, axis=0), data, axis=0)
+        check_cumsum(np.cumsum(data, axis=1), data, axis=1)
+        check_cumsum(np.cumsum(data, axis=-1), data, axis=-1)
+
+    for in_dtype in ["int32", "int64"]:
+        data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype)
+        check_cumsum(np.cumsum(data, dtype=in_dtype), data)
+        check_cumsum(np.cumsum(data, dtype="int64"), data, dtype="int64")
+        check_cumsum(np.cumsum(data, axis=0, dtype=in_dtype), data, axis=0)
+        check_cumsum(np.cumsum(data, axis=1, dtype=in_dtype), data, axis=1)
+
+        # Large values: the running sum exceeds the int32 range, so accumulate in int64.
+        data = np.random.randint(1 << 30, (1 << 31) - 1, size=100).astype(in_dtype)
+        check_cumsum(np.cumsum(data, dtype="int64"), data, dtype="int64")
+
+
+if __name__ == "__main__":
+    # Run directly on CPU (llvm) and on the GPU backends (cuda, nvptx).
+    for ctx_name, target_name in [("cpu", "llvm"), ("cuda", "cuda"), ("nvptx", "nvptx")]:
+        test_cumsum(tvm.context(ctx_name), tvm.target.Target(target_name))