You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/03 14:17:20 UTC

(tvm) branch main updated: [TOPI] improve inclusive_scan for thrust (#16652)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b25588926 [TOPI] improve inclusive_scan for thrust (#16652)
3b25588926 is described below

commit 3b255889262d856efb31fc0b362ac1be57d5d1ea
Author: Yong Wu <yo...@gmail.com>
AuthorDate: Sun Mar 3 06:17:13 2024 -0800

    [TOPI] improve inclusive_scan for thrust (#16652)
    
    Fix comments
---
 python/tvm/topi/cuda/scan.py | 42 +++++++++++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 9 deletions(-)

diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 238163722f..4b1bac0529 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -35,6 +35,21 @@ def _get_thrust_func_name(tvmop):
     return tvmop_to_thrust_func_name[tvmop]
 
 
+def _can_use_scan_thrust(binop):
+    """
+    Check if scan_thrust can be utilized based on the current target and binary op.
+    """
+    target = tvm.target.Target.current()
+    if target is None:
+        return False
+    return binop == tvm.tir.generic.add and any(
+        [
+            can_use_thrust(target, "tvm.contrib.thrust.sum_scan"),
+            can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan"),
+        ]
+    )
+
+
 def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
     """Low level IR to do exclusive sum scan along rows of 2D input.
 
@@ -363,17 +378,9 @@ def exclusive_scan(
     """
 
     def do_scan(data, output_dtype):
-        target = tvm.target.Target.current()
 
         # TODO: add support for a prod_scan
-        if (
-            target
-            and binop == tvm.tir.generic.add
-            and (
-                can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
-                or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
-            )
-        ):
+        if _can_use_scan_thrust(binop):
             return scan_thrust(
                 data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
             )
@@ -479,6 +486,23 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
     output : tvm.te.Tensor
         A N-D tensor of the same rank N as the input data.
     """
+
+    if _can_use_scan_thrust(binop):
+        if output_dtype is None or output_dtype == "":
+            output_dtype = data.dtype
+        ndim = len(data.shape)
+        if axis < 0:
+            axis += ndim
+
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            data = transpose(data, axes)
+        output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            output = transpose(output, axes)
+        return output
+
     ex_scan = exclusive_scan(
         data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
     )