You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/26 22:39:45 UTC
[GitHub] [tvm] junrushao1994 commented on a change in pull request #10066: [TIR] Add software pipelining
junrushao1994 commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r793110174
##########
File path: tests/python/unittest/test_tir_transform_inject_software_pipeline.py
##########
@@ -0,0 +1,824 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm import tir, te, TVMError
+from tvm.script import tir as T
+
+
+def _check(original, transformed):
+ func = original
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
+ mod = tvm.tir.transform.Simplify()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], transformed, True)
+
+
+def _check_error(func):
+ mod = tvm.IRModule.from_expr(func)
+ with pytest.raises(ValueError):
+ tvm.tir.transform.InjectSoftwarePipeline()(mod)
+
+
+# Input fixture: a single-iteration (extent-1) loop annotated for software
+# pipelining with two stages. Presumably paired with
+# `transformed_trivial_pipeline` below via `_check` — the pairing is implied
+# by naming only; the call site is not in this hunk.
+@T.prim_func
+def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        # Loop to be pipelined: stage/order annotations assign the two inner
+        # blocks to stages 0 and 1, ordered 0 then 1.
+        for i in T.serial(
+            0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}
+        ):
+            with T.block():
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                # Stage 0: producer, B = A * 2.
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                # Stage 1: consumer, C = B + 1.
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+# Expected IR after InjectSoftwarePipeline (+ Simplify) for `trivial_pipeline`:
+# the extent-1 loop is fully unrolled into prologue / (empty) body / epilogue,
+# and B gains a leading version axis of extent 2.
+@T.prim_func
+def transformed_trivial_pipeline(
+    A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]
+) -> None:
+    for tx in T.thread_binding(16, thread="threadIdx.x"):
+        with T.block():
+            T.reads(A[tx, 0])
+            T.writes(C[tx, 0])
+            # B is double-buffered: original (16, 1) buffer gets a leading
+            # [2] axis for pipeline versioning.
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            # Prologue: stage-0 block for the only iteration.
+            with T.block():
+                T.reads(A[tx, 0])
+                T.writes(B[0, tx, 0])
+                B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            # Steady-state body is empty (loop extent 1) — placeholder no-op.
+            with T.block():
+                T.reads()
+                T.writes()
+                T.evaluate(0)
+            # Epilogue: stage-1 block for the only iteration.
+            with T.block():
+                T.reads(B[0, tx, 0])
+                T.writes(C[tx, 0])
+                C[tx, 0] = B[0, tx, 0] + T.float32(1)
+
+
+# Input fixture: a 16-iteration producer/consumer loop annotated with two
+# pipeline stages. Presumably paired with `transformed_simple_compute`.
+@T.prim_func
+def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]},
+        ):
+            with T.block():
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                # Stage 0: producer, B = A * 2.
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                # Stage 1: consumer, C = B + 1.
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+# Expected IR for `simple_compute`: prologue executes the stage-0 block for
+# iteration 0; the steady-state loop runs 15 iterations with the producer one
+# iteration ahead of the consumer (B versioned by i % 2); the epilogue drains
+# the last consumer iteration.
+@T.prim_func
+def transformed_simple_compute(
+    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16]])
+            T.writes([C[tx, 0:16]])
+            # Double-buffered B with leading [2] version axis.
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            # Prologue: produce B for iteration 0.
+            with T.block():
+                T.reads([A[tx, 0]])
+                T.writes([B[0, tx, 0]])
+                B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            # Steady state: produce iteration i+1 while consuming iteration i.
+            with T.block():
+                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
+                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1]])
+                        T.writes([B[(i + 1) % 2, tx, 0]])
+                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                    with T.block():
+                        T.reads([B[i % 2, tx, 0]])
+                        T.writes([C[tx, i]])
+                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
+            # Epilogue: consume the last produced value (iteration 15, version 1).
+            with T.block():
+                T.reads([B[1, tx, 0]])
+                T.writes([C[tx, 15]])
+                C[tx, 15] = B[1, tx, 0] + T.float32(1)
+
+
+# Input fixture: a pipelined outer loop (stages [0, 1, 1, 1]) whose body
+# contains a second, nested pipelined loop (stages [0, 1]).
+@T.prim_func
+def nested_pipeline_simple(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 1, 1, 1],
+                "software_pipeline_order": [0, 1, 2, 3],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                # Outer stage 0: cache a row of A in shared memory.
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                # Outer stage 1: inner pipelined producer/consumer loop.
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            # NOTE(review): the read annotation says
+                            # A_shared[tx, i, j] but the statement reads
+                            # A_shared[tx, 0, j] (middle axis has extent 1) —
+                            # confirm the annotated region is intentional.
+                            T.reads(A_shared[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+# Expected IR for `nested_pipeline_simple`: both loop levels are pipelined.
+# The outer pipeline yields prologue / steady state / epilogue over i, and
+# inside each of those the inner j-loop is itself split into prologue /
+# steady state / epilogue. A_shared and B both gain a leading [2] version axis.
+@T.prim_func
+def transformed_nested_pipeline_simple(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            # Outer prologue: load row 0 of A into shared version 0.
+            with T.block():
+                T.reads([A[tx, 0, 0:16]])
+                T.writes([A_shared[0, tx, 0, 0:16]])
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads([A[tx, 0, j]])
+                        T.writes([A_shared[0, tx, 0, j]])
+                        A_shared[0, tx, 0, j] = A[tx, 0, j]
+            # Outer steady state: prefetch row i+1 while running the inner
+            # pipeline on row i.
+            with T.block():
+                T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]])
+                T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]])
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+                                A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
+                    # Inner prologue for row i.
+                    with T.block():
+                        T.reads([A_shared[i % 2, tx, i, 0]])
+                        T.writes([B[0, tx, i, 0]])
+                        B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2)
+                    # Inner steady state for row i.
+                    with T.block():
+                        T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_shared[i % 2, tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_shared[
+                                    i % 2, tx, 0, j + 1
+                                ] * T.float32(2)
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    # Inner epilogue for row i.
+                    with T.block():
+                        T.reads([B[1, tx, i, 0]])
+                        T.writes([C[tx, i, 15]])
+                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
+            # Outer epilogue: inner pipeline on the final row (i = 15),
+            # reading from shared version 1.
+            with T.block():
+                T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]])
+                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
+                with T.block():
+                    T.reads([A_shared[1, tx, 15, 0]])
+                    T.writes([B[0, tx, 15, 0]])
+                    B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2)
+                with T.block():
+                    T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
+                    for j in T.serial(0, 15):
+                        with T.block():
+                            T.reads([A_shared[1, tx, 15, j + 1]])
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
+                        with T.block():
+                            T.reads([B[j % 2, tx, 15, 0]])
+                            T.writes([C[tx, 15, j]])
+                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
+                with T.block():
+                    T.reads([B[1, tx, 15, 0]])
+                    T.writes([C[tx, 15, 15]])
+                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
+
+
+# Input fixture: like `nested_pipeline_simple`, but the outer pipeline uses
+# stages [0, 0, 1, 1] / order [0, 2, 1, 3], so the inner pipeline's first
+# (producer) block is hoisted into outer stage 0 — i.e. the inner prefetch
+# overlaps with the next row's shared-memory load.
+@T.prim_func
+def nested_pipeline_prefetch_inner(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 0, 1, 1],
+                "software_pipeline_order": [0, 2, 1, 3],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                # Cache a row of A in shared memory.
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                # Inner pipelined loop; its stage-0 block belongs to outer
+                # stage 0 per the annotations above.
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            # NOTE(review): annotation reads A_shared[tx, i, j]
+                            # but the body reads A_shared[tx, 0, j] — confirm
+                            # the annotated region is intentional.
+                            T.reads(A_shared[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+# Expected IR for `nested_pipeline_prefetch_inner`: the outer prologue now
+# contains both the row-0 shared load and the first inner producer block;
+# in the steady state the inner prologue for row i+1 runs after the inner
+# epilogue-free consumer loop for row i (order [0, 2, 1, 3]).
+@T.prim_func
+def transformed_nested_pipeline_prefetch_inner(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            # Outer prologue: shared load of row 0 plus first inner producer.
+            with T.block():
+                T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]])
+                T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]])
+                with T.block():
+                    T.reads([A[tx, 0, 0:16]])
+                    T.writes([A_shared[0, tx, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A[tx, 0, j]])
+                            T.writes([A_shared[0, tx, 0, j]])
+                            A_shared[0, tx, 0, j] = A[tx, 0, j]
+                with T.block():
+                    T.reads([A_shared[0, tx, 0, 0]])
+                    T.writes([B[0, tx, 0, 0]])
+                    B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2)
+            # Outer steady state.
+            with T.block():
+                T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]])
+                T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]])
+                for i in T.serial(0, 15):
+                    # Prefetch row i+1 into the other shared version.
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
+                                A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
+                    # Inner steady state for row i.
+                    with T.block():
+                        T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_shared[i % 2, tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_shared[
+                                    i % 2, tx, 0, j + 1
+                                ] * T.float32(2)
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    # Inner prologue for row i+1 (prefetched producer).
+                    with T.block():
+                        T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
+                        T.writes([B[0, tx, i + 1, 0]])
+                        B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2)
+                    # Inner epilogue for row i.
+                    with T.block():
+                        T.reads([B[1, tx, i, 0]])
+                        T.writes([C[tx, i, 15]])
+                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
+            # Outer epilogue: finish row 15 (its producer prologue already ran).
+            with T.block():
+                T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
+                with T.block():
+                    T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
+                    for j in T.serial(0, 15):
+                        with T.block():
+                            T.reads([A_shared[1, tx, 15, j + 1]])
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
+                        with T.block():
+                            T.reads([B[j % 2, tx, 15, 0]])
+                            T.writes([C[tx, 15, j]])
+                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
+                with T.block():
+                    T.reads([B[1, tx, 15, 0]])
+                    T.writes([C[tx, 15, 15]])
+                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
+
+
+# Input fixture: outer pipeline with five blocks, stages [0, 0, 0, 1, 1] and
+# order [0, 2, 3, 1, 4] — a shared-memory load, a shared->local copy, and the
+# nested pipelined compute, interleaved across outer stages.
+@T.prim_func
+def nested_pipeline_interleaving(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 0, 0, 1, 1],
+                "software_pipeline_order": [0, 2, 3, 1, 4],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local")
+                # Global -> shared copy of row i.
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                # Shared -> local copy.
+                for j in T.serial(0, 16):
+                    with T.block():
+                        # NOTE(review): annotation reads A_shared[tx, 0, j] but
+                        # the body reads A_shared[tx, i, j] (middle axis has
+                        # extent 1) — confirm the indices are intentional.
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(A_local[0, 0, j])
+                        A_local[0, 0, j] = A_shared[tx, i, j]
+                # Nested pipelined producer/consumer loop over the local copy.
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_local[0, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            T.reads(A_local[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_local[0, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+# Expected IR for `nested_pipeline_interleaving`. Note A_shared and A_local
+# keep their original shapes (no version axis) while B gains a leading [2]
+# axis — only B crosses an outer-stage boundary here.
+@T.prim_func
+def transformed_nested_pipeline_interleaving(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16, 0:16]])
+            T.writes([C[tx, 0:16, 0:16]])
+            A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared")
+            A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local")
+            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+            # Outer prologue: shared load, shared->local copy, and first
+            # inner producer for row 0.
+            with T.block():
+                T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 0]])
+                T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]])
+                with T.block():
+                    T.reads([A[tx, 0, 0:16]])
+                    T.writes([A_shared[tx, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A[tx, 0, j]])
+                            T.writes([A_shared[tx, 0, j]])
+                            A_shared[tx, 0, j] = A[tx, 0, j]
+                with T.block():
+                    T.reads([A_shared[tx, 0, 0:16]])
+                    T.writes([A_local[0, 0, 0:16]])
+                    for j in T.serial(0, 16):
+                        with T.block():
+                            T.reads([A_shared[tx, 0, j]])
+                            T.writes([A_local[0, 0, j]])
+                            A_local[0, 0, j] = A_shared[tx, 0, j]
+                with T.block():
+                    T.reads([A_local[tx, 0, 0]])
+                    T.writes([B[0, tx, 0, 0]])
+                    B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2)
+            # Outer steady state: blocks interleaved per order [0, 2, 3, 1, 4].
+            with T.block():
+                T.reads(
+                    [
+                        A[tx, 1:16, 0:16],
+                        A_local[tx, 0:16, 0:16],
+                        B[0:2, tx, 0:15, 0],
+                        A_shared[tx, 0, 0:16],
+                    ]
+                )
+                T.writes(
+                    [
+                        A_shared[tx, 0, 0:16],
+                        B[0:2, tx, 0:16, 0],
+                        C[tx, 0:15, 0:16],
+                        A_local[0, 0, 0:16],
+                    ]
+                )
+                for i in T.serial(0, 15):
+                    # Load row i+1 into shared.
+                    with T.block():
+                        T.reads([A[tx, i + 1, 0:16]])
+                        T.writes([A_shared[tx, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A[tx, i + 1, j]])
+                                T.writes([A_shared[tx, 0, j]])
+                                A_shared[tx, 0, j] = A[tx, i + 1, j]
+                    # Inner steady state for row i.
+                    with T.block():
+                        T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]])
+                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+                        for j in T.serial(0, 15):
+                            with T.block():
+                                T.reads([A_local[tx, i, j + 1]])
+                                T.writes([B[(j + 1) % 2, tx, i, 0]])
+                                B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
+                            with T.block():
+                                T.reads([B[j % 2, tx, i, 0]])
+                                T.writes([C[tx, i, j]])
+                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+                    # Shared -> local copy of row i+1.
+                    with T.block():
+                        T.reads([A_shared[tx, 0, 0:16]])
+                        T.writes([A_local[0, 0, 0:16]])
+                        for j in T.serial(0, 16):
+                            with T.block():
+                                T.reads([A_shared[tx, 0, j]])
+                                T.writes([A_local[0, 0, j]])
+                                A_local[0, 0, j] = A_shared[tx, i + 1, j]
+                    # Inner prologue for row i+1.
+                    with T.block():
+                        T.reads([A_local[tx, i + 1, 0]])
+                        T.writes([B[0, tx, i + 1, 0]])
+                        B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2)
+                    # Inner epilogue for row i.
+                    with T.block():
+                        T.reads([B[1, tx, i, 0]])
+                        T.writes([C[tx, i, 15]])
+                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
+            # Outer epilogue: finish row 15.
+            with T.block():
+                T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
+                with T.block():
+                    T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]])
+                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
+                    for j in T.serial(0, 15):
+                        with T.block():
+                            T.reads([A_local[tx, 15, j + 1]])
+                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
+                            B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
+                        with T.block():
+                            T.reads([B[j % 2, tx, 15, 0]])
+                            T.writes([C[tx, 15, j]])
+                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
+                with T.block():
+                    T.reads([B[1, tx, 15, 0]])
+                    T.writes([C[tx, 15, 15]])
+                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
+
+
+# Input fixture: identical to `nested_pipeline_interleaving` except that the
+# shared->local copy block carries the "double_buffer_scope" block attribute,
+# so the pass is expected to version A_local as well.
+@T.prim_func
+def nested_pipeline_double_buffer(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 0, 0, 1, 1],
+                "software_pipeline_order": [0, 2, 3, 1, 4],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local")
+                # Global -> shared copy of row i.
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.reads(A[tx, i, j])
+                        T.writes(A_shared[tx, 0, j])
+                        A_shared[tx, 0, j] = A[tx, i, j]
+                # Shared -> local copy, marked for double buffering.
+                for j in T.serial(0, 16):
+                    with T.block():
+                        T.block_attr({"double_buffer_scope": 0})
+                        T.reads(A_shared[tx, 0, j])
+                        T.writes(A_local[0, 0, j])
+                        A_local[0, 0, j] = A_shared[tx, i, j]
+                # Nested pipelined producer/consumer loop over the local copy.
+                for j in T.serial(
+                    0,
+                    16,
+                    annotations={
+                        "software_pipeline_stage": [0, 1],
+                        "software_pipeline_order": [0, 1],
+                    },
+                ):
+                    with T.block():
+                        T.reads(A_local[0, 0, j])
+                        T.writes(C[tx, i, j])
+                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
+                        with T.block():
+                            T.reads(A_local[tx, i, j])
+                            T.writes(B[tx, i, 0])
+                            B[tx, i, 0] = A_local[0, 0, j] * T.float32(2)
+                        with T.block():
+                            T.reads(B[tx, i, 0])
+                            T.writes(C[tx, i, j])
+                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_nested_pipeline_double_buffer(
+ A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+) -> None:
+ for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+ with T.block():
+ T.reads([A[tx, 0:16, 0:16]])
+ T.writes([C[tx, 0:16, 0:16]])
+ A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared")
+ A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local")
+ B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
+ with T.block():
+ T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]])
+ T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]])
+ with T.block():
+ T.reads([A[tx, 0, 0:16]])
+ T.writes([A_shared[tx, 0, 0:16]])
+ for j in T.serial(0, 16):
+ with T.block():
+ T.reads([A[tx, 0, j]])
+ T.writes([A_shared[tx, 0, j]])
+ A_shared[tx, 0, j] = A[tx, 0, j]
+ with T.block():
+ T.reads([A_shared[tx, 0, 0:16]])
+ T.writes([A_local[0, 0, 0, 0:16]])
+ for j in T.serial(0, 16):
+ with T.block():
+ T.reads([A_shared[tx, 0, j]])
+ T.writes([A_local[0, 0, 0, j]])
+ T.block_attr({"double_buffer_scope": 0})
+ A_local[0, 0, 0, j] = A_shared[tx, 0, j]
+ with T.block():
+ T.reads([A_local[0, tx, 0, 0]])
+ T.writes([B[0, tx, 0, 0]])
+ B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2)
+ with T.block():
+ T.reads(
+ [
+ A[tx, 1:16, 0:16],
+ A_local[0:2, tx, 0:16, 0:16],
+ B[0:2, tx, 0:15, 0],
+ A_shared[tx, 0, 0:16],
+ ]
+ )
+ T.writes(
+ [
+ A_shared[tx, 0, 0:16],
+ B[0:2, tx, 0:16, 0],
+ C[tx, 0:15, 0:16],
+ A_local[0:2, 0, 0, 0:16],
+ ]
+ )
+ for i in T.serial(0, 15):
+ with T.block():
+ T.reads([A[tx, i + 1, 0:16]])
+ T.writes([A_shared[tx, 0, 0:16]])
+ for j in T.serial(0, 16):
+ with T.block():
+ T.reads([A[tx, i + 1, j]])
+ T.writes([A_shared[tx, 0, j]])
+ A_shared[tx, 0, j] = A[tx, i + 1, j]
+ with T.block():
+ T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
+ T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
+ for j in T.serial(0, 15):
+ with T.block():
+ T.reads([A_local[i % 2, tx, i, j + 1]])
+ T.writes([B[(j + 1) % 2, tx, i, 0]])
+ B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(
+ 2
+ )
+ with T.block():
+ T.reads([B[j % 2, tx, i, 0]])
+ T.writes([C[tx, i, j]])
+ C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
+ with T.block():
+ T.reads([A_shared[tx, 0, 0:16]])
+ T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
+ for j in T.serial(0, 16):
+ with T.block():
+ T.reads([A_shared[tx, 0, j]])
+ T.writes([A_local[(i + 1) % 2, 0, 0, j]])
+ T.block_attr({"double_buffer_scope": 0})
+ A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j]
Review comment:
@JosephTheOctonaut do you think we could improve our documentation to make it less misleading? If so, would you be willing to suggest some changes? Thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org