You are viewing a plain-text version of this content; the canonical link to it (a hyperlink in the original page) is not preserved in this text.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/09/13 09:20:37 UTC
[tvm] branch main updated: [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ef784d68e0 [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764)
ef784d68e0 is described below
commit ef784d68e04ab4b858ce4c953b2d83b5d5811eda
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Tue Sep 13 02:20:30 2022 -0700
[MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764)
* Migrate AutoBind
* Migrate RandomComputeLocation
* Migrate CrossThreadReduction
* Migrate ParallelVectorizeUnroll
---
python/tvm/meta_schedule/testing/schedule_rule.py | 48 +-
.../test_meta_schedule_schedule_rule_auto_bind.py | 175 +++---
...chedule_schedule_rule_cross_thread_reduction.py | 665 ++++++++++++++++-----
...dule_schedule_rule_parallel_vectorize_unroll.py | 111 ++--
...hedule_schedule_rule_random_compute_location.py | 72 ++-
5 files changed, 718 insertions(+), 353 deletions(-)
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index b08db0811d..12ca4200d7 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -18,28 +18,15 @@
from typing import List, Union
from tvm.meta_schedule.schedule_rule import (
- AutoBind,
AutoInline,
- CrossThreadReduction,
MultiLevelTiling,
- ParallelizeVectorizeUnroll,
- RandomComputeLocation,
+ MultiLevelTilingTensorCore,
ReuseType,
ScheduleRule,
)
-from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
- MultiLevelTilingTensorCore,
-)
from tvm.target import Target
-def auto_bind(target: Target) -> ScheduleRule:
- """Default schedule rules for auto bind"""
- if target.kind.name == "cuda":
- return AutoBind(max_threadblocks=256, thread_extents=[32, 64, 128, 256, 512, 1024])
- raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
def auto_inline(target: Target) -> ScheduleRule:
"""Default schedule rules for auto inline"""
if target.kind.name == "llvm":
@@ -65,13 +52,6 @@ def auto_inline(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")
-def cross_thread_reduction(target: Target) -> ScheduleRule:
- """Default schedule rules for with cross-thread reduction"""
- if target.kind.name == "cuda":
- return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
- raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
def multi_level_tiling(target: Target) -> ScheduleRule:
"""Default schedule rules for with multi-level tiling and reuse"""
if target.kind.name == "llvm":
@@ -154,29 +134,3 @@ def multi_level_tiling_tensor_core(
use_software_pipeline=use_software_pipeline,
)
raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def random_compute_location(target: Target) -> ScheduleRule:
- """Default schedule rules for with random-compute-location"""
- if target.kind.name == "llvm":
- return RandomComputeLocation()
- raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def parallel_vectorize_unroll(target: Target) -> ScheduleRule:
- """Default schedule rules for with parallel-vectorize-unroll"""
- if target.kind.name == "llvm":
- return ParallelizeVectorizeUnroll(
- max_jobs_per_core=16,
- max_vectorize_extent=32,
- unroll_max_steps=[0, 16, 64, 512],
- unroll_explicit=True,
- )
- if target.kind.name == "cuda":
- return ParallelizeVectorizeUnroll(
- max_jobs_per_core=-1,
- max_vectorize_extent=-1,
- unroll_max_steps=[0, 16, 64, 512, 1024],
- unroll_explicit=True,
- )
- raise NotImplementedError(f"{target.kind.name} is not supported")
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
index a89cca72e1..21ad04da47 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py
@@ -15,10 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.schedule_rule import auto_bind
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
@@ -60,83 +58,120 @@ def zero_dim_add(
C[()] = A[()] + B[()]
-def _create_context(mod, target, rule) -> TuneContext:
- ctx = TuneContext(
- mod=mod,
- target=target,
- space_generator=PostOrderApply(),
- sch_rules=[rule],
- task_name="test",
- )
- return ctx
-
-
def test_cuda_element_wise():
- expected = [
- [
- 'b0 = sch.get_block(name="C", func_name="main")',
- "l1, l2 = sch.get_loops(block=b0)",
- "l3 = sch.fuse(l1, l2, preserve_unit_iters=True)",
- "v4 = sch.sample_categorical(candidates=[32, 64, 128, 256, 512, 1024], probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l5, thread_axis="blockIdx.x")',
- 'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- ]
+ @T.prim_func
+ def elementwise_0(
+ A: T.Buffer[(512, 512), "float32"],
+ B: T.Buffer[(512, 512), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i_j_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
+ for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+ with T.block("C"):
+ vi = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) // 512)
+ vj = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) % 512)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] + T.float32(1)
+
+ decision_0 = [
+ ("SampleCategorical", 5),
]
- target = Target("nvidia/geforce-rtx-3080", host="llvm")
- ctx = _create_context(
- element_wise,
- target=target,
- rule=auto_bind(target=target),
+ mod = element_wise
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3080", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.AutoBind(
+ max_threadblocks=256,
+ thread_extents=[32, 64, 128, 256, 512, 1024],
+ )
+ ],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[elementwise_0],
+ expected_decisions=[decision_0],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 1
- check_trace(spaces, expected)
def test_cuda_reduction_loop_only():
- expected = [
- [
- 'b0 = sch.get_block(name="C", func_name="main")',
- "l1, = sch.get_loops(block=b0)",
- "l2 = sch.add_unit_loop(block_or_loop=l1)",
- "l3 = sch.fuse(l2, preserve_unit_iters=True)",
- "l4, l5 = sch.split(loop=l3, factors=[None, 1], preserve_unit_iters=True)",
- 'sch.bind(loop=l4, thread_axis="blockIdx.x")',
- 'sch.bind(loop=l5, thread_axis="threadIdx.x")',
- ]
- ]
- target = Target("nvidia/geforce-rtx-3080", host="llvm")
- ctx = _create_context(
- reduction_loop_only,
- target=target,
- rule=auto_bind(target=target),
+ @T.prim_func
+ def reduction_loop_only_0(
+ A: T.Buffer[2, "float32"],
+ B: T.Buffer[2, "float32"],
+ C: T.Buffer[(), "float32"],
+ ) -> None:
+ for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+ for i0 in T.serial(2):
+ with T.block("C"):
+ k0 = T.axis.reduce(2, i0)
+ T.reads(A[k0], B[k0])
+ T.writes(C[()])
+ with T.init():
+ C[()] = T.float32(1)
+ C[()] = T.min(C[()], A[k0] / B[k0])
+
+ mod = reduction_loop_only
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3080", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.AutoBind(
+ max_threadblocks=256,
+ thread_extents=[32, 64, 128, 256, 512, 1024],
+ )
+ ],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[reduction_loop_only_0],
+ expected_decisions=[[]],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 1
- check_trace(spaces, expected)
def test_cuda_zero_dim_add():
- expected = [
- [
- 'b0 = sch.get_block(name="C", func_name="main")',
- "l1 = sch.add_unit_loop(block_or_loop=b0)",
- "l2 = sch.fuse(l1, preserve_unit_iters=True)",
- "l3, l4 = sch.split(loop=l2, factors=[None, 1], preserve_unit_iters=True)",
- 'sch.bind(loop=l3, thread_axis="blockIdx.x")',
- 'sch.bind(loop=l4, thread_axis="threadIdx.x")',
- ]
- ]
- target = Target("nvidia/geforce-rtx-3080", host="llvm")
- ctx = _create_context(
- zero_dim_add,
- target=target,
- rule=auto_bind(target=target),
+ @T.prim_func
+ def zero_dim_add_0(
+ A: T.Buffer[(), "float32"],
+ B: T.Buffer[(), "float32"],
+ C: T.Buffer[(), "float32"],
+ ) -> None:
+ for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+ with T.block("C"):
+ vi = T.axis.spatial(1, 0)
+ T.reads(A[()], B[()])
+ T.writes(C[()])
+ C[()] = A[()] + B[()]
+
+ mod = zero_dim_add
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3080", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.AutoBind(
+ max_threadblocks=256,
+ thread_extents=[32, 64, 128, 256, 512, 1024],
+ )
+ ],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[zero_dim_add_0],
+ expected_decisions=[[]],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 1
- check_trace(spaces, expected)
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index 592d32d624..a0ca47c09a 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -17,14 +17,12 @@
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm import meta_schedule as ms
from tvm.meta_schedule.testing import te_workload
-from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
-from tvm.te.operation import create_prim_func
+from tvm.te import create_prim_func
@tvm.script.ir_module
@@ -59,179 +57,522 @@ class Softmax_mn_after_inline:
)
-def _create_context(mod, target, rule) -> TuneContext:
- ctx = TuneContext(
- mod=mod,
- target=target,
- space_generator=PostOrderApply(),
- sch_rules=[rule],
- task_name="test",
- )
- return ctx
+def test_gpu_softmax_mn():
+ @T.prim_func
+ def softmax_mn_0(
+ A: T.Buffer[(256, 256), "float32"],
+ T_softmax_norm: T.Buffer[(256, 256), "float32"],
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+ T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+ T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_maxelem"):
+ i0_1, k = T.axis.remap("SR", [i0, i1])
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem[i0_1])
+ with T.init():
+ T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_exp"):
+ i0_2, i1_1 = T.axis.remap("SS", [i0, i1])
+ T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2])
+ T.writes(T_softmax_exp[i0_2, i1_1])
+ T_softmax_exp[i0_2, i1_1] = T.exp(
+ A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32"
+ )
+ for i0_3, i1 in T.grid(256, 256):
+ with T.block("T_softmax_expsum"):
+ i0_4, k = T.axis.remap("SR", [i0_3, i1])
+ T.reads(T_softmax_exp[i0_4, k])
+ T.writes(T_softmax_expsum[i0_4])
+ with T.init():
+ T_softmax_expsum[i0_4] = T.float32(0)
+ T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k]
+ for i0_5, i1 in T.grid(256, 256):
+ with T.block("T_softmax_norm"):
+ i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1])
+ T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6])
+ T.writes(T_softmax_norm[i0_6, i1_2])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6]
+ @T.prim_func
+ def softmax_mn_1(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+ T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+ for i0 in T.serial(256):
+ for ax0, ax1_0 in T.grid(1, 1):
+ for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_maxelem"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_1 = T.axis.spatial(256, ax0 + i0)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem_shared[i0_1])
+ with T.init():
+ T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem_shared[i0_1] = T.max(
+ T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+ )
+ for i1_0 in T.serial(1):
+ for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_exp"):
+ T.where(i1_0 * 512 + i1_1 < 256)
+ i0_2 = T.axis.spatial(256, i0)
+ i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+ T.reads(A[i0_2, i1], T_softmax_maxelem_shared[i0_2])
+ T.writes(T_softmax_exp[i0_2, i1])
+ T_softmax_exp[i0_2, i1] = T.exp(
+ A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+ )
+ for i0_3, i1 in T.grid(256, 256):
+ with T.block("T_softmax_expsum"):
+ i0_4, k = T.axis.remap("SR", [i0_3, i1])
+ T.reads(T_softmax_exp[i0_4, k])
+ T.writes(T_softmax_expsum[i0_4])
+ with T.init():
+ T_softmax_expsum[i0_4] = T.float32(0)
+ T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k]
+ for i0_5, i1 in T.grid(256, 256):
+ with T.block("T_softmax_norm"):
+ i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1])
+ T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6])
+ T.writes(T_softmax_norm[i0_6, i1_2])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6]
-def test_gpu_softmax_mn():
- expected = [
- [],
- [
- 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
- "b1, = sch.get_consumers(block=b0)",
- "l2, l3 = sch.get_loops(block=b1)",
- "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
- "l7, l8, l9 = sch.get_loops(block=b0)",
- "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l11, thread_axis="threadIdx.x")',
- ],
- [
- 'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
- "b1, = sch.get_consumers(block=b0)",
- "l2, l3 = sch.get_loops(block=b1)",
- "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
- "l7, l8, l9 = sch.get_loops(block=b0)",
- "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l11, thread_axis="threadIdx.x")',
- ],
- [
- 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
- 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
- "b2, = sch.get_consumers(block=b1)",
- "l3, l4 = sch.get_loops(block=b2)",
- "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
- 'sch.bind(loop=l7, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
- "l8, l9, l10 = sch.get_loops(block=b1)",
- "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
- 'sch.bind(loop=l12, thread_axis="threadIdx.x")',
- "b13, = sch.get_consumers(block=b0)",
- "l14, l15 = sch.get_loops(block=b13)",
- "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)",
- 'sch.bind(loop=l18, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
- "l19, l20, l21 = sch.get_loops(block=b0)",
- "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)",
- 'sch.bind(loop=l23, thread_axis="threadIdx.x")',
- ],
+ @T.prim_func
+ def softmax_mn_2(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+ T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+ T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_maxelem"):
+ i0_1, k = T.axis.remap("SR", [i0, i1])
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem[i0_1])
+ with T.init():
+ T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_exp"):
+ i0_2, i1_1 = T.axis.remap("SS", [i0, i1])
+ T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2])
+ T.writes(T_softmax_exp[i0_2, i1_1])
+ T_softmax_exp[i0_2, i1_1] = T.exp(
+ A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32"
+ )
+ for i0_3 in T.serial(256):
+ for ax0, ax1_0 in T.grid(1, 32):
+ for ax1_1 in T.thread_binding(8, thread="threadIdx.x"):
+ with T.block("T_softmax_expsum"):
+ i0_4 = T.axis.spatial(256, ax0 + i0_3)
+ k = T.axis.reduce(256, ax1_0 * 8 + ax1_1)
+ T.reads(T_softmax_exp[i0_4, k])
+ T.writes(T_softmax_expsum_shared[i0_4])
+ with T.init():
+ T_softmax_expsum_shared[i0_4] = T.float32(0)
+ T_softmax_expsum_shared[i0_4] = (
+ T_softmax_expsum_shared[i0_4] + T_softmax_exp[i0_4, k]
+ )
+ for i1_0 in T.serial(32):
+ for i1_1_1 in T.thread_binding(8, thread="threadIdx.x"):
+ with T.block("T_softmax_norm"):
+ i0_5 = T.axis.spatial(256, i0_3)
+ i1 = T.axis.spatial(256, i1_0 * 8 + i1_1_1)
+ T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5])
+ T.writes(T_softmax_norm[i0_5, i1])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_5, i1] = (
+ T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5]
+ )
+
+ @T.prim_func
+ def softmax_mn_3(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32")
+ T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ for i0 in T.serial(256):
+ for ax0, ax1_0 in T.grid(1, 1):
+ for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_maxelem"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_1 = T.axis.spatial(256, ax0 + i0)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem_shared[i0_1])
+ with T.init():
+ T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem_shared[i0_1] = T.max(
+ T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+ )
+ for i1_0 in T.serial(1):
+ for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_exp"):
+ T.where(i1_0 * 512 + i1_1 < 256)
+ i0_2 = T.axis.spatial(256, i0)
+ i1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+ T.reads(A[i0_2, i1], T_softmax_maxelem_shared[i0_2])
+ T.writes(T_softmax_exp[i0_2, i1])
+ T_softmax_exp[i0_2, i1] = T.exp(
+ A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+ )
+ for i0_3 in T.serial(256):
+ for ax0, ax1_0 in T.grid(1, 32):
+ for ax1_1 in T.thread_binding(8, thread="threadIdx.x"):
+ with T.block("T_softmax_expsum"):
+ i0_4 = T.axis.spatial(256, ax0 + i0_3)
+ k = T.axis.reduce(256, ax1_0 * 8 + ax1_1)
+ T.reads(T_softmax_exp[i0_4, k])
+ T.writes(T_softmax_expsum_shared[i0_4])
+ with T.init():
+ T_softmax_expsum_shared[i0_4] = T.float32(0)
+ T_softmax_expsum_shared[i0_4] = (
+ T_softmax_expsum_shared[i0_4] + T_softmax_exp[i0_4, k]
+ )
+ for i1_0 in T.serial(32):
+ for i1_1 in T.thread_binding(8, thread="threadIdx.x"):
+ with T.block("T_softmax_norm"):
+ i0_5 = T.axis.spatial(256, i0_3)
+ i1 = T.axis.spatial(256, i1_0 * 8 + i1_1)
+ T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5])
+ T.writes(T_softmax_norm[i0_5, i1])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_5, i1] = (
+ T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5]
+ )
+
+ decision_0 = [] # type: ignore
+ decision_1 = [
+ ("SampleCategorical", 7),
+ ]
+ decision_2 = [
+ ("SampleCategorical", 1),
+ ]
+ decision_3 = [
+ ("SampleCategorical", 1),
+ ("SampleCategorical", 7),
]
- target = Target("nvidia/geforce-rtx-3090", host="llvm")
- ctx = _create_context(
- create_prim_func(
- te_workload.softmax_mn(
- n=256,
- m=256,
- )
- ),
- target=target,
- rule=cross_thread_reduction(target=target),
+ mod = create_prim_func(te_workload.softmax_mn(n=256, m=256))
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
+ ],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[softmax_mn_0, softmax_mn_1, softmax_mn_2, softmax_mn_3],
+ expected_decisions=[decision_0, decision_1, decision_2, decision_3],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 4
- check_trace(spaces, expected)
def test_gpu_softmax_mn_after_inline():
- expected = [
- [],
- [
- 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
- "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l2, l3 = sch.get_loops(block=b0)",
- "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)",
- 'sch.bind(loop=l5, thread_axis="threadIdx.x")',
- ],
- [
- 'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
- "b1, = sch.get_consumers(block=b0)",
- "l2, l3 = sch.get_loops(block=b1)",
- "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l6, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
- "l7, l8, l9 = sch.get_loops(block=b0)",
- "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
- 'sch.bind(loop=l11, thread_axis="threadIdx.x")',
+ @T.prim_func
+ def softmax_mn_after_inline_0(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+ T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_maxelem"):
+ i0_1, k = T.axis.remap("SR", [i0, i1])
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem[i0_1])
+ with T.init():
+ T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_expsum"):
+ i0_2, k = T.axis.remap("SR", [i0, i1])
+ T.reads(A[i0_2, k], T_softmax_maxelem[i0_2])
+ T.writes(T_softmax_expsum[i0_2])
+ with T.init():
+ T_softmax_expsum[i0_2] = T.float32(0)
+ T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
+ A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+ )
+ for i0_3, i1 in T.grid(256, 256):
+ with T.block("T_softmax_norm"):
+ i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
+ T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4])
+ T.writes(T_softmax_norm[i0_4, i1_1])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_4, i1_1] = (
+ T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+ / T_softmax_expsum[i0_4]
+ )
+
+ @T.prim_func
+ def softmax_mn_after_inline_1(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+ T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+ for i0, i1_0 in T.grid(256, 4):
+ for i1_1 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("T_softmax_maxelem"):
+ i0_1 = T.axis.spatial(256, i0)
+ k = T.axis.reduce(256, i1_0 * 64 + i1_1)
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem[i0_1])
+ with T.init():
+ T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_expsum"):
+ i0_2, k = T.axis.remap("SR", [i0, i1])
+ T.reads(A[i0_2, k], T_softmax_maxelem[i0_2])
+ T.writes(T_softmax_expsum[i0_2])
+ with T.init():
+ T_softmax_expsum[i0_2] = T.float32(0)
+ T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
+ A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+ )
+ for i0_3, i1 in T.grid(256, 256):
+ with T.block("T_softmax_norm"):
+ i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
+ T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4])
+ T.writes(T_softmax_norm[i0_4, i1_1])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_4, i1_1] = (
+ T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+ / T_softmax_expsum[i0_4]
+ )
+
+ @T.prim_func
+ def softmax_mn_after_inline_2(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+ T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ for i0, i1 in T.grid(256, 256):
+ with T.block("T_softmax_maxelem"):
+ i0_1, k = T.axis.remap("SR", [i0, i1])
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem[i0_1])
+ with T.init():
+ T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+ for i0_3 in T.serial(256):
+ for ax0, ax1_0 in T.grid(1, 1):
+ for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_expsum"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_2 = T.axis.spatial(256, ax0 + i0_3)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_2, k], T_softmax_maxelem[i0_2])
+ T.writes(T_softmax_expsum_shared[i0_2])
+ with T.init():
+ T_softmax_expsum_shared[i0_2] = T.float32(0)
+ T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+ A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+ )
+ for i1_0 in T.serial(1):
+ for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_norm"):
+ T.where(i1_0 * 512 + i1_1 < 256)
+ i0_4 = T.axis.spatial(256, i0_3)
+ i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+ T.reads(
+ A[i0_4, i1_1_1], T_softmax_maxelem[i0_4], T_softmax_expsum_shared[i0_4]
+ )
+ T.writes(T_softmax_norm[i0_4, i1_1_1])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_4, i1_1_1] = (
+ T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+ / T_softmax_expsum_shared[i0_4]
+ )
+
+ @T.prim_func
+ def softmax_mn_after_inline_3(
+ A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+ ) -> None:
+ T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+ for i0_3 in T.serial(256):
+ for ax0, ax1_0 in T.grid(1, 1):
+ for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_maxelem"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_1 = T.axis.spatial(256, ax0 + i0_3)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_1, k])
+ T.writes(T_softmax_maxelem_shared[i0_1])
+ with T.init():
+ T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+ T_softmax_maxelem_shared[i0_1] = T.max(
+ T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+ )
+ for ax0, ax1_0 in T.grid(1, 1):
+ for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_expsum"):
+ T.where(ax1_0 * 512 + ax1_1 < 256)
+ i0_2 = T.axis.spatial(256, ax0 + i0_3)
+ k = T.axis.reduce(256, ax1_0 * 512 + ax1_1)
+ T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2])
+ T.writes(T_softmax_expsum_shared[i0_2])
+ with T.init():
+ T_softmax_expsum_shared[i0_2] = T.float32(0)
+ T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+ A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+ )
+ for i1_0 in T.serial(1):
+ for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+ with T.block("T_softmax_norm"):
+ T.where(i1_0 * 512 + i1_1 < 256)
+ i0_4 = T.axis.spatial(256, i0_3)
+ i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1)
+ T.reads(
+ A[i0_4, i1_1_1],
+ T_softmax_maxelem_shared[i0_4],
+ T_softmax_expsum_shared[i0_4],
+ )
+ T.writes(T_softmax_norm[i0_4, i1_1_1])
+ T.block_attr({"axis": 1})
+ T_softmax_norm[i0_4, i1_1_1] = (
+ T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem_shared[i0_4], dtype="float32")
+ / T_softmax_expsum_shared[i0_4]
+ )
+
+ decision_0 = [] # type: ignore
+ decision_1 = [
+ ("SampleCategorical", 4),
+ ]
+ decision_2 = [
+ ("SampleCategorical", 7),
+ ]
+ decision_3 = [
+ ("SampleCategorical", 7),
+ ("SampleCategorical", 0),
+ ]
+
+ mod = Softmax_mn_after_inline
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
],
- [
- 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
- 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
- "b2, = sch.get_consumers(block=b1)",
- "l3, l4 = sch.get_loops(block=b2)",
- "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
- 'sch.bind(loop=l7, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
- "l8, l9, l10 = sch.get_loops(block=b1)",
- "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
- 'sch.bind(loop=l12, thread_axis="threadIdx.x")',
- "b13, b14 = sch.get_consumers(block=b0)",
- "l15, l16, l17, l18 = sch.get_loops(block=b13)",
- "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
- "l19, l20, l21 = sch.get_loops(block=b0)",
- "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
- 'sch.bind(loop=l23, thread_axis="threadIdx.x")',
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[
+ softmax_mn_after_inline_0,
+ softmax_mn_after_inline_1,
+ softmax_mn_after_inline_2,
+ softmax_mn_after_inline_3,
],
- ]
- target = Target("nvidia/geforce-rtx-3090", host="llvm")
- ctx = _create_context(
- mod=Softmax_mn_after_inline,
- target=target,
- rule=cross_thread_reduction(target=target),
+ expected_decisions=[decision_0, decision_1, decision_2, decision_3],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 4
- check_trace(spaces, expected)
def test_gpu_batch_norm_bmn():
- expected = [
- [],
- [
- 'b0 = sch.get_block(name="C", func_name="main")',
- "b1, = sch.get_consumers(block=b0)",
- "l2, = sch.get_loops(block=b1)",
- "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
- "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
- 'sch.bind(loop=l5, thread_axis="threadIdx.x")',
- "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True, index=-1)",
- 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
- "l6, l7, l8, l9 = sch.get_loops(block=b0)",
- "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
- "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)",
- 'sch.bind(loop=l12, thread_axis="threadIdx.x")',
- ],
+ @T.prim_func
+ def batch_norm_bmn_0(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "float32"]) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ C = T.alloc_buffer([1], dtype="float32")
+ for i0, i1, i2 in T.grid(1, 512, 512):
+ with T.block("C"):
+ b, i, j = T.axis.remap("SRR", [i0, i1, i2])
+ T.reads(A[b, i, j])
+ T.writes(C[b])
+ with T.init():
+ C[b] = T.float32(0)
+ C[b] = C[b] + A[b, i, j] * A[b, i, j]
+ for i0 in T.serial(1):
+ with T.block("D"):
+ b = T.axis.spatial(1, i0)
+ T.reads(C[b])
+ T.writes(D[b])
+ D[b] = T.sqrt(C[b], dtype="float32")
+
+ @T.prim_func
+ def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "float32"]) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ C_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+ for i0_0 in T.serial(1):
+ for ax0, ax1_ax2_fused_0 in T.grid(1, 1024):
+ for ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("C"):
+ b = T.axis.spatial(1, ax0)
+ i = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) // 512)
+ j = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) % 512)
+ T.reads(A[b, i, j])
+ T.writes(C_shared[b])
+ with T.init():
+ C_shared[b] = T.float32(0)
+ C_shared[b] = C_shared[b] + A[b, i, j] * A[b, i, j]
+ for i0_1 in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("D"):
+ T.where(i0_0 * 256 + i0_1 < 1)
+ b = T.axis.spatial(1, i0_0 * 256 + i0_1)
+ T.reads(C_shared[b])
+ T.writes(D[b])
+ D[b] = T.sqrt(C_shared[b], dtype="float32")
+
+ decision_0 = [] # type: ignore
+ decision_1 = [
+ ("SampleCategorical", 6),
]
- target = Target("nvidia/geforce-rtx-3090", host="llvm")
- ctx = _create_context(
- create_prim_func(
- te_workload.norm_bmn(
- B=1,
- M=512,
- N=512,
- )
- ),
- target=target,
- rule=cross_thread_reduction(target=target),
+
+ mod = create_prim_func(te_workload.norm_bmn(B=1, M=512, N=512))
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("nvidia/geforce-rtx-3090", host="llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
+ ],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[batch_norm_bmn_0, batch_norm_bmn_1],
+ expected_decisions=[decision_0, decision_1],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 2
- check_trace(spaces, expected)
if __name__ == "__main__":
- # test_gpu_softmax_mn()
- # test_gpu_softmax_mn_after_inline()
+ test_gpu_softmax_mn()
+ test_gpu_softmax_mn_after_inline()
test_gpu_batch_norm_bmn()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
index 02b55350b7..8076fcaa8b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
@@ -17,10 +17,7 @@
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
from tvm import meta_schedule as ms
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
@@ -68,10 +65,7 @@ class ParallelizeVectorizeUnroll:
class PureSpatial:
@T.prim_func
def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None:
- # function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- # body
- # with T.block("root")
T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
@@ -224,55 +218,80 @@ class PureSpatial:
# fmt: on
-def _create_context(mod, target, rule):
- ctx = TuneContext(
- mod=mod,
- target=target,
- space_generator=PostOrderApply(),
- sch_rules=[rule],
- task_name="test",
- )
- return ctx
-
-
def test_parallel_vectorize_unroll():
- expected = [
- [
- 'b0 = sch.get_block(name="root", func_name="main")',
- 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)',
- 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)',
- "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])",
- 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)',
- ]
+ @T.prim_func
+ def Matmul_0(
+ A: T.Buffer[(1024, 1024), "float32"],
+ B: T.Buffer[(1024, 1024), "float32"],
+ C: T.Buffer[(1024, 1024), "float32"],
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main"})
+ # body
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T.block_attr(
+ {
+ "meta_schedule.parallel": 512,
+ "meta_schedule.unroll_explicit": 16,
+ "meta_schedule.vectorize": 32,
+ }
+ )
+ for i, j, k in T.grid(1024, 1024, 1024):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[vi, vk], B[vk, vj])
+ T.writes(C[vi, vj])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ decision_0 = [
+ ("SampleCategorical", 1),
]
+
mod = Matmul
- target = Target("llvm --num-cores=32")
- ctx = _create_context(
+ actual = ms.TuneContext(
mod=mod,
- target=target,
- rule=parallel_vectorize_unroll(target=target),
+ target=Target("llvm --num-cores=32"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=16,
+ max_vectorize_extent=32,
+ unroll_max_steps=[0, 16, 64, 512],
+ unroll_explicit=True,
+ ),
+ ],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[Matmul_0],
+ expected_decisions=[decision_0],
)
- spaces = ctx.space_generator.generate_design_space(mod=mod)
- assert len(spaces) == 1
- check_trace(spaces, expected)
def test_parallel_vectorize_unroll_spatial():
mod = PureSpatial
- target = Target("llvm --num-cores=32")
- ctx = _create_context(
+ actual = ms.TuneContext(
mod=mod,
- target=target,
- rule=ms.schedule_rule.ParallelizeVectorizeUnroll(
- max_jobs_per_core=-1,
- max_vectorize_extent=-1,
- unroll_max_steps=[1, 2, 4, 8, 16, 32, 64],
- unroll_explicit=True,
- ),
- )
- spaces = ctx.space_generator.generate_design_space(mod=mod)
- assert len(spaces) == 1
- trace = spaces[0].trace.simplified(remove_postproc=True)
+ target=Target("llvm --num-cores=32"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[
+ ms.schedule_rule.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=-1,
+ max_vectorize_extent=-1,
+ unroll_max_steps=[0, 16, 64, 512],
+ unroll_explicit=True,
+ ),
+ ],
+ task_name="test",
+ ).generate_design_space()
+ assert len(actual) == 1
+ trace = actual[0].trace.simplified(remove_postproc=True)
assert not trace.insts
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
index c951a5adf3..fc52aa199c 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py
@@ -16,10 +16,8 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
-from tvm.meta_schedule.schedule_rule import RandomComputeLocation
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
@@ -55,35 +53,53 @@ class Add:
# fmt: on
-def _create_context(mod, target, rule):
- ctx = TuneContext(
- mod=mod,
- target=target,
- space_generator=PostOrderApply(),
- sch_rules=[rule],
- task_name="test",
- )
- return ctx
-
-
def test_random_compute_location():
- expected = [
- [
- 'b0 = sch.get_block(name="move", func_name="main")',
- "l1 = sch.sample_compute_location(block=b0)",
- "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True, index=-1)",
- ]
+ @T.prim_func
+ def add_0(
+ A: T.Buffer[(2048, 2048, 2048), "float32"],
+ B: T.Buffer[(2048, 2048, 2048), "float32"],
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main"})
+ # body
+ # with T.block("root")
+ A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32")
+ for i0, j0, i1, j1, k0, i2 in T.grid(128, 64, 4, 4, 64, 4):
+ for ax0, ax1, ax2 in T.grid(1, 8, 32):
+ with T.block("move"):
+ vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2 + ax0)
+ vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + ax1)
+ vk = T.axis.spatial(2048, k0 * 32 + ax2)
+ T.reads(A[vi, vj, vk])
+ T.writes(A_cached[vi, vj, vk])
+ A_cached[vi, vj, vk] = A[vi, vj, vk]
+ for j2, k1 in T.grid(8, 32):
+ with T.block("add"):
+ vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2)
+ vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2)
+ vk = T.axis.spatial(2048, k0 * 32 + k1)
+ T.reads(A_cached[vi, vj, vk])
+ T.writes(B[vi, vj, vk])
+ B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1)
+
+ decision_0 = [
+ ("SampleComputeLocation", 5),
]
+
mod = Add
- target = Target("llvm")
- ctx = _create_context(
+ actual = ms.TuneContext(
mod=mod,
- target=target,
- rule=RandomComputeLocation(),
+ target=Target("llvm"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[ms.schedule_rule.RandomComputeLocation()],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[add_0],
+ expected_decisions=[decision_0],
)
- spaces = ctx.space_generator.generate_design_space(mod=mod)
- assert len(spaces) == 1
- check_trace(spaces, expected)
if __name__ == "__main__":