You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/04/19 20:52:45 UTC

[tvm] branch main updated: [TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (#7873)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ae0ef0  [TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (#7873)
5ae0ef0 is described below

commit 5ae0ef045ca9d24e4ce4c0c90ee480e6bcf09c0e
Author: Siyuan Feng <Hz...@vip.qq.com>
AuthorDate: Tue Apr 20 04:52:23 2021 +0800

    [TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (#7873)
    
    Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
---
 include/tvm/tir/transform.h                        |   8 +
 python/tvm/tir/transform/transform.py              |  13 ++
 .../plan_update_buffer_allocation_location.cc      | 169 +++++++++++++++++++++
 ...sform_plan_update_buffer_allocation_location.py | 128 ++++++++++++++++
 4 files changed, 318 insertions(+)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 2397caf..8e7c16b 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -352,6 +352,14 @@ TVM_DLL Pass HoistIfThenElse();
  */
 TVM_DLL Pass LowerInitBlock();
 
+/*!
+ * \brief Locate each buffer allocation at its exact position, which is usually
+ *        the lowest common ancestor (LCA) of the buffer's accesses. When that
+ *        position is a loop, this pass injects an opaque block carrying the
+ *        alloc_buffers as the loop body; when it is a block, the buffers are
+ *        attached to that block's alloc_buffers.
+ * \return The pass.
+ */
+TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 8bd63bd..8317421 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -547,3 +547,16 @@ def LowerInitBlock():
         The result pass
     """
     return _ffi_api.LowerInitBlock()
+
+
+def PlanAndUpdateBufferAllocationLocation():
+    """Locate the buffer allocation to the exact position (usually is
+    the lca of buffer access). This pass will inject opaque block
+    with alloc_buffers at the allocation site.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.PlanAndUpdateBufferAllocationLocation()
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
new file mode 100644
index 0000000..ecedaa6
--- /dev/null
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Plan where buffers should be allocated and update the AST accordingly.
+ * \file plan_update_buffer_allocation_location.cc
+ */
+
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Collect the buffers allocated by blocks (`alloc_buffers`) in the order they are met
+ *        during a pre-order traversal, i.e. an outer block's buffers come before those of the
+ *        blocks nested inside it.
+ */
+class BufferAllocateOrderCollector : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Collect the allocation order of all block-allocated buffers in `func->body`.
+   * \param func The function to scan.
+   * \return The allocated buffers, in visiting order.
+   */
+  static Array<Buffer> Collect(const PrimFunc& func) {
+    BufferAllocateOrderCollector collector;
+    collector(func->body);
+    return std::move(collector.buffer_alloc_recorder_);
+  }
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    // Record this block's buffers before descending into its body (pre-order).
+    for (const Buffer& buffer : op->alloc_buffers) {
+      buffer_alloc_recorder_.push_back(buffer);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The allocated buffers, in visiting order. */
+  Array<Buffer> buffer_alloc_recorder_;
+};
+
+/*!
+ * \brief Mutator that moves each buffer allocation to the statement DetectBufferAccessLCA
+ *        reports for it (usually the lowest common ancestor of its accesses).
+ *
+ * - Buffers bound in `func->buffer_map` (function parameters) are never allocated.
+ * - Every block's `alloc_buffers` is replaced by the buffers planned at that block.
+ * - Buffers planned at a loop are allocated by an opaque block injected as the loop body.
+ * - The read/write regions of every non-root block, and of every injected block, are
+ *   recomputed so that buffers allocated by the block itself are not reported as accesses.
+ *
+ * NOTE(review): only ForNode and BlockNode are looked up as allocation sites; a buffer planned
+ * at any other statement kind would be silently dropped. Presumably DetectBufferAccessLCA only
+ * returns loops or blocks -- confirm.
+ */
+class BufferAllocationLocator : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Plan the new allocation site of every non-parameter buffer in `func`.
+   * \param func The function to be transformed.
+   */
+  explicit BufferAllocationLocator(const PrimFunc& func) {
+    Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
+    std::unordered_set<const BufferNode*> arg_buffers;
+    for (const auto& kv : func->buffer_map) {
+      const Buffer& buffer = kv.second;
+      arg_buffers.emplace(buffer.get());
+      buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
+    // `buffer_lca` is an unordered Map, so iterating it directly makes the order of the buffers
+    // attached to one allocation site vary from run to run. Place buffers in their original
+    // allocation order first, then any buffer the collector did not see, so that every buffer
+    // in `buffer_lca` is still placed exactly once.
+    std::unordered_set<const BufferNode*> placed;
+    auto place = [&](const Buffer& buffer) {
+      // Parameter buffers come from `func->buffer_map` and are never (re-)allocated.
+      if (arg_buffers.count(buffer.get()) || !buffer_lca.count(buffer)) {
+        return;
+      }
+      if (placed.insert(buffer.get()).second) {
+        alloc_buffers_[buffer_lca.at(buffer).get()].push_back(buffer);
+      }
+    };
+    for (const Buffer& buffer : BufferAllocateOrderCollector::Collect(func)) {
+      place(buffer);
+    }
+    for (const auto& kv : buffer_lca) {
+      place(kv.first);
+    }
+  }
+
+ private:
+  /*!
+   * \brief If buffers are planned at this loop, visit the loop with those buffers in scope and
+   *        then wrap the loop body in an opaque block that allocates them.
+   */
+  Stmt VisitStmt_(const ForNode* op) final {
+    auto it = alloc_buffers_.find(op);
+    if (it == alloc_buffers_.end()) {
+      return StmtMutator::VisitStmt_(op);
+    }
+    // Make the buffers visible to the access-region analysis of nested blocks.
+    for (const Buffer& buf : it->second) {
+      buffer_data_to_buffer_.Set(buf->data, buf);
+    }
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    op = stmt.as<ForNode>();
+    ICHECK(op != nullptr);
+    // Remove them again before computing the injected block's regions: accesses to buffers the
+    // block allocates itself are not part of its reads/writes.
+    for (const Buffer& buf : it->second) {
+      buffer_data_to_buffer_.erase(buf->data);
+    }
+    Stmt body = InjectOpaqueBlock(op->body, it->second);
+    ObjectPtr<ForNode> n = CopyOnWrite(op);
+    n->body = std::move(body);
+    return Stmt(n);
+  }
+
+  /*!
+   * \brief Replace the block's alloc_buffers with the buffers planned at it and, unless it is
+   *        the root block, recompute its read/write regions.
+   */
+  Stmt VisitStmt_(const BlockNode* op) final {
+    ICHECK(!op->init.defined());
+    bool is_root = is_root_;
+    is_root_ = false;
+    Array<Buffer> alloc_buffers;
+    auto it = alloc_buffers_.find(op);
+    if (it != alloc_buffers_.end()) {
+      alloc_buffers = it->second;
+      for (const Buffer& buf : it->second) {
+        buffer_data_to_buffer_.Set(buf->data, buf);
+      }
+    }
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    op = stmt.as<BlockNode>();
+    ICHECK(op != nullptr);
+
+    // Ignore buffer allocated inside the block when getting access region.
+    if (it != alloc_buffers_.end()) {
+      for (const Buffer& buf : it->second) {
+        buffer_data_to_buffer_.erase(buf->data);
+      }
+    }
+
+    ObjectPtr<BlockNode> n = CopyOnWrite(op);
+    n->alloc_buffers = std::move(alloc_buffers);
+    // The read/write regions of root block are always empty.
+    if (!is_root) {
+      // Recalculate block access region
+      CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
+    }
+
+    return Stmt(n);
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
+    // Unreachable: the check above always fails. `throw;` only tells the compiler that control
+    // never falls off the end of a function returning Stmt.
+    throw;
+  }
+
+  /*!
+   * \brief Wrap `body` in an opaque (iterator-free) block that allocates `alloc_buffers`.
+   * \param body The statement to wrap.
+   * \param alloc_buffers The buffers to allocate; must be non-empty.
+   * \return A BlockRealize (predicate `true`) of the new block, with regions recomputed.
+   */
+  Stmt InjectOpaqueBlock(Stmt body, const Array<Buffer>& alloc_buffers) {
+    ICHECK(!alloc_buffers.empty());
+    Block opaque_block(/*iter_vars=*/{},
+                       /*reads=*/{},
+                       /*writes=*/{},
+                       /*name_hint=*/"",
+                       /*body=*/std::move(body),
+                       /*init=*/NullOpt,
+                       /*alloc_buffers=*/alloc_buffers);
+    ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
+    CollectReadWrite(opaque_block, &n->reads, &n->writes);
+    BlockRealize realize({}, Bool(true), Block(n));
+    return std::move(realize);
+  }
+
+  /*!
+   * \brief Compute the read/write regions of `block` over the buffers currently in scope.
+   *        Opaque accesses are conservatively reported as both reads and writes.
+   * \param block The block to analyze.
+   * \param reads Output: overwritten with the read regions.
+   * \param writes Output: overwritten with the write regions.
+   */
+  void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
+                        Array<BufferRegion>* writes) {
+    Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
+    *reads = access[0];
+    *writes = access[1];
+    for (const auto& opaque_access : access[2]) {
+      reads->push_back(opaque_access);
+      writes->push_back(opaque_access);
+    }
+  }
+
+  /*! \brief The map from stmt to the buffers to be allocated under it. */
+  std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
+  /*! \brief The buffers in scope (parameters plus those allocated by enclosing statements). */
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  /*! \brief Whether the next block visited is the root block. */
+  bool is_root_{true};
+};
+
+/*!
+ * \brief Move every buffer allocation of `func` to the LCA of the buffer's accesses.
+ * \param func The function to transform; its node is copied first if it is shared.
+ * \return The function with its body rewritten.
+ */
+PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
+  // Planning only reads the function, so it can run before taking the mutable handle.
+  BufferAllocationLocator locator(func);
+  PrimFuncNode* node = func.CopyOnWrite();
+  node->body = locator(node->body);
+  return func;
+}
+
+namespace transform {
+
+/*!
+ * \brief Create the PrimFunc pass that wraps tir::PlanAndUpdateBufferAllocationLocation.
+ * \return The pass, named "tir.PlanAndUpdateBufferAllocationLocation".
+ */
+Pass PlanAndUpdateBufferAllocationLocation() {
+  // Each PrimFunc is rewritten on its own; the module and the pass context are not consulted.
+  auto rewrite = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return tir::PlanAndUpdateBufferAllocationLocation(std::move(f));
+  };
+  return CreatePrimFuncPass(rewrite, /*opt_level=*/0,
+                            /*name=*/"tir.PlanAndUpdateBufferAllocationLocation",
+                            /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation")
+    .set_body_typed(PlanAndUpdateBufferAllocationLocation);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
new file mode 100644
index 0000000..d42c5e1
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+def _check(original, transformed):
+    func = original
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], transformed)
+
+
+# Input: B is allocated at the top level of the function, but every access to
+# it happens inside loop i_0, so the pass should sink its allocation under i_0.
+@tvm.script.tir
+def element_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16))
+    C = tir.match_buffer(c, (16, 16))
+    B = tir.alloc_buffer((16, 16))
+    for i_0 in range(0, 16):
+        # Producer of B.
+        for j_0 in range(0, 16):
+            with tir.block([16, 16]) as [i, j]:
+                B[i, j] = A[i, j] + 1.0
+        # Consumer of B.
+        for j_0 in range(0, 16):
+            with tir.block([16, 16]) as [i, j]:
+                C[i, j] = B[i, j] * 2.0
+
+
+# Expected output: B's allocation now lives in an opaque block injected as the
+# body of loop i_0. That block's reads/writes mention only A and C, because
+# accesses to the buffer it allocates itself are excluded.
+@tvm.script.tir
+def transformed_element_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [16, 16])
+    C = tir.match_buffer(c, [16, 16])
+
+    for i_0 in range(0, 16):
+        # Injected opaque block that owns B.
+        with tir.block([]):
+            tir.reads([A[i_0, 0:16]])
+            tir.writes([C[i_0, 0:16]])
+            B = tir.alloc_buffer([16, 16])
+            for j_0 in tir.serial(0, 16):
+                with tir.block([16, 16], "") as [i, j]:
+                    tir.bind(i, i_0)
+                    tir.bind(j, j_0)
+                    B[i, j] = A[i, j] + 1.0
+            for j_0 in tir.serial(0, 16):
+                with tir.block([16, 16], "") as [i, j]:
+                    tir.bind(i, i_0)
+                    tir.bind(j, j_0)
+                    C[i, j] = B[i, j] * 2.0
+
+
+# Input: A is used by two sibling blocks. B, C and D are all allocated by the
+# reduction block, but C is only touched inside the (ii, jj) loop nest and D
+# only inside the second kk loop.
+@tvm.script.tir
+def original_func() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([128, 128]) as [i, j]:
+        A[i, j] = tir.float32(0)
+    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
+        B = tir.alloc_buffer((128, 128), "float32")
+        C = tir.alloc_buffer((128, 128), "float32")
+        D = tir.alloc_buffer((128, 128), "float32")
+        if k == 0:
+            for ii, jj in tir.grid(4, 4):
+                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
+        for ii, jj in tir.grid(4, 4):
+            for kk in range(0, 4):
+                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
+            for kk in range(0, 4):
+                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
+
+
+# Expected output: A stays at the root, B stays on the reduction block, C moves
+# into an opaque block under loop jj, and D moves into an opaque block under the
+# second kk loop. Each injected block's reads/writes omit the buffer it
+# allocates itself.
+@tvm.script.tir
+def transformed_func() -> None:
+    A = tir.alloc_buffer([128, 128])
+    with tir.block([128, 128], "") as [i, j]:
+        A[i, j] = tir.float32(0)
+    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
+        B = tir.alloc_buffer([128, 128])
+        if k == 0:
+            for ii, jj in tir.grid(4, 4):
+                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
+        for ii, jj in tir.grid(4, 4):
+            # Injected block owning C (LCA of C's accesses is loop jj).
+            with tir.block([], ""):
+                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
+                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
+                C = tir.alloc_buffer([128, 128])
+                for kk in tir.serial(0, 4):
+                    B[((i * 4) + ii), ((j * 4) + jj)] = (
+                        B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)]
+                    )
+                for kk in tir.serial(0, 4):
+                    # Injected block owning D; C is still in scope here, so it
+                    # appears in this block's reads.
+                    with tir.block([], ""):
+                        tir.reads(
+                            [
+                                B[((i * 4) + ii), ((j * 4) + jj)],
+                                C[((i * 4) + ii), ((k * 4) + kk)],
+                            ]
+                        )
+                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
+                        D = tir.alloc_buffer([128, 128])
+                        B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (
+                            D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)]
+                        )
+
+
+def test_elementwise():
+    _check(element_func, transformed_element_func)
+
+
+def test_locate_buffer_allocation():
+    _check(original_func, transformed_func)
+
+
+if __name__ == "__main__":
+    test_elementwise()
+    test_locate_buffer_allocation()