You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/09 02:02:01 UTC
[tvm] branch main updated: [MetaSchedule][Test] Add unittests for CAP (#12047)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 00ce86d68b [MetaSchedule][Test] Add unittests for CAP (#12047)
00ce86d68b is described below
commit 00ce86d68b123f3389b0fca1eca72e81f6054443
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Jul 8 19:01:54 2022 -0700
[MetaSchedule][Test] Add unittests for CAP (#12047)
---
.../unittest/test_meta_schedule_space_cpu.py | 194 +++++++++++++++++++++
.../unittest/test_meta_schedule_space_cuda.py | 102 +++++++++++
2 files changed, 296 insertions(+)
diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 259f0da07b..87f61ec328 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -548,7 +548,201 @@ def test_cpu_c3d():
)
+def test_cpu_cap():  # Generates the design space for the "CAP" workload on _target() and checks it against three expected sketches (cap_0, cap_1, cap_2) with their decisions.
+ # fmt: off
+ @T.prim_func
+ def cap_0(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
+ # function attr dict: global symbol "main", "tir.noalias" set to True
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body: PadInput fill over a (1, 3, 17, 4, 4, 32) grid, conv2d_capsule_nhwijc accumulating into conv2d_capsule_nhwijc_global, then a (1, 1, 2, 4, 1, 16) copy to the output buffer
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64})
+ PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32")
+ conv2d_capsule_nhwijc_global = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32")
+ for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1, i1_1 in T.grid(1, 2, 1, 1, 1, 1, 1, 4):
+ for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 17, 4, 4, 32):
+ with T.block("PadInput"):
+ i0 = T.axis.spatial(1, ax0)
+ i1 = T.axis.spatial(18, i1_0 * 8 + i1_1 * 2 + ax1)
+ i2 = T.axis.spatial(18, ax2)
+ i3, i4, i5 = T.axis.remap("SSS", [ax3, ax4, ax5])
+ T.reads(inputs[i0, i1 - 1, i2 - 1, i3, i4, i5])
+ T.writes(PadInput[i0, i1, i2, i3, i4, i5])
+ PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32")
+ for i2_1, i3_1, i4_1, i5_1 in T.grid(4, 1, 4, 2):
+ for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
+ with T.block("conv2d_capsule_nhwijc"):
+ n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+ h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
+ w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
+ cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
+ cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
+ co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
+ rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
+ rw = T.axis.reduce(3, i7_0 + i7_1)
+ cap_k = T.axis.reduce(4, i8_0 + i8_1)
+ rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
+ T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
+ T.writes(conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ with T.init():
+ conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = T.float32(0)
+ conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co]
+ for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 2, 4, 1, 16):
+ with T.block("conv2d_capsule_nhwijc_global"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(8, i1_0 * 4 + i1_1 + ax1)
+ v2 = T.axis.spatial(8, i2_1 * 2 + ax2)
+ v3 = T.axis.spatial(4, ax3)
+ v4 = T.axis.spatial(4, i4_1 + ax4)
+ v5 = T.axis.spatial(32, i5_1 * 16 + ax5)
+ T.reads(conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5])
+ T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5])
+ conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5]
+ @T.prim_func
+ def cap_1(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
+ # function attr dict: global symbol "main", "tir.noalias" set to True
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body: PadInput fill over a (1, 3, 5, 4, 4, 32) grid, conv2d_capsule_nhwijc accumulating into conv2d_capsule_nhwijc_global, then a (1, 4, 8, 4, 4, 32) copy to the output buffer
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":0, "meta_schedule.vectorize":64})
+ PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32")
+ conv2d_capsule_nhwijc_global = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32")
+ for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0 in T.grid(1, 2, 1, 1, 1, 1):
+ for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 in T.grid(1, 4, 4, 1, 4, 2):
+ for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 5, 4, 4, 32):
+ with T.block("PadInput"):
+ i0 = T.axis.spatial(1, ax0)
+ i1 = T.axis.spatial(18, i1_0 * 8 + i1_1 * 2 + ax1)
+ i2 = T.axis.spatial(18, i2_1 * 4 + ax2)
+ i3, i4, i5 = T.axis.remap("SSS", [ax3, ax4, ax5])
+ T.reads(inputs[i0, i1 - 1, i2 - 1, i3, i4, i5])
+ T.writes(PadInput[i0, i1, i2, i3, i4, i5])
+ PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32")
+ for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
+ with T.block("conv2d_capsule_nhwijc"):
+ n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+ h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
+ w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
+ cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
+ cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
+ co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
+ rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
+ rw = T.axis.reduce(3, i7_0 + i7_1)
+ cap_k = T.axis.reduce(4, i8_0 + i8_1)
+ rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
+ T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
+ T.writes(conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ with T.init():
+ conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = T.float32(0)
+ conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_global[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co]
+ for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 8, 4, 4, 32):
+ with T.block("conv2d_capsule_nhwijc_global"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(8, i1_0 * 4 + ax1)
+ v2, v3, v4, v5 = T.axis.remap("SSSS", [ax2, ax3, ax4, ax5])
+ T.reads(conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5])
+ T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5])
+ conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5]
+ @T.prim_func
+ def cap_2(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
+ # function attr dict: global symbol "main", "tir.noalias" set to True
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body: PadInput filled over its full (1, 18, 18, 4, 4, 32) grid; conv2d_capsule_nhwijc accumulates directly into the output buffer (no conv2d_capsule_nhwijc_global staging buffer here)
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+ PadInput = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32")
+ for i0, i1, i2, i3, i4, i5 in T.grid(1, 18, 18, 4, 4, 32):
+ with T.block("PadInput"):
+ i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5])
+ T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1])
+ T.writes(PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1])
+ PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1], T.float32(0), dtype="float32")
+ for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1_1, i5_1_1, i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
+ with T.block("conv2d_capsule_nhwijc"):
+ n = T.axis.spatial(1, i0_3 + i0_2 + i0_1_1 + i0_0)
+ h = T.axis.spatial(8, i1_0 * 4 + i1_1_1 + i1_2 + i1_3)
+ w = T.axis.spatial(8, (i2_0 * 4 + i2_1_1) * 2 + i2_2 + i2_3)
+ cap_i = T.axis.spatial(4, (i3_0 + i3_1_1 + i3_2) * 4 + i3_3)
+ cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1_1 + i4_2 + i4_3)
+ co = T.axis.spatial(32, (i5_0 * 2 + i5_1_1 + i5_2) * 16 + i5_3)
+ rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
+ rw = T.axis.reduce(3, i7_0 + i7_1)
+ cap_k = T.axis.reduce(4, i8_0 + i8_1)
+ rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
+ T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
+ T.writes(conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ with T.init():
+ conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] = T.float32(0)
+ conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc[n, h, w, cap_i, cap_j, co] + PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight[rh, rw, cap_k, cap_j, rc, co]
+ # fmt: on
+ decision_0 = [  # listed at the same position as cap_0 in the check_sketches call below
+ ("SamplePerfectTile", [1, 1, 1, 1]),
+ ("SamplePerfectTile", [2, 4, 1, 1]),
+ ("SamplePerfectTile", [1, 4, 2, 1]),
+ ("SamplePerfectTile", [1, 1, 1, 4]),
+ ("SamplePerfectTile", [1, 4, 1, 1]),
+ ("SamplePerfectTile", [1, 2, 1, 16]),
+ ("SamplePerfectTile", [1, 3]),
+ ("SamplePerfectTile", [3, 1]),
+ ("SamplePerfectTile", [4, 1]),
+ ("SamplePerfectTile", [1, 32]),
+ ("SampleCategorical", 0),
+ ("SampleComputeLocation", 7),
+ ]
+ decision_1 = [  # same tile decisions as decision_0; only the SampleComputeLocation value differs (11 instead of 7)
+ ("SamplePerfectTile", [1, 1, 1, 1]),
+ ("SamplePerfectTile", [2, 4, 1, 1]),
+ ("SamplePerfectTile", [1, 4, 2, 1]),
+ ("SamplePerfectTile", [1, 1, 1, 4]),
+ ("SamplePerfectTile", [1, 4, 1, 1]),
+ ("SamplePerfectTile", [1, 2, 1, 16]),
+ ("SamplePerfectTile", [1, 3]),
+ ("SamplePerfectTile", [3, 1]),
+ ("SamplePerfectTile", [4, 1]),
+ ("SamplePerfectTile", [1, 32]),
+ ("SampleCategorical", 0),
+ ("SampleComputeLocation", 11),
+ ]
+ decision_2 = [  # same tile decisions again; SampleCategorical is 1 and SampleComputeLocation is -1
+ ("SamplePerfectTile", [1, 1, 1, 1]),
+ ("SamplePerfectTile", [2, 4, 1, 1]),
+ ("SamplePerfectTile", [1, 4, 2, 1]),
+ ("SamplePerfectTile", [1, 1, 1, 4]),
+ ("SamplePerfectTile", [1, 4, 1, 1]),
+ ("SamplePerfectTile", [1, 2, 1, 16]),
+ ("SamplePerfectTile", [1, 3]),
+ ("SamplePerfectTile", [3, 1]),
+ ("SamplePerfectTile", [4, 1]),
+ ("SamplePerfectTile", [1, 32]),
+ ("SampleCategorical", 1),
+ ("SampleComputeLocation", -1),  # NOTE(review): presumably -1 leaves PadInput at root, consistent with cap_2's standalone PadInput grid - confirm
+ ]
+ mod = create_te_workload("CAP", 0)  # workload whose design space is checked
+ actual = ms.TuneContext(  # design space generated with PostOrderApply and the "default" schedule rules
+ mod=mod,
+ target=_target(),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules="default",
+ ).generate_design_space()
+ check_sketches(  # hands the generated sketches plus the expected modules and decisions to check_sketches
+ mod,
+ sketches=actual,
+ expected_mods=[cap_0, cap_1, cap_2],
+ expected_decisions=[decision_0, decision_1, decision_2],
+ )
+
+
if __name__ == "__main__":
test_cpu_c1d()
test_cpu_c2d()
test_cpu_c3d()
+ test_cpu_cap()  # also run the newly added CAP test when this file is executed as a script
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 277f74d888..bffb80436c 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -295,7 +295,109 @@ def test_cuda_c3d():
)
+def test_cuda_cap():  # Generates the design space for the "CAP" workload on _target() and checks it against one expected sketch (cap_0) with its decisions.
+ # fmt: off
+ @T.prim_func
+ def cap_0(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[(3, 3, 4, 4, 32, 32), "float32"], conv2d_capsule_nhwijc: T.Buffer[(1, 8, 8, 4, 4, 32), "float32"]) -> None:
+ # function attr dict: global symbol "main", "tir.noalias" set to True
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body: PadInput and weight are staged in "shared"-scope buffers, conv2d_capsule_nhwijc accumulates into a "local"-scope buffer, which is then copied to the output buffer
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T.block_attr({"meta_schedule.unroll_explicit":64})
+ conv2d_capsule_nhwijc_local = T.alloc_buffer([1, 8, 8, 4, 4, 32], dtype="float32", scope="local")
+ PadInput_shared = T.alloc_buffer([1, 18, 18, 4, 4, 32], dtype="float32", scope="shared")
+ weight_shared = T.alloc_buffer([3, 3, 4, 4, 32, 32], dtype="float32", scope="shared")
+ for i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused in T.thread_binding(256, thread="blockIdx.x"):
+ for i0_1_i1_1_i2_1_i3_1_i4_1_i5_1_fused in T.thread_binding(1, thread="vthread.x"):
+ for i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused in T.thread_binding(4, thread="threadIdx.x"):
+ for i6_0, i7_0, i8_0, i9_0 in T.grid(3, 3, 2, 8):
+ for ax0_ax1_ax2_ax3_ax4_ax5_fused in T.serial(48):
+ with T.block("PadInput_shared"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16)
+ v2 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0 + 0)
+ v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8)
+ v4 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4)
+ v5 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4)
+ T.reads(inputs[v0, v1 - 1, v2 - 1, v3, v4, v5])
+ T.writes(PadInput_shared[v0, v1, v2, v3, v4, v5])
+ T.block_attr({"meta_schedule.cooperative_fetch":2})
+ PadInput_shared[v0, v1, v2, v3, v4, v5] = T.if_then_else(1 <= v1 and v1 < 17 and 1 <= v2 and v2 < 17, inputs[v0, v1 - 1, v2 - 1, v3, v4, v5], T.float32(0), dtype="float32")
+ for ax0_ax1_ax2_ax3_ax4_ax5_fused in T.serial(256):
+ with T.block("weight_shared"):
+ v0, v1 = T.axis.remap("SS", [i6_0, i7_0])
+ v2 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused // 128)
+ v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_ax4_ax5_fused % 128 // 32)
+ v4 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 32 // 8)
+ v5 = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8)
+ T.reads(weight[v0, v1, v2, v3, v4, v5])
+ T.writes(weight_shared[v0, v1, v2, v3, v4, v5])
+ T.block_attr({"meta_schedule.cooperative_fetch":4})
+ weight_shared[v0, v1, v2, v3, v4, v5] = weight[v0, v1, v2, v3, v4, v5]
+ for i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3, i6_2, i7_2, i8_2, i9_2, i0_4, i1_4, i2_4, i3_4, i4_4, i5_4 in T.grid(1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 8):
+ with T.block("conv2d_capsule_nhwijc"):
+ n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
+ h = T.axis.spatial(8, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 256 // 64 + 0 + 0) * 2 + i1_3 + i1_4)
+ w = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 + 0 + 0 + i2_3 + i2_4)
+ cap_i = T.axis.spatial(4, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 4 // 2 + i3_3 + i3_4)
+ cap_j = T.axis.spatial(4, ((0 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 + i4_3) * 2 + i4_4)
+ co = T.axis.spatial(32, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 + 0 + 0 + i5_3) * 8 + i5_4)
+ rh = T.axis.reduce(3, i6_0 + i6_1 + i6_2)
+ rw = T.axis.reduce(3, i7_0 + i7_1 + i7_2)
+ cap_k = T.axis.reduce(4, (i8_0 + i8_1) * 2 + i8_2)
+ rc = T.axis.reduce(32, i9_0 * 4 + i9_1 + i9_2)
+ T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight_shared[rh, rw, cap_k, cap_j, rc, co])
+ T.writes(conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co])
+ T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+ with T.init():
+ conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co] = T.float32(0)
+ conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co] = conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co] + PadInput_shared[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc] * weight_shared[rh, rw, cap_k, cap_j, rc, co]
+ for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 2, 1, 1, 2, 8):
+ with T.block("conv2d_capsule_nhwijc_local"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 2 + ax1)
+ v2 = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 + ax2)
+ v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused // 2 + ax3)
+ v4 = T.axis.spatial(4, i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 * 2 + ax4)
+ v5 = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + ax5)
+ T.reads(conv2d_capsule_nhwijc_local[v0, v1, v2, v3, v4, v5])
+ T.writes(conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5])
+ conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_local[v0, v1, v2, v3, v4, v5]
+ # fmt: on
+ decision_0 = [  # passed with cap_0 to check_sketches below: ten SamplePerfectTile entries followed by three SampleCategorical entries
+ ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+ ("SamplePerfectTile", [4, 1, 1, 2, 1]),
+ ("SamplePerfectTile", [8, 1, 1, 1, 1]),
+ ("SamplePerfectTile", [2, 1, 2, 1, 1]),
+ ("SamplePerfectTile", [1, 1, 2, 1, 2]),
+ ("SamplePerfectTile", [4, 1, 1, 1, 8]),
+ ("SamplePerfectTile", [3, 1, 1]),
+ ("SamplePerfectTile", [3, 1, 1]),
+ ("SamplePerfectTile", [2, 1, 2]),
+ ("SamplePerfectTile", [8, 4, 1]),
+ ("SampleCategorical", 1),
+ ("SampleCategorical", 3),
+ ("SampleCategorical", 2),
+ ]
+ mod = create_te_workload("CAP", 0)  # workload whose design space is checked
+ actual = ms.TuneContext(  # design space generated with PostOrderApply and the "default" schedule rules
+ mod=mod,
+ target=_target(),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules="default",
+ ).generate_design_space()
+ check_sketches(  # hands the generated sketches plus the expected module and decisions to check_sketches
+ mod,
+ sketches=actual,
+ expected_mods=[cap_0],
+ expected_decisions=[decision_0],
+ )
+
+
if __name__ == "__main__":
test_cuda_c1d()
test_cuda_c2d()
test_cuda_c3d()
+ test_cuda_cap()  # also run the newly added CAP test when this file is executed as a script