You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/07/05 06:21:45 UTC

[tvm] branch unity updated: [Unity][TIR][Transform] Support no spatial axes cases for DefaultGPUSchedule (#15232)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c45f72b4cc [Unity][TIR][Transform] Support no spatial axes cases for DefaultGPUSchedule (#15232)
c45f72b4cc is described below

commit c45f72b4cc4943b9c2af16f93006fbd304028d25
Author: Yixin Dong <ub...@gmail.com>
AuthorDate: Wed Jul 5 14:21:39 2023 +0800

    [Unity][TIR][Transform] Support no spatial axes cases for DefaultGPUSchedule (#15232)
    
    This PR makes DefaultGPUSchedule handle blocks that have no spatial (data-parallel) axes.
    
    Specifically, this PR adds the following logic:
    - Find all leaf blocks (blocks with no child blocks) in a TIR function
    - If such a block has no data-parallel loops, add a dummy unit loop for it
      so that it can still be fused and bound to GPU threads
---
 src/tir/transforms/default_gpu_schedule.cc         | 22 +++---
 .../test_transform_default_gpu_schedule.py         | 89 ++++++++++++----------
 2 files changed, 59 insertions(+), 52 deletions(-)

diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc
index 5ee9f8995f..5a22d0b0d9 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -41,15 +41,9 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread
   }
   Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
 
-  // special check for no-loop case
-  // in te/operation/create_primfunc.cc:L321, it will create a dummy iter var
-  // which makes loops.size() == 0 and iters.size() == 1.
-  if (loops.size() == 0 && iters.size() == 1) {
-    auto loop = sch->AddUnitLoop(block);
-    loops.push_back(loop);
-  }
-
-  ICHECK_EQ(loops.size(), iters.size());
+  // when there is no loops, tir will add a dummy iter var for the block
+  // so loops.size() == 0 && iters.size() == 1
+  ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1));
 
   Array<tir::LoopRV> data_parallel_loops;
   // only fuse data parallel loops
@@ -58,9 +52,11 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread
       data_parallel_loops.push_back(loops[i]);
     }
   }
-  // skip if no data parallel loops
+
+  // Add a dummy loop if there is no data parallel loops
   if (data_parallel_loops.size() == 0) {
-    return;
+    data_parallel_loops.push_back(loops.empty() ? sch->AddUnitLoop(block)
+                                                : sch->AddUnitLoop(loops[0]));
   }
   // fuse all data parallel loops
   tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
@@ -123,6 +119,10 @@ Pass DefaultGPUSchedule() {
             sch->WorkOn(gv->name_hint);
             Array<tir::BlockRV> blocks = meta_schedule::BlockCollector::Collect(sch);
             for (const tir::BlockRV& block : blocks) {
+              auto childs = sch->GetChildBlocks(block);
+              if (!childs.empty()) {
+                continue;
+              }
               ThreadBind(sch, block, max_thread_per_block);
             }
           }
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py b/tests/python/unittest/test_transform_default_gpu_schedule.py
index 622d1e9322..1af846c9d5 100644
--- a/tests/python/unittest/test_transform_default_gpu_schedule.py
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -46,49 +46,17 @@ def test_broadcast_to_symbolic():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func
-        def broadcast_to(
-            rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
-            var_T_broadcast_to: T.handle,
-        ):
-            T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-            x_0 = T.int64()
-            x_1 = T.int64()
+        def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), var_T_broadcast_to: T.handle):
+            T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
+            x_0, x_1 = T.int64(), T.int64()
             T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
-            # with T.block("root"):
             for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
-                for ax0_ax1_fused_2 in T.thread_binding(
-                    T.int64(1024), thread="threadIdx.x"
-                ):
-                    for ax0_ax1_fused_0 in range(
-                        (x_0 * x_1 + T.int64(262143)) // T.int64(262144)
-                    ):
+                for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
+                    for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) // T.int64(262144)):
                         with T.block("T_broadcast_to"):
-                            v_ax0 = T.axis.spatial(
-                                x_0,
-                                (
-                                    (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
-                                    * T.int64(1024)
-                                    + ax0_ax1_fused_2
-                                )
-                                // x_1,
-                            )
-                            v_ax1 = T.axis.spatial(
-                                x_1,
-                                (
-                                    (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
-                                    * T.int64(1024)
-                                    + ax0_ax1_fused_2
-                                )
-                                % x_1,
-                            )
-                            T.where(
-                                (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
-                                * T.int64(1024)
-                                + ax0_ax1_fused_2
-                                < x_0 * x_1
-                            )
-                            T.reads(rxplaceholder[v_ax0, T.int64(0)])
-                            T.writes(T_broadcast_to[v_ax0, v_ax1])
+                            v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * x_0) // x_1)
+                            v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
+                            T.where((ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
                             T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
     # fmt: on
     # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
@@ -432,7 +400,7 @@ def test_add_on_metal():
     class Expected:
         @T.prim_func
         def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
-            T.func_attr({"tir.noalias": T.bool(True)})
+            T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
             for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                 for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"):
                     with T.block("T_add"):
@@ -486,5 +454,44 @@ def test_scalar_add():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_sum():
+    # sum has two reduction axes and no spatial axis
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")):
+            for k0, k1 in T.grid(T.int64(2), T.int64(2)):
+                with T.block("A_red"):
+                    v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+                    with T.init():
+                        A_red[()] = T.float64(0)
+                    A_red[()] = A_red[()] + A[v_k0, v_k1]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")):
+            T.func_attr({"tir.is_scheduled": T.bool(True)})
+            # with T.block("root"):
+            for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+                for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+                    for k0, k1 in T.grid(T.int64(2), T.int64(2)):
+                        with T.block("A_red"):
+                            v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+                            T.reads(A[v_k0, v_k1])
+                            T.writes(A_red[()])
+                            with T.init():
+                                A_red[()] = T.float64(0)
+                            A_red[()] = A_red[()] + A[v_k0, v_k1]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=0):
+        mod = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()