Posted to commits@tvm.apache.org by sy...@apache.org on 2022/09/13 09:20:37 UTC

[tvm] branch main updated: [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ef784d68e0 [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764)
ef784d68e0 is described below

commit ef784d68e04ab4b858ce4c953b2d83b5d5811eda
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Tue Sep 13 02:20:30 2022 -0700

    [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764)
    
    * Migrate AutoBind
    
    * Migrate RandomComputeLocation
    
    * Migrate CrossThreadReduction
    
    * Migrate ParallelVectorizeUnroll
---
 python/tvm/meta_schedule/testing/schedule_rule.py  |  48 +-
 .../test_meta_schedule_schedule_rule_auto_bind.py  | 175 +++---
 ...chedule_schedule_rule_cross_thread_reduction.py | 665 ++++++++++++++++-----
 ...dule_schedule_rule_parallel_vectorize_unroll.py | 111 ++--
 ...hedule_schedule_rule_random_compute_location.py |  72 ++-
 5 files changed, 718 insertions(+), 353 deletions(-)
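
The migration pattern is the same in every test file below: the per-file _create_context helper and the string-based check_trace expectations are replaced by constructing ms.TuneContext directly, calling generate_design_space(), and checking the resulting sketches against expected TIR modules and sampling decisions with check_sketches. A minimal sketch of the new pattern, condensed from the AutoBind test in this diff (the wrapper name check_rule, the fixed CUDA target, and the single-rule parameterization are illustrative, not part of the commit):

    from tvm import meta_schedule as ms
    from tvm.meta_schedule.testing.space_generation import check_sketches
    from tvm.target import Target

    def check_rule(mod, rule, expected_mods, expected_decisions):
        # Build the design space for a single schedule rule via PostOrderApply.
        sketches = ms.TuneContext(
            mod=mod,
            target=Target("nvidia/geforce-rtx-3080", host="llvm"),  # illustrative target
            space_generator=ms.space_generator.PostOrderApply(),
            sch_rules=[rule],
            task_name="test",
        ).generate_design_space()
        # Compare each generated sketch against an expected TIR module and the
        # sampling decisions (e.g. ("SampleCategorical", 5)) recorded in its trace.
        check_sketches(
            mod,
            sketches=sketches,
            expected_mods=expected_mods,
            expected_decisions=expected_decisions,
        )

Compared with check_trace, the expected output is a TVMScript module rather than a list of trace strings, so the tests assert on the scheduled IR itself instead of the textual form of the schedule instructions.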

diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index b08db0811d..12ca4200d7 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -18,28 +18,15 @@
 from typing import List, Union
 
 from tvm.meta_schedule.schedule_rule import (
-    AutoBind,
     AutoInline,
-    CrossThreadReduction,
     MultiLevelTiling,
-    ParallelizeVectorizeUnroll,
-    RandomComputeLocation,
+    MultiLevelTilingTensorCore,
     ReuseType,
     ScheduleRule,
 )
-from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
-    MultiLevelTilingTensorCore,
-)
 from tvm.target import Target
 
 
-def auto_bind(target: Target) -> ScheduleRule:
-    """Default schedule rules for auto bind"""
-    if target.kind.name == "cuda":
-        return AutoBind(max_threadblocks=256, thread_extents=[32, 64, 128, 256, 512, 1024])
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
 def auto_inline(target: Target) -> ScheduleRule:
     """Default schedule rules for auto inline"""
     if target.kind.name == "llvm":
@@ -65,13 +52,6 @@ def auto_inline(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
-def cross_thread_reduction(target: Target) -> ScheduleRule:
-    """Default schedule rules for with cross-thread reduction"""
-    if target.kind.name == "cuda":
-        return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
 def multi_level_tiling(target: Target) -> ScheduleRule:
     """Default schedule rules for with multi-level tiling and reuse"""
     if target.kind.name == "llvm":
@@ -154,29 +134,3 @@ def multi_level_tiling_tensor_core(
             use_software_pipeline=use_software_pipeline,
         )
     raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def random_compute_location(target: Target) -> ScheduleRule:
-    """Default schedule rules for with random-compute-location"""
-    if target.kind.name == "llvm":
-        return RandomComputeLocation()
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def parallel_vectorize_unroll(target: Target) -> ScheduleRule:
-    """Default schedule rules for with parallel-vectorize-unroll"""
-    if target.kind.name == "llvm":
-        return ParallelizeVectorizeUnroll(
-            max_jobs_per_core=16,
-            max_vectorize_extent=32,
-            unroll_max_steps=[0, 16, 64, 512],
-            unroll_explicit=True,
-        )
-    if target.kind.name == "cuda":
-        return ParallelizeVectorizeUnroll(
-            max_jobs_per_core=-1,
-            max_vectorize_extent=-1,
-            unroll_max_steps=[0, 16, 64, 512, 1024],
-            unroll_explicit=True,
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index a89cca72e1..21ad04da47 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -15,10 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.schedule_rule import auto_bind
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import check_sketches
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -60,83 +58,120 @@ def zero_dim_add(
         C[()] = A[()] + B[()]
 
 
-def _create_context(mod, target, rule) -> TuneContext:
-    ctx = TuneContext(
-        mod=mod,
-        target=target,
-        space_generator=PostOrderApply(),
-        sch_rules=[rule],
-        task_name="test",
-    )
-    return ctx
-
-
 def test_cuda_element_wise():
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            "l1, l2 = sch.get_loops(block=b0)",
-            "l3 = sch.fuse(l1, l2, preserve_unit_iters=True)",
-            "v4 = sch.sample_categorical(candidates=[32, 64, 128, 256, 512, 1024], probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l5, thread_axis="blockIdx.x")',
-            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-        ]
+    @T.prim_func
+    def elementwise_0(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # body
+        # with T.block("root")
+        for i_j_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
+            for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+                with T.block("C"):
+                    vi = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) // 512)
+                    vj = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) % 512)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+    decision_0 = [
+        ("SampleCategorical", 5),
     ]
-    target = Target("nvidia/geforce-rtx-3080", host="llvm")
-    ctx = _create_context(
-        element_wise,
-        target=target,
-        rule=auto_bind(target=target),
+    mod = element_wise
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3080", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.AutoBind(
+                max_threadblocks=256,
+                thread_extents=[32, 64, 128, 256, 512, 1024],
+            )
+        ],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[elementwise_0],
+        expected_decisions=[decision_0],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
 
 
 def test_cuda_reduction_loop_only():
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            "l1, = sch.get_loops(block=b0)",
-            "l2 = sch.add_unit_loop(block_or_loop=l1)",
-            "l3 = sch.fuse(l2, preserve_unit_iters=True)",
-            "l4, l5 = sch.split(loop=l3, factors=[None, 1], preserve_unit_iters=True)",
-            'sch.bind(loop=l4, thread_axis="blockIdx.x")',
-            'sch.bind(loop=l5, thread_axis="threadIdx.x")',
-        ]
-    ]
-    target = Target("nvidia/geforce-rtx-3080", host="llvm")
-    ctx = _create_context(
-        reduction_loop_only,
-        target=target,
-        rule=auto_bind(target=target),
+    @T.prim_func
+    def reduction_loop_only_0(
+        A: T.Buffer[2, "float32"],
+        B: T.Buffer[2, "float32"],
+        C: T.Buffer[(), "float32"],
+    ) -> None:
+        for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+            for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+                for i0 in T.serial(2):
+                    with T.block("C"):
+                        k0 = T.axis.reduce(2, i0)
+                        T.reads(A[k0], B[k0])
+                        T.writes(C[()])
+                        with T.init():
+                            C[()] = T.float32(1)
+                        C[()] = T.min(C[()], A[k0] / B[k0])
+
+    mod = reduction_loop_only
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3080", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.AutoBind(
+                max_threadblocks=256,
+                thread_extents=[32, 64, 128, 256, 512, 1024],
+            )
+        ],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[reduction_loop_only_0],
+        expected_decisions=[[]],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
 
 
 def test_cuda_zero_dim_add():
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            "l1 = sch.add_unit_loop(block_or_loop=b0)",
-            "l2 = sch.fuse(l1, preserve_unit_iters=True)",
-            "l3, l4 = sch.split(loop=l2, factors=[None, 1], preserve_unit_iters=True)",
-            'sch.bind(loop=l3, thread_axis="blockIdx.x")',
-            'sch.bind(loop=l4, thread_axis="threadIdx.x")',
-        ]
-    ]
-    target = Target("nvidia/geforce-rtx-3080", host="llvm")
-    ctx = _create_context(
-        zero_dim_add,
-        target=target,
-        rule=auto_bind(target=target),
+    @T.prim_func
+    def zero_dim_add_0(
+        A: T.Buffer[(), "float32"],
+        B: T.Buffer[(), "float32"],
+        C: T.Buffer[(), "float32"],
+    ) -> None:
+        for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+            for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+                with T.block("C"):
+                    vi = T.axis.spatial(1, 0)
+                    T.reads(A[()], B[()])
+                    T.writes(C[()])
+                    C[()] = A[()] + B[()]
+
+    mod = zero_dim_add
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3080", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.AutoBind(
+                max_threadblocks=256,
+                thread_extents=[32, 64, 128, 256, 512, 1024],
+            )
+        ],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[zero_dim_add_0],
+        expected_decisions=[[]],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 592d32d624..a0ca47c09a 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -17,14 +17,12 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 
 import tvm
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing import te_workload
-from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule.testing.space_generation import check_sketches
 from tvm.script import tir as T
 from tvm.target import Target
-from tvm.te.operation import create_prim_func
+from tvm.te import create_prim_func
 
 
 @tvm.script.ir_module
@@ -59,179 +57,522 @@ class Softmax_mn_after_inline:
                 )
 
 
-def _create_context(mod, target, rule) -> TuneContext:
-    ctx = TuneContext(
-        mod=mod,
-        target=target,
-        space_generator=PostOrderApply(),
-        sch_rules=[rule],
-        task_name="test",
-    )
-    return ctx
+def test_gpu_softmax_mn():
+    @T.prim_func
+    def softmax_mn_0(
+        A: T.Buffer[(256, 256), "float32"],
+        T_softmax_norm: T.Buffer[(256, 256), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+        T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+        T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_maxelem"):
+                i0_1, k = T.axis.remap("SR", [i0, i1])
+                T.reads(A[i0_1, k])
+                T.writes(T_softmax_maxelem[i0_1])
+                with T.init():
+                    T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+                T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_exp"):
+                i0_2, i1_1 = T.axis.remap("SS", [i0, i1])
+                T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2])
+                T.writes(T_softmax_exp[i0_2, i1_1])
+                T_softmax_exp[i0_2, i1_1] = T.exp(
+                    A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32"
+                )
+        for i0_3, i1 in T.grid(256, 256):
+            with T.block("T_softmax_expsum"):
+                i0_4, k = T.axis.remap("SR", [i0_3, i1])
+                T.reads(T_softmax_exp[i0_4, k])
+                T.writes(T_softmax_expsum[i0_4])
+                with T.init():
+                    T_softmax_expsum[i0_4] = T.float32(0)
+                T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k]
+        for i0_5, i1 in T.grid(256, 256):
+            with T.block("T_softmax_norm"):
+                i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1])
+                T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6])
+                T.writes(T_softmax_norm[i0_6, i1_2])
+                T.block_attr({"axis": 1})
+                T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6]
 
+    @T.prim_func
+    def softmax_mn_1(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+        T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+        for i0 in T.serial(256):
+            for ax0, ax1_0 in T.grid(1, 1):
+                for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_maxelem"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_1 = T.axis.spatial(256, ax0 + i0)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_1, k])
+                        T.writes(T_softmax_maxelem_shared[i0_1])
+                        with T.init():
+                            T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+                        T_softmax_maxelem_shared[i0_1] = T.max(
+                            T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                        )
+            for i1_0 in T.serial(1):
+                for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_exp"):
+                        T.where(i1_0 * 512 + i1_1 < 256)
+                        i0_2 = T.axis.spatial(256, i0)
+                        i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+                        T.reads(A[i0_2, i1], T_softmax_maxelem_shared[i0_2])
+                        T.writes(T_softmax_exp[i0_2, i1])
+                        T_softmax_exp[i0_2, i1] = T.exp(
+                            A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+                        )
+        for i0_3, i1 in T.grid(256, 256):
+            with T.block("T_softmax_expsum"):
+                i0_4, k = T.axis.remap("SR", [i0_3, i1])
+                T.reads(T_softmax_exp[i0_4, k])
+                T.writes(T_softmax_expsum[i0_4])
+                with T.init():
+                    T_softmax_expsum[i0_4] = T.float32(0)
+                T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k]
+        for i0_5, i1 in T.grid(256, 256):
+            with T.block("T_softmax_norm"):
+                i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1])
+                T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6])
+                T.writes(T_softmax_norm[i0_6, i1_2])
+                T.block_attr({"axis": 1})
+                T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6]
 
-def test_gpu_softmax_mn():
-    expected = [
-        [],
-        [
-            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
-            "b1, = sch.get_consumers(block=b0)",
-            "l2, l3 = sch.get_loops(block=b1)",
-            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
-            "l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l11, thread_axis="threadIdx.x")',
-        ],
-        [
-            'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
-            "b1, = sch.get_consumers(block=b0)",
-            "l2, l3 = sch.get_loops(block=b1)",
-            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
-            "l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l11, thread_axis="threadIdx.x")',
-        ],
-        [
-            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
-            'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
-            "b2, = sch.get_consumers(block=b1)",
-            "l3, l4 = sch.get_loops(block=b2)",
-            "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
-            'sch.bind(loop=l7, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
-            "l8, l9, l10 = sch.get_loops(block=b1)",
-            "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
-            'sch.bind(loop=l12, thread_axis="threadIdx.x")',
-            "b13, = sch.get_consumers(block=b0)",
-            "l14, l15 = sch.get_loops(block=b13)",
-            "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)",
-            'sch.bind(loop=l18, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
-            "l19, l20, l21 = sch.get_loops(block=b0)",
-            "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)",
-            'sch.bind(loop=l23, thread_axis="threadIdx.x")',
-        ],
+    @T.prim_func
+    def softmax_mn_2(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+        T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+        T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_maxelem"):
+                i0_1, k = T.axis.remap("SR", [i0, i1])
+                T.reads(A[i0_1, k])
+                T.writes(T_softmax_maxelem[i0_1])
+                with T.init():
+                    T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+                T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_exp"):
+                i0_2, i1_1 = T.axis.remap("SS", [i0, i1])
+                T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2])
+                T.writes(T_softmax_exp[i0_2, i1_1])
+                T_softmax_exp[i0_2, i1_1] = T.exp(
+                    A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32"
+                )
+        for i0_3 in T.serial(256):
+            for ax0, ax1_0 in T.grid(1, 32):
+                for ax1_1 in T.thread_binding(8, thread="threadIdx.x"):
+                    with T.block("T_softmax_expsum"):
+                        i0_4 = T.axis.spatial(256, ax0 + i0_3)
+                        k = T.axis.reduce(256, ax1_0 * 8 + ax1_1)
+                        T.reads(T_softmax_exp[i0_4, k])
+                        T.writes(T_softmax_expsum_shared[i0_4])
+                        with T.init():
+                            T_softmax_expsum_shared[i0_4] = T.float32(0)
+                        T_softmax_expsum_shared[i0_4] = (
+                            T_softmax_expsum_shared[i0_4] + T_softmax_exp[i0_4, k]
+                        )
+            for i1_0 in T.serial(32):
+                for i1_1_1 in T.thread_binding(8, thread="threadIdx.x"):
+                    with T.block("T_softmax_norm"):
+                        i0_5 = T.axis.spatial(256, i0_3)
+                        i1 = T.axis.spatial(256, i1_0 * 8 + i1_1_1)
+                        T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5])
+                        T.writes(T_softmax_norm[i0_5, i1])
+                        T.block_attr({"axis": 1})
+                        T_softmax_norm[i0_5, i1] = (
+                            T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5]
+                        )
+
+    @T.prim_func
+    def softmax_mn_3(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+        T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        for i0 in T.serial(256):
+            for ax0, ax1_0 in T.grid(1, 1):
+                for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_maxelem"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_1 = T.axis.spatial(256, ax0 + i0)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_1, k])
+                        T.writes(T_softmax_maxelem_shared[i0_1])
+                        with T.init():
+                            T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+                        T_softmax_maxelem_shared[i0_1] = T.max(
+                            T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                        )
+            for i1_0 in T.serial(1):
+                for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_exp"):
+                        T.where(i1_0 * 512 + i1_1 < 256)
+                        i0_2 = T.axis.spatial(256, i0)
+                        i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+                        T.reads(A[i0_2, i1], T_softmax_maxelem_shared[i0_2])
+                        T.writes(T_softmax_exp[i0_2, i1])
+                        T_softmax_exp[i0_2, i1] = T.exp(
+                            A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+                        )
+        for i0_3 in T.serial(256):
+            for ax0, ax1_0 in T.grid(1, 32):
+                for ax1_1 in T.thread_binding(8, thread="threadIdx.x"):
+                    with T.block("T_softmax_expsum"):
+                        i0_4 = T.axis.spatial(256, ax0 + i0_3)
+                        k = T.axis.reduce(256, ax1_0 * 8 + ax1_1)
+                        T.reads(T_softmax_exp[i0_4, k])
+                        T.writes(T_softmax_expsum_shared[i0_4])
+                        with T.init():
+                            T_softmax_expsum_shared[i0_4] = T.float32(0)
+                        T_softmax_expsum_shared[i0_4] = (
+                            T_softmax_expsum_shared[i0_4] + T_softmax_exp[i0_4, k]
+                        )
+            for i1_0 in T.serial(32):
+                for i1_1 in T.thread_binding(8, thread="threadIdx.x"):
+                    with T.block("T_softmax_norm"):
+                        i0_5 = T.axis.spatial(256, i0_3)
+                        i1 = T.axis.spatial(256, i1_0 * 8 + i1_1)
+                        T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5])
+                        T.writes(T_softmax_norm[i0_5, i1])
+                        T.block_attr({"axis": 1})
+                        T_softmax_norm[i0_5, i1] = (
+                            T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5]
+                        )
+
+    decision_0 = []  # type: ignore
+    decision_1 = [
+        ("SampleCategorical", 7),
+    ]
+    decision_2 = [
+        ("SampleCategorical", 1),
+    ]
+    decision_3 = [
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 7),
     ]
-    target = Target("nvidia/geforce-rtx-3090", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.softmax_mn(
-                n=256,
-                m=256,
-            )
-        ),
-        target=target,
-        rule=cross_thread_reduction(target=target),
+    mod = create_prim_func(te_workload.softmax_mn(n=256, m=256))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
+        ],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[softmax_mn_0, softmax_mn_1, softmax_mn_2, softmax_mn_3],
+        expected_decisions=[decision_0, decision_1, decision_2, decision_3],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 4
-    check_trace(spaces, expected)
 
 
 def test_gpu_softmax_mn_after_inline():
-    expected = [
-        [],
-        [
-            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
-            "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l2, l3 = sch.get_loops(block=b0)",
-            "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)",
-            'sch.bind(loop=l5, thread_axis="threadIdx.x")',
-        ],
-        [
-            'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
-            "b1, = sch.get_consumers(block=b0)",
-            "l2, l3 = sch.get_loops(block=b1)",
-            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
-            "l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
-            'sch.bind(loop=l11, thread_axis="threadIdx.x")',
+    @T.prim_func
+    def softmax_mn_after_inline_0(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+        T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_maxelem"):
+                i0_1, k = T.axis.remap("SR", [i0, i1])
+                T.reads(A[i0_1, k])
+                T.writes(T_softmax_maxelem[i0_1])
+                with T.init():
+                    T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+                T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_expsum"):
+                i0_2, k = T.axis.remap("SR", [i0, i1])
+                T.reads(A[i0_2, k], T_softmax_maxelem[i0_2])
+                T.writes(T_softmax_expsum[i0_2])
+                with T.init():
+                    T_softmax_expsum[i0_2] = T.float32(0)
+                T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
+                    A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+                )
+        for i0_3, i1 in T.grid(256, 256):
+            with T.block("T_softmax_norm"):
+                i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
+                T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4])
+                T.writes(T_softmax_norm[i0_4, i1_1])
+                T.block_attr({"axis": 1})
+                T_softmax_norm[i0_4, i1_1] = (
+                    T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+                    / T_softmax_expsum[i0_4]
+                )
+
+    @T.prim_func
+    def softmax_mn_after_inline_1(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+        T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+        for i0, i1_0 in T.grid(256, 4):
+            for i1_1 in T.thread_binding(64, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, i1_0 * 64 + i1_1)
+                    T.reads(A[i0_1, k])
+                    T.writes(T_softmax_maxelem[i0_1])
+                    with T.init():
+                        T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+                    T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_expsum"):
+                i0_2, k = T.axis.remap("SR", [i0, i1])
+                T.reads(A[i0_2, k], T_softmax_maxelem[i0_2])
+                T.writes(T_softmax_expsum[i0_2])
+                with T.init():
+                    T_softmax_expsum[i0_2] = T.float32(0)
+                T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
+                    A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+                )
+        for i0_3, i1 in T.grid(256, 256):
+            with T.block("T_softmax_norm"):
+                i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
+                T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4])
+                T.writes(T_softmax_norm[i0_4, i1_1])
+                T.block_attr({"axis": 1})
+                T_softmax_norm[i0_4, i1_1] = (
+                    T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+                    / T_softmax_expsum[i0_4]
+                )
+
+    @T.prim_func
+    def softmax_mn_after_inline_2(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+        T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_maxelem"):
+                i0_1, k = T.axis.remap("SR", [i0, i1])
+                T.reads(A[i0_1, k])
+                T.writes(T_softmax_maxelem[i0_1])
+                with T.init():
+                    T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+                T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+        for i0_3 in T.serial(256):
+            for ax0, ax1_0 in T.grid(1, 1):
+                for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_expsum"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_2 = T.axis.spatial(256, ax0 + i0_3)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_2, k], T_softmax_maxelem[i0_2])
+                        T.writes(T_softmax_expsum_shared[i0_2])
+                        with T.init():
+                            T_softmax_expsum_shared[i0_2] = T.float32(0)
+                        T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+                            A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+                        )
+            for i1_0 in T.serial(1):
+                for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_norm"):
+                        T.where(i1_0 * 512 + i1_1 < 256)
+                        i0_4 = T.axis.spatial(256, i0_3)
+                        i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+                        T.reads(
+                            A[i0_4, i1_1_1], T_softmax_maxelem[i0_4], T_softmax_expsum_shared[i0_4]
+                        )
+                        T.writes(T_softmax_norm[i0_4, i1_1_1])
+                        T.block_attr({"axis": 1})
+                        T_softmax_norm[i0_4, i1_1_1] = (
+                            T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+                            / T_softmax_expsum_shared[i0_4]
+                        )
+
+    @T.prim_func
+    def softmax_mn_after_inline_3(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+        for i0_3 in T.serial(256):
+            for ax0, ax1_0 in T.grid(1, 1):
+                for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_maxelem"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_1 = T.axis.spatial(256, ax0 + i0_3)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_1, k])
+                        T.writes(T_softmax_maxelem_shared[i0_1])
+                        with T.init():
+                            T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+                        T_softmax_maxelem_shared[i0_1] = T.max(
+                            T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                        )
+            for ax0, ax1_0 in T.grid(1, 1):
+                for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_expsum"):
+                        T.where(ax1_0 * 512 + ax1_1 < 256)
+                        i0_2 = T.axis.spatial(256, ax0 + i0_3)
+                        k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+                        T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2])
+                        T.writes(T_softmax_expsum_shared[i0_2])
+                        with T.init():
+                            T_softmax_expsum_shared[i0_2] = T.float32(0)
+                        T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+                        )
+            for i1_0 in T.serial(1):
+                for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                    with T.block("T_softmax_norm"):
+                        T.where(i1_0 * 512 + i1_1 < 256)
+                        i0_4 = T.axis.spatial(256, i0_3)
+                        i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+                        T.reads(
+                            A[i0_4, i1_1_1],
+                            T_softmax_maxelem_shared[i0_4],
+                            T_softmax_expsum_shared[i0_4],
+                        )
+                        T.writes(T_softmax_norm[i0_4, i1_1_1])
+                        T.block_attr({"axis": 1})
+                        T_softmax_norm[i0_4, i1_1_1] = (
+                            T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem_shared[i0_4], dtype="float32")
+                            / T_softmax_expsum_shared[i0_4]
+                        )
+
+    decision_0 = []  # type: ignore
+    decision_1 = [
+        ("SampleCategorical", 4),
+    ]
+    decision_2 = [
+        ("SampleCategorical", 7),
+    ]
+    decision_3 = [
+        ("SampleCategorical", 7),
+        ("SampleCategorical", 0),
+    ]
+
+    mod = Softmax_mn_after_inline
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
         ],
-        [
-            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
-            'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
-            "b2, = sch.get_consumers(block=b1)",
-            "l3, l4 = sch.get_loops(block=b2)",
-            "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
-            'sch.bind(loop=l7, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
-            "l8, l9, l10 = sch.get_loops(block=b1)",
-            "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
-            'sch.bind(loop=l12, thread_axis="threadIdx.x")',
-            "b13, b14 = sch.get_consumers(block=b0)",
-            "l15, l16, l17, l18 = sch.get_loops(block=b13)",
-            "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
-            "l19, l20, l21 = sch.get_loops(block=b0)",
-            "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
-            'sch.bind(loop=l23, thread_axis="threadIdx.x")',
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[
+            softmax_mn_after_inline_0,
+            softmax_mn_after_inline_1,
+            softmax_mn_after_inline_2,
+            softmax_mn_after_inline_3,
         ],
-    ]
-    target = Target("nvidia/geforce-rtx-3090", host="llvm")
-    ctx = _create_context(
-        mod=Softmax_mn_after_inline,
-        target=target,
-        rule=cross_thread_reduction(target=target),
+        expected_decisions=[decision_0, decision_1, decision_2, decision_3],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 4
-    check_trace(spaces, expected)
 
 
 def test_gpu_batch_norm_bmn():
-    expected = [
-        [],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            "b1, = sch.get_consumers(block=b0)",
-            "l2, = sch.get_loops(block=b1)",
-            "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
-            "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
-            'sch.bind(loop=l5, thread_axis="threadIdx.x")',
-            "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True, index=-1)",
-            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
-            "l6, l7, l8, l9 = sch.get_loops(block=b0)",
-            "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
-            "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)",
-            'sch.bind(loop=l12, thread_axis="threadIdx.x")',
-        ],
+    @T.prim_func
+    def batch_norm_bmn_0(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C = T.alloc_buffer([1], dtype="float32")
+        for i0, i1, i2 in T.grid(1, 512, 512):
+            with T.block("C"):
+                b, i, j = T.axis.remap("SRR", [i0, i1, i2])
+                T.reads(A[b, i, j])
+                T.writes(C[b])
+                with T.init():
+                    C[b] = T.float32(0)
+                C[b] = C[b] + A[b, i, j] * A[b, i, j]
+        for i0 in T.serial(1):
+            with T.block("D"):
+                b = T.axis.spatial(1, i0)
+                T.reads(C[b])
+                T.writes(D[b])
+                D[b] = T.sqrt(C[b], dtype="float32")
+
+    @T.prim_func
+    def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+        for i0_0 in T.serial(1):
+            for ax0, ax1_ax2_fused_0 in T.grid(1, 1024):
+                for ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block("C"):
+                        b = T.axis.spatial(1, ax0)
+                        i = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) // 512)
+                        j = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) % 512)
+                        T.reads(A[b, i, j])
+                        T.writes(C_shared[b])
+                        with T.init():
+                            C_shared[b] = T.float32(0)
+                        C_shared[b] = C_shared[b] + A[b, i, j] * A[b, i, j]
+            for i0_1 in T.thread_binding(256, thread="threadIdx.x"):
+                with T.block("D"):
+                    T.where(i0_0 * 256 + i0_1 < 1)
+                    b = T.axis.spatial(1, i0_0 * 256 + i0_1)
+                    T.reads(C_shared[b])
+                    T.writes(D[b])
+                    D[b] = T.sqrt(C_shared[b], dtype="float32")
+
+    decision_0 = []  # type: ignore
+    decision_1 = [
+        ("SampleCategorical", 6),
     ]
-    target = Target("nvidia/geforce-rtx-3090", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.norm_bmn(
-                B=1,
-                M=512,
-                N=512,
-            )
-        ),
-        target=target,
-        rule=cross_thread_reduction(target=target),
+
+    mod = create_prim_func(te_workload.norm_bmn(B=1, M=512, N=512))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
+        ],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[batch_norm_bmn_0, batch_norm_bmn_1],
+        expected_decisions=[decision_0, decision_1],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 2
-    check_trace(spaces, expected)
 
 
 if __name__ == "__main__":
-    # test_gpu_softmax_mn()
-    # test_gpu_softmax_mn_after_inline()
+    test_gpu_softmax_mn()
+    test_gpu_softmax_mn_after_inline()
     test_gpu_batch_norm_bmn()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
index 02b55350b7..8076fcaa8b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
@@ -17,10 +17,7 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
 from tvm import meta_schedule as ms
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule.testing.space_generation import check_sketches
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -68,10 +65,7 @@ class ParallelizeVectorizeUnroll:
 class PureSpatial:
     @T.prim_func
     def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None:
-        # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        # body
-        # with T.block("root")
         T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
         T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
         T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
@@ -224,55 +218,80 @@ class PureSpatial:
 # fmt: on
 
 
-def _create_context(mod, target, rule):
-    ctx = TuneContext(
-        mod=mod,
-        target=target,
-        space_generator=PostOrderApply(),
-        sch_rules=[rule],
-        task_name="test",
-    )
-    return ctx
-
-
 def test_parallel_vectorize_unroll():
-    expected = [
-        [
-            'b0 = sch.get_block(name="root", func_name="main")',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)',
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)',
-            "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])",
-            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)',
-        ]
+    @T.prim_func
+    def Matmul_0(
+        A: T.Buffer[(1024, 1024), "float32"],
+        B: T.Buffer[(1024, 1024), "float32"],
+        C: T.Buffer[(1024, 1024), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main"})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr(
+                {
+                    "meta_schedule.parallel": 512,
+                    "meta_schedule.unroll_explicit": 16,
+                    "meta_schedule.vectorize": 32,
+                }
+            )
+            for i, j, k in T.grid(1024, 1024, 1024):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    T.reads(A[vi, vk], B[vk, vj])
+                    T.writes(C[vi, vj])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+    decision_0 = [
+        ("SampleCategorical", 1),
     ]
+
     mod = Matmul
-    target = Target("llvm --num-cores=32")
-    ctx = _create_context(
+    actual = ms.TuneContext(
         mod=mod,
-        target=target,
-        rule=parallel_vectorize_unroll(target=target),
+        target=Target("llvm --num-cores=32"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.ParallelizeVectorizeUnroll(
+                max_jobs_per_core=16,
+                max_vectorize_extent=32,
+                unroll_max_steps=[0, 16, 64, 512],
+                unroll_explicit=True,
+            ),
+        ],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[Matmul_0],
+        expected_decisions=[decision_0],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
 
 
 def test_parallel_vectorize_unroll_spatial():
     mod = PureSpatial
-    target = Target("llvm --num-cores=32")
-    ctx = _create_context(
+    actual = ms.TuneContext(
         mod=mod,
-        target=target,
-        rule=ms.schedule_rule.ParallelizeVectorizeUnroll(
-            max_jobs_per_core=-1,
-            max_vectorize_extent=-1,
-            unroll_max_steps=[1, 2, 4, 8, 16, 32, 64],
-            unroll_explicit=True,
-        ),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=mod)
-    assert len(spaces) == 1
-    trace = spaces[0].trace.simplified(remove_postproc=True)
+        target=Target("llvm --num-cores=32"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.ParallelizeVectorizeUnroll(
+                max_jobs_per_core=-1,
+                max_vectorize_extent=-1,
+                unroll_max_steps=[0, 16, 64, 512],
+                unroll_explicit=True,
+            ),
+        ],
+        task_name="test",
+    ).generate_design_space()
+    assert len(actual) == 1
+    trace = actual[0].trace.simplified(remove_postproc=True)
     assert not trace.insts
 
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
index c951a5adf3..fc52aa199c 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
@@ -16,10 +16,8 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
-from tvm.meta_schedule.schedule_rule import RandomComputeLocation
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import check_sketches
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -55,35 +53,53 @@ class Add:
 # fmt: on
 
 
-def _create_context(mod, target, rule):
-    ctx = TuneContext(
-        mod=mod,
-        target=target,
-        space_generator=PostOrderApply(),
-        sch_rules=[rule],
-        task_name="test",
-    )
-    return ctx
-
-
 def test_random_compute_location():
-    expected = [
-        [
-            'b0 = sch.get_block(name="move", func_name="main")',
-            "l1 = sch.sample_compute_location(block=b0)",
-            "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True, index=-1)",
-        ]
+    @T.prim_func
+    def add_0(
+        A: T.Buffer[(2048, 2048, 2048), "float32"],
+        B: T.Buffer[(2048, 2048, 2048), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main"})
+        # body
+        # with T.block("root")
+        A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32")
+        for i0, j0, i1, j1, k0, i2 in T.grid(128, 64, 4, 4, 64, 4):
+            for ax0, ax1, ax2 in T.grid(1, 8, 32):
+                with T.block("move"):
+                    vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2 + ax0)
+                    vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + ax1)
+                    vk = T.axis.spatial(2048, k0 * 32 + ax2)
+                    T.reads(A[vi, vj, vk])
+                    T.writes(A_cached[vi, vj, vk])
+                    A_cached[vi, vj, vk] = A[vi, vj, vk]
+            for j2, k1 in T.grid(8, 32):
+                with T.block("add"):
+                    vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2)
+                    vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2)
+                    vk = T.axis.spatial(2048, k0 * 32 + k1)
+                    T.reads(A_cached[vi, vj, vk])
+                    T.writes(B[vi, vj, vk])
+                    B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1)
+
+    decision_0 = [
+        ("SampleComputeLocation", 5),
     ]
+
     mod = Add
-    target = Target("llvm")
-    ctx = _create_context(
+    actual = ms.TuneContext(
         mod=mod,
-        target=target,
-        rule=RandomComputeLocation(),
+        target=Target("llvm"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[ms.schedule_rule.RandomComputeLocation()],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[add_0],
+        expected_decisions=[decision_0],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
 
 
 if __name__ == "__main__":