You are viewing a plain-text version of this content. The canonical link for it was not preserved in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/09 00:06:33 UTC

[tvm] branch main updated: [MetaSchedule][Test] Add unittests for C3D (#12046)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c9356fd18 [MetaSchedule][Test] Add unittests for C3D (#12046)
6c9356fd18 is described below

commit 6c9356fd18d0be4282acf3d428ce6f72f8e91e52
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Jul 8 17:06:28 2022 -0700

    [MetaSchedule][Test] Add unittests for C3D (#12046)
---
 .../unittest/test_meta_schedule_space_cpu.py       | 198 +++++++++++++++++++++
 .../unittest/test_meta_schedule_space_cuda.py      |  96 ++++++++++
 2 files changed, 294 insertions(+)

diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index d6bfbde71f..259f0da07b 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -351,6 +351,204 @@ def test_cpu_c2d():
     )
 
 
+def test_cpu_c3d():  # expected CPU design-space sketches for the C3D workload
+    # fmt: off
+    @T.prim_func
+    def c3d_0(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 7, 3, 64), "float32"], conv3d_ndhwc: T.Buffer[(1, 8, 112, 112, 64), "float32"]) -> None:
+        # function attrs: exported as global symbol "main"; buffers marked no-alias
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body (sketch 0): PadInput computed once per outer tile; results staged in conv3d_ndhwc_global and copied out per inner tile
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
+            PadInput = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32")
+            conv3d_ndhwc_global = T.alloc_buffer([1, 8, 112, 112, 64], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i4_0 in T.grid(1, 2, 4, 1, 2):
+                for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 13, 61, 229, 3):
+                    with T.block("PadInput"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(22, i1_0 * 8 + ax1)
+                        i2 = T.axis.spatial(230, i2_0 * 56 + ax2)
+                        i3 = T.axis.spatial(230, ax3)
+                        i4 = T.axis.spatial(3, ax4)
+                        T.reads(inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4])
+                        T.writes(PadInput[i0, i1, i2, i3, i4])
+                        PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
+                for i0_1, i1_1, i2_1, i3_1, i4_1 in T.grid(1, 4, 4, 14, 1):
+                    for i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
+                        with T.block("conv3d_ndhwc"):
+                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+                            d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
+                            h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
+                            w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
+                            co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+                            rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
+                            rh = T.axis.reduce(7, i6_0 + i6_1)
+                            rw = T.axis.reduce(7, i7_0 + i7_1)
+                            rc = T.axis.reduce(3, i8_0 + i8_1)
+                            T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
+                            T.writes(conv3d_ndhwc_global[n, d, h, w, co])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                conv3d_ndhwc_global[n, d, h, w, co] = T.float32(0)
+                            conv3d_ndhwc_global[n, d, h, w, co] = conv3d_ndhwc_global[n, d, h, w, co] + PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rd, rh, rw, rc, co]
+                    for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 7, 8, 32):
+                        with T.block("conv3d_ndhwc_global"):
+                            v0 = T.axis.spatial(1, ax0)
+                            v1 = T.axis.spatial(8, i1_0 * 4 + i1_1 + ax1)
+                            v2 = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + ax2)
+                            v3 = T.axis.spatial(112, i3_1 * 8 + ax3)
+                            v4 = T.axis.spatial(64, i4_0 * 32 + ax4)
+                            T.reads(conv3d_ndhwc_global[v0, v1, v2, v3, v4])
+                            T.writes(conv3d_ndhwc[v0, v1, v2, v3, v4])
+                            conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_global[v0, v1, v2, v3, v4]
+    @T.prim_func
+    def c3d_1(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 7, 3, 64), "float32"], conv3d_ndhwc: T.Buffer[(1, 8, 112, 112, 64), "float32"]) -> None:
+        # function attrs: exported as global symbol "main"; buffers marked no-alias
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body (sketch 1): PadInput computed per inner tile; conv3d_ndhwc_global copied out once per outer tile
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
+            PadInput = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32")
+            conv3d_ndhwc_global = T.alloc_buffer([1, 8, 112, 112, 64], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i4_0 in T.grid(1, 2, 4, 1, 2):
+                for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 4, 4, 14):
+                    for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 7, 19, 21, 3):
+                        with T.block("PadInput"):
+                            i0 = T.axis.spatial(1, ax0)
+                            i1 = T.axis.spatial(22, i1_0 * 8 + i1_1 * 2 + ax1)
+                            i2 = T.axis.spatial(230, i2_0 * 56 + i2_1 * 14 + ax2)
+                            i3 = T.axis.spatial(230, i3_1 * 16 + ax3)
+                            i4 = T.axis.spatial(3, ax4)
+                            T.reads(inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4])
+                            T.writes(PadInput[i0, i1, i2, i3, i4])
+                            PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
+                    for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
+                        with T.block("conv3d_ndhwc"):
+                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+                            d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
+                            h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
+                            w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
+                            co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+                            rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
+                            rh = T.axis.reduce(7, i6_0 + i6_1)
+                            rw = T.axis.reduce(7, i7_0 + i7_1)
+                            rc = T.axis.reduce(3, i8_0 + i8_1)
+                            T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
+                            T.writes(conv3d_ndhwc_global[n, d, h, w, co])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                conv3d_ndhwc_global[n, d, h, w, co] = T.float32(0)
+                            conv3d_ndhwc_global[n, d, h, w, co] = conv3d_ndhwc_global[n, d, h, w, co] + PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rd, rh, rw, rc, co]
+                for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 4, 28, 112, 32):
+                    with T.block("conv3d_ndhwc_global"):
+                        v0 = T.axis.spatial(1, ax0)
+                        v1 = T.axis.spatial(8, i1_0 * 4 + ax1)
+                        v2 = T.axis.spatial(112, i2_0 * 28 + ax2)
+                        v3 = T.axis.spatial(112, ax3)
+                        v4 = T.axis.spatial(64, i4_0 * 32 + ax4)
+                        T.reads(conv3d_ndhwc_global[v0, v1, v2, v3, v4])
+                        T.writes(conv3d_ndhwc[v0, v1, v2, v3, v4])
+                        conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_global[v0, v1, v2, v3, v4]
+    @T.prim_func
+    def c3d_2(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 7, 3, 64), "float32"], conv3d_ndhwc: T.Buffer[(1, 8, 112, 112, 64), "float32"]) -> None:
+        # function attrs: exported as global symbol "main"; buffers marked no-alias
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body (sketch 2): PadInput computed per inner tile; accumulation writes conv3d_ndhwc directly (no staging buffer)
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            PadInput = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i4_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 2, 4, 1, 2, 1, 4, 4, 14):
+                for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 7, 19, 21, 3):
+                    with T.block("PadInput"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(22, i1_0 * 8 + i1_1 * 2 + ax1)
+                        i2 = T.axis.spatial(230, i2_0 * 56 + i2_1 * 14 + ax2)
+                        i3 = T.axis.spatial(230, i3_1 * 16 + ax3)
+                        i4 = T.axis.spatial(3, ax4)
+                        T.reads(inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4])
+                        T.writes(PadInput[i0, i1, i2, i3, i4])
+                        PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
+                for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
+                    with T.block("conv3d_ndhwc"):
+                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+                        d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
+                        h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
+                        w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
+                        co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+                        rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
+                        rh = T.axis.reduce(7, i6_0 + i6_1)
+                        rw = T.axis.reduce(7, i7_0 + i7_1)
+                        rc = T.axis.reduce(3, i8_0 + i8_1)
+                        T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
+                        T.writes(conv3d_ndhwc[n, d, h, w, co])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            conv3d_ndhwc[n, d, h, w, co] = T.float32(0)
+                        conv3d_ndhwc[n, d, h, w, co] = conv3d_ndhwc[n, d, h, w, co] + PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rd, rh, rw, rc, co]
+    # fmt: on
+
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),  # n
+        ("SamplePerfectTile", [2, 4, 1, 1]),  # d
+        ("SamplePerfectTile", [4, 4, 1, 7]),  # h
+        ("SamplePerfectTile", [1, 14, 1, 8]),  # w
+        ("SamplePerfectTile", [2, 1, 32, 1]),  # co
+        ("SamplePerfectTile", [1, 7]),  # rd
+        ("SamplePerfectTile", [7, 1]),  # rh
+        ("SamplePerfectTile", [7, 1]),  # rw
+        ("SamplePerfectTile", [3, 1]),  # rc
+        ("SampleCategorical", 3),  # presumably the unroll_explicit choice (512 in c3d_0)
+        ("SampleComputeLocation", 4),  # presumably where PadInput is computed (under i4_0 in c3d_0)
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),  # all SamplePerfectTile entries equal decision_0's
+        ("SamplePerfectTile", [2, 4, 1, 1]),
+        ("SamplePerfectTile", [4, 4, 1, 7]),
+        ("SamplePerfectTile", [1, 14, 1, 8]),
+        ("SamplePerfectTile", [2, 1, 32, 1]),
+        ("SamplePerfectTile", [1, 7]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SampleCategorical", 2),  # presumably the unroll_explicit choice (64 in c3d_1)
+        ("SampleComputeLocation", 8),  # presumably where PadInput is computed (under i3_1 in c3d_1)
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),  # all SamplePerfectTile entries equal decision_0's
+        ("SamplePerfectTile", [2, 4, 1, 1]),
+        ("SamplePerfectTile", [4, 4, 1, 7]),
+        ("SamplePerfectTile", [1, 14, 1, 8]),
+        ("SamplePerfectTile", [2, 1, 32, 1]),
+        ("SamplePerfectTile", [1, 7]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [7, 1]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SampleCategorical", 1),  # presumably the unroll_explicit choice (16 in c3d_2)
+        ("SampleComputeLocation", 8),  # presumably where PadInput is computed (under i3_1 in c3d_2)
+    ]
+
+    mod = create_te_workload("C3D", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[c3d_0, c3d_1, c3d_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_c1d()
     test_cpu_c2d()
+    test_cpu_c3d()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 472a7ccc13..277f74d888 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -200,6 +200,102 @@ def test_cuda_c2d():
     )
 
 
+def test_cuda_c3d():  # expected CUDA design-space sketch for the C3D workload
+    # fmt: off
+    @T.prim_func
+    def c3d_0(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 7, 3, 64), "float32"], conv3d_ndhwc: T.Buffer[(1, 8, 112, 112, 64), "float32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":16})
+            conv3d_ndhwc_local = T.alloc_buffer([1, 8, 112, 112, 64], dtype="float32", scope="local")
+            PadInput_shared = T.alloc_buffer([1, 22, 230, 230, 3], dtype="float32", scope="shared")
+            weight_shared = T.alloc_buffer([7, 7, 7, 3, 64], dtype="float32", scope="shared")
+            for i0_0_i1_0_i2_0_i3_0_i4_0_fused in T.thread_binding(2, thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_i3_1_i4_1_fused in T.thread_binding(8, thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_i3_2_i4_2_fused in T.thread_binding(392, thread="threadIdx.x"):
+                        for i5_0, i6_0, i7_0, i8_0 in T.grid(1, 1, 1, 1):
+                            for ax0_ax1_ax2_ax3_ax4_fused in T.serial(1687959):  # copy the needed PadInput window into PadInput_shared
+                                with T.block("PadInput_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused % 1687959 // 80379)
+                                    v2 = T.axis.spatial(230, ax0_ax1_ax2_ax3_ax4_fused % 80379 // 351)
+                                    v3 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 112 + ax0_ax1_ax2_ax3_ax4_fused % 351 // 3)
+                                    v4 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_fused % 3)
+                                    T.reads(inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4])
+                                    T.writes(PadInput_shared[v0, v1, v2, v3, v4])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":4})
+                                    PadInput_shared[v0, v1, v2, v3, v4] = T.if_then_else(3 <= v1 and v1 < 19 and 3 <= v2 and v2 < 227 and 3 <= v3 and v3 < 227, inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4], T.float32(0), dtype="float32")
+                            for ax0_ax1_ax2_ax3_ax4_fused in T.serial(65856):  # copy all of weight (7*7*7*3*64 elements) into weight_shared
+                                with T.block("weight_shared"):
+                                    v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused // 9408)
+                                    v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused % 9408 // 1344)
+                                    v2 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused % 1344 // 192)
+                                    v3 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_fused % 192 // 64)
+                                    v4 = T.axis.spatial(64, ax0_ax1_ax2_ax3_ax4_fused % 64)
+                                    T.reads(weight[v0, v1, v2, v3, v4])
+                                    T.writes(weight_shared[v0, v1, v2, v3, v4])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":3})
+                                    weight_shared[v0, v1, v2, v3, v4] = weight[v0, v1, v2, v3, v4]
+                            for i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_2, i6_2, i7_2, i8_2, i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(7, 7, 1, 3, 1, 2, 2, 1, 32, 1, 1, 7, 1, 1, 1, 2, 4, 1):
+                                with T.block("conv3d_ndhwc"):
+                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
+                                    d = T.axis.spatial(8, ((0 + 0) * 4 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 392 // 98) * 2 + i1_3 + i1_4)
+                                    h = T.axis.spatial(112, (((0 * 4 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 8 // 2) * 7 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14) * 2 + i2_3) * 2 + i2_4)
+                                    w = T.axis.spatial(112, ((i0_0_i1_0_i2_0_i3_0_i4_0_fused % 2 * 2 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2) * 7 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 + i3_3) * 4 + i3_4)
+                                    co = T.axis.spatial(64, ((0 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2) * 32 + i4_3 + i4_4)
+                                    rd = T.axis.reduce(7, i5_0 * 7 + i5_1 + i5_2)
+                                    rh = T.axis.reduce(7, i6_0 * 7 + i6_1 + i6_2)
+                                    rw = T.axis.reduce(7, (i7_0 + i7_1) * 7 + i7_2)
+                                    rc = T.axis.reduce(3, i8_0 * 3 + i8_1 + i8_2)
+                                    T.reads(PadInput_shared[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rd, rh, rw, rc, co])
+                                    T.writes(conv3d_ndhwc_local[n, d, h, w, co])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        conv3d_ndhwc_local[n, d, h, w, co] = T.float32(0)
+                                    conv3d_ndhwc_local[n, d, h, w, co] = conv3d_ndhwc_local[n, d, h, w, co] + PadInput_shared[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight_shared[rd, rh, rw, rc, co]
+                        for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 4, 4, 32):  # write conv3d_ndhwc_local back to conv3d_ndhwc
+                            with T.block("conv3d_ndhwc_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(8, i0_2_i1_2_i2_2_i3_2_i4_2_fused // 98 * 2 + ax1)
+                                v2 = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_i4_1_fused // 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14 * 4 + ax2)
+                                v3 = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 56 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 * 4 + ax3)
+                                v4 = T.axis.spatial(64, i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2 * 32 + ax4)
+                                T.reads(conv3d_ndhwc_local[v0, v1, v2, v3, v4])
+                                T.writes(conv3d_ndhwc[v0, v1, v2, v3, v4])
+                                conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_local[v0, v1, v2, v3, v4]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),  # n
+        ("SamplePerfectTile", [1, 1, 4, 2, 1]),  # d
+        ("SamplePerfectTile", [1, 4, 7, 2, 2]),  # h
+        ("SamplePerfectTile", [2, 2, 7, 1, 4]),  # w
+        ("SamplePerfectTile", [1, 1, 2, 32, 1]),  # co
+        ("SamplePerfectTile", [1, 7, 1]),  # rd
+        ("SamplePerfectTile", [1, 7, 1]),  # rh
+        ("SamplePerfectTile", [1, 1, 7]),  # rw
+        ("SamplePerfectTile", [1, 3, 1]),  # rc
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+    ]
+    mod = create_te_workload("C3D", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[c3d_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cuda_c1d()
     test_cuda_c2d()
+    test_cuda_c3d()