You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/04/19 20:52:45 UTC
[tvm] branch main updated: [TensorIR][PASS][M1c]
PlanUpdateBufferAllocationLocation (#7873)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5ae0ef0 [TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (#7873)
5ae0ef0 is described below
commit 5ae0ef045ca9d24e4ce4c0c90ee480e6bcf09c0e
Author: Siyuan Feng <Hz...@vip.qq.com>
AuthorDate: Tue Apr 20 04:52:23 2021 +0800
[TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (#7873)
Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
---
include/tvm/tir/transform.h | 8 +
python/tvm/tir/transform/transform.py | 13 ++
.../plan_update_buffer_allocation_location.cc | 169 +++++++++++++++++++++
...sform_plan_update_buffer_allocation_location.py | 128 ++++++++++++++++
4 files changed, 318 insertions(+)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 2397caf..8e7c16b 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -352,6 +352,14 @@ TVM_DLL Pass HoistIfThenElse();
*/
TVM_DLL Pass LowerInitBlock();
+/*!
+ * \brief Move each buffer allocation to its exact position (usually the
+ * LCA of the buffer's accesses). This pass injects an opaque block
+ * with alloc_buffers at the allocation site.
+ * \return The pass.
+ */
+TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
+
} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 8bd63bd..8317421 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -547,3 +547,16 @@ def LowerInitBlock():
The result pass
"""
return _ffi_api.LowerInitBlock()
+
+
+def PlanAndUpdateBufferAllocationLocation():
+ """Move each buffer allocation to its exact position (usually the
+ LCA of the buffer's accesses). This pass injects an opaque block
+ with alloc_buffers at the allocation site.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.PlanAndUpdateBufferAllocationLocation()
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
new file mode 100644
index 0000000..ecedaa6
--- /dev/null
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Plan where buffers are to be allocated and update the AST.
+ * \file plan_update_buffer_allocation_location.cc
+ */
+
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+class BufferAllocationLocator : public StmtExprMutator {  // moves each buffer's allocation to the stmt reported as its access LCA
+ public:
+ explicit BufferAllocationLocator(const PrimFunc& func) {
+ Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
+ std::unordered_set<const BufferNode*> arg_buffers;
+ for (const auto& kv : func->buffer_map) {  // buffers bound to function arguments
+ const Buffer& buffer = kv.second;
+ arg_buffers.emplace(buffer.get());
+ buffer_data_to_buffer_.Set(buffer->data, buffer);
+ }
+ // Record, for each stmt, the buffers that are to be allocated under it.
+ for (const auto& kv : buffer_lca) {
+ const Buffer& buffer = kv.first;
+ const StmtNode* stmt = kv.second.get();
+ if (arg_buffers.count(buffer.get())) {  // argument buffers are never re-allocated
+ continue;
+ }
+ alloc_buffers_[stmt].push_back(buffer);
+ }
+ }
+
+ private:
+ Stmt VisitStmt_(const ForNode* op) final {  // wraps the loop body in an opaque block when buffers are planned at this loop
+ auto it = alloc_buffers_.find(op);
+ if (it == alloc_buffers_.end()) {  // nothing to allocate here: plain recursive mutation
+ return StmtMutator::VisitStmt_(op);
+ }
+ for (const Buffer& buf : it->second) {  // make the planned buffers visible while visiting the body
+ buffer_data_to_buffer_.Set(buf->data, buf);
+ }
+ Stmt stmt = StmtMutator::VisitStmt_(op);
+ op = stmt.as<ForNode>();  // re-point op at the mutated loop
+ ICHECK(op != nullptr);
+ for (const Buffer& buf : it->second) {  // unregister before computing the opaque block's access regions
+ buffer_data_to_buffer_.erase(buf->data);
+ }
+ Stmt body = InjectOpaqueBlock(op->body, it->second);
+ ObjectPtr<ForNode> n = CopyOnWrite(op);
+ n->body = std::move(body);
+ return Stmt(n);
+ }
+
+ Stmt VisitStmt_(const BlockNode* op) final {  // replaces the block's alloc_buffers with the planned ones
+ ICHECK(!op->init.defined());  // blocks that still carry an init stmt are rejected
+ bool is_root = is_root_;  // only the first block visited sees is_root == true
+ is_root_ = false;
+ Array<Buffer> alloc_buffers;  // stays empty when no buffer is planned at this block
+ auto it = alloc_buffers_.find(op);
+ if (it != alloc_buffers_.end()) {
+ alloc_buffers = it->second;
+ for (const Buffer& buf : it->second) {
+ buffer_data_to_buffer_.Set(buf->data, buf);
+ }
+ }
+ Stmt stmt = StmtMutator::VisitStmt_(op);
+ op = stmt.as<BlockNode>();  // re-point op at the mutated block
+ ICHECK(op != nullptr);
+
+ // Ignore buffers allocated inside the block when computing its access regions.
+ if (it != alloc_buffers_.end()) {
+ for (const Buffer& buf : it->second) {
+ buffer_data_to_buffer_.erase(buf->data);
+ }
+ }
+
+ ObjectPtr<BlockNode> n = CopyOnWrite(op);
+ n->alloc_buffers = std::move(alloc_buffers);
+ // The read/write regions of the root block are always empty.
+ if (!is_root) {
+ // Recalculate the block's access regions.
+ CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
+ }
+
+ return Stmt(n);
+ }
+
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {  // BufferRealize is rejected outright
+ ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
+ throw;  // NOTE(review): presumably unreachable because ICHECK(false) already aborts -- confirm
+ }
+
+ Stmt InjectOpaqueBlock(Stmt body, const Array<Buffer>& alloc_buffers) {  // wraps body in a block with no iter vars that owns alloc_buffers
+ ICHECK(!alloc_buffers.empty());
+ Block opaque_block(/*iter_vars=*/{},
+ /*reads=*/{},
+ /*writes=*/{},
+ /*name_hint=*/"",
+ /*body=*/std::move(body),
+ /*init=*/NullOpt,
+ /*alloc_buffers=*/alloc_buffers);
+ ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
+ CollectReadWrite(opaque_block, &n->reads, &n->writes);  // fill in the reads/writes left empty above
+ BlockRealize realize({}, Bool(true), Block(n));  // no iter values, predicate is constant true
+ return std::move(realize);
+ }
+
+ void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
+ Array<BufferRegion>* writes) {
+ Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
+ *reads = access[0];
+ *writes = access[1];
+ for (const auto& opaque_access : access[2]) {  // entries of access[2] are added to both reads and writes
+ reads->push_back(opaque_access);
+ writes->push_back(opaque_access);
+ }
+ }
+
+ /*! \brief The map from a stmt to the buffers to be allocated under it. */
+ std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
+ /*! \brief The buffers already allocated during recursive visiting, keyed by their data var. */
+ Map<Var, Buffer> buffer_data_to_buffer_;
+ /*! \brief Whether the next visited block is the root block (true until the first block is visited). */
+ bool is_root_{true};
+};
+
+PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {  // returns func with its body rewritten by BufferAllocationLocator
+ auto fptr = func.CopyOnWrite();
+ BufferAllocationLocator locator(func);  // the constructor plans the allocation sites from func
+ fptr->body = locator(fptr->body);
+ return func;
+}
+
+namespace transform {
+
+Pass PlanAndUpdateBufferAllocationLocation() {  // builds the pass that applies the PrimFunc-level transform
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {  // m and ctx are unused
+ return PlanAndUpdateBufferAllocationLocation(std::move(f));
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {});  // NOTE(review): 0 is presumably the opt level -- confirm
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation")  // registers the pass factory under this global name
+ .set_body_typed(PlanAndUpdateBufferAllocationLocation);
+
+} // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
new file mode 100644
index 0000000..d42c5e1
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+def _check(original, transformed):  # run the pass on `original` and require structural equality with `transformed`
+ func = original
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], transformed)  # the function is looked up under the name "main"
+
+
+@tvm.script.tir
+def element_func(a: ty.handle, c: ty.handle) -> None:  # pass input for test_elementwise
+ A = tir.match_buffer(a, (16, 16))
+ C = tir.match_buffer(c, (16, 16))
+ B = tir.alloc_buffer((16, 16))  # intermediate buffer, allocated at function level here
+ for i_0 in range(0, 16):
+ for j_0 in range(0, 16):
+ with tir.block([16, 16]) as [i, j]:
+ B[i, j] = A[i, j] + 1.0
+ for j_0 in range(0, 16):
+ with tir.block([16, 16]) as [i, j]:
+ C[i, j] = B[i, j] * 2.0
+
+
+@tvm.script.tir
+def transformed_element_func(a: ty.handle, c: ty.handle) -> None:  # expected pass output for element_func
+ A = tir.match_buffer(a, [16, 16])
+ C = tir.match_buffer(c, [16, 16])
+
+ for i_0 in range(0, 16):
+ with tir.block([]):  # opaque block (no iter vars) injected under the i_0 loop
+ tir.reads([A[i_0, 0:16]])
+ tir.writes([C[i_0, 0:16]])
+ B = tir.alloc_buffer([16, 16])  # B is now allocated inside the opaque block
+ for j_0 in tir.serial(0, 16):
+ with tir.block([16, 16], "") as [i, j]:
+ tir.bind(i, i_0)
+ tir.bind(j, j_0)
+ B[i, j] = A[i, j] + 1.0
+ for j_0 in tir.serial(0, 16):
+ with tir.block([16, 16], "") as [i, j]:
+ tir.bind(i, i_0)
+ tir.bind(j, j_0)
+ C[i, j] = B[i, j] * 2.0
+
+
+@tvm.script.tir
+def original_func() -> None:  # pass input for test_locate_buffer_allocation
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([128, 128]) as [i, j]:
+ A[i, j] = tir.float32(0)
+ with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
+ B = tir.alloc_buffer((128, 128), "float32")  # B, C and D all start out allocated at this block
+ C = tir.alloc_buffer((128, 128), "float32")
+ D = tir.alloc_buffer((128, 128), "float32")
+ if k == 0:
+ for ii, jj in tir.grid(4, 4):
+ B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
+ for ii, jj in tir.grid(4, 4):
+ for kk in range(0, 4):
+ B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]  # C is accessed in both kk loops
+ for kk in range(0, 4):
+ B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]  # D is accessed only here
+
+
+@tvm.script.tir
+def transformed_func() -> None:  # expected pass output for original_func
+ A = tir.alloc_buffer([128, 128])
+ with tir.block([128, 128], "") as [i, j]:
+ A[i, j] = tir.float32(0)
+ with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
+ B = tir.alloc_buffer([128, 128])  # B stays allocated at this block
+ if k == 0:
+ for ii, jj in tir.grid(4, 4):
+ B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
+ for ii, jj in tir.grid(4, 4):
+ with tir.block([], ""):  # opaque block injected under the (ii, jj) loops
+ tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
+ tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
+ C = tir.alloc_buffer([128, 128])  # C moves here, enclosing both kk loops
+ for kk in tir.serial(0, 4):
+ B[((i * 4) + ii), ((j * 4) + jj)] = (
+ B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)]
+ )
+ for kk in tir.serial(0, 4):
+ with tir.block([], ""):  # opaque block injected under the second kk loop
+ tir.reads(
+ [
+ B[((i * 4) + ii), ((j * 4) + jj)],
+ C[((i * 4) + ii), ((k * 4) + kk)],
+ ]
+ )
+ tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
+ D = tir.alloc_buffer([128, 128])  # D moves into the innermost opaque block
+ B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (
+ D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)]
+ )
+
+
+def test_elementwise():  # element_func must be transformed into transformed_element_func
+ _check(element_func, transformed_element_func)
+
+
+def test_locate_buffer_allocation():  # original_func must be transformed into transformed_func
+ _check(original_func, transformed_func)
+
+
+if __name__ == "__main__":  # run both tests when the file is executed as a script
+ test_elementwise()
+ test_locate_buffer_allocation()