You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/12/10 03:49:04 UTC

[GitHub] [tvm] nverke commented on a diff in pull request #13110: [Hexagon] Add a test to show how to use multi input async dma pipelin…

nverke commented on code in PR #13110:
URL: https://github.com/apache/tvm/pull/13110#discussion_r1044960977


##########
tests/python/contrib/test_hexagon/test_async_dma_pipeline.py:
##########
@@ -349,5 +449,313 @@ def test_loading_vtcm_for_vrmpy(
                 "async_dma_input": async_input_runtime,
                 "async_dma_output": async_output_runtime,
                 "async_dma_input_output": async_input_output_runtime,
+                "async_dma_multi_input_output": async_multi_input_output_runtime,
+                "async_input_output_runtime_larger_buffers": async_input_output_runtime_larger_buffers,
             },
         )
+
+
+# from tvm.script import tir as T
+@tvm.script.ir_module
+class ModulePipelined:
+    @T.prim_func
+    def main(
+        p0: T.Buffer[(1, 1, 230, 230, 4), "uint8"],
+        p1: T.Buffer[(2, 1, 7, 7, 1, 32, 4), "int8"],
+        T_cast: T.Buffer[(1, 2, 112, 112, 32), "int32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+        # body
+        # with T.block("root")
+        conv2d_NCHWc_int8 = T.alloc_buffer([1, 2, 112, 112, 32], dtype="int32", scope="global.vtcm")
+        p0_global_vtcm = T.alloc_buffer([1, 1, 230, 230, 4], dtype="uint8", scope="global.vtcm")
+        p1_global_vtcm = T.alloc_buffer([2, 1, 7, 7, 1, 32, 4], dtype="int8", scope="global.vtcm")
+        for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(2, 1, 7, 7, 1, 32, 4):
+            with T.block("p1_global.vtcm"):
+                v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
+                    "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6]
+                )
+                T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
+                T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+                p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6]
+        for po in T.serial(4):
+            for i in T.serial(55876):
+                with T.block("p0_global.vtcm"):
+                    v0 = T.axis.spatial(1, 0)
+                    v1 = T.axis.spatial(1, 0)
+                    v2 = T.axis.spatial(230, po * 56 + i // 916)
+                    v3 = T.axis.spatial(230, i % 916 // 4)
+                    v4 = T.axis.spatial(4, i % 4)
+                    T.reads(p0[v0, v1, v2, v3, v4])
+                    T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+                    p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+            for i in T.parallel(28):
+                for ii, iii, iiii in T.grid(2, 14, 8):
+                    with T.block("conv2d_NCHWc_int8_o_init"):
+                        n = T.axis.spatial(1, 0)
+                        oc_chunk = T.axis.spatial(2, ii)
+                        oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 + iii)
+                        ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + iiii)
+                        oc_block_o = T.axis.spatial(1, 0)
+                        T.reads()
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+                        for i4_1 in T.vectorized(32):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                oc_block_i_init = T.axis.spatial(32, i4_1)
+                                T.reads()
+                                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init])
+                                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+                for i1_1, i5_1, i6_1, i2_2, i3_2 in T.grid(2, 7, 7, 14, 8):
+                    with T.block("conv2d_NCHWc_int8_o_update"):
+                        n = T.axis.spatial(1, 0)
+                        oc_chunk = T.axis.spatial(2, i1_1)
+                        oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 + i2_2)
+                        ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + i3_2)
+                        oc_block_o = T.axis.spatial(1, 0)
+                        kh = T.axis.reduce(7, i5_1)
+                        kw = T.axis.reduce(7, i6_1)
+                        ic_outer = T.axis.reduce(1, 0)
+                        ic_f_inner = T.axis.reduce(1, 0)
+                        ic_s_inner_o = T.axis.reduce(1, 0)
+                        T.reads(
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                            p0_global_vtcm[
+                                n,
+                                ic_outer,
+                                oh * 2 + kh,
+                                ow * 2 + kw,
+                                ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                            ],
+                            p1_global_vtcm[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                        )
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+                        A = T.match_buffer(
+                            p0_global_vtcm[
+                                n,
+                                ic_outer,
+                                oh * 2 + kh,
+                                ow * 2 + kw,
+                                ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                            ],
+                            [4],
+                            dtype="uint8",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        B = T.match_buffer(
+                            p1_global_vtcm[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                            [32, 4],
+                            dtype="int8",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        C = T.match_buffer(
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                            [32],
+                            dtype="int32",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        A_u8x4: T.uint8x4 = A[0:4]
+                        A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                        B_i8x128 = B[0, 0:128]
+                        B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+                        C[0:32] = T.call_llvm_pure_intrin(
+                            4217,

Review Comment:
   Ahh, interesting. I thought that they were tied to each intrinsic; will update accordingly!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org