You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/26 22:39:45 UTC

[GitHub] [tvm] junrushao1994 commented on a change in pull request #10066: [TIR] Add software pipelining

junrushao1994 commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r793110174



##########
File path: tests/python/unittest/test_tir_transform_inject_software_pipeline.py
##########
@@ -0,0 +1,824 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm import tir, te, TVMError
+from tvm.script import tir as T
+
+
+def _check(original, transformed):
+    func = original
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
+    mod = tvm.tir.transform.Simplify()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], transformed, True)
+
+
+def _check_error(func):
+    mod = tvm.IRModule.from_expr(func)
+    with pytest.raises(ValueError):
+        tvm.tir.transform.InjectSoftwarePipeline()(mod)
+
+
+@T.prim_func
+def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}
+        ):
+            with T.block():
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_trivial_pipeline(
+    A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]
+) -> None:
+    for tx in T.thread_binding(16, thread="threadIdx.x"):
+        with T.block():
+            T.reads(A[tx, 0])
+            T.writes(C[tx, 0])
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads(A[tx, 0])
+                T.writes(B[0, tx, 0])
+                B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            with T.block():
+                T.reads()
+                T.writes()
+                T.evaluate(0)
+            with T.block():
+                T.reads(B[0, tx, 0])
+                T.writes(C[tx, 0])
+                C[tx, 0] = B[0, tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]},
+        ):
+            with T.block():
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_simple_compute(
+    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16]])
+            T.writes([C[tx, 0:16]])
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads([A[tx, 0]])
+                T.writes([B[0, tx, 0]])
+                B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            with T.block():
+                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
+                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1]])
+                        T.writes([B[(i + 1) % 2, tx, 0]])
+                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                    with T.block():
+                        T.reads([B[i % 2, tx, 0]])
+                        T.writes([C[tx, i]])
+                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
+            with T.block():
+                T.reads([B[1, tx, 0]])
+                T.writes([C[tx, 15]])
+                C[tx, 15] = B[1, tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def nested_pipeline_simple(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 1, 1, 1],
+                "software_pipeline_order": [0, 1, 2, 3],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            T.reads(A_shared[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_nested_pipeline_simple(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads([A[tx, 0, 0:16]])
+                T.writes([A_shared[0, tx, 0, 0:16]])
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads([A[tx, 0, j]])
+                        T.writes([A_shared[0, tx, 0, j]])
+                        A_shared[0, tx, 0, j] = A[tx, 0, j]
+            with T.block():
+                T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]])
+                T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]])
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+                                A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
+                    with T.block():
+                        T.reads([A_shared[i % 2, tx, i, 0]])
+                        T.writes([B[0, tx, i, 0]])
+                        B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2)
+                    with T.block():
+                        T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_shared[i % 2, tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_shared[
+                                    i % 2, tx, 0, j + 1
+                                ] * T.float32(2)
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    with T.block():
+                        T.reads([B[1, tx, i, 0]])
+                        T.writes([C[tx, i, 15]])
+                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
+            with T.block():
+                T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]])
+                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
+                with T.block():
+                    T.reads([A_shared[1, tx, 15, 0]])
+                    T.writes([B[0, tx, 15, 0]])
+                    B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2)
+                with T.block():
+                    T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
+                    for j in T.serial(0, 15):
+                        with T.block():
+                            T.reads([A_shared[1, tx, 15, j + 1]])
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
+                        with T.block():
+                            T.reads([B[j % 2, tx, 15, 0]])
+                            T.writes([C[tx, 15, j]])
+                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
+                with T.block():
+                    T.reads([B[1, tx, 15, 0]])
+                    T.writes([C[tx, 15, 15]])
+                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
+
+
+@T.prim_func
+def nested_pipeline_prefetch_inner(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 0, 1, 1],
+                "software_pipeline_order": [0, 2, 1, 3],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            T.reads(A_shared[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_nested_pipeline_prefetch_inner(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]])
+                T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]])
+                with T.block():
+                    T.reads([A[tx, 0, 0:16]])
+                    T.writes([A_shared[0, tx, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A[tx, 0, j]])
+                            T.writes([A_shared[0, tx, 0, j]])
+                            A_shared[0, tx, 0, j] = A[tx, 0, j]
+                with T.block():
+                    T.reads([A_shared[0, tx, 0, 0]])
+                    T.writes([B[0, tx, 0, 0]])
+                    B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2)
+            with T.block():
+                T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]])
+                T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]])
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+                                A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
+                    with T.block():
+                        T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_shared[i % 2, tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_shared[
+                                    i % 2, tx, 0, j + 1
+                                ] * T.float32(2)
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    with T.block():
+                        T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
+                        T.writes([B[0, tx, i + 1, 0]])
+                        B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2)
+                    with T.block():
+                        T.reads([B[1, tx, i, 0]])
+                        T.writes([C[tx, i, 15]])
+                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
+            with T.block():
+                T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
+                with T.block():
+                    T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
+                    for j in T.serial(0, 15):
+                        with T.block():
+                            T.reads([A_shared[1, tx, 15, j + 1]])
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
+                        with T.block():
+                            T.reads([B[j % 2, tx, 15, 0]])
+                            T.writes([C[tx, 15, j]])
+                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
+                with T.block():
+                    T.reads([B[1, tx, 15, 0]])
+                    T.writes([C[tx, 15, 15]])
+                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
+
+
+@T.prim_func
+def nested_pipeline_interleaving(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 0, 0, 1, 1],
+                "software_pipeline_order": [0, 2, 3, 1, 4],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local")
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(A_local[0, 0, j])
+                        A_local[0, 0, j] = A_shared[tx, i, j]
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_local[0, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            T.reads(A_local[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_local[0, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_nested_pipeline_interleaving(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared")
+            A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 0]])
+                T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]])
+                with T.block():
+                    T.reads([A[tx, 0, 0:16]])
+                    T.writes([A_shared[tx, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A[tx, 0, j]])
+                            T.writes([A_shared[tx, 0, j]])
+                            A_shared[tx, 0, j] = A[tx, 0, j]
+                with T.block():
+                    T.reads([A_shared[tx, 0, 0:16]])
+                    T.writes([A_local[0, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A_shared[tx, 0, j]])
+                            T.writes([A_local[0, 0, j]])
+                            A_local[0, 0, j] = A_shared[tx, 0, j]
+                with T.block():
+                    T.reads([A_local[tx, 0, 0]])
+                    T.writes([B[0, tx, 0, 0]])
+                    B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2)
+            with T.block():
+                T.reads(
+                    [
+                        A[tx, 1:16, 0:16],
+                        A_local[tx, 0:16, 0:16],
+                        B[0:2, tx, 0:15, 0],
+                        A_shared[tx, 0, 0:16],
+                    ]
+                )
+                T.writes(
+                    [
+                        A_shared[tx, 0, 0:16],
+                        B[0:2, tx, 0:16, 0],
+                        C[tx, 0:15, 0:16],
+                        A_local[0, 0, 0:16],
+                    ]
+                )
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[tx, 0, j]])
+                                A_shared[tx, 0, j] = A[tx, i + 1, j]
+                    with T.block():
+                        T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_local[tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    with T.block():
+                        T.reads([A_shared[tx, 0, 0:16]])
+                        T.writes([A_local[0, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A_shared[tx, 0, j]])
+                                T.writes([A_local[0, 0, j]])
+                                A_local[0, 0, j] = A_shared[tx, i + 1, j]
+                    with T.block():
+                        T.reads([A_local[tx, i + 1, 0]])
+                        T.writes([B[0, tx, i + 1, 0]])
+                        B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2)
+                    with T.block():
+                        T.reads([B[1, tx, i, 0]])
+                        T.writes([C[tx, i, 15]])
+                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
+            with T.block():
+                T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
+                with T.block():
+                    T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
+                    for j in T.serial(0, 15):
+                        with T.block():
+                            T.reads([A_local[tx, 15, j + 1]])
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
+                        with T.block():
+                            T.reads([B[j % 2, tx, 15, 0]])
+                            T.writes([C[tx, 15, j]])
+                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
+                with T.block():
+                    T.reads([B[1, tx, 15, 0]])
+                    T.writes([C[tx, 15, 15]])
+                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
+
+
+@T.prim_func
+def nested_pipeline_double_buffer(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 0, 0, 1, 1],
+                "software_pipeline_order": [0, 2, 3, 1, 4],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local")
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.block_attr({"double_buffer_scope": 0})
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(A_local[0, 0, j])
+                        A_local[0, 0, j] = A_shared[tx, i, j]
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_local[0, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            T.reads(A_local[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_local[0, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_nested_pipeline_double_buffer(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared")
+            A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]])
+                T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]])
+                with T.block():
+                    T.reads([A[tx, 0, 0:16]])
+                    T.writes([A_shared[tx, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A[tx, 0, j]])
+                            T.writes([A_shared[tx, 0, j]])
+                            A_shared[tx, 0, j] = A[tx, 0, j]
+                with T.block():
+                    T.reads([A_shared[tx, 0, 0:16]])
+                    T.writes([A_local[0, 0, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A_shared[tx, 0, j]])
+                            T.writes([A_local[0, 0, 0, j]])
+                            T.block_attr({"double_buffer_scope": 0})
+                            A_local[0, 0, 0, j] = A_shared[tx, 0, j]
+                with T.block():
+                    T.reads([A_local[0, tx, 0, 0]])
+                    T.writes([B[0, tx, 0, 0]])
+                    B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2)
+            with T.block():
+                T.reads(
+                    [
+                        A[tx, 1:16, 0:16],
+                        A_local[0:2, tx, 0:16, 0:16],
+                        B[0:2, tx, 0:15, 0],
+                        A_shared[tx, 0, 0:16],
+                    ]
+                )
+                T.writes(
+                    [
+                        A_shared[tx, 0, 0:16],
+                        B[0:2, tx, 0:16, 0],
+                        C[tx, 0:15, 0:16],
+                        A_local[0:2, 0, 0, 0:16],
+                    ]
+                )
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[tx, 0, j]])
+                                A_shared[tx, 0, j] = A[tx, i + 1, j]
+                    with T.block():
+                        T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_local[i % 2, tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(
+                                    2
+                                )
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    with T.block():
+                        T.reads([A_shared[tx, 0, 0:16]])
+                        T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A_shared[tx, 0, j]])
+                                T.writes([A_local[(i + 1) % 2, 0, 0, j]])
+                                T.block_attr({"double_buffer_scope": 0})
+                                A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j]

Review comment:
       @JosephTheOctonaut, do you think we could improve our documentation to make it less misleading? If so, would you like to suggest some changes? Thanks!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org