You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/03 14:17:20 UTC
(tvm) branch main updated: [TOPI] improve inclusive_scan for thrust (#16652)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3b25588926 [TOPI] improve inclusive_scan for thrust (#16652)
3b25588926 is described below
commit 3b255889262d856efb31fc0b362ac1be57d5d1ea
Author: Yong Wu <yo...@gmail.com>
AuthorDate: Sun Mar 3 06:17:13 2024 -0800
[TOPI] improve inclusive_scan for thrust (#16652)
Fix comments
---
python/tvm/topi/cuda/scan.py | 42 +++++++++++++++++++++++++++++++++---------
1 file changed, 33 insertions(+), 9 deletions(-)
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 238163722f..4b1bac0529 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -35,6 +35,21 @@ def _get_thrust_func_name(tvmop):
return tvmop_to_thrust_func_name[tvmop]
+def _can_use_scan_thrust(binop):
+ """
+ Check if scan_thrust can be utilized based on the current target and binary op.
+ """
+ target = tvm.target.Target.current()
+ if target is None:
+ return False
+ return binop == tvm.tir.generic.add and any(
+ [
+ can_use_thrust(target, "tvm.contrib.thrust.sum_scan"),
+ can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan"),
+ ]
+ )
+
+
def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
"""Low level IR to do exclusive sum scan along rows of 2D input.
@@ -363,17 +378,9 @@ def exclusive_scan(
"""
def do_scan(data, output_dtype):
- target = tvm.target.Target.current()
# TODO: add support for a prod_scan
- if (
- target
- and binop == tvm.tir.generic.add
- and (
- can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
- or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
- )
- ):
+ if _can_use_scan_thrust(binop):
return scan_thrust(
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
)
@@ -479,6 +486,23 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
output : tvm.te.Tensor
A N-D tensor of the same rank N as the input data.
"""
+
+ if _can_use_scan_thrust(binop):
+ if output_dtype is None or output_dtype == "":
+ output_dtype = data.dtype
+ ndim = len(data.shape)
+ if axis < 0:
+ axis += ndim
+
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+ output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ output = transpose(output, axes)
+ return output
+
ex_scan = exclusive_scan(
data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
)