You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/07/07 21:06:07 UTC

[tvm] branch main updated: [MetaSchedule][Testing] Test search space of conv1d (#12032)

This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 99d42b2238 [MetaSchedule][Testing] Test search space of conv1d (#12032)
99d42b2238 is described below

commit 99d42b22382d19cfd2e2e1ec65d92f1fe41e4c10
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Jul 7 14:05:59 2022 -0700

    [MetaSchedule][Testing] Test search space of conv1d (#12032)
    
    * [MetaSchedule][Testing] Test search space of conv1d
    
    * Add checks for trace roundtripping
---
 .../tvm/meta_schedule/testing/space_generation.py  |  65 +++++++++++-
 .../unittest/test_meta_schedule_space_cuda.py      | 115 +++++++++++++++++++++
 2 files changed, 179 insertions(+), 1 deletion(-)

diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py
index 10e31e7213..2d846e244a 100644
--- a/python/tvm/meta_schedule/testing/space_generation.py
+++ b/python/tvm/meta_schedule/testing/space_generation.py
@@ -15,10 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-from typing import List
+from typing import List, Optional, Tuple, Union
 
+from tvm.ir import IRModule, structural_equal
 from tvm.tir import Schedule
 from tvm.tir.schedule import Trace
+from tvm.tir.schedule.testing import verify_trace_roundtrip
 
 
 def check_trace(spaces: List[Schedule], expected: List[List[str]]):
@@ -31,3 +33,64 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]):
         actual_traces.add(str_trace)
         assert str_trace in expected_traces, "\n" + str_trace
     assert len(expected_traces) == len(actual_traces)
+
+
+def _find_match_sketch_id(
+    mod: IRModule,
+    sketches: List[Schedule],
+    expected_mod: IRModule,
+    expected_decision: List[Tuple[str, List[int]]],
+) -> Optional[int]:
+    for sketch_id, sketch in enumerate(sketches):
+        i = 0
+        new_decisions = {}
+        for inst in sketch.trace.insts:
+            if not inst.kind.name.startswith("Sample"):
+                continue
+            assert i < len(expected_decision)
+            if inst.kind.name == expected_decision[i][0]:
+                new_decisions[inst] = expected_decision[i][1]
+                i += 1
+        if len(new_decisions) != len(expected_decision):
+            continue
+        sch = Schedule(mod, debug_mask="all")
+        Trace(
+            insts=sketch.trace.insts,
+            decisions=new_decisions,
+        ).apply_to_schedule(sch, remove_postproc=True)
+        if structural_equal(sch.mod, expected_mod):
+            verify_trace_roundtrip(sch=sch, mod=mod)
+            return sketch_id
+    return None
+
+
+def check_sketches(
+    mod: IRModule,
+    sketches: List[Schedule],
+    expected_mods: List[IRModule],
+    expected_decisions: List[List[Tuple[str, List[int]]]],
+):
+    assert len(expected_mods) == len(expected_decisions)
+    assert len(sketches) == len(expected_mods)
+    expected_mods = [
+        IRModule({"main": m}) if not isinstance(m, IRModule) else m for m in expected_mods
+    ]
+    sketches = list(sketches)
+    for expected_id, (expected_mod, expected_decision) in enumerate(
+        zip(expected_mods, expected_decisions)
+    ):
+        sketch_id = _find_match_sketch_id(mod, sketches, expected_mod, expected_decision)
+        if sketch_id is None:
+            raise AssertionError(
+                f"Expected sketch #{expected_id} doesn't exist in the generated sketches."
+            )
+        sketches.pop(sketch_id)
+
+
+def print_sketches(sketches: List[Schedule]):
+    for i, sch in enumerate(sketches):
+        print(f"###### {i}")
+        print(sch.mod.script())
+        for inst in sch.trace.insts:
+            if inst in sch.trace.decisions:
+                print(f'("{inst.kind.name}", {sch.trace.decisions[inst]}),')
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
new file mode 100644
index 0000000000..e2c324cfda
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for MetaSchedule search space on CUDA"""
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import check_sketches
+from tvm.meta_schedule.testing.te_workload import create_te_workload
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target():
+    return Target("nvidia/geforce-rtx-3070")
+
+
+def test_cuda_c1d():
+    """Check the CUDA design space generated for the C1D (conv1d) workload.
+
+    `mod_0` is the single expected sketch, and `decision_0` holds the sampling
+    decisions that turn the generated sketch into `mod_0`.
+    """
+    # fmt: off
+    @T.prim_func
+    def mod_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":16})
+            conv1d_nlc_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
+            PadInput_shared = T.alloc_buffer([1, 258, 64], dtype="float32", scope="shared")
+            weight_shared = T.alloc_buffer([3, 64, 128], dtype="float32", scope="shared")
+            for i0_0_i1_0_i2_0_fused in T.thread_binding(4, thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_fused in T.thread_binding(16, thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_fused in T.thread_binding(4, thread="threadIdx.x"):
+                        for i3_0, i4_0 in T.grid(1, 16):
+                            for ax0_ax1_ax2_fused in T.serial(260):
+                                with T.block("PadInput_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
+                                    v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
+                                    T.reads(inputs[v0, v1 - 1, v2])
+                                    T.writes(PadInput_shared[v0, v1, v2])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":4})
+                                    PadInput_shared[v0, v1, v2] = T.if_then_else(1 <= v1 and v1 < 257, inputs[v0, v1 - 1, v2], T.float32(0), dtype="float32")
+                            for ax0_ax1_ax2_fused in T.serial(1536):
+                                with T.block("weight_shared"):
+                                    v0 = T.axis.spatial(3, ax0_ax1_ax2_fused // 512)
+                                    v1 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 512 // 128)
+                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
+                                    T.reads(weight[v0, v1, v2])
+                                    T.writes(weight_shared[v0, v1, v2])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":3})
+                                    weight_shared[v0, v1, v2] = weight[v0, v1, v2]
+                            for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
+                                with T.block("conv1d_nlc"):
+                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
+                                    l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
+                                    co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
+                                    rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
+                                    rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
+                                    T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
+                                    T.writes(conv1d_nlc_local[n, l, co])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        conv1d_nlc_local[n, l, co] = T.float32(0)
+                                    conv1d_nlc_local[n, l, co] = conv1d_nlc_local[n, l, co] + PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc] * weight_shared[rl, rc, co]
+                        for ax0, ax1, ax2 in T.grid(1, 4, 16):
+                            with T.block("conv1d_nlc_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + ax1)
+                                v2 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + ax2)
+                                T.reads(conv1d_nlc_local[v0, v1, v2])
+                                T.writes(conv1d_nlc[v0, v1, v2])
+                                conv1d_nlc[v0, v1, v2] = conv1d_nlc_local[v0, v1, v2]
+    # fmt: on
+
+    # Sampling decisions that turn the generated sketch into `mod_0`: one
+    # `(instruction kind, decision)` entry per sampling instruction, in trace order.
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [4, 8, 1, 1, 4]),
+        ("SamplePerfectTile", [1, 2, 4, 2, 8]),
+        ("SamplePerfectTile", [1, 1, 3]),
+        ("SamplePerfectTile", [16, 2, 2]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+    ]
+
+    # Generate the design space of the C1D workload for `_target()` using the
+    # default schedule rules, then compare it against the expectation above.
+    mod = create_te_workload("C1D", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[mod_0],
+        expected_decisions=[decision_0],
+    )
+
+
+if __name__ == "__main__":
+    # Allow running this test directly with `python`, without pytest.
+    test_cuda_c1d()