You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/09 12:14:23 UTC
[tvm] branch unity updated: [Unity] DefaultGPUSchedule working for targets other than CUDA (#14540)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 97f4db0f91 [Unity] DefaultGPUSchedule working for targets other than CUDA (#14540)
97f4db0f91 is described below
commit 97f4db0f91521a28f0744a89ddd5a799dca5ed3b
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sun Apr 9 08:14:14 2023 -0400
[Unity] DefaultGPUSchedule working for targets other than CUDA (#14540)
Previously, the DefaultGPUSchedule only worked for CUDA. This PR enables
it for other GPU targets like Metal.
---
src/tir/transforms/default_gpu_schedule.cc | 4 ---
.../test_transform_default_gpu_schedule.py | 38 ++++++++++++++++++++++
2 files changed, 38 insertions(+), 4 deletions(-)
diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc
index 8666d7eb47..2b56dda0d6 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -80,10 +80,6 @@ Pass DefaultGPUSchedule() {
// get the target from context.
tvm::Target target = tvm::Target::Current();
ICHECK(target.defined()) << "Target is not set in current context";
- // skip non-cuda targets.
- if (target->kind->name != "cuda") {
- return m;
- }
// get the max thread per block from target.
Optional<Integer> opt_max_thread_per_block = target->GetAttr<Integer>("max_num_threads");
ICHECK(opt_max_thread_per_block.defined())
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py b/tests/python/unittest/test_transform_default_gpu_schedule.py
index 644a9aede0..2503a7009d 100644
--- a/tests/python/unittest/test_transform_default_gpu_schedule.py
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -413,5 +413,43 @@ def test_multiple():
assert tvm.ir.structural_equal(After, Expected)
+def test_add_on_metal():
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
+ with T.block("T_add"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+ for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"):
+ with T.block("T_add"):
+ ax0 = T.axis.spatial(T.int64(4), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) // T.int64(18))
+ ax1 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(18) // T.int64(6))
+ ax2 = T.axis.spatial(T.int64(2), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(6) // T.int64(3))
+ ax3 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(3))
+ T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("apple/m1-gpu")
+ with target, tvm.transform.PassContext(opt_level=0):
+ mod = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()