You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/07/05 06:21:45 UTC
[tvm] branch unity updated: [Unity][TIR][Transform] Support no spatial axes cases for DefaultGPUSchedule (#15232)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c45f72b4cc [Unity][TIR][Transform] Support no spatial axes cases for DefaultGPUSchedule (#15232)
c45f72b4cc is described below
commit c45f72b4cc4943b9c2af16f93006fbd304028d25
Author: Yixin Dong <ub...@gmail.com>
AuthorDate: Wed Jul 5 14:21:39 2023 +0800
[Unity][TIR][Transform] Support no spatial axes cases for DefaultGPUSchedule (#15232)
This PR adds support for no spatial axes cases for DefaultGPUSchedule.
To be specific, this PR will add this logic
- Find all innermost loops in a TIR function
- If such a loop does not have a data-parallel loop, add a dummy loop for it
---
src/tir/transforms/default_gpu_schedule.cc | 22 +++---
.../test_transform_default_gpu_schedule.py | 89 ++++++++++++----------
2 files changed, 59 insertions(+), 52 deletions(-)
diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc
index 5ee9f8995f..5a22d0b0d9 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -41,15 +41,9 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread
}
Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
- // special check for no-loop case
- // in te/operation/create_primfunc.cc:L321, it will create a dummy iter var
- // which makes loops.size() == 0 and iters.size() == 1.
- if (loops.size() == 0 && iters.size() == 1) {
- auto loop = sch->AddUnitLoop(block);
- loops.push_back(loop);
- }
-
- ICHECK_EQ(loops.size(), iters.size());
+ // when there is no loops, tir will add a dummy iter var for the block
+ // so loops.size() == 0 && iters.size() == 1
+ ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1));
Array<tir::LoopRV> data_parallel_loops;
// only fuse data parallel loops
@@ -58,9 +52,11 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread
data_parallel_loops.push_back(loops[i]);
}
}
- // skip if no data parallel loops
+
+ // Add a dummy loop if there is no data parallel loops
if (data_parallel_loops.size() == 0) {
- return;
+ data_parallel_loops.push_back(loops.empty() ? sch->AddUnitLoop(block)
+ : sch->AddUnitLoop(loops[0]));
}
// fuse all data parallel loops
tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
@@ -123,6 +119,10 @@ Pass DefaultGPUSchedule() {
sch->WorkOn(gv->name_hint);
Array<tir::BlockRV> blocks = meta_schedule::BlockCollector::Collect(sch);
for (const tir::BlockRV& block : blocks) {
+ auto childs = sch->GetChildBlocks(block);
+ if (!childs.empty()) {
+ continue;
+ }
ThreadBind(sch, block, max_thread_per_block);
}
}
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py b/tests/python/unittest/test_transform_default_gpu_schedule.py
index 622d1e9322..1af846c9d5 100644
--- a/tests/python/unittest/test_transform_default_gpu_schedule.py
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -46,49 +46,17 @@ def test_broadcast_to_symbolic():
@tvm.script.ir_module
class Expected:
@T.prim_func
- def broadcast_to(
- rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
- var_T_broadcast_to: T.handle,
- ):
- T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
- x_0 = T.int64()
- x_1 = T.int64()
+ def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), var_T_broadcast_to: T.handle):
+ T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
+ x_0, x_1 = T.int64(), T.int64()
T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
- # with T.block("root"):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
- for ax0_ax1_fused_2 in T.thread_binding(
- T.int64(1024), thread="threadIdx.x"
- ):
- for ax0_ax1_fused_0 in range(
- (x_0 * x_1 + T.int64(262143)) // T.int64(262144)
- ):
+ for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
+ for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) // T.int64(262144)):
with T.block("T_broadcast_to"):
- v_ax0 = T.axis.spatial(
- x_0,
- (
- (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
- * T.int64(1024)
- + ax0_ax1_fused_2
- )
- // x_1,
- )
- v_ax1 = T.axis.spatial(
- x_1,
- (
- (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
- * T.int64(1024)
- + ax0_ax1_fused_2
- )
- % x_1,
- )
- T.where(
- (ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1)
- * T.int64(1024)
- + ax0_ax1_fused_2
- < x_0 * x_1
- )
- T.reads(rxplaceholder[v_ax0, T.int64(0)])
- T.writes(T_broadcast_to[v_ax0, v_ax1])
+ v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * x_0) // x_1)
+ v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
+ T.where((ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
# fmt: on
# pylint: enable=no-self-argument,missing-class-docstring,line-too-long
@@ -432,7 +400,7 @@ def test_add_on_metal():
class Expected:
@T.prim_func
def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": T.bool(True)})
+ T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"):
with T.block("T_add"):
@@ -486,5 +454,44 @@ def test_scalar_add():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_sum():
+ # sum has two reduction axes and no spatial axis
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")):
+ for k0, k1 in T.grid(T.int64(2), T.int64(2)):
+ with T.block("A_red"):
+ v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+ with T.init():
+ A_red[()] = T.float64(0)
+ A_red[()] = A_red[()] + A[v_k0, v_k1]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")):
+ T.func_attr({"tir.is_scheduled": T.bool(True)})
+ # with T.block("root"):
+ for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+ for k0, k1 in T.grid(T.int64(2), T.int64(2)):
+ with T.block("A_red"):
+ v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+ T.reads(A[v_k0, v_k1])
+ T.writes(A_red[()])
+ with T.init():
+ A_red[()] = T.float64(0)
+ A_red[()] = A_red[()] + A[v_k0, v_k1]
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=0):
+ mod = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()