You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/09 12:14:23 UTC

[tvm] branch unity updated: [Unity] DefaultGPUSchedule working for targets other than CUDA (#14540)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 97f4db0f91 [Unity] DefaultGPUSchedule working for targets other than CUDA (#14540)
97f4db0f91 is described below

commit 97f4db0f91521a28f0744a89ddd5a799dca5ed3b
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sun Apr 9 08:14:14 2023 -0400

    [Unity] DefaultGPUSchedule working for targets other than CUDA (#14540)
    
    Previously, the DefaultGPUSchedule only worked for CUDA. This PR enables
    it for other GPU targets such as Metal.
---
 src/tir/transforms/default_gpu_schedule.cc         |  4 ---
 .../test_transform_default_gpu_schedule.py         | 38 ++++++++++++++++++++++
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc
index 8666d7eb47..2b56dda0d6 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -80,10 +80,6 @@ Pass DefaultGPUSchedule() {
         // get the target from context.
         tvm::Target target = tvm::Target::Current();
         ICHECK(target.defined()) << "Target is not set in current context";
-        // skip non-cuda targets.
-        if (target->kind->name != "cuda") {
-          return m;
-        }
         // get the max thread per block from target.
         Optional<Integer> opt_max_thread_per_block = target->GetAttr<Integer>("max_num_threads");
         ICHECK(opt_max_thread_per_block.defined())
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py b/tests/python/unittest/test_transform_default_gpu_schedule.py
index 644a9aede0..2503a7009d 100644
--- a/tests/python/unittest/test_transform_default_gpu_schedule.py
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -413,5 +413,43 @@ def test_multiple():
     assert tvm.ir.structural_equal(After, Expected)
 
 
+def test_add_on_metal():
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
+                with T.block("T_add"):
+                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+                    T.writes(T_add[ax0, ax1, ax2, ax3])
+                    T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"):
+                    with T.block("T_add"):
+                        ax0 = T.axis.spatial(T.int64(4), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) // T.int64(18))
+                        ax1 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(18) // T.int64(6))
+                        ax2 = T.axis.spatial(T.int64(2), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(6) // T.int64(3))
+                        ax3 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(3))
+                        T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)])
+                        T.writes(T_add[ax0, ax1, ax2, ax3])
+                        T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("apple/m1-gpu")
+    with target, tvm.transform.PassContext(opt_level=0):
+        mod = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()