You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jc...@apache.org on 2021/04/28 05:12:06 UTC
[tvm] branch main updated: [TensorIR][PASS] CompactBufferAllocation
(#7923)
This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dee3133 [TensorIR][PASS] CompactBufferAllocation (#7923)
dee3133 is described below
commit dee3133c5418fc1d44ab202bff8b2c6906593d1a
Author: Siyuan Feng <Hz...@vip.qq.com>
AuthorDate: Wed Apr 28 13:11:43 2021 +0800
[TensorIR][PASS] CompactBufferAllocation (#7923)
Co-authored-by: Tianqi Chen <tq...@users.noreply.github.com>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Cody Yu <co...@gmail.com>
---
include/tvm/tir/expr.h | 1 +
include/tvm/tir/stmt.h | 12 +-
include/tvm/tir/transform.h | 46 ++
python/tvm/tir/transform/transform.py | 50 +++
src/support/utils.h | 19 +
src/tir/ir/stmt.cc | 8 +
src/tir/transforms/compact_buffer_region.cc | 468 +++++++++++++++++++++
src/tir/transforms/convert_blocks_to_opaque.cc | 104 +++++
.../test_tir_transform_compact_buffer_region.py | 331 +++++++++++++++
.../test_tir_transform_convert_blocks_to_opaque.py | 77 ++++
10 files changed, 1115 insertions(+), 1 deletion(-)
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 7cab197..e1d0974 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -638,6 +638,7 @@ class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
/*!
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 0931768..cc10c21 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -324,6 +324,7 @@ class BufferStore : public Stmt {
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};
/*!
@@ -991,13 +992,22 @@ class BufferRegion : public ObjectRef {
TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
/*!
- * \brief Create a BufferRegion which is full region of the given buffer..
+ * \brief Create a BufferRegion which is full region of the given buffer.
* \param buffer The buffer to generate full BufferRegion.
* \return The BufferRegion which covers all region of the given buffer
*/
TVM_DLL static BufferRegion FullRegion(Buffer buffer);
+ /*!
+ * \brief Create a BufferRegion which is a single point of the given buffer.
+ * \param buffer The buffer to generate single point BufferRegion.
+ * \param indices The access point indices of the buffer
+ * \return The BufferRegion which is the single point of the given buffer.
+ */
+ TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> indices);
+
TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode);
};
/*!
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 8e7c16b..a236c50 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -360,6 +360,52 @@ TVM_DLL Pass LowerInitBlock();
*/
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
+/*!
+ * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
+ *  corresponding iter_values in BlockRealize, and then convert the blocks into opaque ones by
+ *  removing all the iter_values in BlockRealize and iter_vars in Block.
+ * \note Blocks that still have an init part are rejected by the pass.
+ * \return The pass.
+ */
+TVM_DLL Pass ConvertBlocksToOpaque();
+
+/*!
+ * \brief Compact the buffer access region by removing the buffer regions that are not accessed,
+ *  i.e. narrowing the buffer shape and adjusting the access region if necessary.
+ * \note Only buffers allocated by blocks (their alloc_buffers) are compacted; function parameter
+ *  buffers keep their original shape. Blocks must not have an init part.
+ * \example
+ * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
+ * \code
+ *
+ *  for i in range(0, 16):
+ *      with tir.block([]):
+ *          B = tir.alloc_buffer((16, 16))
+ *          for j in range(0, 16):
+ *              B[i, j] = A[i, j] + 1
+ *          for j in range(0, 16):
+ *              C[i, j] = B[i, j] + 1
+ *
+ * \endcode
+ *
+ * This pass narrows the buffer shape and adjusts its accessed region accordingly.
+ * In this particular case, because only a `1 * 16` vector of `B` is accessed,
+ * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
+ *
+ * \code
+ *
+ *  for i in range(0, 16):
+ *      with tir.block([]):
+ *          B = tir.alloc_buffer((1, 16))
+ *          for j in range(0, 16):
+ *              B[0, j] = A[i, j] + 1
+ *          for j in range(0, 16):
+ *              C[i, j] = B[0, j] + 1
+ *
+ * \endcode
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CompactBufferAllocation();
+
} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 8317421..2ae75d2 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -560,3 +560,53 @@ def PlanAndUpdateBufferAllocationLocation():
The result pass
"""
return _ffi_api.PlanAndUpdateBufferAllocationLocation()
+
+
+def ConvertBlocksToOpaque():
+ """Substitute all the block vars with the PrimExprs they are bound to, indicated by
+ the corresponding iter_values in BlockRealize, and then convert the blocks into
+ opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.ConvertBlocksToOpaque()
+
+
+def CompactBufferAllocation():
+ """Compact the buffer access region. by removing the buffer regions that are not accessed,
+ i.e. narrowing the buffer shape and adjust the access region if necessary.
+
+ Example
+ -------
+ Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
+ .. code-block:: python
+
+ for i in range(0, 16):
+ with tir.block([]):
+ B = tir.alloc_buffer(16, 16)
+ for j in range(0, 16):
+ B[i, j] = A[i, j] + 1
+ for j in range(0, 16):
+ C[i, j] = B[i, j] + 1
+ This pass narrows the buffer shape and adjust its accessed region accordingly.
+ In this particular case, because only a `1 * 16` vector of `B` is accessed,
+ the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
+ .. code-block:: python
+
+ for i in range(0, 16):
+ with tir.block([]):
+ B = tir.alloc_buffer(1, 16)
+ for j in range(0, 16):
+ B[0, j] = A[i, j] + 1
+ for j in range(0, 16):
+ C[i, j] = B[0, j] + 1
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.CompactBufferAllocation()
diff --git a/src/support/utils.h b/src/support/utils.h
index 2f55d40..0753517 100644
--- a/src/support/utils.h
+++ b/src/support/utils.h
@@ -31,6 +31,9 @@
#include <sys/wait.h>
#endif // __hexagon__
#endif // _WIN32
+
+#include <tvm/runtime/container.h>
+
#include <algorithm>
#include <array>
#include <cctype>
@@ -129,6 +132,22 @@ inline std::vector<std::string> Split(const std::string& str, char delim) {
}
/*!
+ * \brief Check whether the string starts with a given prefix.
+ * \param str The given string.
+ * \param prefix The given prefix, a NUL-terminated C string.
+ * \return Whether the prefix matched. An empty prefix matches any string, and a prefix equal
+ *  to the whole string matches as well.
+ */
+inline bool StartsWith(const String& str, const char* prefix) {
+  size_t n = str.length();
+  for (size_t i = 0; i < n; i++) {
+    // The whole prefix was consumed before `str` ran out: it matched.
+    if (prefix[i] == '\0') return true;
+    if (str.data()[i] != prefix[i]) return false;
+  }
+  // All n characters of `str` matched, so the prefix matches only if it ends exactly here
+  // (i.e. `str` equals the prefix). Index n is the prefix's terminator in that case; reading
+  // index n + 1 would go past it and would also accept prefixes one character longer than `str`.
+  return prefix[n] == '\0';
+}
+
+/*!
* \brief EndsWith check whether the strings ends with
* \param value The full string
* \param end The end substring
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 87ead3e..b2016eb 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -646,6 +646,14 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) {
return BufferRegion(buffer, region);
}
+BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
+  // Each index becomes a unit-extent range [index, index + 1).
+  Array<Range> point_region;
+  point_region.reserve(indices.size());
+  for (size_t dim = 0; dim < indices.size(); ++dim) {
+    point_region.push_back(Range::FromMinExtent(indices[dim], 1));
+  }
+  return BufferRegion(buffer, point_region);
+}
+
TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<Range> region) {
return BufferRegion(buffer, region);
});
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
new file mode 100644
index 0000000..a5ca67e
--- /dev/null
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -0,0 +1,468 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file compact_buffer_region.cc
+ * \brief Shrink each block-allocated buffer to the region of it that is actually accessed.
+ */
+
+#include <tvm/arith/int_set.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <stack>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace tir {
+
+using NDIntSet = std::vector<arith::IntSet>;
+
+/*! \brief Build the integer set covered by the range starting at `min` with length `extent`. */
+arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) {
+  Range covered = Range::FromMinExtent(min, extent);
+  return arith::IntSet::FromRange(covered);
+}
+
+/*! \brief Convert every Range of `region` into an IntSet, one per dimension. */
+NDIntSet NDIntSetFromRegion(const Region& region) {
+  NDIntSet sets;
+  sets.reserve(region.size());
+  for (size_t dim = 0; dim < region.size(); ++dim) {
+    sets.push_back(arith::IntSet::FromRange(region[dim]));
+  }
+  return sets;
+}
+
+/*! \brief Build, for each dimension of `shape`, the IntSet covering [0, extent). */
+NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  const PrimExpr zero = Integer(0);
+  NDIntSet sets;
+  sets.reserve(shape.size());
+  for (size_t dim = 0; dim < shape.size(); ++dim) {
+    sets.push_back(IntSetFromMinExtent(zero, shape[dim]));
+  }
+  return sets;
+}
+
+/*! \brief Build a single-point IntSet for each access index. */
+NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
+  NDIntSet points;
+  points.reserve(indices.size());
+  for (size_t dim = 0; dim < indices.size(); ++dim) {
+    points.push_back(arith::IntSet::SinglePoint(indices[dim]));
+  }
+  return points;
+}
+
+/*! \brief Union `rhs` into `*lhs` dimension by dimension; both must have the same rank. */
+void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
+  ICHECK_EQ(lhs->size(), rhs.size());
+  NDIntSet& target = *lhs;
+  for (size_t dim = 0; dim < rhs.size(); ++dim) {
+    target[dim] = arith::Union({target[dim], rhs[dim]});
+  }
+}
+
+/*! \brief Build an `ndim`-dimensional set whose every dimension is the empty set. */
+NDIntSet NDIntSetEmpty(int ndim) {
+  NDIntSet empty_sets;
+  empty_sets.assign(ndim, arith::IntSet::Nothing());
+  return empty_sets;
+}
+
+/*!
+ * \brief Evaluate every dimension of `nd_int_set` with arith::EvalSet under `dom_map`, so that
+ *  the vars listed in `dom_map` are replaced by their domains.
+ */
+NDIntSet EvalNDIntSet(const NDIntSet& nd_int_set,
+                      const std::unordered_map<const VarNode*, arith::IntSet>& dom_map) {
+  NDIntSet evaluated;
+  evaluated.reserve(nd_int_set.size());
+  for (size_t dim = 0; dim < nd_int_set.size(); ++dim) {
+    evaluated.push_back(arith::EvalSet(nd_int_set[dim], dom_map));
+  }
+  return evaluated;
+}
+
+/*!
+ * \brief Turn the collected per-dimension access sets into the compacted buffer region.
+ * \param nd_int_set The accessed set of each dimension.
+ * \param original_shape The buffer shape before compaction.
+ * \return For each dimension i, the range given by IntSet::CoverRange on that dimension's set,
+ *  with [0, original_shape[i]) passed as the fallback range. Per the original note, the full
+ *  original extent is used when the set is empty.
+ */
+Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
+                                      const Array<PrimExpr>& original_shape) {
+  Array<Range> narrowed;
+  narrowed.reserve(nd_int_set.size());
+  for (size_t dim = 0; dim < nd_int_set.size(); ++dim) {
+    Range full_extent(/*begin=*/0, /*end=*/original_shape[dim]);
+    narrowed.push_back(nd_int_set[dim].CoverRange(full_extent));
+  }
+  return narrowed;
+}
+
+/*!
+ * \brief Collect the access region of each buffer allocated by a block.
+ * \details Every access is pushed onto a stack in DFS order. When the visitor leaves a block,
+ *  the accesses recorded inside it are popped and evaluated against the domains of the loops that
+ *  have already been fully visited (post order), then unioned per buffer. For each buffer the
+ *  block allocates, the result is further relaxed over the enclosing thread-binding loops that
+ *  the buffer's storage scope requires, and stored as that buffer's compacted region.
+ * \note The param buffer regions will not be collected.
+ */
+class BufferAccessRegionCollector : public StmtExprVisitor {
+ public:
+  static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
+      const PrimFunc& f) {
+    BufferAccessRegionCollector collector;
+    collector(f->body);
+    return std::move(collector.buffer_access_region_);
+  }
+
+ private:
+  struct BufferAccessInfo {
+    /*! \brief The buffer. */
+    Buffer buffer;
+    /*! \brief The buffer access region, which can be updated during visiting. */
+    NDIntSet accessed_region;
+
+    explicit BufferAccessInfo(const Buffer& buffer, const NDIntSet& region)
+        : buffer(buffer), accessed_region(region) {}
+  };
+
+  BufferAccessRegionCollector() = default;
+
+  /**************** Visitor overload ****************/
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    // Also visit the stored value and the indices: they may load from buffers allocated in the
+    // current scope, and those reads must be part of the compacted region.
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    // The indices may themselves contain buffer loads.
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // A direct reference to a buffer's data var cannot be analyzed element-wise, so it is
+  // conservatively treated as an access to the full buffer.
+  void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }
+
+  void VisitExpr_(const LoadNode* op) final {
+    StmtExprVisitor::VisitExpr_(op);
+    VisitBufferVar(op->buffer_var);
+  }
+
+  void VisitStmt_(const StoreNode* op) final {
+    StmtExprVisitor::VisitStmt_(op);
+    VisitBufferVar(op->buffer_var);
+  }
+
+  void VisitStmt_(const ForNode* op) final {
+    ancestor_loops_.push_back(op);
+    StmtExprVisitor::VisitStmt_(op);
+    ancestor_loops_.pop_back();
+    // The iter_dom_map is updated by post DFS order.
+    // If the union point is under the for node, the loop var will not be relaxed.
+    // If the union point is outer of the for loop, the loop var should be relaxed.
+    iter_dom_map_on_post_order_[op->loop_var.get()] = IntSetFromMinExtent(op->min, op->extent);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // Step 0. Check there is no init part.
+    ICHECK(!op->init.defined())
+        << "Block Init part is not allowed in pass CompactBufferAllocation";
+    // Step 1. Update outer buffer access info using buffer region
+    for (const BufferRegion& region : op->reads) {
+      VisitBufferAccess(region);
+    }
+    for (const BufferRegion& region : op->writes) {
+      VisitBufferAccess(region);
+    }
+
+    // Step 2. Update inner buffer
+    // Step 2.1. rebuild map buffer_var_in_scope
+    std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope;
+    for (const Buffer& buffer : op->alloc_buffers) {
+      buffer_var_in_scope.emplace(buffer->data, buffer);
+    }
+    // Step 2.2 Record top stack element before recursive visiting.
+    size_t stack_top = buffer_access_stack_.size();
+
+    // Step 2.3. Update the buffer_var_in_scope_ of visitor and visit recursively
+    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
+    StmtExprVisitor::VisitStmt_(op);
+    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
+
+    // Step 2.4. Combine and relax access
+    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_region =
+        CombineAndRelax(stack_top);
+
+    // Step 2.5. Visit ancestor_loops and try to relax outer thread loops.
+    for (const Buffer& buffer : op->alloc_buffers) {
+      auto it = relaxed_region.find(buffer);
+      ICHECK(it != relaxed_region.end())
+          << "Buffer " << buffer->name << " is allocated but not accessed within the block";
+      const NDIntSet& nd_int_set = it->second;
+      std::unordered_map<const VarNode*, arith::IntSet> dom_map;
+      for (const ForNode* loop : ancestor_loops_) {
+        const VarNode* loop_var = loop->loop_var.get();
+        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer->scope))) {
+          dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent);
+        }
+      }
+      NDIntSet int_set = EvalNDIntSet(nd_int_set, dom_map);
+      buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape);
+    }
+  }
+
+  /**************** Helper functions ****************/
+
+  /*! \brief Record an access if its buffer is allocated by the block currently being visited. */
+  void VisitBufferAccess(const BufferRegion& buffer_region) {
+    const BufferNode* buffer_node = buffer_region->buffer.get();
+    auto it = buffer_var_in_scope_.find(buffer_node->data);
+    if (it != buffer_var_in_scope_.end()) {
+      const Buffer& buffer = it->second;
+      const BufferAccessInfo* info =
+          arena_.make<BufferAccessInfo>(buffer, NDIntSetFromRegion(buffer_region->region));
+      buffer_access_stack_.push(info);
+    }
+  }
+
+  /*! \brief Record a full-buffer access if `var` is the data var of an in-scope buffer. */
+  void VisitBufferVar(const Var& var) {
+    auto it = buffer_var_in_scope_.find(var);
+    if (it != buffer_var_in_scope_.end()) {
+      const Buffer& buffer = it->second;
+      VisitBufferAccess(BufferRegion::FullRegion(buffer));
+    }
+  }
+
+  /*!
+   * \brief Combine buffer accesses in the sub-tree.
+   * \details The access info is stored in a stack by DFS order, so that the accesses in the
+   * sub-tree are top-n elements in the stack.
+   * \param stack_top compact the access information in `stack[stack_top:end]`.
+   */
+  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineAndRelax(
+      size_t stack_top) {
+    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses;
+    while (buffer_access_stack_.size() > stack_top) {
+      const BufferAccessInfo* info = buffer_access_stack_.top();
+      buffer_access_stack_.pop();
+      NDIntSet nd_int_set = EvalNDIntSet(info->accessed_region, iter_dom_map_on_post_order_);
+      auto it = accesses.find(info->buffer);
+      if (it != accesses.end()) {
+        NDIntSetUnionWith(&it->second, nd_int_set);
+      } else {
+        accesses[info->buffer] = nd_int_set;
+      }
+    }
+    return accesses;
+  }
+
+  /*!
+   * \brief Combine buffer accesses in the sub-tree and push the combined result into the stack.
+   * \details The access info is stored in a stack by DFS order, so that the accesses in the
+   * sub-tree are top-n elements in the stack.
+   * \param stack_top The top element of the stack before visiting the sub-tree.
+   */
+  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineRelaxAndPushStack(
+      size_t stack_top) {
+    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses =
+        CombineAndRelax(stack_top);
+    for (const auto& kv : accesses) {
+      const Buffer& buffer = kv.first;
+      const NDIntSet& int_set = kv.second;
+      buffer_access_stack_.push(arena_.make<BufferAccessInfo>(buffer, int_set));
+    }
+    return accesses;
+  }
+
+  /*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */
+  static bool NeedRelaxThread(const For& loop, const runtime::StorageScope& scope) {
+    if (loop->kind != ForKind::kThreadBinding) {
+      return false;
+    }
+    ICHECK(loop->thread_binding.defined());
+    IterVar binding = loop->thread_binding.value();
+    runtime::ThreadScope ts = runtime::ThreadScope::Create(binding->thread_tag);
+
+    // When there is warp memory
+    // threadIdx.x must be set to be warp index.
+    if (scope.rank == runtime::StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) {
+      return true;
+    }
+    return static_cast<int>(scope.rank) <= ts.rank;
+  }
+
+  /**************** Class members ****************/
+
+  /*! \brief Buffer access in DFS order. */
+  std::stack<const BufferAccessInfo*> buffer_access_stack_;
+  /*! \brief The loops from the current node up to the root. */
+  std::vector<const ForNode*> ancestor_loops_;
+  /*! \brief The vars of the buffer allocated under the current block. */
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope_;
+  /*! \brief The map from loop vars to their iter range. */
+  std::unordered_map<const VarNode*, arith::IntSet> iter_dom_map_on_post_order_;
+  /*! \brief The map from Buffer to its entire access region, used for returning. */
+  std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
+  /*! \brief Internal arena. */
+  support::Arena arena_;
+};
+
+/*!
+ * \brief Re-allocate every block-allocated buffer with the compacted shape, and shift all of its
+ *  accesses and block read/write regions by the minimum of the compacted region.
+ */
+class BufferCompactor : public StmtExprMutator {
+ public:
+  static Stmt Compact(
+      const PrimFunc& f,
+      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions) {
+    std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> infos;
+    for (const auto& entry : regions) {
+      infos.emplace(entry.first, BufferAllocInfo(entry.second));
+    }
+    BufferCompactor mutator(std::move(infos));
+    return mutator(f->body);
+  }
+
+ private:
+  struct BufferAllocInfo {
+    /*! \brief The compacted region of the original buffer. */
+    Region region;
+    /*!
+     * \brief The buffer re-allocated with the compacted shape.
+     * \note Filled in when the allocating block is visited. Buffers that are not re-allocated
+     *  (e.g. parameter buffers) have no entry in the map at all.
+     */
+    Buffer new_buffer;
+
+    explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
+  };
+
+  explicit BufferCompactor(
+      std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
+      : buffer_info_(std::move(buffer_info)) {}
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore result = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    BufferStoreNode* node = result.CopyOnWrite();
+    RewriteBufferAccess(&node->buffer, &node->indices);
+    return std::move(result);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad result = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    BufferLoadNode* node = result.CopyOnWrite();
+    RewriteBufferAccess(&node->buffer, &node->indices);
+    return std::move(result);
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    ICHECK(!op->init.defined());
+    // The compacted buffers must exist before the body is visited, because every access in the
+    // body is redirected to them.
+    Array<Buffer> compacted_allocs = RewriteAllocBuffer(op->alloc_buffers);
+    Block result = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    BlockNode* node = result.CopyOnWrite();
+    RewriteBufferRegions(&node->reads);
+    RewriteBufferRegions(&node->writes);
+    node->alloc_buffers = std::move(compacted_allocs);
+    return std::move(result);
+  }
+
+  /*! \brief Create the compacted buffers and remember them in `buffer_info_`. */
+  Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
+    Array<Buffer> rewritten;
+    rewritten.reserve(buffers.size());
+    for (const Buffer& old_buffer : buffers) {
+      auto found = buffer_info_.find(old_buffer);
+      ICHECK(found != buffer_info_.end());
+      BufferAllocInfo& alloc_info = found->second;
+      Array<PrimExpr> compact_shape;
+      compact_shape.reserve(alloc_info.region.size());
+      for (const Range& dim_range : alloc_info.region) {
+        compact_shape.push_back(dim_range->extent);
+      }
+      ObjectPtr<BufferNode> node = make_object<BufferNode>(*old_buffer.get());
+      node->shape = std::move(compact_shape);
+      alloc_info.new_buffer = Buffer(std::move(node));
+      rewritten.push_back(alloc_info.new_buffer);
+    }
+    return rewritten;
+  }
+
+  /*! \brief Redirect a point access to the compacted buffer, shifting each index by the min. */
+  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
+    auto found = buffer_info_.find(*buffer);
+    // Parameter buffers have no entry and keep their original layout.
+    if (found == buffer_info_.end()) return;
+    const BufferAllocInfo& alloc_info = found->second;
+    ICHECK_EQ(indices->size(), alloc_info.region.size());
+    Array<PrimExpr> shifted;
+    shifted.reserve(alloc_info.region.size());
+    for (size_t dim = 0; dim < alloc_info.region.size(); ++dim) {
+      shifted.push_back((*indices)[dim] - alloc_info.region[dim]->min);
+    }
+    *buffer = alloc_info.new_buffer;
+    *indices = std::move(shifted);
+  }
+
+  /*! \brief Redirect a region to the compacted buffer, shifting each range's min. */
+  void RewriteBufferRegion(Buffer* buffer, Region* region) const {
+    auto found = buffer_info_.find(*buffer);
+    // Parameter buffers have no entry and keep their original layout.
+    if (found == buffer_info_.end()) return;
+    const BufferAllocInfo& alloc_info = found->second;
+    ICHECK_EQ(region->size(), alloc_info.region.size());
+    Region shifted;
+    shifted.reserve(alloc_info.region.size());
+    for (size_t dim = 0; dim < alloc_info.region.size(); ++dim) {
+      const Range& old_range = (*region)[dim];
+      shifted.push_back(
+          Range::FromMinExtent(old_range->min - alloc_info.region[dim]->min, old_range->extent));
+    }
+    *buffer = alloc_info.new_buffer;
+    *region = std::move(shifted);
+  }
+
+  /*! \brief Apply RewriteBufferRegion to every entry of a block's read/write list. */
+  void RewriteBufferRegions(Array<BufferRegion>* regions) const {
+    Array<BufferRegion> rewritten;
+    rewritten.reserve(regions->size());
+    for (BufferRegion buffer_region : *regions) {
+      BufferRegionNode* node = buffer_region.CopyOnWrite();
+      RewriteBufferRegion(&node->buffer, &node->region);
+      rewritten.push_back(std::move(buffer_region));
+    }
+    *regions = std::move(rewritten);
+  }
+
+  /*! \brief The allocation information about each buffer. */
+  std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
+};
+
+/*! \brief Collect the accessed region of every block-allocated buffer, then compact them. */
+PrimFunc CompactBufferAllocation(PrimFunc f) {
+  std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> regions =
+      BufferAccessRegionCollector::Collect(f);
+  PrimFuncNode* func_node = f.CopyOnWrite();
+  func_node->body = BufferCompactor::Compact(f, regions);
+  return f;
+}
+
+namespace transform {
+
+/*! \brief Wrap tir::CompactBufferAllocation as a PrimFunc pass (opt level 0, no required passes). */
+Pass CompactBufferAllocation() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return tir::CompactBufferAllocation(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation")
+    .set_body_typed(CompactBufferAllocation);
+}  // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc
new file mode 100644
index 0000000..4c5e1dd
--- /dev/null
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file convert_blocks_to_opaque.cc
+ * \brief Convert the blocks to opaque blocks which do not have block vars.
+ */
+
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Replace every block var by the value bound to it in the enclosing BlockRealize, then
+ *  strip the iter_vars / iter_values so that each block becomes opaque.
+ */
+class OpaqueBlockConverter : public StmtExprMutator {
+ public:
+  static Stmt Substitute(const PrimFunc& f) {
+    OpaqueBlockConverter converter;
+    return converter.VisitStmt(f->body);
+  }
+
+ private:
+  OpaqueBlockConverter() = default;
+
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto found = var_substitutes_.find(var);
+    if (found == var_substitutes_.end()) {
+      return GetRef<Var>(var);
+    }
+    return found->second;
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    ICHECK(!block->init.defined())
+        << "Block Init part is not allowed in pass ConvertBlocksToOpaque";
+    Block result = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
+    if (result->iter_vars.empty()) {
+      return std::move(result);
+    }
+    result.CopyOnWrite()->iter_vars.clear();
+    return std::move(result);
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    const BlockNode* inner = realize->block.get();
+    ICHECK(!inner->init.defined());
+    ICHECK_EQ(inner->iter_vars.size(), realize->iter_values.size());
+    // Register "block var -> bound value" before descending so the block body is rewritten.
+    size_t num_vars = inner->iter_vars.size();
+    for (size_t k = 0; k < num_vars; ++k) {
+      IterVar iter = inner->iter_vars[k];
+      PrimExpr bound = this->VisitExpr(realize->iter_values[k]);
+      var_substitutes_.emplace(iter->var.get(), bound);
+    }
+    BlockRealize result = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize));
+    if (!result->iter_values.empty()) {
+      result.CopyOnWrite()->iter_values.clear();
+    }
+    return std::move(result);
+  }
+
+  /*! \brief The map from block vars to their binding values. */
+  std::unordered_map<const VarNode*, PrimExpr> var_substitutes_;
+};
+
+/*! \brief Rewrite the body of `f` so that every block in it is opaque. */
+PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
+  Stmt opaque_body = OpaqueBlockConverter::Substitute(f);
+  f.CopyOnWrite()->body = std::move(opaque_body);
+  return f;
+}
+
+namespace transform {
+
+/*! \brief Wrap tir::ConvertBlocksToOpaque as a PrimFunc pass (opt level 0, no required passes). */
+Pass ConvertBlocksToOpaque() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return tir::ConvertBlocksToOpaque(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque").set_body_typed(ConvertBlocksToOpaque);
+}  // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
new file mode 100644
index 0000000..7c06b5e
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -0,0 +1,331 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+def _check(original, transformed):
+ func = original
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.CompactBufferAllocation()(mod)
+ mod = tvm.tir.transform.Simplify()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], transformed)
+
+
+@tvm.script.tir
+def elementwise_func(a: ty.handle, c: ty.handle) -> None:  # input for test_elementwise
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((16, 16), "float32")  # full 16x16, but iteration i only uses row i
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    B[i, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j] * 2.0
+
+
+@tvm.script.tir
+def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None:  # expected output for elementwise_func
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((1, 16), "float32")  # compacted from (16, 16) to the one row used
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[0, j])  # row index i rebased to 0
+                    B[0, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j] * 2.0
+
+
+@tvm.script.tir
+def unschedulable_func(a: ty.handle, c: ty.handle) -> None:  # test_unschedulable_block expects no change
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((16, 16), "float32")
+            for j in range(0, 16):
+                tir.store(B.data, i * 16 + j, A[i, j] + 1.0)  # raw store through B's data pointer
+            for j in range(0, 16):
+                C[i, j] = B[i, j] * 2.0  # B is read outside any sub-block
+
+
+@tvm.script.tir
+def param_buffer_access_func(a: ty.handle, c: ty.handle) -> None:  # test_param_access expects no change
+    A = tir.match_buffer(a, (20, 20), "float32")
+    B = tir.match_buffer(c, (20, 20), "float32")  # parameter buffer (match_buffer), not an alloc_buffer
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(B[i, 0:16])
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    B[i, j] = A[i, j] + 1.0  # only [0:16, 0:16] of the 20x20 buffers is accessed
+
+
+@tvm.script.tir
+def shared_mem_func(a: ty.handle, c: ty.handle) -> None:  # input for test_shared_mem
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((16, 16), "float32", scope="shared")  # row = i0*8 + i1*4 + i2
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
+
+
+@tvm.script.tir
+def compacted_shared_mem_func(a: ty.handle, c: ty.handle) -> None:  # expected output for shared_mem_func
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((8, 16), "float32", scope="shared")  # blockIdx.x (i0) term dropped: 8 rows
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i1 * 4 + i2, j])  # row = i1 * 4 + i2
+                            B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i1 * 4 + i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0
+
+
+@tvm.script.tir
+def warp_mem_func(a: ty.handle, c: ty.handle) -> None:  # input for test_warp_mem
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((16, 16), "float32", scope="warp")  # row = i0*8 + i1*4 + i2
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
+
+
+@tvm.script.tir
+def compacted_warp_mem_func(a: ty.handle, c: ty.handle) -> None:  # expected output for warp_mem_func
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((4, 16), "float32", scope="warp")  # i0 and i1 terms dropped: 4 rows
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i2, j])  # row = i2 (threadIdx.x)
+                            B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
+
+
+@tvm.script.tir
+def symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:  # input for test_symbolic; n is a symbolic extent
+    A = tir.match_buffer(a, (n * 8,), "float32")
+    C = tir.match_buffer(c, (n * 8,), "float32")
+    for i in range(0, n):
+        with tir.block([]):
+            tir.reads(A[i * 8 : i * 8 + 8])
+            tir.writes(C[i * 8 : i * 8 + 8])
+            B = tir.alloc_buffer((n * 8,), "float32")  # symbolic size; iteration i uses only 8 elements
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(A[i * 8 + j])
+                    tir.writes(B[i * 8 + j])
+                    B[i * 8 + j] = A[i * 8 + j] + 1.0
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[i * 8 + j])
+                    tir.writes(C[i * 8 + j])
+                    C[i * 8 + j] = B[i * 8 + j] * 2.0
+
+
+@tvm.script.tir
+def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:  # expected output for symbolic_func
+    A = tir.match_buffer(a, (n * 8,), "float32")
+    C = tir.match_buffer(c, (n * 8,), "float32")
+    for i in range(0, n):
+        with tir.block([]):
+            tir.reads(A[i * 8 : i * 8 + 8])
+            tir.writes(C[i * 8 : i * 8 + 8])
+            B = tir.alloc_buffer((8,), "float32")  # compacted from (n * 8,) to a constant 8 elements
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(A[i * 8 + j])
+                    tir.writes(B[j])  # offset i * 8 removed
+                    B[j] = A[i * 8 + j] + 1.0
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[j])
+                    tir.writes(C[i * 8 + j])
+                    C[i * 8 + j] = B[j] * 2.0
+
+
+@tvm.script.tir
+def complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:  # input for test_complex; n is unused
+    A = tir.match_buffer(a, (8, 8), "float32")
+    C = tir.match_buffer(c, (8, 8), "float32")
+    for i in range(0, 8):
+        with tir.block([]):
+            tir.reads(A[0, 8])  # NOTE(review): index 8 is past the end of the 8x8 A; likely meant a slice
+            tir.writes(C[0, 8])  # NOTE(review): same issue; must stay in sync with compacted_complex_func
+            B = tir.alloc_buffer((8, 8), "float32")
+            for j in range(0, 4):
+                with tir.block([]) as []:
+                    D = tir.alloc_buffer((8, 8), "float32")  # only D[2:8, j] is touched below
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    for k in range(4, 8):
+                        D[k, j] = 1.0
+                    for k in range(2, 4):
+                        tir.store(B.data, j, A[i, j] + D[k, j])  # raw store into B's data at offset j
+            for j in range(3, 5):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j]
+            for j in range(6, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j]
+
+
+@tvm.script.tir
+def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:  # expected output for complex_func
+    A = tir.match_buffer(a, (8, 8), "float32")
+    C = tir.match_buffer(c, (8, 8), "float32")
+    for i in range(0, 8):
+        with tir.block([]):
+            tir.reads(A[0, 8])  # mirrors complex_func's region annotation verbatim
+            tir.writes(C[0, 8])
+            B = tir.alloc_buffer((1, 8), "float32")  # (8, 8) -> (1, 8)
+            for j in range(0, 4):
+                with tir.block([]) as []:
+                    D = tir.alloc_buffer((6, 1), "float32")  # (8, 8) -> (6, 1): rows 2..7 of column j
+                    tir.reads(A[i, j])
+                    tir.writes(B[0, j])
+                    for k in range(4, 8):
+                        D[k - 2, 0] = 1.0  # row rebased by -2, column j rebased to 0
+                    for k in range(2, 4):
+                        tir.store(B.data, j, A[i, j] + D[k - 2, 0])  # the raw store offset itself is unchanged
+            for j in range(3, 5):
+                with tir.block([]) as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j]
+            for j in range(6, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j]
+
+
+def test_elementwise():
+ _check(elementwise_func, compacted_elementwise_func)
+
+
+def test_unschedulable_block():
+ _check(unschedulable_func, unschedulable_func) # changes nothing
+
+
+def test_param_access():
+ _check(param_buffer_access_func, param_buffer_access_func) # changes nothing
+
+
+def test_shared_mem():
+ _check(shared_mem_func, compacted_shared_mem_func)
+
+
+def test_warp_mem():
+ _check(warp_mem_func, compacted_warp_mem_func)
+
+
+def test_symbolic():
+ _check(symbolic_func, compacted_symbolic_func)
+
+
+def test_complex():
+ _check(complex_func, compacted_complex_func)
+
+
+if __name__ == "__main__":
+ test_elementwise()
+ test_unschedulable_block()
+ test_param_access()
+ test_shared_mem()
+ test_warp_mem()
+ test_symbolic()
+ test_complex()
diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
new file mode 100644
index 0000000..38fe1c9
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+def _check(original, transformed):
+ func = original
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
+ mod = tvm.tir.transform.Simplify()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], transformed)
+
+
+@tvm.script.tir
+def elementwise_func(a: ty.handle, c: ty.handle) -> None:  # input: blocks whose iter vars are bound to loop vars
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((16, 16), "float32")
+            for j in range(0, 16):
+                with tir.block([16, 16]) as [vi, vj]:  # block with two iter vars
+                    tir.bind(vi, i)  # binding values that the pass substitutes for vi / vj
+                    tir.bind(vj, j)
+                    B[vi, vj] = A[vi, vj] + 1.0
+            for j in range(0, 16):
+                with tir.block([16, 16]) as [vi, vj]:
+                    tir.bind(vi, i)
+                    tir.bind(vj, j)
+                    C[vi, vj] = B[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def substituted_elementwise_func(a: ty.handle, c: ty.handle) -> None:  # expected output for elementwise_func
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer([16, 16], "float32")
+            for j in range(0, 16):
+                with tir.block() as []:  # opaque block: no iter vars left
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    B[i, j] = A[i, j] + 1.0  # vi / vj replaced by their bindings i / j
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j] * 2.0
+
+
+def test_elementwise():
+ _check(elementwise_func, substituted_elementwise_func)
+
+
+if __name__ == "__main__":
+ test_elementwise()