You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/26 23:54:55 UTC
[tvm] branch main updated: [Relay, TOPI] Add numpy style cumsum op (#7334)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1e0d356 [Relay, TOPI] Add numpy style cumsum op (#7334)
1e0d356 is described below
commit 1e0d3569b94f650243f4d0ac204d196e3be8b0aa
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Jan 27 08:54:36 2021 +0900
[Relay, TOPI] Add numpy style cumsum op (#7334)
* Add cumsum relay/topi op
* relay tests working
* add torch frontend converter
* fix for importing detr
* fix bad merge
* begin cuda cumsum
* support non innermost axis
* support rank higher than 3
* making binop parameter
* fix overflow issue in thrust scan
* generic binop parameter working
* relay test working
* fixed for bool input
* remove pytorch change
* fix pylint
* doc update
* Update python/tvm/topi/cumsum.py
Co-authored-by: Tristan Konolige <tr...@gmail.com>
* Update tests/python/relay/test_op_level3.py
Co-authored-by: Tristan Konolige <tr...@gmail.com>
* add example outputs
* add supported input and output dtype in thrust log
* adding more loop var names
* fix cpplint
* fix missing check for the cuda target in nms thrust sort
* parallelize cpu cumsum
* making binop argument tir function
* update doc for binop
* doc update
Co-authored-by: Tristan Konolige <tr...@gmail.com>
---
include/tvm/relay/attrs/transform.h | 10 ++
python/tvm/relay/op/_transform.py | 12 +-
python/tvm/relay/op/strategy/cuda.py | 12 ++
python/tvm/relay/op/strategy/generic.py | 21 +++
python/tvm/relay/op/transform.py | 49 +++++
python/tvm/topi/__init__.py | 1 +
python/tvm/topi/cuda/__init__.py | 1 +
python/tvm/topi/cuda/nms.py | 3 +-
python/tvm/topi/cuda/scan.py | 255 +++++++++++++++++++--------
python/tvm/topi/cuda/sort.py | 7 +-
python/tvm/topi/cumsum.py | 106 +++++++++++
python/tvm/topi/utils.py | 5 +
src/relay/op/tensor/transform.cc | 52 ++++++
src/runtime/contrib/thrust/thrust.cu | 73 ++++++--
tests/python/contrib/test_thrust.py | 4 +-
tests/python/relay/test_op_level3.py | 36 ++++
tests/python/topi/python/test_topi_cumsum.py | 72 ++++++++
17 files changed, 625 insertions(+), 94 deletions(-)
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index efa44e0..4316624 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -438,6 +438,16 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
}
}; // struct MatrixSetDiagAttrs
+/*! \brief Attributes used in cumsum operator */
+struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
+ Integer axis;
+ DataType dtype;
+ TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
+ TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
+ TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
+ }
+};
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 05ca6d2..fd07c98 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -103,7 +103,7 @@ def compute_scatter_add(attrs, inputs, output_type):
_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
-# scatter
+# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
@@ -112,6 +112,16 @@ def compute_scatter_nd(attrs, inputs, output_type):
_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
+# cumsum
+@_reg.register_compute("cumsum")
+def compute_cumsum(attrs, inputs, output_type):
+ """Compute definition of cumsum"""
+ return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]
+
+
+_reg.register_strategy("cumsum", strategy.cumsum_strategy)
+_reg.register_shape_func("cumsum", False, elemwise_shape_func)
+
#####################
# Shape functions #
#####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 3863df0..346e934 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -996,3 +996,15 @@ def argwhere_strategy_cuda(attrs, inputs, out_type, target):
name="argwhere.cuda",
)
return strategy
+
+
+@cumsum_strategy.register(["cuda", "gpu"])
+def cumsum_strategy_cuda(attrs, inputs, out_type, target):
+ """cumsum cuda strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_cumsum(topi.cuda.cumsum),
+ wrap_topi_schedule(topi.cuda.schedule_scan),
+ name="cumsum.cuda",
+ )
+ return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 8dd9dc5..3ad75fa 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1361,3 +1361,24 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
name="threefry_split.generic",
)
return strategy
+
+
+def wrap_compute_cumsum(topi_compute):
+ """Wrap cumsum topi compute"""
+
+ def _compute_cumsum(attrs, inputs, _):
+ return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]
+
+ return _compute_cumsum
+
+
+@override_native_generic_func("cumsum_strategy")
+def cumsum_strategy(attrs, inputs, out_type, target):
+ """cumsum generic strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_cumsum(topi.cumsum),
+ wrap_topi_schedule(topi.generic.schedule_extern),
+ name="cumsum.generic",
+ )
+ return strategy
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 7e7f9b2..6785ff2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1320,3 +1320,52 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))
+
+
+def cumsum(data, axis=None, dtype=None):
+ """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
+ a given axis.
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ axis : int, optional
+ Axis along which the cumulative sum is computed. The default (None) is to compute
+ the cumsum over the flattened array.
+
+ dtype : string, optional
+ Type of the returned array and of the accumulator in which the elements are summed.
+ If dtype is not specified, it defaults to the dtype of data.
+
+ Returns
+ -------
+ result : relay.Expr
+ The result has the same size as data, and the same shape as data if axis is not None.
+ If axis is None, the result is a 1-d array.
+
+ Examples
+ --------
+ .. code-block:: python
+ a = [[1,2,3], [4,5,6]]
+
+ cumsum(a) # if axis is not provided, cumsum is done over the flattened input.
+ -> [ 1, 3, 6, 10, 15, 21]
+
+ cumsum(a, dtype="float32")
+ -> [ 1., 3., 6., 10., 15., 21.]
+
+ cumsum(a, axis=0) # sum over rows for each of the 3 columns
+ -> [[1, 2, 3],
+ [5, 7, 9]]
+
+ cumsum(a, axis=1)
+ -> [[ 1, 3, 6],
+ [ 4, 9, 15]]
+
+ a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array
+ cumsum(a, dtype="int32") # dtype should be provided to get the expected results
+ -> [1, 1, 2, 2, 3, 4, 4]
+ """
+ return _make.cumsum(data, axis, dtype)
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index cb94b5b..873901d 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@ from .sort import *
from .scatter import *
from .scatter_add import *
from .argwhere import *
+from .cumsum import *
from . import generic
from . import nn
from . import x86
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 42bf980..e0ff5a1 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -56,3 +56,4 @@ from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
from .argwhere import *
+from .scan import *
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 32691da..2d6e1e4 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -609,7 +609,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
tag="fetch_score",
)
- if is_thrust_available():
+ target = tvm.target.Target.current()
+ if target and target.kind.name == "cuda" and is_thrust_available():
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index f19e4a1..232d679 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -19,30 +19,41 @@
import tvm
from tvm import te
from tvm._ffi import get_global_func
-from ..transform import expand_dims, squeeze
-from ..utils import ceil_div
+from ..transform import expand_dims, squeeze, transpose, reshape
+from ..utils import ceil_div, swap, prod, get_const_int
from ..math import cast
from .. import tag
from .injective import schedule_injective_from_existing
-def exclusive_sum_scan2d_ir(data, output, reduction=None):
+def _get_thrust_func_name(tvmop):
+ tvmop_to_thrust_func_name = {tvm.tir.generic.add: "tvm.contrib.thrust.sum_scan"}
+ assert tvmop in tvmop_to_thrust_func_name, "{} not supported by thrust".format(tvmop)
+ return tvmop_to_thrust_func_name[tvmop]
+
+
+def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
"""Low level IR to do exclusive sum scan along rows of 2D input.
Parameters
----------
data : Buffer
- Input data. 2-D Buffer with shape [batch_size, scan_axis_size].
+ Input N-D Buffer. Scan is done over the innermost axis.
output: Buffer
- A buffer to store the output scan, of the same size as data
+ A buffer to store the output scan, of the same shape as data
reduction: Buffer, optional
- 1D Buffer of size [batch_size], to store the sum of each row.
+ (N-1)-D Buffer, to store the reduction (under binop) of each scan axis.
+
+ binop: function, optional
+ A binary associative op to use for scan. The function takes two TIR expressions
+ and produces a new TIR expression. By default it uses tvm.tir.generic.add to compute
+ the prefix sum.
"""
- batch_size = data.shape[0]
- scan_axis_size = data.shape[1]
+ batch_size = prod(data.shape[:-1])
+ scan_axis_size = data.shape[-1]
ib = tvm.tir.ir_builder.create()
@@ -76,7 +87,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None):
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < scan_axis_size):
- output[by, tid] = data[by, tid]
+ output[by * scan_axis_size + tid] = cast(data[by * scan_axis_size + tid], out_dtype)
nthread_tx = max_threads
nthread_bx = ceil_div(scan_axis_size, max_threads)
@@ -111,9 +122,10 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, scan_axis_size)
with ib.if_scope(middle[0] < scan_axis_size):
- output[by * scan_axis_size + end[0] - 1] += output[
- by * scan_axis_size + middle[0] - 1
- ]
+ output[by * scan_axis_size + end[0] - 1] = binop(
+ output[by * scan_axis_size + end[0] - 1],
+ output[by * scan_axis_size + middle[0] - 1],
+ )
# Down Sweep of exclusive scan
with ib.new_scope():
@@ -153,28 +165,33 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None):
output[by * scan_axis_size + middle[0] - 1] = output[
by * scan_axis_size + end[0] - 1
]
- output[by * scan_axis_size + end[0] - 1] += tmp[0]
+ output[by * scan_axis_size + end[0] - 1] = binop(
+ output[by * scan_axis_size + end[0] - 1], tmp[0]
+ )
return ib.get()
-def get_reduction_from_exclusive_scan(data, ex_scan_output):
+def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generic.add):
"""Return the sum of the last element of data and the exclusive scan output.
The is the reduction of data along each row (for 2-D case).
Parameters
----------
data : tvm.te.Tensor
- Input data. 1-D tensor with shape [scan_axis_size], or
- 2-D tensor with shape [batch_size, scan_axis_size].
+ Input data of any shape
ex_scan_output : tvm.te.Tensor
- 1-D tensor that is the exclusive scan of the input, or
- 2-D tensor storing the exclusive scan of each row.
+ The output of exclusive scan on data
+
+ binop: function, optional
+ A binary associative op to use for scan. The function takes two TIR expressions
+ and produces a new TIR expression. By default it uses tvm.tir.generic.add to compute
+ the prefix sum.
Returns
-------
reduction : tvm.te.Tensor
- 1-D tensor storing the reduction of each row.
+ (N-1)-D tensor storing the reduction of each scan axis.
"""
ndim = len(data.shape)
if ndim == 1:
@@ -182,8 +199,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
ex_scan_output = expand_dims(ex_scan_output, axis=0)
def ir(data, data_ex_scan, reduction):
- batch_size = data.shape[0]
- num_anchors = data.shape[1]
+ batch_size = prod(data.shape[:-1])
+ scan_axis_size = data.shape[-1]
ib = tvm.tir.ir_builder.create()
@@ -201,21 +218,23 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
- with ib.if_scope(num_anchors > 0):
- reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1]
+ with ib.if_scope(scan_axis_size > 0):
+ reduction[tid] = binop(
+ data_ex_scan[tid * scan_axis_size + scan_axis_size - 1],
+ data[tid, scan_axis_size - 1],
+ )
with ib.else_scope():
reduction[tid] = 0
return ib.get()
- assert len(data.shape) == 2, "Only 2D input supported for now"
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8)
ex_scan_output_buf = tvm.tir.decl_buffer(
ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8
)
reduction = te.extern(
- [(data.shape[0],)],
+ [data.shape[:-1]],
[data, ex_scan_output],
lambda ins, outs: ir(ins[0], ins[1], outs[0]),
dtype=[ex_scan_output.dtype],
@@ -235,14 +254,15 @@ def is_thrust_available():
return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
-def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False):
- """Do exclusive scan on 1D input or along rows of 2D input, using thrust.
+def scan_thrust(
+ data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add
+):
+ """Do exclusive or inclusive scan on 1D or multidimensional input, using thrust.
Parameters
----------
data : tvm.te.Tensor
- Input data. 1-D tensor with shape [scan_axis_size], or
- 2-D tensor with shape [batch_size, scan_axis_size].
+ Input data of any shape. The scan is done over the innermost axis.
output_dtype: string
The dtype of the output scan tensor.
@@ -251,99 +271,104 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False):
Whether or not do exclusive or inclusive scan.
return_reduction: bool, optional
- Whether or not return a 1-D tensor storing the reduction of each row.
+ Whether or not to return an (N-1)-D tensor storing the reduction of each scan axis.
Reductions are computed as part of the upsweep pass, so there is no extra cost.
- If False, reductions are ignored.
+ If False, reductions are ignored. It must be False when exclusive is False.
+
+ binop: function, optional
+ A binary associative op to use for scan. Since we need to look up the corresponding
+ thrust function, arbitrary callables are not supported. Currently only
+ tvm.tir.generic.add can be passed in.
Returns
-------
output : tvm.te.Tensor
- 1-D tensor that is the exclusive scan of the input, or
- 2-D tensor storing the exclusive scan of each row.
+ An N-D tensor of the same rank N and shape as the input data.
reduction : tvm.te.Tensor, optional
- 1-D tensor storing the reduction of each row.
+ (N-1)-D tensor storing the reduction of each scan axis.
Returned if return_reduction is True.
"""
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
+
output = te.extern(
[data.shape],
[data],
lambda ins, outs: tvm.tir.call_packed(
- "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
+ _get_thrust_func_name(binop), ins[0], outs[0], exclusive
),
dtype=[output_dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
- name="exclusive_sum_scan2d",
- tag="exclusive_sum_scan2d_gpu",
+ name="exclusive_scan_thrust",
+ tag="exclusive_scan_thrust_gpu",
)
if return_reduction:
assert exclusive, "return_reduction should be False for inclusive scan"
- reduction = get_reduction_from_exclusive_scan(data, output)
+ reduction = get_reduction_from_exclusive_scan(data, output, binop)
return output, reduction
return output
-def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
- """Do exclusive scan on 1D input or along rows of 2D input.
+def exclusive_scan(
+ data, axis=-1, return_reduction=False, output_dtype=None, binop=tvm.tir.generic.add
+):
+ """Do exclusive scan on 1D or multidimensional input.
Parameters
----------
data : tvm.te.Tensor
- Input data. 1-D tensor with shape [scan_axis_size], or
- 2-D tensor with shape [batch_size, scan_axis_size].
+ Input data of any shape.
axis: int, optional
- The axis to do scan on. For now, only the inner most axis is supported.
+ The axis to do scan on. By default, scan is done on the innermost axis.
return_reduction: bool, optional
- Whether or not return a 1-D tensor storing the reduction of each row.
+ Whether or not to return a tensor storing the reduction over each scan axis.
+ If the input rank is N, this tensor is of rank N - 1.
Reductions are computed as part of the upsweep pass, so there is no extra cost.
If False, reductions are ignored.
output_dtype: string, optional
The dtype of the output scan tensor. If not provided, the dtype of the input is used.
+ binop: function, optional
+ A binary associative op to use for scan. The function takes two TIR expressions
+ and produces a new TIR expression. By default it uses tvm.tir.generic.add to compute
+ the prefix sum.
+
Returns
-------
output : tvm.te.Tensor
- 1-D tensor that is the exclusive scan of the input, or
- 2-D tensor storing the exclusive scan of each row.
+ An N-D tensor of the same rank N and shape as the input data.
reduction : tvm.te.Tensor, optional
- 1-D tensor storing the reduction of each row.
+ (N-1)-D tensor storing the reduction of each scan axis.
Returned if return_reduction is True.
"""
- # TODO(masahi): Support other binary operators
- ndim = len(data.shape)
- if axis < 0:
- axis += ndim
- assert axis == ndim - 1, "Only support scan on the inner most axis."
- if output_dtype is None:
- output_dtype = data.dtype
+ def do_scan(data, output_dtype):
+ target = tvm.target.Target.current()
+ if target and target.kind.name == "cuda" and is_thrust_available():
+ return scan_thrust(
+ data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
+ )
- target = tvm.target.Target.current()
- if target and target.kind.name == "cuda" and is_thrust_available():
- return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction)
+ if ndim == 1:
+ # TIR exclusive scan accepts only 2D or higher-rank inputs.
+ data = expand_dims(data, axis=0)
- if ndim == 1:
- # TIR exclusive scan accepts only 2D inputs.
- data = expand_dims(data, axis=0)
+ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+ output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
- data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
- output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
-
- if len(data.shape) == 2:
if return_reduction:
output, reduction = te.extern(
- [data.shape, (data.shape[0],)],
+ [data.shape, data.shape[:-1]],
[data],
- lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]),
+ lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop),
dtype=[data.dtype, output_dtype],
in_buffers=[data_buf],
name="exclusive_scan",
@@ -353,7 +378,7 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
output = te.extern(
[data.shape],
[data],
- lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
+ lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop),
dtype=[output_dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
@@ -361,13 +386,38 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
tag="exclusive_scan_gpu",
)
reduction = None
- else:
- assert False, "Unsupported dimension {}".format(ndim)
- if ndim == 1:
- output = squeeze(output, 0)
+ if ndim == 1:
+ output = squeeze(output, 0)
+ if return_reduction:
+ reduction = squeeze(reduction, 0)
+
if return_reduction:
- reduction = squeeze(reduction, 0)
+ return output, reduction
+
+ return output
+
+ if output_dtype is None or output_dtype == "":
+ output_dtype = data.dtype
+
+ ndim = len(data.shape)
+ if axis < 0:
+ axis += ndim
+
+ # If the scan axis is not the innermost one, swap the scan and the innermost axes.
+ # Scan is always done on the innermost axis, for performance reasons.
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+
+ if return_reduction:
+ output, reduction = do_scan(data, output_dtype)
+ else:
+ output = do_scan(data, output_dtype)
+
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ output = transpose(output, axes)
if return_reduction:
return output, reduction
@@ -375,6 +425,38 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None):
return output
+def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add):
+ """Do inclusive scan on 1D or multidimensional input.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ Input data of any shape.
+
+ axis: int, optional
+ The axis to do scan on. By default, scan is done on the innermost axis.
+
+ output_dtype: string, optional
+ The dtype of the output scan tensor. If not provided, the dtype of the input is used.
+
+ binop: function, optional
+ A binary associative op to use for scan. The function takes two TIR expressions
+ and produces a new TIR expression. By default it uses tvm.tir.generic.add to compute
+ the prefix sum.
+
+ Returns
+ -------
+ output : tvm.te.Tensor
+ An N-D tensor of the same rank N as the input data.
+ """
+ ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop)
+
+ if output_dtype is not None and data.dtype != output_dtype and output_dtype != "":
+ data = cast(data, output_dtype)
+
+ return binop(data, ex_scan)
+
+
def schedule_scan(outs):
"""Schedule for scan operator.
@@ -404,3 +486,32 @@ def schedule_scan(outs):
for out in outs:
traverse(out.op)
return s
+
+
+def cumsum(data, axis=None, dtype=None):
+ """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ The input data to the operator.
+
+ axis : int, optional
+ Axis along which the cumulative sum is computed. The default (None) is to compute
+ the cumsum over the flattened array.
+
+ dtype : string, optional
+ Type of the returned array and of the accumulator in which the elements are summed.
+ If dtype is not specified, it defaults to the dtype of data.
+
+ Returns
+ -------
+ result : tvm.te.Tensor
+ The result has the same size as data, and the same shape as data if axis is not None.
+ If axis is None, the result is a 1-d array.
+ """
+ if axis is None:
+ axis = 0
+ data = reshape(data, (prod(data.shape),))
+ axis = get_const_int(axis)
+ return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 1834038..c0f076f 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -23,12 +23,7 @@ from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from ..transform import strided_slice, transpose
from .. import tag
-from ..utils import ceil_div
-
-
-def swap(arr, axis):
- """ swap arr[axis] and arr[-1] """
- return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
+from ..utils import ceil_div, swap
def _schedule_sort(outs):
diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py
new file mode 100644
index 0000000..855427b
--- /dev/null
+++ b/python/tvm/topi/cumsum.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Cumsum operator"""
+from ..tir import decl_buffer, ir_builder
+from ..te import extern
+from .utils import prod, get_const_int
+from .math import cast
+
+
+def cumsum(data, axis=None, dtype=None):
+ """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ The input data to the operator.
+
+ axis : int, optional
+ Axis along which the cumulative sum is computed. The default (None) is to compute
+ the cumsum over the flattened array.
+
+ dtype : string, optional
+ Type of the returned array and of the accumulator in which the elements are summed.
+ If dtype is not specified, it defaults to the dtype of data.
+
+ Returns
+ -------
+ result : tvm.te.Tensor
+ The result has the same size as data, and the same shape as data if axis is not None.
+ If axis is None, the result is a 1-d array.
+ """
+ if dtype is None or dtype == "":
+ dtype = data.dtype
+
+ def maybe_cast(x):
+ if dtype != data.dtype:
+ return cast(x, dtype)
+ return x
+
+ axis_mul_before = 1
+ axis_mul_after = 1
+
+ if axis is None:
+ axis = 0
+ cumsum_axis_len = prod(data.shape)
+ shape = (cumsum_axis_len,)
+ else:
+ if not isinstance(axis, int):
+ axis = get_const_int(axis)
+
+ shape = data.shape
+ cumsum_axis_len = shape[axis]
+
+ if axis < 0:
+ axis = len(shape) + axis
+
+ for i, value in enumerate(shape, 0):
+ if i < axis:
+ axis_mul_before *= value
+ elif i > axis:
+ axis_mul_after *= value
+
+ def gen_ir(data_buf, out_buf):
+ ib = ir_builder.create()
+ data_buf = ib.buffer_ptr(data_buf)
+ out_buf = ib.buffer_ptr(out_buf)
+
+ with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", kind="parallel") as fused:
+ i = fused // axis_mul_after
+ j = fused % axis_mul_after
+ base_idx = i * cumsum_axis_len * axis_mul_after + j
+ out_buf[base_idx] = maybe_cast(data_buf[base_idx])
+ with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
+ k = _k + 1
+ cur_idx = base_idx + k * axis_mul_after
+ prev_idx = base_idx + (k - 1) * axis_mul_after
+ out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
+
+ return ib.get()
+
+ out_buf = decl_buffer(shape, dtype, "out_buf")
+
+ return extern(
+ [shape],
+ [data],
+ lambda ins, outs: gen_ir(ins[0], outs[0]),
+ dtype=dtype,
+ out_buffers=[out_buf],
+ name="cumsum_generic",
+ tag="cumsum_generic",
+ )
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index dfc226f..cd9f0c6 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -492,3 +492,8 @@ def is_empty_shape(shape):
def ceil_div(a, b):
"""Return ceil division of a by b"""
return tvm.tir.indexdiv(a + (b - 1), b)
+
+
+def swap(arr, axis):
+ """ swap arr[axis] and arr[-1] """
+ return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index ecfde35..0e868cd 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3674,5 +3674,57 @@ RELAY_REGISTER_OP("adv_index")
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
+TVM_REGISTER_NODE_TYPE(CumsumAttrs);
+
+bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types: [data, output]
+ ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ ICHECK(types[0].as<IncompleteTypeNode>())
+ << "cumsum: expect input type to be TensorType but get " << types[0];
+ return false;
+ }
+
+ const auto* param = attrs.as<CumsumAttrs>();
+
+ auto dtype = param->dtype;
+ if (dtype.is_void()) {
+ dtype = data->dtype;
+ }
+
+ if (param->axis.defined()) {
+ reporter->Assign(types[1], TensorType(data->shape, dtype));
+ } else {
+ auto prod = data->shape[0];
+ for (size_t i = 1; i < data->shape.size(); ++i) {
+ prod = prod * data->shape[i];
+ }
+ reporter->Assign(types[1], TensorType({prod}, dtype));
+ }
+
+ return true;
+}
+
+Expr MakeCumsum(Expr data, Integer axis, DataType dtype) {
+ auto attrs = make_object<CumsumAttrs>();
+ attrs->dtype = dtype;
+ attrs->axis = axis;
+ static const Op& op = Op::Get("cumsum");
+ return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum);
+
+RELAY_REGISTER_OP("cumsum")
+ .describe(
+ R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_support_level(3)
+ .add_type_rel("Cumsum", CumsumRel)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque);
+
} // namespace relay
} // namespace tvm
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 4e3e3a8..7295d4c 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -275,9 +275,22 @@ void thrust_scan(DLTensor* data,
if (scan_size == 0) return;
- if (data->ndim == 1 || (data->ndim == 2 && data->shape[0] == 1)) {
- if (exclusive) {
+ size_t size = 1;
+ for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];
+
+ const bool need_cast = std::is_same<InType, OutType>::value == false;
+
+ auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) {
+ return static_cast<OutType>(v);
+ }); // NOLINT(*)
+
+ if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
+ if (exclusive && need_cast) {
+ thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
+ } else if (exclusive && !need_cast) {
thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
+ } else if (!exclusive && need_cast) {
+ thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
} else {
thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
}
@@ -288,17 +301,19 @@ void thrust_scan(DLTensor* data,
// This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,...,
// without materializing the sequence vector
- auto counting_iter = thrust::counting_iterator<int64_t>(0);
+ auto counting_iter = thrust::counting_iterator<size_t>(0);
// Without __host__ annotation, cub crashes
- auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) {
+ auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) {
return i / scan_size;
}; // NOLINT(*)
auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);
- int64_t size = 1;
- for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];
- if (exclusive) {
+ if (exclusive && need_cast) {
+ thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
+ } else if (exclusive && !need_cast) {
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
+ } else if (!exclusive && need_cast) {
+ thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
} else {
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
}
@@ -315,28 +330,62 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
auto in_dtype = DLDataType2String(data->dtype);
auto out_dtype = DLDataType2String(output->dtype);
- if (in_dtype == "int32") {
+ if (in_dtype == "bool") {
+ if (out_dtype == "int32") {
+ thrust_scan<bool, int>(data, output, exclusive);
+ } else if (out_dtype == "int64") {
+ thrust_scan<bool, int64_t>(data, output, exclusive);
+ } else if (out_dtype == "float32") {
+ thrust_scan<bool, float>(data, output, exclusive);
+ } else if (out_dtype == "float64") {
+ thrust_scan<bool, double>(data, output, exclusive);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are int32, int64, float32, and float64";
+ }
+ } else if (in_dtype == "int32") {
if (out_dtype == "int32") {
thrust_scan<int, int>(data, output, exclusive);
} else if (out_dtype == "int64") {
thrust_scan<int, int64_t>(data, output, exclusive);
+ } else if (out_dtype == "float32") {
+ thrust_scan<int, float>(data, output, exclusive);
+ } else if (out_dtype == "float64") {
+ thrust_scan<int, double>(data, output, exclusive);
} else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are int32, int64, float32, and float64";
}
} else if (in_dtype == "int64") {
if (out_dtype == "int64") {
thrust_scan<int64_t, int64_t>(data, output, exclusive);
+ } else if (out_dtype == "float32") {
+ thrust_scan<int64_t, float>(data, output, exclusive);
+ } else if (out_dtype == "float64") {
+ thrust_scan<int64_t, double>(data, output, exclusive);
} else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are int64, float32, and float64";
}
} else if (in_dtype == "float32") {
if (out_dtype == "float32") {
thrust_scan<float, float>(data, output, exclusive);
+ } else if (out_dtype == "float64") {
+ thrust_scan<float, double>(data, output, exclusive);
} else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are float32, and float64";
+ }
+ } else if (in_dtype == "float64") {
+ if (out_dtype == "float64") {
+ thrust_scan<double, double>(data, output, exclusive);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtype is float64";
}
} else {
- LOG(FATAL) << "Unsupported input dtype: " << in_dtype;
+ LOG(FATAL) << "Unsupported input dtype: " << in_dtype
+ << ". Supported input dtypes are bool, int32, int64, float32, and float64";
}
});
diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py
index 5f66d46..c5b6a29 100644
--- a/tests/python/contrib/test_thrust.py
+++ b/tests/python/contrib/test_thrust.py
@@ -59,7 +59,7 @@ def test_exclusive_scan():
print("skip because thrust is not enabled...")
return
- for ishape in [(1,), (10, 10)]:
+ for ishape in [(10,), (10, 10), (10, 10, 10)]:
values = te.placeholder(ishape, name="values", dtype="int32")
with tvm.target.Target("cuda"):
@@ -75,7 +75,7 @@ def test_exclusive_scan():
if len(ishape) == 1:
reduction_shape = ()
else:
- reduction_shape = (ishape[0],)
+ reduction_shape = ishape[:-1]
reduction_np_out = np.zeros(reduction_shape, np.int32)
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 5e44170..559eb24 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1311,6 +1311,7 @@ def test_sparse_to_dense():
# verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+@tvm.testing.uses_gpu
def test_adv_index():
def verify_adv_index(data_shape, index_shapes):
dtype = "float32"
@@ -1342,6 +1343,40 @@ def test_adv_index():
verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+@tvm.testing.parametrize_targets
+def test_cumsum(target, ctx):  # checks relay.op.cumsum against np.cumsum; target/ctx presumably supplied by the decorator
+ def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e-5):
+ inp = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype)))  # input typed from data_np
+ # Build a one-op function: cumsum of the input with the given axis and out_dtype.
+ out = relay.op.cumsum(inp, axis, out_dtype)
+ func = relay.Function([inp], out)
+ # Evaluate with both executor kinds and compare each result with the reference np_out.
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(data_np)
+ tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol)
+ # Small integer vector: default output dtype, then an explicit int64 output.
+ data = np.array([2, 3, 0])
+ verify_cumsum(data, np.cumsum(data))
+ verify_cumsum(data, np.cumsum(data), out_dtype="int64")
+ # 2-D random input: no axis (np.cumsum works on the flattened array), then each axis.
+ data = np.random.randn(10, 10)
+ verify_cumsum(data, np.cumsum(data))
+ verify_cumsum(data, np.cumsum(data, axis=0), axis=0)
+ verify_cumsum(data, np.cumsum(data, axis=1), axis=1)
+ # 3-D float32 input, including a negative axis; compared with looser 1e-4 tolerances.
+ data = np.random.randn(10, 5, 10).astype("float32")
+ verify_cumsum(data, np.cumsum(data), rtol=1e-4, atol=1e-4)
+ verify_cumsum(data, np.cumsum(data, axis=0), axis=0, rtol=1e-4, atol=1e-4)
+ verify_cumsum(data, np.cumsum(data, axis=1), axis=1, rtol=1e-4, atol=1e-4)
+ verify_cumsum(data, np.cumsum(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4)
+ # Random 0/1 values: the boolean mask is cast to int32 before it is passed to the op.
+ data = np.random.rand(10) > 0.5
+ data = data.astype(np.int32)
+ verify_cumsum(data, np.cumsum(data, dtype=np.int32))
+ verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64")
+
+
if __name__ == "__main__":
test_cast()
test_zeros_ones()
@@ -1379,3 +1414,4 @@ if __name__ == "__main__":
test_sparse_to_dense()
test_fixed_point_multiply()
test_adv_index()
+ test_cumsum()
diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
new file mode 100644
index 0000000..a01a496
--- /dev/null
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import topi
+import tvm.topi.testing
+
+
+@tvm.testing.parametrize_targets
+def test_cumsum(ctx, target):
+ def check_cumsum(np_ref, data, axis=None, dtype=None):
+ implementations = {
+ "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
+ "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+ "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+ }
+ fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+ tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
+
+ data = np.array([2, 3, 0])
+ check_cumsum(np.cumsum(data), data)
+
+ data = np.random.rand(10) > 0.5
+ data = data.astype(np.int32)
+ check_cumsum(np.cumsum(data, dtype=np.int32), data)
+ check_cumsum(np.cumsum(data), data, dtype="int64")
+
+ data = np.random.rand(10) > 0.5
+ check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
+
+ for in_dtype in ["float32", "float64"]:
+ data = np.random.randn(10, 10).astype(in_dtype)
+ check_cumsum(np.cumsum(data), data)
+ check_cumsum(np.cumsum(data, axis=0), data, axis=0)
+ check_cumsum(np.cumsum(data, axis=1), data, axis=1)
+
+ data = np.random.randn(10, 5, 10).astype(in_dtype)
+ check_cumsum(np.cumsum(data), data)
+ check_cumsum(np.cumsum(data, axis=0), data, axis=0)
+ check_cumsum(np.cumsum(data, axis=1), data, axis=1)
+ check_cumsum(np.cumsum(data, axis=-1), data, axis=-1)
+
+ for in_dtype in ["int32", "int64"]:
+ data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype)
+ check_cumsum(np.cumsum(data, dtype=in_dtype), data)
+ check_cumsum(np.cumsum(data), data, dtype="int64")
+ check_cumsum(np.cumsum(data, axis=0, dtype=in_dtype), data, axis=0)
+ check_cumsum(np.cumsum(data, axis=1, dtype=in_dtype), data, axis=1)
+
+ data = np.random.randint(1 << 30, (1 << 31) - 1, size=(100)).astype(in_dtype)
+ check_cumsum(np.cumsum(data), data, dtype="int64")
+
+
+if __name__ == "__main__":  # direct run: call the test once per (context, target) pair below
+ test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
+ test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))  # NOTE(review): unguarded; presumably needs a CUDA-enabled build - confirm
+ test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))  # NOTE(review): unguarded; presumably needs nvptx support - confirm