You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/05/06 22:21:23 UTC
[tvm] branch main updated: [TIR] Add schedule primitive SetAxisSeparator (#11225)
This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 31be30062b [TIR] Add schedule primitive SetAxisSeparator (#11225)
31be30062b is described below
commit 31be30062badf658bc71cb3a906411291a7db12a
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri May 6 15:21:18 2022 -0700
[TIR] Add schedule primitive SetAxisSeparator (#11225)
* [TIR] Add schedule primitive SetAxisSeparator
* remove unused include
* Move ReplaceBufferMutator impl to cc file
---
include/tvm/tir/schedule/schedule.h | 12 ++
python/tvm/script/tir/__init__.pyi | 3 +
python/tvm/script/tir/special_stmt.py | 12 +-
python/tvm/script/tir/ty.py | 3 +
python/tvm/tir/schedule/schedule.py | 81 +++++++++++-
src/printer/tvmscript_printer.cc | 6 +
src/tir/schedule/concrete_schedule.cc | 10 ++
src/tir/schedule/concrete_schedule.h | 3 +
src/tir/schedule/primitive.h | 11 ++
src/tir/schedule/primitive/block_annotate.cc | 107 ++--------------
.../schedule/primitive/layout_transformation.cc | 112 +++++++++++++++++
src/tir/schedule/schedule.cc | 7 +-
src/tir/schedule/traced_schedule.cc | 13 ++
src/tir/schedule/traced_schedule.h | 3 +
src/tir/schedule/transform.cc | 81 ++++++++++++
src/tir/schedule/transform.h | 55 ++++++++
.../test_tir_schedule_set_axis_separator.py | 139 +++++++++++++++++++++
tests/python/unittest/test_tvmscript_roundtrip.py | 20 +++
18 files changed, 575 insertions(+), 103 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index e78cef2cac..18e15d1670 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -545,6 +545,18 @@ class ScheduleNode : public runtime::Object {
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
+ /*!
+ * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
+ * or write index
+ * \param block_rv The block that accesses the target buffer.
+ * \param buffer_index The index of the buffer in block's read or write region.
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+ * \param axis_separators The axis separators of the buffer.
+ */
+ virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type,
+ const Array<IntImm>& axis_separators) = 0;
+
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index 9727a8db63..e4513feb43 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -199,6 +199,7 @@ def match_buffer(
align: int = -1,
offset_factor: int = 0,
buffer_type: str = "default",
+ axis_separators: Optional[List[int]] = None,
) -> Buffer: ...
def buffer_decl(
shape: Sequence[Union[PrimExpr, int]],
@@ -210,6 +211,7 @@ def buffer_decl(
align: int = -1,
offset_factor: int = 0,
buffer_type: str = "default",
+ axis_separators: Optional[List[int]] = None,
) -> Buffer: ...
def alloc_buffer(
shape: Sequence[Union[PrimExpr, int]],
@@ -221,6 +223,7 @@ def alloc_buffer(
align: int = -1,
offset_factor: int = 0,
buffer_type: str = "default",
+ axis_separators: Optional[List[int]] = None,
) -> Buffer: ...
"""
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index 45eaa8b8be..39a345de7f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -100,7 +100,7 @@ class SpecialStmt:
@register
class MatchBuffer(SpecialStmt):
"""Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
- offset_factor, buffer_type)
+ offset_factor, buffer_type, axis_separators)
Note
----
@@ -131,6 +131,7 @@ class MatchBuffer(SpecialStmt):
align=-1,
offset_factor=0,
buffer_type="default",
+ axis_separators=None,
span=None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
@@ -157,6 +158,7 @@ class MatchBuffer(SpecialStmt):
align,
offset_factor,
buffer_type,
+ axis_separators,
span=span,
)
if isinstance(param, tvm.tir.Var):
@@ -184,7 +186,7 @@ class MatchBuffer(SpecialStmt):
@register
class BufferDeclare(SpecialStmt):
"""Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
- offset_factor, buffer_type)
+ offset_factor, buffer_type, axis_separators)
Example
-------
.. code-block:: python
@@ -202,6 +204,7 @@ class BufferDeclare(SpecialStmt):
align=-1,
offset_factor=0,
buffer_type="default",
+ axis_separators=None,
span=None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
@@ -228,6 +231,7 @@ class BufferDeclare(SpecialStmt):
align,
offset_factor,
buffer_type,
+ axis_separators,
span=span,
)
self.context.update_symbol(buffer_name, buffer, self.node)
@@ -239,7 +243,7 @@ class BufferDeclare(SpecialStmt):
@register
class AllocBuffer(SpecialStmt):
"""Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
- offset_factor, buffer_type)
+ offset_factor, buffer_type, axis_separators)
Example
-------
@@ -259,6 +263,7 @@ class AllocBuffer(SpecialStmt):
align=-1,
offset_factor=0,
buffer_type="default",
+ axis_separators=None,
span=None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
@@ -286,6 +291,7 @@ class AllocBuffer(SpecialStmt):
align,
offset_factor,
buffer_type,
+ axis_separators,
span=span,
)
if self.context.current_block_scope():
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index dfe2fbbe42..7d90dec646 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -121,6 +121,7 @@ class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods,
align=-1,
offset_factor=0,
buffer_type="default",
+ axis_separators=None,
span=None,
):
if strides is None:
@@ -140,6 +141,7 @@ class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods,
align,
offset_factor,
buffer_type,
+ axis_separators,
span=span,
)
return buffer
@@ -160,6 +162,7 @@ class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods,
align=-1,
offset_factor=0,
buffer_type="default",
+ axis_separators=None,
span=None,
):
"""
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index d537db2800..8bfd906315 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2125,7 +2125,7 @@ class Schedule(Object):
"""Apply a transformation represented by IndexMap to buffer
Parameters
----------
- block_rv : BlockRV
+ block : BlockRV
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
@@ -2190,6 +2190,85 @@ class Schedule(Object):
self, block, buffer_index, buffer_index_type_enum, index_map
)
+ @type_checked
+ def set_axis_separator(
+ self,
+ block: BlockRV,
+ buffer_index: int,
+ buffer_index_type: str,
+ axis_separators: Optional[List[int]],
+ ) -> None:
+ """Set the axis separator of a buffer, where the buffer is specified by a block and a read
+ or write index.
+
+ Parameters
+ ----------
+ block : BlockRV
+ The block that accesses the target buffer
+ buffer_index: int
+ The index of the buffer in block's read or write region
+ buffer_index_type : str
+ Type of the buffer index, "read" or "write"
+ axis_separators : Optional[List[int]]
+ The axis separators. None is treated as an empty list.
+
+ Examples
+ --------
+
+ Before set_axis_separator, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_set_axis_separator(
+ A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ B = T.alloc_buffer((128, 128), dtype="float32")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ Create the schedule and do set_axis_separator:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_set_axis_separator)
+ sch.set_axis_separator(sch.get_block("B"), buffer_index=0, buffer_index_type="write",
+ axis_separators=[1])
+ print(sch.mod["main"].script())
+
+ After applying set_axis_separator, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_set_axis_separator(
+ A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+ """
+ axis_separators = axis_separators or []
+ assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
+ buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
+ _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member
+ self, block, buffer_index, buffer_index_type_enum, axis_separators
+ )
+
########## Schedule: Misc ##########
@type_checked
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index aeb118a49c..6f8d10b320 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -502,6 +502,9 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
if (buf->buffer_type != BufferType::kDefault) {
doc << ", type=" << Doc::StrLiteral("auto");
}
+ if (buf->axis_separators.size()) {
+ doc << ", axis_separators=" << Print(buf->axis_separators);
+ }
return doc;
}
@@ -606,6 +609,9 @@ bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
if (buf->buffer_type != BufferType::kDefault) {
return false;
}
+ if (buf->axis_separators.size()) {
+ return false;
+ }
return true;
}
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 59a19631fc..7b953220f2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -693,6 +693,16 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}
+void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type,
+ const Array<IntImm>& axis_separators) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
+ axis_separators);
+ TVM_TIR_SCHEDULE_END("set-axis-separator", this->error_render_level_);
+ this->state_->DebugVerify();
+}
+
/******** Schedule: Misc ********/
} // namespace tir
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 4534406d79..9293aa3493 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -134,6 +134,9 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
const IndexMap& index_map) override;
+ void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type,
+ const Array<IntImm>& axis_separators) override;
/******** Schedule: Misc ********/
void EnterPostproc() override {}
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 5e21075d58..d55b896934 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -377,6 +377,17 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
*/
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& storage_scope);
+/*!
+ * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
+ * or write index
+ * \param block_sref The sref of the block that accesses the target buffer.
+ * \param buffer_index The index of the buffer in block's read or write region.
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+ * \param axis_separators The axis separators of the buffer.
+ */
+TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+ BufferIndexType buffer_index_type,
+ const Array<IntImm>& axis_separators);
/******** Schedule: Blockize & Tensorize ********/
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index f9cec421cd..ede239878a 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include "../../ir/functor_common.h"
#include "../utils.h"
namespace tvm {
@@ -202,7 +201,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
* \brief A helper mutator which recursively mutates the old buffer's storage scope and collects
* the block sref reuse information for the following replacement.
*/
-class StorageScopeMutator : StmtExprMutator {
+class StorageScopeMutator : private ReplaceBufferMutator {
public:
/*!
* \param allocate_site The block where `old_buffer` was allocated.
@@ -222,107 +221,19 @@ class StorageScopeMutator : StmtExprMutator {
private:
StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope,
Map<Block, Block>* block_sref_reuse)
- : storage_scope_(std::move(storage_scope)), block_sref_reuse_(block_sref_reuse) {
- buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
- }
-
- PrimExpr VisitExpr_(const VarNode* var) final {
- auto it = buffer_var_map_.find(var);
- return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
- }
-
- PrimExpr VisitExpr_(const BufferLoadNode* load) final {
- BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
-
- auto it = buffer_var_map_.find(res->buffer->data.get());
- if (it != buffer_var_map_.end()) {
- ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
- ptr->buffer = it->second;
- return PrimExpr(ptr);
- } else {
- return std::move(res);
- }
- }
+ : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {}
- Stmt VisitStmt_(const BufferStoreNode* store) final {
- BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
-
- auto it = buffer_var_map_.find(res->buffer->data.get());
+ MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
+ auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
if (it != buffer_var_map_.end()) {
- ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
- ptr->buffer = it->second;
- return Stmt(ptr);
+ Buffer new_target_buffer = WithScope(match_buffer->buffer, it->second.scope());
+ buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
+ return MatchBufferRegion(new_target_buffer,
+ BufferRegion(it->second, match_buffer->source->region));
} else {
- return std::move(res);
+ return match_buffer;
}
}
-
- Stmt VisitStmt_(const BlockNode* block) final {
- // To reduce the number of blocks in block sref reuse map, we check whether the block is really
- // mutated (i.e., the old buffer appears in the block). If so, we return the block after
- // mutation. Otherwise we just return the original block.
-
- // Define the mutation functions.
- auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) {
- auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
- if (it != buffer_var_map_.end()) {
- Buffer new_target_buffer = WithScope(match_buffer->buffer, storage_scope_);
- buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
- return MatchBufferRegion(new_target_buffer,
- BufferRegion(it->second, match_buffer->source->region));
- } else {
- return match_buffer;
- }
- };
- auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
- auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
- return it == buffer_var_map_.end() ? buffer_region
- : BufferRegion(it->second, buffer_region->region);
- };
- auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
- auto it = buffer_var_map_.find(buffer->data.get());
- return it == buffer_var_map_.end() ? buffer : it->second;
- };
-
- // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion,
- // the storage scope of the target buffer also needs to be set.
- Array<MatchBufferRegion> match_buffers =
- MutateArray(block->match_buffers, f_mutate_match_buffers);
- // Step 2. Mutate the read/write region.
- Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
- Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
- // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
- Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
- // Step 4. Recursively mutate the block.
- Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
-
- if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
- writes.same_as(mutated_block->writes) &&
- alloc_buffers.same_as(mutated_block->alloc_buffers) &&
- match_buffers.same_as(mutated_block->match_buffers)) {
- return GetRef<Block>(block);
- } else {
- ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
- n->reads = std::move(reads);
- n->writes = std::move(writes);
- n->alloc_buffers = std::move(alloc_buffers);
- n->match_buffers = std::move(match_buffers);
-
- Block new_block(n);
- block_sref_reuse_->Set(GetRef<Block>(block), new_block);
- return std::move(new_block);
- }
- }
-
- /*! \brief The storage scope to be set. */
- String storage_scope_;
- /*!
- * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
- * MatchBufferRegion.
- */
- std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
- /*! \brief The block sref reuse map for the following replacement */
- Map<Block, Block>* block_sref_reuse_;
};
void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index fcfce5d217..b133f537b5 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -185,6 +185,84 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
self->Replace(scope_sref, new_scope_block, block_sref_reuse);
}
+class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
+ public:
+ static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer,
+ Map<Block, Block>* block_sref_reuse) {
+ BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse);
+ return Downcast<Block>(mutator.VisitStmt(scope_block));
+ }
+
+ private:
+ BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer,
+ Map<Block, Block>* block_sref_reuse)
+ : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {}
+
+ MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
+ auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ const Buffer& new_source_buffer = it->second;
+ Buffer new_target_buffer = match_buffer->buffer;
+ new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
+ if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
+ LOG(WARNING)
+ << "Target buffer in match_buffer doesn't have the same dimensionality as its source "
+ "buffer. `axis_separators` for the target buffer might be incorrect.";
+ }
+ buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
+ return MatchBufferRegion(new_target_buffer,
+ BufferRegion(new_source_buffer, match_buffer->source->region));
+ }
+ return match_buffer;
+ }
+};
+
+void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+ BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) {
+ const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+ Buffer old_buffer = GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index,
+ buffer_index_type == BufferIndexType::kWrite);
+ Optional<StmtSRef> defining_site_sref;
+ bool is_alloc;
+ std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer);
+ if (defining_site_sref.defined() && !is_alloc) {
+ throw BufferIsSubregionError(self->mod, old_buffer);
+ }
+
+ StmtSRef scope_sref = defining_site_sref.defined()
+ ? defining_site_sref.value()
+ : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+ const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+
+ // Step 1: Update axis_separators of the buffer.
+ Buffer new_buffer = old_buffer;
+ new_buffer.CopyOnWrite()->axis_separators = axis_separators;
+
+ Map<Block, Block> block_sref_reuse;
+
+ // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc.
+ Block new_scope_block = BufferAxisSeparatorMutator::Mutate(GetRef<Block>(scope_block), old_buffer,
+ new_buffer, &block_sref_reuse);
+ if (!defining_site_sref.defined()) {
+ // mutate buffer_map of the PrimFunc
+ GlobalVar g_var;
+ GetRootPrimFunc(self->mod, scope_block, &g_var);
+ IRModuleNode* new_mod = self->mod.CopyOnWrite();
+ MapNode* new_map = new_mod->functions.CopyOnWrite();
+ PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+ PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+ MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
+ for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) {
+ if ((*it).second.same_as(old_buffer)) {
+ (*it).second = new_buffer;
+ }
+ }
+ new_map->at(g_var) = std::move(ref_new_func);
+ }
+
+ // Step 3: Replace the scope block with the new block
+ self->Replace(scope_sref, new_scope_block, block_sref_reuse);
+}
/******** InstructionKind Registration ********/
struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits> {
@@ -238,7 +316,41 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
friend struct ::tvm::tir::UnpackedInstTraits;
};
+struct SetAxisSeparatorTraits : public UnpackedInstTraits<SetAxisSeparatorTraits> {
+ static constexpr const char* kName = "SetAxisSeparator";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 3;
+ static constexpr size_t kNumDecisions = 0;
+
+ static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
+ Integer buffer_index_type, Array<IntImm> axis_separators) {
+ return sch->SetAxisSeparator(block_rv, buffer_index,
+ static_cast<BufferIndexType>(buffer_index_type->value),
+ axis_separators);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
+ Integer buffer_index_type, Array<IntImm> axis_separators) {
+ PythonAPICall py("set_axis_separator");
+ py.Input("block", block_rv);
+ py.Input("buffer_index", buffer_index);
+ py.Input("buffer_index_type", '"' +
+ std::string(BufferIndexType2Str(
+ static_cast<BufferIndexType>(buffer_index_type->value))) +
+ '"');
+ py.Input("axis_separators", axis_separators);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits);
+TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 82cd0a4a35..8dc0c52111 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -233,7 +233,12 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
return self->TransformLayout(block_rv, buffer_index,
static_cast<BufferIndexType>(buffer_index_type), index_map);
});
-
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
+ .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
+ int buffer_index_type, const Array<IntImm>& axis_separators) {
+ return self->SetAxisSeparator(
+ block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type), axis_separators);
+ });
/******** (FFI) Misc ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
.set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 417f80dd93..865b6f3784 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -442,6 +442,19 @@ void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_ind
/*outputs=*/{}));
}
+void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type,
+ const Array<IntImm>& axis_separators) {
+ ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type,
+ axis_separators);
+ static const InstructionKind& kind = InstructionKind::Get("SetAxisSeparator");
+ trace_->Append(/*inst=*/Instruction(
+ /*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), axis_separators},
+ /*outputs=*/{}));
+}
+
/******** Schedule: Misc ********/
void TracedScheduleNode::EnterPostproc() {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 442b50ad0c..12c076d886 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -98,6 +98,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
const IndexMap& index_map) override;
+ void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type,
+ const Array<IntImm>& axis_separators) final;
/******** Schedule: Misc ********/
void EnterPostproc() final;
};
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index b2e71a9a0d..6c4f3e1b7a 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -70,6 +70,87 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
return match_buffers;
}
+/******** ReplaceBufferMutator ********/
+ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+ Map<Block, Block>* block_sref_reuse)
+ : block_sref_reuse_(block_sref_reuse) {
+ buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+}
+
+PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) {
+ auto it = buffer_var_map_.find(var);
+ return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+}
+
+Stmt ReplaceBufferMutator::VisitStmt_(const BufferStoreNode* op) {
+ auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ return VisitBufferAccess(std::move(node));
+}
+
+PrimExpr ReplaceBufferMutator::VisitExpr_(const BufferLoadNode* op) {
+ auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+ return VisitBufferAccess(std::move(node));
+}
+
+MatchBufferRegion ReplaceBufferMutator::VisitMatchBufferRegion(
+ const MatchBufferRegion& match_buffer) {
+ auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ return MatchBufferRegion(match_buffer->buffer,
+ BufferRegion(it->second, match_buffer->source->region));
+ } else {
+ return match_buffer;
+ }
+}
+
+Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) {
+ // To reduce the number of blocks in block sref reuse map, we check whether the block is really
+ // mutated (i.e., the old buffer appears in the block). If so, we return the block after
+ // mutation. Otherwise we just return the original block.
+
+ auto f_mutate_match_buffer = [this](const MatchBufferRegion& match_buffer) {
+ return this->VisitMatchBufferRegion(match_buffer);
+ };
+ auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
+ auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+ return it == buffer_var_map_.end() ? buffer_region
+ : BufferRegion(it->second, buffer_region->region);
+ };
+ auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
+ auto it = buffer_var_map_.find(buffer->data.get());
+ return it == buffer_var_map_.end() ? buffer : it->second;
+ };
+
+ // Step 1. Mutate `match_buffers` through the overridable VisitMatchBufferRegion hook.
+ Array<MatchBufferRegion> match_buffers = MutateArray(block->match_buffers, f_mutate_match_buffer);
+ // Step 2. Mutate the read/write region.
+ Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
+ Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+ // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
+ Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
+ // Step 4. Recursively mutate the block.
+ Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+
+ if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
+ writes.same_as(mutated_block->writes) &&
+ alloc_buffers.same_as(mutated_block->alloc_buffers) &&
+ match_buffers.same_as(mutated_block->match_buffers)) {
+ return GetRef<Block>(block);
+ } else {
+ ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
+ n->reads = std::move(reads);
+ n->writes = std::move(writes);
+ n->alloc_buffers = std::move(alloc_buffers);
+ n->match_buffers = std::move(match_buffers);
+
+ Block new_block(n);
+ if (block_sref_reuse_ != nullptr) {
+ block_sref_reuse_->Set(GetRef<Block>(block), new_block);
+ }
+ return std::move(new_block);
+ }
+}
+
/******** Block Removal ********/
void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 12326b3418..52e27350d4 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -21,6 +21,12 @@
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+#include "../ir/functor_common.h"
namespace tvm {
namespace tir {
@@ -66,6 +72,55 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
const Buffer& target);
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use case, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers` by overriding
+ * VisitMatchBufferRegion, which is the only virtual (non-final) hook declared below.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+ /*!
+ * \brief The constructor
+ * \param old_buffer The old buffer
+ * \param new_buffer The new buffer
+ * \param block_sref_reuse Optional map to record mapping between old and new blocks that reuse
+ * sref. May be nullptr, in which case no mapping is recorded.
+ */
+ ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+ Map<Block, Block>* block_sref_reuse);
+
+ protected:
+ /*!
+ * \brief Visit a Var expression; defined out of line.
+ * NOTE(review): presumably remaps the data var of a replaced buffer -- confirm against the
+ * definition in the .cc file.
+ */
+ PrimExpr VisitExpr_(const VarNode* var) final;
+
+ /*!
+ * \brief Point a buffer access at its replacement buffer, if one is registered.
+ *
+ * The data var of `node->buffer` is looked up in buffer_var_map_. On a hit, only the `buffer`
+ * field is replaced (through CopyOnWrite); the indices are left untouched. On a miss, `node`
+ * is returned unchanged.
+ * \tparam Node An ObjectRef type whose node exposes a mutable `buffer` field
+ * (presumably BufferLoad / BufferStore -- see the two visitors below).
+ * \param node The access to rewrite.
+ * \return The access, with its buffer remapped when applicable.
+ */
+ template <typename Node>
+ Node VisitBufferAccess(Node node) {
+ auto it = buffer_var_map_.find(node->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ node.CopyOnWrite()->buffer = it->second;
+ }
+ return node;
+ }
+
+ /*! \brief Rewrite the buffer of a BufferStore; defined out of line. */
+ Stmt VisitStmt_(const BufferStoreNode* op) final;
+
+ /*! \brief Rewrite the buffer of a BufferLoad; defined out of line. */
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+
+ /*!
+ * \brief Rewrite a single MatchBufferRegion of a block.
+ *
+ * Virtual so that subclasses can also mutate the matched (target) buffer when the source
+ * region refers to the replaced buffer.
+ * \param match_buffer The region to rewrite.
+ * \return The rewritten region.
+ */
+ virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);
+
+ /*!
+ * \brief Rewrite a block.
+ *
+ * The block's match_buffers, read/write regions and alloc_buffers are rewritten, and then the
+ * block is visited recursively. If nothing changed, the original block is returned. Otherwise
+ * a new block is built, and the old -> new pair is recorded in block_sref_reuse_ when that map
+ * is non-null.
+ */
+ Stmt VisitStmt_(const BlockNode* block) final;
+
+ /*!
+ * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
+ * MatchBufferRegion.
+ */
+ std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
+ /*! \brief The block sref reuse map for the following replacement; may be nullptr */
+ Map<Block, Block>* block_sref_reuse_;
+};
+
/******** Block Removal ********/
/*!
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
new file mode 100644
index 0000000000..d829a3f1b7
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+# Input workload shared by most tests: block "B" computes B = A * 2 into an
+# intermediate buffer allocated inside the function, then block "C" computes
+# C = B + 1. Block "B" reads one buffer (A) and writes one buffer (B).
+@T.prim_func
+def element_wise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+ B = T.alloc_buffer((128, 128), dtype="float32")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
+# Expected result of set_axis_separator(block "B", 0, "write", [1]) applied to
+# element_wise: the intermediate buffer B is now allocated with
+# axis_separators=[1]; everything else is unchanged.
+@T.prim_func
+def element_wise_set_axis_separator(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+# Expected result of set_axis_separator(block "B", 0, "read", [1]) applied to
+# element_wise: the separator lands on the function parameter A (the buffer
+# read by block "B"), while the intermediate B is left without separators.
+@T.prim_func
+def element_wise_set_axis_separator_input_buffer(A: T.Buffer(shape=(128, 128), dtype="float32", axis_separators=(1,)), C: T.Buffer[(128, 128), "float32"]) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+# Input for the match_buffer test: each block views a single element of the
+# intermediate buffer B through T.match_buffer. The matched region is indexed
+# with the block iterators vi/vj (not the enclosing loop variables i/j), so each
+# block body refers only to its own iter vars.
+@T.prim_func
+def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+            B_subregion0[()] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+            C[vi, vj] = B_subregion1[()] + 1.0
+
+
+# Expected result of set_axis_separator(block "B", 0, "write", [1]) applied to
+# element_wise_subregion_match: B gets axis_separators=[1], and so does every
+# buffer matched from B via T.match_buffer. The matched regions use the same
+# block iterators as the input function so the structural comparison lines up.
+@T.prim_func
+def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            B_subregion0[()] = A[vi, vj] * T.float32(2)
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            C[vi, vj] = B_subregion1[()] + T.float32(1)
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def test_set_axis_separator():
+    """Separating axis 1 of the buffer written by block "B" yields the expected PrimFunc."""
+    original = element_wise
+    sch = tir.Schedule(original, debug_mask="all")
+    block_b = sch.get_block("B")
+    sch.set_axis_separator(block_b, 0, "write", [1])
+    tvm.ir.assert_structural_equal(element_wise_set_axis_separator, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=original)
+
+
+def test_set_scope_fail_on_index_out_of_bound():
+    """set_axis_separator raises ScheduleError for buffer indices block "B" does not have.
+
+    Block "B" writes a single buffer, so write index 1 is out of range. A negative read
+    index is rejected as well.
+
+    NOTE(review): the name mentions set_scope but this exercises set_axis_separator.
+    """
+    sch = tir.Schedule(element_wise, debug_mask="all")
+    for buffer_index, buffer_kind in ((1, "write"), (-1, "read")):
+        with pytest.raises(tvm.tir.ScheduleError):
+            sch.set_axis_separator(sch.get_block("B"), buffer_index, buffer_kind, [1])
+
+
+def test_set_axis_separator_input_buffer():
+    """The separator can target the buffer read by block "B", a function parameter (A)."""
+    original = element_wise
+    sch = tir.Schedule(original, debug_mask="all")
+    block_b = sch.get_block("B")
+    sch.set_axis_separator(block_b, 0, "read", [1])
+    expected = element_wise_set_axis_separator_input_buffer
+    tvm.ir.assert_structural_equal(expected, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=original)
+
+
+def test_set_axis_separator_subregion():
+    """Axis separators on B propagate to the buffers matched from it via T.match_buffer."""
+    original = element_wise_subregion_match
+    sch = tir.Schedule(original, debug_mask="all")
+    block_b = sch.get_block("B")
+    sch.set_axis_separator(block_b, 0, "write", [1])
+    expected = element_wise_subregion_match_set_axis_separator
+    tvm.ir.assert_structural_equal(expected, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=original)
+
+
+if __name__ == "__main__":
+    # Run this file's tests directly, forwarding any extra command-line arguments to pytest.
+    sys.exit(pytest.main([__file__, *sys.argv[1:]]))
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 0437576462..c704baebc7 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3253,6 +3253,25 @@ def pointer_type():
return func_with_ptr_type_annotations
+# Round-trip case for buffers declared with axis_separators: both a
+# T.match_buffer parameter (A) and a T.alloc_buffer intermediate (B) carry the
+# attribute, while C is declared without it.
+def buffer_axis_separator():
+ @T.prim_func
+ def element_wise(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32", axis_separators=[1])
+ C = T.match_buffer(c, (128, 128), "float32")
+ B = T.alloc_buffer((128, 128), "float32", axis_separators=[1])
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+
+ return element_wise
+
+
ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
@@ -3288,6 +3307,7 @@ ir_generator = tvm.testing.parameter(
int64_support,
string_annotation_escaping,
pointer_type,
+ buffer_axis_separator,
)