You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/07 12:07:04 UTC
[tvm] branch main updated: [M1b] Scaffolding ScheduleState data
structure (#7765)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf0f87d [M1b] Scaffolding ScheduleState data structure (#7765)
bf0f87d is described below
commit bf0f87dcc6400ab15626538cdf694cb9fa07a7df
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Apr 7 05:04:48 2021 -0700
[M1b] Scaffolding ScheduleState data structure (#7765)
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Cody Yu <co...@gmail.com>
Co-authored-by: Jared Roesch <ro...@gmail.com>
---
include/tvm/runtime/object.h | 15 +
include/tvm/tir/schedule/block_scope.h | 271 +++++++
include/tvm/tir/schedule/state.h | 216 ++++++
python/tvm/tir/__init__.py | 3 +
python/tvm/tir/schedule/__init__.py | 21 +
python/tvm/tir/schedule/_ffi_api_schedule.py | 20 +
python/tvm/tir/schedule/block_scope.py | 152 ++++
python/tvm/tir/schedule/state.py | 185 +++++
python/tvm/tir/stmt.py | 19 +-
src/printer/tvmscript_printer.cc | 1 +
src/tir/analysis/var_touch.cc | 2 +-
src/tir/ir/stmt.cc | 6 +-
src/tir/schedule/analysis.h | 47 ++
src/tir/schedule/analysis/analysis.cc | 60 ++
src/tir/schedule/analysis/verify.cc | 146 ++++
src/tir/schedule/block_scope.cc | 162 +++++
src/tir/schedule/state.cc | 870 +++++++++++++++++++++++
src/tir/schedule/utils.h | 93 +++
tests/python/unittest/test_tir_block_scope.py | 145 ++++
tests/python/unittest/test_tir_schedule_state.py | 352 +++++++++
20 files changed, 2776 insertions(+), 10 deletions(-)
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 048fc1d..f13bdee 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -739,6 +739,21 @@ struct ObjectPtrEqual {
ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
using ContainerType = ObjectName;
+/*!
+ * \brief Define object reference methods that are both not nullable and mutable.
+ *
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
+ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
+ ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
+ ObjectName* get() const { return operator->(); } \
+ static constexpr bool _type_is_nullable = false; \
+ using ContainerType = ObjectName;
+
/*!
* \brief Define CopyOnWrite function in an ObjectRef.
* \param ObjectName The Type of the Node.
diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h
new file mode 100644
index 0000000..49d5e7f
--- /dev/null
+++ b/include/tvm/tir/schedule/block_scope.h
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structures for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: A StmtSRef that points to a TensorIR block.
+ * - Loop sref: A StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block or loop sref that points to the closest
+ * schedulable statement (block or for-loop) among its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+ /*!
+ * \brief The block or `for` stmt the object refers to
+ * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+ * optimization on statements when possible. The strong reference is held in the ScheduleState.
+ */
+ const StmtNode* stmt;
+ /*! \brief The parent sref. */
+ StmtSRefNode* parent;
+ /*!
+ * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+ * then `seq_index` is set to its index; otherwise `seq_index` is -1
+ */
+ int64_t seq_index;
+
+ void VisitAttrs(AttrVisitor* v) {
+ // `stmt` is not visited
+ // `parent` is not visited
+ v->Visit("seq_index", &seq_index);
+ }
+
+ static constexpr const char* _type_key = "tir.StmtSRef";
+ TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+ /*! \brief Reset the object inplace to the invalid state */
+ void Reset() {
+ this->stmt = nullptr;
+ this->parent = nullptr;
+ this->seq_index = -1;
+ }
+
+ /*!
+ * \brief Get the referenced statement with proper type checking.
+ * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+ * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+ * tvm::tir::BlockNode or tvm::tir::ForNode
+ * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+ */
+ template <typename StmtType>
+ const StmtType* StmtAs() const {
+ if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+ return static_cast<const StmtType*>(stmt);
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor
+ * \param stmt The corresponding stmt node, can be either block or for loop.
+ * \param parent The parent sref.
+ * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+ * -1 if the parent does not contain multiple children.
+ */
+ TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+ /*! \return The mutable pointer to the StmtSRefNode */
+ StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+ /*!
+ * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+ * only serving as a "mark" to hint compute-at to do the work of compute-inline
+ * \note This only serves as a fake loop sref for compute-at and reverse-compute-at,
+ * i.e.
+ *
+ * compute-at(block, loop_sref):
+ * compute-inline(block) if loop_sref.same_as(InlineMark())
+ * no-op if loop_sref.same_as(RootMark())
+ * compute-at-impl(block, loop_sref) otherwise
+ */
+ TVM_DLL static StmtSRef InlineMark();
+ /*!
+ * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+ * only serving as a "mark" to hint compute-at to do nothing
+ * \note This only serves as a fake loop sref for compute-at and reverse-compute-at,
+ * i.e.
+ *
+ * compute-at(block, loop_sref):
+ * compute-inline(block) if loop_sref.same_as(InlineMark())
+ * no-op if loop_sref.same_as(RootMark())
+ * compute-at-impl(block, loop_sref) otherwise
+ */
+ TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+ kRAW = 0,
+ kWAW = 1,
+ kWAR = 2,
+ kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+ /*! \brief The source of the dependency relation */
+ StmtSRef src;
+ /*! \brief The destination of the dependency relation */
+ StmtSRef dst;
+ /*! \brief The dependency kind */
+ DepKind kind;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("src", &src);
+ v->Visit("dst", &dst);
+ v->Visit("kind", &kind);
+ }
+
+ static constexpr const char* _type_key = "tir.Dependency";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+ /*! \brief Constructor */
+ TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object with 1-to-1 correspondence with each block reference in the sref tree.
+ * This data structure is used to track the producer-consumer dependencies between blocks.
+ * For example even leaf nodes have a scope node, even though they have no dependencies.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ * - scope root: a block sref
+ * - internal srefs: loop srefs
+ * - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+ /*!
+ * \brief Lookup table for the `src` of dependencies
+ * \note We intentionally didn't use tvm::Map as the data structure, because we need the values
+ * inside to be mutable so that they could be further maintained properly during transformations.
+ */
+ std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+ /*! \brief Lookup table for the `dst` of dependencies */
+ std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+ /*! \brief The mapping from the buffer to the blocks who write it */
+ std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+ /*!
+ * \brief This property indicates that the block scope (rooted at its corresponding block) is
+ * equivalent to a stage pipeline under the following conditions:
+ *
+ * 1) The region cover property holds for each of its child blocks
+ * 2) No write-after-read dependency
+ */
+ bool stage_pipeline{false};
+
+ void VisitAttrs(AttrVisitor* v) {}
+
+ static constexpr const char* _type_key = "tir.BlockScope";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object);
+
+ public:
+ /******** Dependency ********/
+ /*!
+ * \brief Get all dependencies whose `src` equals `src`
+ * \param src The queried block
+ * \return The dependencies
+ */
+ TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
+ /*!
+ * \brief Get all dependencies whose `dst` equals `dst`
+ * \param dst The queried block
+ * \return The dependencies
+ */
+ TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
+};
+
+/*!
+ * \brief Managed reference to BlockScopeNode
+ * \sa BlockScopeNode
+ */
+class BlockScope : public ObjectRef {
+ public:
+ /*! \brief The constructor creating an empty block scope with no dependency information */
+ TVM_DLL BlockScope();
+ /*!
+ * \brief Create the object with the specific leaf blocks, and compute the dependency information
+ * between the leaf blocks.
+ * \param child_block_srefs The srefs to the leaf blocks
+ * \note We assume the leaf blocks are given in pre-DFS order
+ */
+ TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
+
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
+};
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
new file mode 100644
index 0000000..12b6fc1
--- /dev/null
+++ b/include/tvm/tir/schedule/state.h
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+ /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+ BlockScope scope{nullptr};
+ // The properties below are information about the current block realization under its parent scope
+ /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+ bool affine_binding{false};
+ /*!
+ * \brief Property of a block, indicating each of the block's read regions is fully
+ * produced by its producers
+ */
+ bool region_cover{false};
+
+ BlockInfo() = default;
+
+ explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+ : scope(std::move(scope)), //
+ affine_binding(affine_binding), //
+ region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+ /*! \brief Verify the correctness of the sref tree */
+ kVerifySRefTree = 1,
+ /*! \brief Verify the correctness of affine_binding */
+ kVerifyAffineBinding = 2,
+ /*! \brief Verify the correctness of region_cover */
+ kVerifyRegionCover = 4,
+ /*! \brief Verify the correctness of stage_pipeline */
+ kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary interface for all the scheduling primitives to manipulate the TensorIR.
+ *
+ * The data structure contains the following information
+ * 1) The AST being scheduled (mod)
+ * 2) The sref tree of schedulable statements (indicated by the srefs)
+ * 3) The dependency information of each block scope (block_info)
+ * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
+ * 5) A debug flag, if set, extra checking is enabled (debug_mode)
+ */
+class ScheduleStateNode : public Object {
+ public:
+ /*! \brief The AST of the module being scheduled */
+ IRModule mod;
+ /*!
+ * \brief Mapping from a block sref to its corresponding BlockInfo,
+ * tracking the dependency inside the block scope,
+ * and storing necessary information flags for scheduling
+ */
+ std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual> block_info;
+ /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
+ std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
+ /*!
+ * \brief Do extra correctness checking after the class creation
+ * and each time after calling the Replace method.
+ * \sa ScheduleDebugMask
+ */
+ int debug_mode;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("mod", &mod);
+ // `block_info` is not visited
+ // `stmt2ref` is not visited
+ v->Visit("debug_mode", &debug_mode);
+ }
+ /*!
+ * \brief Replace the part of the AST, as being pointed to by `src_sref`,
+ * with a specific statement `tgt_stmt`, and maintain the sref tree accordingly.
+ * Replace will try to perform copy on write as much as possible when the ScheduleState holds
+ * the only copy to the IRModule and IR nodes.
+ *
+ * Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`.
+ * 1) Block -> Block
+ * 2) Loop -> Loop
+ * 3) Loop -> BlockRealize
+ *
+ * \param src_sref The sref to the statement to be replaced
+ * \param tgt_stmt The new statement that replaces `src_sref->stmt`
+ * \param block_sref_reuse Maps an old block (to be replaced in the subtree under
+ * `src_sref->stmt`) to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces
+ * reuse of srefs between them (rather than create new srefs) i.e. after being replaced, the sref
+ * that points to the old block will point to the new one
+ * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars.
+ */
+ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
+ const Map<Block, Block>& block_sref_reuse);
+ /*!
+ * \brief Trigger the verification according to the `debug_mode` bitmask.
+ * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
+ * 2) If the bitmask `kVerifyAffineBinding` is on, verify the correctness of `affine_binding`
+ * 3) If the bitmask `kVerifyRegionCover` is on, verify the correctness of `region_cover`
+ * 4) If the bitmask `kVerifyStagePipeline` is on, verify the correctness of `stage_pipeline`
+ */
+ TVM_DLL void DebugVerify() const;
+
+ static constexpr const char* _type_key = "tir.ScheduleState";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleStateNode, Object);
+
+ /******** Property of blocks ********/
+ /*! \brief Returns the BlockInfo corresponding to the block sref */
+ TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+ /*!
+ * \brief Get the BlockScope corresponding to the sref of scope root block
+ * \param scope_root The block sref to be retrieved
+ * \return The corresponding BlockScope
+ */
+ BlockScope GetBlockScope(const StmtSRef& scope_root) const {
+ return GetBlockInfo(scope_root).scope;
+ }
+ /*!
+ * \brief Check a cached flag indicating if the specific block has quasi-affine bindings
+ * \param block_sref The block sref to be checked
+ * \return A boolean flag indicating if the block has quasi-affine bindings
+ */
+ bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
+ return GetBlockInfo(block_sref).affine_binding;
+ }
+ /*!
+ * \brief Check a cached flag indicating if each of the specific consumer block's read region
+ * is fully produced by its producers
+ * \param consumer_block_sref The specific consumer block
+ * \return A boolean flag indicating if each of the block's read regions is fully produced
+ */
+ bool IsRegionCoveredConsumer(const StmtSRef& consumer_block_sref) const {
+ return GetBlockInfo(consumer_block_sref).region_cover;
+ }
+ /*!
+ * \brief Check a cached flag indicating if a block scope is an equivalence of a stage pipeline
+ * \param scope_root The sref of the scope root block to be checked
+ * \return A boolean flag indicating if the block scope is equivalent to a stage pipeline
+ */
+ bool IsStagePipeline(const StmtSRef& scope_root) const {
+ return GetBlockScope(scope_root)->stage_pipeline;
+ }
+};
+
+/*!
+ * \brief Managed reference to ScheduleStateNode
+ * \sa ScheduleStateNode
+ */
+class ScheduleState : public ObjectRef {
+ public:
+ /*!
+ * \brief Construct a schedule state from an IRModule
+ * \param mod The IRModule to be scheduled
+ * \param debug_mode Do extra correctness checking after the class creation
+ * and each time after calling the Replace method.
+ */
+ TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0);
+ /*!
+ * \brief Construct a schedule state from a PrimFunc
+ * \param func The PrimFunc to be scheduled. A new IRModule will be created with
+ * this specific PrimFunc as "main" function in the module to be scheduled
+ * \param debug_mode Do extra correctness checking after the class creation
+ * and each time after calling the Replace method.
+ */
+ TVM_DLL explicit ScheduleState(PrimFunc func, int debug_mode = 0);
+
+ /*! \return The mutable pointer to the ScheduleStateNode */
+ ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleState, ObjectRef, ScheduleStateNode);
+};
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_SCHEDULE_STATE_H_
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index ad91eab..681fc31 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -48,6 +48,9 @@ from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
+from .schedule import StmtSRef, BlockScope, ScheduleState
+
+from . import schedule
from . import ir_builder
from . import transform
from . import analysis
diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py
new file mode 100644
index 0000000..21721f7
--- /dev/null
+++ b/python/tvm/tir/schedule/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""Namespace for the TensorIR schedule API."""
+
+from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
+from .state import ScheduleDebugMask, ScheduleState
diff --git a/python/tvm/tir/schedule/_ffi_api_schedule.py b/python/tvm/tir/schedule/_ffi_api_schedule.py
new file mode 100644
index 0000000..ae8bdfd
--- /dev/null
+++ b/python/tvm/tir/schedule/_ffi_api_schedule.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.tir.schedule"""
+import tvm._ffi
+
+tvm._ffi._init_api("tir.schedule", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
new file mode 100644
index 0000000..8281452
--- /dev/null
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope."""
+from enum import IntEnum
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir import Block, For
+
+from . import _ffi_api_schedule
+
+
+@register_object("tir.StmtSRef")
+class StmtSRef(Object):
+ """An object that refers to schedulable elements in the TensorIR, aka "sref".
+
+ Glossary
+ - Block sref: A StmtSRef that points to a TensorIR block.
+ - Loop sref: A StmtSRef that points to a TensorIR for loop.
+ - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ schedulable statement of its ancestors on the TensorIR AST.
+ - Root sref: Sref to the root block. Every sref has exactly one parent sref
+ except for root sref.
+ - Sref tree: The parent-children-relationship of srefs that forms a tree,
+ uniquely determined by the TensorIR AST.
+ """
+
+ seq_index: int
+
+ @property
+ def stmt(self) -> Optional[Union[Block, For]]:
+ """The block/for stmt the object refers to"""
+ return _ffi_api_schedule.StmtSRefStmt(self) # pylint: disable=no-member
+
+ @property
+ def parent(self) -> Optional["StmtSRef"]:
+ """The parent sref"""
+ return _ffi_api_schedule.StmtSRefParent(self) # pylint: disable=no-member
+
+ @staticmethod
+ def inline_mark() -> "StmtSRef":
+ """A special StmtSRef, which doesn't point to any stmt in the AST,
+ only serving as a "mark" to hint compute-at to do the work of compute-inline"""
+ return _ffi_api_schedule.StmtSRefInlineMark() # pylint: disable=no-member
+
+ @staticmethod
+ def root_mark() -> "StmtSRef":
+ """A special StmtSRef, which doesn't point to any stmt in the AST,
+ only serving as a "mark" to hint compute-at to do nothing"""
+ return _ffi_api_schedule.StmtSRefRootMark() # pylint: disable=no-member
+
+
+class DepKind(IntEnum):
+ """Type of dependency.
+
+ Attributes
+ ----------
+ RAW : int = 0
+ Read-after-write dependency
+ WAW : int = 1
+ Write-after-write dependency
+ WAR : int = 2
+ Write-after-read dependency. Not supported in TensorIR for now.
+ OPAQUE: int = 3
+ Opaque dependency
+ """
+
+ RAW = 0
+ WAW = 1
+ WAR = 2
+ OPAQUE = 3
+
+
+@register_object("tir.Dependency")
+class Dependency(Object):
+ """A tuple (src, dst, kind) representing certain types of dependency.
+ For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ read-after-write, which means block B reads the result written by block A.
+
+ Parameters
+ ----------
+ src : StmtSRef
+ The source of the dependency relation
+ dst : StmtSRef
+ The destination of the dependency relation
+ kind : DepKind
+ The dependency kind
+ """
+
+ src: StmtSRef
+ dst: StmtSRef
+ kind: DepKind
+
+
+@register_object("tir.BlockScope")
+class BlockScope(Object):
+ """An object corresponds to each block sref in the sref tree,
+ which tracks the producer-consumer dependency between blocks.
+
+ Glossary:
+ - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ whose components are:
+ - scope root: a block sref
+ - internal srefs: loop srefs
+ - scope leaves: block srefs
+ - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ """
+
+ def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]:
+ """Get all dependencies whose `src` is the target`block`.
+
+ Parameters
+ ----------
+ block: StmtSRef
+ The queried block
+
+ Returns
+ -------
+ blocks: List[Dependency]
+ The dependencies
+ """
+ return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block) # pylint: disable=no-member
+
+ def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]:
+ """Get all dependencies whose `dst` is the target `block`.
+
+ Parameters
+ ----------
+ block: StmtSRef
+ The queried block
+
+ Returns
+ -------
+ blocks: List[Dependency]
+ The dependencies
+ """
+ return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block) # pylint: disable=no-member
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
new file mode 100644
index 0000000..180fede
--- /dev/null
+++ b/python/tvm/tir/schedule/state.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from enum import IntEnum
+from typing import Dict, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.tir import Block, BlockRealize, For, PrimFunc
+
+from . import _ffi_api_schedule
+from .block_scope import BlockScope, StmtSRef
+
+
+class ScheduleDebugMask(IntEnum):
+ """The bitmask of the `debug_mode` flag in the ScheduleState class.
+
+ If the `debug_mode` flag has a certain bit on, then the corresponding
+ verification pass will be conducted. For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`,
+ then the correctness of the sref tree will be verified after each schedule instruction.
+
+ Attributes
+ ----------
+ VERIFY_SREF_TREE : int = 1
+ Verify the correctness of the sref tree
+ VERIFY_AFFINE_BINDING : int = 2
+ Verify the correctness of affine_binding
+ VERIFY_REGION_COVER : int = 4
+ Verify the correctness of region_cover
+ VERIFY_STAGE_PIPELINE: int = 8
+ Verify the correctness of stage_pipeline
+ """
+
+ VERIFY_SREF_TREE = 1
+ VERIFY_AFFINE_BINDING = 2
+ VERIFY_REGION_COVER = 4
+ VERIFY_STAGE_PIPELINE = 8
+
+
+@register_object("tir.ScheduleState")
+class ScheduleState(Object):
+ """The state of scheduling, which exposes a `Replace` method as
+ the primary means for all the scheduling primitives to manipulate the TensorIR.
+
+ The data structure contains the following information
+ 1) The AST being scheduled (mod)
+ 2) The sref tree of schedulable statements (indicated by the srefs)
+ 3) The dependency information of each block scope (block_info)
+ 4) A reverse mapping from the AST nodes to those in the sref tree (get_sref)
+ 5) A debug flag, if set, extra checking is enabled (debug_mode)
+
+ Parameters
+ ----------
+ mod : IRModule
+ The AST of the module being scheduled
+ debug_mode : int
+ Do extra correctness checking after the object construction
+ and each time after calling the Replace method.
+ """
+
+ mod: IRModule
+ debug_mode: int
+
+ def __init__(
+ self,
+ func_or_mod: Union[PrimFunc, IRModule],
+ debug_mode: Union[bool, int] = False,
+ ):
+ """Construct a schedule state from an IRModule or a PrimFunc
+
+ Parameters
+ ----------
+ func_or_mod : Union[PrimFunc, IRModule]
+ The IRModule or PrimFunc to be scheduled
+ debug_mode : Union[bool, int]
+ Do extra correctness checking after the class creation and each time
+ after calling the Replace method.
+ Possible choices of `debug_mode`:
+ 1) True - Turn on all the checks
+ 2) False - Turn off all the checks
+ 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
+ """
+ if isinstance(debug_mode, bool):
+ if debug_mode:
+ debug_mode = -1
+ else:
+ debug_mode = 0
+ if not isinstance(debug_mode, int):
+ raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
+ self.__init_handle_by_constructor__(
+ _ffi_api_schedule.ScheduleState, # pylint: disable=no-member
+ func_or_mod,
+ debug_mode,
+ )
+
+ def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
+ """Return the corresponding sref that points to the stmt
+
+ Parameters
+ ----------
+ stmt : Union[Block, For]
+ The schedulable statement in the TensorIR to be retrieved for its sref
+
+ Returns
+ -------
+ sref : Optional[StmtSRef]
+ The corresponding sref
+ """
+ return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt) # pylint: disable=no-member
+
+ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope:
+ """Get the BlockScope correpsonding to the block sref
+
+ Parameters
+ ----------
+ block_sref : StmtSRef
+ The block sref to be retrieved
+
+ Returns
+ -------
+ scope : BlockScope
+ The corresponding BlockScope
+ """
+ return _ffi_api_schedule.ScheduleStateGetBlockScope( # pylint: disable=no-member
+ self, block_sref
+ )
+
+ def replace(
+ self,
+ src_sref: StmtSRef,
+ tgt_stmt: Union[Block, For, BlockRealize],
+ block_sref_reuse: Optional[Dict[Block, Block]] = None,
+ ) -> None:
+ """
+ Replace the part of the AST, as being pointed to by `src_sref`,
+ with a specific statement `tgt_stmt`, and maintain the sref tree accordingly.
+ Replace will try to perform copy on write as much as possible when the ScheduleState holds
+ the only copy to the IRModule and IR nodes.
+
+ Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`.
+ 1) Block -> Block
+ 2) Loop -> Loop
+ 3) Loop -> BlockRealize
+
+ Parameters
+ ----------
+ src_sref : StmtSRef
+ The sref to the statement to be replaced in the TensorIR AST
+
+ tgt_stmt : Union[Block, For, BlockRealize]
+ The statement to be replaced to
+
+ block_sref_reuse : Optional[Dict[Block, Block]] = None
+ Maps an old block (to be replaced in the subtree under `src_sref->stmt`)
+ to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces
+ reuse of srefs between them (rather than create new srefs) i.e. after being replaced,
+ the sref that points to the old block will point to the new one
+
+ Note
+ ----------
+ The reuse of loop srefs is detected automatically according to the reuse of loop vars.
+ """
+ if block_sref_reuse is None:
+ block_sref_reuse = {}
+ _ffi_api_schedule.ScheduleStateReplace( # pylint: disable=no-member
+ self,
+ src_sref,
+ tgt_stmt,
+ block_sref_reuse,
+ )
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4746206..46f456c 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -26,12 +26,13 @@ Each statement node have subfields that can be visited from python side.
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
"""
-from typing import List, Optional, Mapping
from enum import IntEnum
+from typing import List, Mapping, Optional, Union
+
import tvm._ffi
+from tvm.ir import PrimExpr, Range, Span
+from tvm.runtime import Object, const
-from tvm.runtime import Object
-from tvm.ir import Span, PrimExpr, Range
from . import _ffi_api
from .buffer import Buffer
from .expr import IterVar
@@ -589,7 +590,7 @@ class BlockRealize(Stmt):
iter_values : List[PrimExpr]
The binding values of the block var.
- predicate : PrimExpr
+ predicate : Union[PrimExpr, bool]
The predicate of the block.
block : Block
@@ -607,12 +608,18 @@ class BlockRealize(Stmt):
def __init__(
self,
iter_values: List[PrimExpr],
- predicate: PrimExpr,
+ predicate: Union[PrimExpr, bool],
block: Block,
span: Optional[Span] = None,
):
+ if isinstance(predicate, bool):
+ predicate = const(predicate, "bool")
self.__init_handle_by_constructor__(
- _ffi_api.BlockRealize, iter_values, predicate, block, span
+ _ffi_api.BlockRealize,
+ iter_values,
+ predicate,
+ block,
+ span,
)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 4380795..7afdcab 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -683,6 +683,7 @@ Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
doc << "if " << Print(op->condition) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
if (!is_one(op->condition) && op->else_case.defined()) {
+ doc << Doc::NewLine();
doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case));
}
return doc;
diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc
index 2a23329..40a2cce 100644
--- a/src/tir/analysis/var_touch.cc
+++ b/src/tir/analysis/var_touch.cc
@@ -18,7 +18,7 @@
*/
/*!
- * \file simple_analysis.cc
+ * \file var_touch.cc
* \brief Implementation of simple passes
*/
#include <tvm/tir/analysis.h>
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 2aeaae3..87ead3e 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -221,7 +221,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<WhileNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const WhileNode*>(node.get());
p->PrintIndent();
- p->stream << "while(" << op->condition << "){\n";
+ p->stream << "while(" << op->condition << ") {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
@@ -781,7 +781,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
auto* op = static_cast<const BlockNode*>(node.get());
p->PrintIndent();
PrintBlockTitle(op, p);
- p->stream << "{\n";
+ p->stream << " {\n";
p->indent += 2;
// Print block elements (e.g. reads/writes, etc)
@@ -820,7 +820,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
auto* block_op = op->block.get();
p->PrintIndent();
PrintBlockTitle(block_op, p);
- p->stream << "{\n";
+ p->stream << " {\n";
p->indent += 2;
// Print binding iter_values
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
new file mode 100644
index 0000000..32d9f6d
--- /dev/null
+++ b/src/tir/schedule/analysis.h
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
+#define TVM_TIR_SCHEDULE_ANALYSIS_H_
+
+#include <tvm/tir/schedule/state.h>
+
+namespace tvm {
+namespace tir {
+
+/******** Verification ********/
+/*!
+ * \brief Verify the sref tree state is consistent with the IR
+ * \param self The schedule state containing the sref to be verified
+ * \throw An exception will be thrown if the sref tree is not valid
+ */
+void VerifySRefTree(const ScheduleState& self);
+
+/******** Block-loop relation ********/
+/*!
+ * \brief Get the leaf blocks of a scope where a specific block/loop is in
+ * \param self The schedule state
+ * \param parent_sref The StmtSRef that points to the parent block/loop
+ * \return A list of leaf blocks
+ */
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_SCHEDULE_ANALYSIS_H_
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
new file mode 100644
index 0000000..005ff37
--- /dev/null
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Block-loop relation ********/
+
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) {
+ struct Collector : public StmtVisitor {
+ public:
+ static Array<StmtSRef> Collect(const ScheduleState& self, const Stmt& stmt) {
+ Collector collector(self);
+ collector(stmt);
+ return std::move(collector.result_);
+ }
+
+ private:
+ explicit Collector(const ScheduleState& self) : self_(self) {}
+
+ void VisitStmt_(const BlockNode* block) final {
+ auto it = self_->stmt2ref.find(block);
+ ICHECK(it != self_->stmt2ref.end());
+ result_.push_back(it->second);
+ }
+
+ const ScheduleState& self_;
+ Array<StmtSRef> result_;
+ };
+
+ if (parent_sref->stmt->IsInstance<ForNode>()) {
+ const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
+ return Collector::Collect(self, loop->body);
+ } else if (parent_sref->stmt->IsInstance<BlockNode>()) {
+ const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
+ return Collector::Collect(self, block->body);
+ }
+ ICHECK(false) << "Unreachable";
+ throw;
+}
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc
new file mode 100644
index 0000000..edb62b5
--- /dev/null
+++ b/src/tir/schedule/analysis/verify.cc
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+class SRefTreeVerifier : public StmtVisitor {
+ public:
+ static void Verify(const ScheduleStateNode* self) { SRefTreeVerifier(self).Verify(); }
+
+ private:
+ /*! \brief Constructor */
+ explicit SRefTreeVerifier(const ScheduleStateNode* self) : self_(self) {}
+
+ void Verify() {
+ VisitPrimFuncs(self_->mod, [this](const PrimFuncNode* func) { this->VisitStmt(func->body); });
+ ICHECK_EQ(n_sref_visited_, static_cast<int>(self_->stmt2ref.size()));
+ for (const auto& kv : self_->block_info) {
+ const StmtSRef& sref = kv.first;
+ ICHECK(sref->stmt != nullptr)
+ << "InternalError: An expired sref is found in the block_scope mapping";
+ auto it = self_->stmt2ref.find(sref->stmt);
+ ICHECK(it != self_->stmt2ref.end())
+ << "InternalError: The sref points to a statement that does not exist in stmt2ref";
+ const StmtSRef& sref2 = it->second;
+ ICHECK(sref.same_as(sref2))
+ << "InternalError: The sref points to a statement whose corresponding sref in stmt2ref "
+ "is not the same object as itself";
+ }
+ ICHECK_EQ(n_block_sref_visited_, static_cast<int>(self_->block_info.size()));
+ }
+
+ void VisitStmt_(const BlockNode* block) final {
+ if (init_block_depth_) {
+ ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its "
+ "corresponding sref, which is not allowed";
+ StmtVisitor::VisitStmt_(block);
+ return;
+ }
+ ICHECK(self_->stmt2ref.count(block))
+ << "InternalError: A BlockNode should appear in sref map, but it didn't\n"
+ << GetRef<Stmt>(block);
+ ++n_sref_visited_;
+ ++n_block_sref_visited_;
+ const StmtSRef& sref = self_->stmt2ref.at(block);
+ ICHECK(self_->block_info.count(sref))
+ << "InternalError: Cannot find scope information of the BlockNode:\n"
+ << GetRef<Stmt>(block);
+ ICHECK(sref->parent == ancestors_.back())
+ << "InternalError: Parent information mismatch for BlockNode:\n"
+ << GetRef<Stmt>(block) << "\nIts parent is supposed to be:\n"
+ << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
+ << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt))
+ : Optional<Stmt>(NullOpt));
+ ancestors_.push_back(sref.operator->());
+ if (block->init.defined()) {
+ ++init_block_depth_;
+ VisitStmt(block->init.value());
+ --init_block_depth_;
+ }
+ VisitStmt(block->body);
+ ancestors_.pop_back();
+ }
+
+ void VisitStmt_(const ForNode* loop) final {
+ if (init_block_depth_) {
+ ICHECK(!self_->stmt2ref.count(loop)) << "InternalError: A loop inside init block has its "
+ "corresponding sref, which is not allowed";
+ StmtVisitor::VisitStmt_(loop);
+ return;
+ }
+ ICHECK(self_->stmt2ref.count(loop))
+ << "InternalError: A ForNode should appear in sref map, but it didn't\n"
+ << GetRef<Stmt>(loop);
+ ++n_sref_visited_;
+ const StmtSRef& sref = self_->stmt2ref.at(loop);
+ Optional<Stmt> stmt = NullOpt;
+ ICHECK(sref->parent == ancestors_.back())
+ << "InternalError: Parent information mismatch for ForNode:\n"
+ << GetRef<Stmt>(loop) << "\nIts parent is supposed to be:\n"
+ << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
+ << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt))
+ : Optional<Stmt>(NullOpt));
+ ancestors_.push_back(sref.operator->());
+ StmtVisitor::VisitStmt_(loop);
+ ancestors_.pop_back();
+ }
+
+ void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ // Verify seq_index
+ if (init_block_depth_) {
+ StmtVisitor::VisitStmt_(seq_stmt);
+ return;
+ }
+ int n = static_cast<int>(seq_stmt->seq.size());
+ for (int i = 0; i < n; ++i) {
+ const Stmt& child = seq_stmt->seq[i];
+ StmtSRef sref{nullptr};
+ if (const auto* realize = child.as<BlockRealizeNode>()) {
+ const auto* block = realize->block.get();
+ ICHECK(self_->stmt2ref.count(block));
+ sref = self_->stmt2ref.at(block);
+ } else if (child->IsInstance<ForNode>()) {
+ ICHECK(self_->stmt2ref.count(child.get()));
+ sref = self_->stmt2ref.at(child.get());
+ } else {
+ continue;
+ }
+ ICHECK_EQ(sref->seq_index, i) << "InternalError: A StmtSRef has incorrect seq_index";
+ }
+ StmtVisitor::VisitStmt_(seq_stmt);
+ }
+
+ /*! \brief The schedule it belongs to */
+ const ScheduleStateNode* self_;
+ /*! \brief Parent information during the visit */
+ std::vector<const StmtSRefNode*> ancestors_ = {nullptr};
+ /*! \brief If the visitor is currently in the init block */
+ int init_block_depth_ = 0;
+ /*! \brief Number of srefs that are visited */
+ int n_sref_visited_ = 0;
+ /*! \brief Number of block srefs that are visited */
+ int n_block_sref_visited_ = 0;
+};
+
+void VerifySRefTree(const ScheduleState& self) { SRefTreeVerifier::Verify(self.get()); }
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc
new file mode 100644
index 0000000..f1ce65e
--- /dev/null
+++ b/src/tir/schedule/block_scope.cc
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Utility functions ********/
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief Add a dependency relation.
+ * \param src The source of the dependency
+ * \param dst The destination of the dependency
+ * \param kind Type of the dependency
+ * \note This method is effectively NOP on self-loops
+ */
+void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
+ if (!src.same_as(dst)) {
+ Dependency dep(src, dst, kind);
+ self->src2deps[src].push_back(dep);
+ self->dst2deps[dst].push_back(dep);
+ }
+}
+
+/******** Constructors ********/
+
+StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
+ ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
+ n->stmt = stmt;
+ n->parent = parent;
+ n->seq_index = seq_index;
+ data_ = std::move(n);
+}
+
+StmtSRef StmtSRef::InlineMark() {
+ static StmtSRef result(nullptr, nullptr, -1);
+ return result;
+}
+
+StmtSRef StmtSRef::RootMark() {
+ static StmtSRef result(nullptr, nullptr, -1);
+ return result;
+}
+
+Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
+ ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
+ node->src = std::move(src);
+ node->dst = std::move(dst);
+ node->kind = kind;
+ data_ = std::move(node);
+}
+
+BlockScope::BlockScope() { data_ = make_object<BlockScopeNode>(); }
+
+BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
+ ObjectPtr<BlockScopeNode> n = make_object<BlockScopeNode>();
+ SMap<Buffer, Array<StmtSRef>> buffer_readers;
+ SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
+ for (const StmtSRef& child_block_sref : child_block_srefs) {
+ const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref);
+ // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
+ for (const BufferRegion& region : child_block->reads) {
+ buffer_readers[region->buffer].push_back(child_block_sref);
+ }
+ for (const BufferRegion& region : child_block->writes) {
+ buffer_writers[region->buffer].push_back(child_block_sref);
+ }
+ // Step 2. Update RAW dependency
+ for (const BufferRegion& region : child_block->reads) {
+ auto it = buffer_writers.find(region->buffer);
+ if (it != buffer_writers.end()) {
+ for (const StmtSRef& from : it->second) {
+ AddDependency(n.get(), from, child_block_sref, DepKind::kRAW);
+ }
+ }
+ }
+ // Step 3. Update WAW dependency
+ for (const BufferRegion& region : child_block->writes) {
+ auto it = buffer_writers.find(region->buffer);
+ if (it != buffer_writers.end()) {
+ for (const StmtSRef& from : it->second) {
+ AddDependency(n.get(), from, child_block_sref, DepKind::kWAW);
+ }
+ }
+ }
+ // Step 4. Update WAR dependency
+ for (const BufferRegion& region : child_block->writes) {
+ auto it = buffer_readers.find(region->buffer);
+ if (it != buffer_readers.end()) {
+ for (const StmtSRef& from : it->second) {
+ AddDependency(n.get(), from, child_block_sref, DepKind::kWAR);
+ }
+ }
+ }
+ }
+ data_ = std::move(n);
+}
+
+/******** Dependency ********/
+
+Array<Dependency> BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const {
+ auto iter = this->src2deps.find(block_sref);
+ if (iter != this->src2deps.end()) {
+ return iter->second;
+ } else {
+ return {};
+ }
+}
+
+Array<Dependency> BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const {
+ auto iter = this->dst2deps.find(block_sref);
+ if (iter != this->dst2deps.end()) {
+ return iter->second;
+ } else {
+ return {};
+ }
+}
+
+/******** FFI ********/
+
+TVM_REGISTER_NODE_TYPE(StmtSRefNode);
+TVM_REGISTER_NODE_TYPE(DependencyNode);
+TVM_REGISTER_NODE_TYPE(BlockScopeNode);
+
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
+ .set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
+ return GetRef<Optional<Stmt>>(sref->stmt);
+ });
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent")
+ .set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
+ return GetRef<Optional<StmtSRef>>(sref->parent);
+ });
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") //
+ .set_body_typed(StmtSRef::RootMark);
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") //
+ .set_body_typed(StmtSRef::InlineMark);
+TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
+ .set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
+TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
+ .set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
new file mode 100644
index 0000000..d1b899b
--- /dev/null
+++ b/src/tir/schedule/state.cc
@@ -0,0 +1,870 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not schedulable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+ if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+ const BlockNode* block = realize->block.get();
+ ICHECK(self->stmt2ref.count(block));
+ self->stmt2ref.at(block)->seq_index = seq_index;
+ } else if (const auto* block = stmt.as<BlockNode>()) {
+ ICHECK(self->stmt2ref.count(block));
+ self->stmt2ref.at(block)->seq_index = seq_index;
+ } else if (const auto* loop = stmt.as<ForNode>()) {
+ ICHECK(self->stmt2ref.count(loop));
+ self->stmt2ref.at(loop)->seq_index = seq_index;
+ } else {
+ // do nothing
+ }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+ int i = 0;
+ for (const Stmt& stmt : seq_stmt->seq) {
+ SetSeqIndex(self, stmt, i);
+ ++i;
+ }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ * `sref->stmt` to `new_stmt`
+ * `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+ ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+ const StmtNode* old_stmt = sref->stmt;
+ ICHECK_NE(new_stmt, old_stmt);
+ self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+ self->stmt2ref.erase(sref->stmt);
+ sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+ GlobalVar* result_g_var) {
+ for (const auto& kv : mod->functions) {
+ const GlobalVar& g_var = kv.first;
+ const BaseFunc& base_func = kv.second;
+ if (const auto* func = base_func.as<PrimFuncNode>()) {
+ if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+ if (realize->block.get() == root_block) {
+ *result_g_var = g_var;
+ return func;
+ }
+ }
+ }
+ }
+ LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+ "statement:\n"
+ << GetRef<Stmt>(root_block);
+ throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+ /*!
+ * \brief The entry function
+ * \param mod The IRModule to create the schedule state from (`debug_mode` is stored as-is)
+ */
+ static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+ ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+ ScheduleStateNode* self = n.get();
+ // Set `n->mod`
+ n->mod = std::move(mod);
+ // Set `n->debug_mode`
+ n->debug_mode = debug_mode;
+ // Set `n->stmt2ref` and `n->block_info`
+ StateCreator creator(self);
+ for (const auto& kv : n->mod->functions) {
+ const BaseFunc& base_func = kv.second;
+ if (const auto* func = base_func.as<PrimFuncNode>()) {
+ creator.VisitStmt(func->body);
+ }
+ }
+ return n;
+ }
+
+ private:
+ explicit StateCreator(ScheduleStateNode* self)
+ : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+ block_frames_.emplace({});
+ }
+
+ /*!
+ * \brief Add a new statement to the stack, which becomes the current scope
+ * \param stmt A for-loop statement or a block statement
+ * \return A sref to the stmt
+ */
+ StmtSRef PushSRef(const StmtNode* stmt) {
+ if (srefs_.empty()) {
+ srefs_.push_back(
+ StmtSRef(stmt,
+ /*parent=*/nullptr,
+ /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex
+ } else {
+ StmtSRefNode* parent = srefs_.back().get();
+ srefs_.push_back(
+ StmtSRef(stmt, parent,
+ /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex
+ }
+ return srefs_.back();
+ }
+
+ /*! \brief Pop the top of the scope and record it in stmt2ref map */
+ StmtSRef PopAndRecordSRef() {
+ StmtSRef sref = std::move(srefs_.back());
+ self_->stmt2ref[sref->stmt] = sref;
+ srefs_.pop_back();
+ return sref;
+ }
+
+ void MakeBlockInfo(StmtSRef scope_root) {
+ // Calculate `BlockInfo::scope`
+ Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+ BlockInfo& info =
+ self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+ .first->second;
+ // TODO(@junrushao1994): calculate the flags
+ // Set `affine_binding`
+ info.affine_binding = false;
+ // Set `region_cover`
+ info.region_cover = false;
+ // Set `stage_pipeline`
+ info.scope->stage_pipeline = false;
+ }
+
+ void VisitStmt_(const ForNode* loop) final {
+ PushSRef(loop);
+ VisitStmt(loop->body);
+ PopAndRecordSRef();
+ }
+
+ void VisitStmt_(const BlockRealizeNode* realize) final {
+ realizes_.push_back(realize);
+ block_frames_.emplace_back();
+ const BlockNode* block = realize->block.get();
+ // Recursive visit
+ PushSRef(block);
+ VisitStmt(block->body); // `block->init` is not visited
+ StmtSRef sref = PopAndRecordSRef();
+ // Create BlockInfo for the block
+ MakeBlockInfo(sref);
+ // Update parent scope
+ block_frames_.pop_back();
+ block_frames_.back().push_back(sref);
+ realizes_.pop_back();
+ }
+
+ /*!
+ * \brief Visit a SeqStmt: delegate to the default StmtVisitor traversal first,
+ * then call SetSeqIndexInChildren on the SeqStmt
+ */
+ void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ // Set `seq_index` information for SeqStmtNode
+ StmtVisitor::VisitStmt_(seq_stmt);
+ SetSeqIndexInChildren(self_, seq_stmt);
+ }
+
+ /*! \brief The result ScheduleStateNode */
+ ScheduleStateNode* self_;
+ /*! \brief The stack frame used to indicate the current scope */
+ std::vector<StmtSRef> srefs_;
+ /*! \brief The BlockRealize in the ancestors */
+ std::vector<const BlockRealizeNode*> realizes_;
+ /*! \brief The stack frames of blocks in the DFS visit. */
+ std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+/*!
+ * \brief Construct a schedule state from an IRModule
+ * \param mod The IRModule, passed to StateCreator::Create
+ * \param debug_mode The debug mode, values smaller than -1 are rejected
+ */
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+ CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+ data_ = StateCreator::Create(mod, debug_mode);
+ // Run the verifications selected by `debug_mode` on the freshly created state
+ (*this)->DebugVerify();
+}
+
+/*!
+ * \brief Construct a schedule state from a single PrimFunc
+ * \param func The PrimFunc, which is wrapped into an IRModule under a GlobalVar named "main"
+ * \param debug_mode The debug mode, forwarded to the IRModule-based constructor
+ */
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+ : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST with a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold don't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused srefs, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is a correspondence between them,
+ * which lets us reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+ /*!
+ * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+ * sref is reused and it is an intact reuse.
+ */
+ std::unordered_set<const StmtNode*> intact;
+ /*!
+ * \brief Kind 2.1. Loop sref reuse
+ * If the loop var of a loop is in `loop_sref_possible_reuse`,
+ * it means that when `src_stmt` has a loop that uses this loop var,
+ * the reuse kind is loop sref reuse.
+ * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+ * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+ */
+ std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+ /*!
+ * \brief Kind 2.2. Block sref reuse.
+ * Maps an old block in `src_stmt` to a new block in `tgt_stmt`,
+ * indicating that the sref to the old block should be reused as the sref to the new block.
+ */
+ std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A visitor over `tgt_stmt` which gathers the sref reuses detectable from `tgt_stmt` alone:
+ *
+ * 1) Intact reuse: a loop or block in `tgt_stmt` that already has an sref in `stmt2ref`.
+ *    Given the immutability of the IR, the entire subtree under it is unchanged,
+ *    so the visitor does not descend into it.
+ *
+ * 2) Possible loop sref reuse: the loop var of every loop in `tgt_stmt` that is not intact.
+ *
+ * Block sref reuse is not collected here, it is supposed to be set by the caller.
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state whose `stmt2ref` is consulted to detect intact statements
+   * \param tgt_stmt The statement to be visited
+   * \return The reuse info, with `intact` and `loop_sref_possible_reuse` filled in
+   */
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector visitor(self);
+    visitor.VisitStmt(tgt_stmt);
+    ReuseInfo info;
+    info.intact.insert(visitor.intact_.begin(), visitor.intact_.end());
+    info.loop_sref_possible_reuse.insert(visitor.loop_vars_.begin(), visitor.loop_vars_.end());
+    // `info.block_sref_reuse` is deliberately left empty: this visitor doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return info;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  /*! \brief Whether the statement already has an sref recorded in the schedule state */
+  bool HasSRef(const StmtNode* stmt) const { return self_->stmt2ref.count(stmt) != 0; }
+
+  void VisitStmt_(const ForNode* op) final {
+    if (HasSRef(op)) {
+      // Intact loop, no need to look into its subtree
+      intact_.push_back(op);
+      return;
+    }
+    // The loop var is what a loop sref reuse is detected by
+    loop_vars_.push_back(op->loop_var.get());
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (HasSRef(op)) {
+      // Intact block, no need to look into its subtree
+      intact_.push_back(op);
+      return;
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements collected so far, in visiting order */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variables of the non-intact loops collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ *
+ * In either case, the entry of every non-intact loop/block visited is erased from `stmt2ref`.
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+ /*!
+ * \brief The entry function
+ * \param self The schedule class
+ * \param reuse_info The reuse info about intact reuse and loop/block reuse
+ * \param src_stmt The `src_stmt` where stale srefs to be removed
+ * \return Mapping from the reuse elements to reused srefs, more specifically:
+ * 1) Loop reuse: maps a loop var to the reused sref
+ * 2) Block reuse: maps a block stmt to the reused sref,
+ * where the block comes from the subtree of `tgt_stmt`
+ * 3) Intact reuse: not returned
+ */
+ static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+ const ReuseInfo& reuse_info,
+ const Stmt& src_stmt) {
+ SRefTreePruner pruner(self, reuse_info);
+ pruner.VisitStmt(src_stmt);
+ // The result is a member of the local `pruner`, so it is moved out explicitly
+ return std::move(pruner.reused_srefs_);
+ }
+
+ private:
+ explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+ : self_(self), reuse_info_(reuse_info) {}
+
+ void VisitStmt_(const ForNode* op) final {
+ // Intact reuse: the sref of the loop is left as is, and its subtree is not visited
+ if (reuse_info_.intact.count(op)) {
+ return;
+ }
+ auto it = self_->stmt2ref.find(op);
+ ICHECK(it != self_->stmt2ref.end())
+ << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+ << GetRef<For>(op);
+ StmtSRef& sref = it->second;
+ // Detect reuse
+ const VarNode* loop_var = op->loop_var.get();
+ if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+ // sref can be reused
+ // It is moved out of the `stmt2ref` entry, which is erased right below
+ reused_srefs_.emplace(loop_var, std::move(sref));
+ } else {
+ // No loop in `tgt_stmt` uses this loop var, so the sref is not reused
+ sref->Reset();
+ }
+ // erase the statement
+ self_->stmt2ref.erase(it);
+ // detect recursively
+ VisitStmt(op->body);
+ }
+
+ void VisitStmt_(const BlockNode* op) final {
+ // Intact reuse: the sref of the block is left as is, and its subtree is not visited
+ if (reuse_info_.intact.count(op)) {
+ return;
+ }
+ auto it = self_->stmt2ref.find(op);
+ ICHECK(it != self_->stmt2ref.end())
+ << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+ << GetRef<Block>(op);
+ StmtSRef& sref = it->second;
+ // Detect reuse
+ auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+ if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+ // sref can be reused
+ // The key is the new block in `tgt_stmt`, not `op`; the `block_info` entry is kept
+ reused_srefs_.emplace(reuse_it->second, std::move(sref));
+ } else {
+ // The block is not reused, so its `block_info` entry is dropped together with the sref
+ sref->Reset();
+ self_->block_info.erase(sref);
+ }
+ // erase the statement
+ self_->stmt2ref.erase(it);
+ // detect recursively
+ // op->init is omitted
+ VisitStmt(op->body);
+ }
+
+ /*! \brief The schedule state we are working on */
+ ScheduleStateNode* self_;
+ /*! \brief The reuse information we collected previously */
+ const ReuseInfo& reuse_info_;
+ /*!
+ * \brief Reused srefs:
+ * 1) loop var -> StmtSRef
+ * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+ */
+ std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state whose `stmt2ref` and `block_info` are updated
+   * \param src_stmt_parent The sref used as the parent of the root of `tgt_stmt`
+   * \param reused_srefs Maps a loop var / block in `tgt_stmt` to the sref to be reused for it
+   * \param tgt_stmt The statement to be visited
+   */
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    // `operator[]` creates an undefined entry if `op` has no sref yet,
+    // in which case the entry is filled in below
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // `operator[]` creates an undefined entry if `op` has no sref yet,
+    // in which case the entry is filled in below
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  /*!
+   * \brief Create or refresh the BlockInfo of a block whose subtree has just been updated
+   * \param block_sref The sref to the block
+   */
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The scope is always rebuilt from the child blocks as they are after the update
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, i.e. the block had no entry before: set all the flags to false
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged,
+      // and only the scope is replaced
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+ /*!
+ * \brief The entry function
+ * \param parent_stmt The parent statement, must be a block or a for-loop
+ * \param child_src_stmt The child to be replaced, must be a block or a for-loop
+ * \param child_tgt_stmt The statement to be replaced in, must be a block, a for-loop
+ * or a BlockRealize
+ * \param seq_index The `seq_index` of `child_src_stmt`, tried first when a SeqStmt is visited
+ * \param allow_copy_on_write Assigned to `allow_copy_on_write_` before the visit starts
+ * \return The block or for-loop that results from the replacement
+ */
+ static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+ const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+ // Check the invariant
+ ICHECK(child_src_stmt->IsInstance<BlockNode>() || //
+ child_src_stmt->IsInstance<ForNode>());
+ ICHECK(child_tgt_stmt->IsInstance<BlockNode>() || //
+ child_tgt_stmt->IsInstance<ForNode>() || //
+ child_tgt_stmt->IsInstance<BlockRealizeNode>());
+ ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+ // NOTE(review): `allow_copy_on_write_` is not declared in this class,
+ // presumably it is inherited from StmtMutator - confirm
+ replacer.allow_copy_on_write_ = allow_copy_on_write;
+ return replacer.CopyOnWriteAndVisit(parent_stmt);
+ }
+
+ private:
+ explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+ : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+ Stmt VisitStmt(const Stmt& stmt) final {
+ if (stmt.get() == src_stmt_) {
+ // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+ return tgt_stmt_;
+ } else {
+ return StmtMutator::VisitStmt(stmt);
+ }
+ }
+
+ // Skipping sibling blocks and loops other than `src_stmt_`
+ Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+ Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+ Stmt VisitStmt_(const SeqStmtNode* op) final {
+ // Try the position given by `seq_index_` first, which avoids visiting every child
+ int i = this->seq_index_;
+ int n = static_cast<int>(op->seq.size());
+ if (0 <= i && i < n) {
+ const Stmt& stmt = op->seq[i];
+ Optional<Stmt> new_stmt = NullOpt;
+ const StmtNode* src_stmt = this->src_stmt_;
+ // `stmt` can be For or BlockRealize
+ // `src_stmt` can be For or Block
+ // so the match from `stmt` to `src_stmt` can be
+ // 1) For -> For
+ // 2) BlockRealize -> Block
+ if (stmt.get() == src_stmt) {
+ // Case 1. src_stmt is For, stmt is For
+ new_stmt = tgt_stmt_;
+ } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+ // Case 2. stmt is BlockRealize, src_stmt is Block
+ if (realize->block.get() == src_stmt) {
+ // The BlockRealize is copied, and only its `block` is switched to the target block
+ const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+ ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+ new_realize->block = GetRef<Block>(tgt_block);
+ new_stmt = BlockRealize(std::move(new_realize));
+ }
+ }
+ // Move new_stmt to position i
+ if (new_stmt.defined()) {
+ ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+ new_seq_stmt->seq.Set(i, new_stmt.value());
+ return SeqStmt(std::move(new_seq_stmt));
+ }
+ }
+ // No match at position `seq_index_`: fall back to the default mutation of the SeqStmt
+ return StmtMutator::VisitStmt_(op);
+ }
+
+ /*!
+ * \brief Copy-on-write `parent_stmt`, then mutate its body to carry out the replacement
+ * \param parent_stmt The parent statement, must be a block or a for-loop
+ * \return The resulting block or for-loop
+ */
+ Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+ // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+ // where `body` means the body of either a block or a loop
+ // Step 2. Mutate the `block/loop->body`, searching for `child_src_stmt`
+ // and replace it with `child_tgt_stmt`
+ if (parent_stmt->IsInstance<BlockNode>()) {
+ auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+ ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+ new_block->body = this->VisitStmt(new_block->body);
+ return Block(std::move(new_block));
+ } else if (parent_stmt->IsInstance<ForNode>()) {
+ auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+ ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+ new_loop->body = this->VisitStmt(new_loop->body);
+ return For(std::move(new_loop));
+ }
+ LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+ // NOTE(review): presumably never reached, assuming LOG(FATAL) does not return - confirm
+ throw;
+ }
+
+ /*! \brief The `src_stmt` to be replaced */
+ const StmtNode* src_stmt_;
+ /*!
+ * \brief The `tgt_stmt` to be replaced in
+ * \note Held by reference: the Stmt given to `Replace` has to outlive the replacer,
+ * which holds there because the replacer is a local variable of `Replace`
+ */
+ const Stmt& tgt_stmt_;
+ /*!
+ * \brief The `seq_index` of the `src_stmt`
+ * \sa StmtSRefNode
+ */
+ int seq_index_;
+};
+
+/*!
+ * \brief Replace the subtree that `_src_sref` points to with `tgt_stmt`, then update the sref
+ * tree, the ancestors of the replaced statement and, when needed, the PrimFunc and the IRModule
+ * \param _src_sref The sref to the statement to be replaced
+ * \param tgt_stmt The statement to be replaced in
+ * \param _block_sref_reuse Maps an old block in the replaced subtree to a new block in
+ * `tgt_stmt`, indicating that the sref to the old block should be reused for the new block
+ */
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    // Walk up to the root sref, remembering the farthest ancestor that is not uniquely referenced
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||               //
+                       !this->mod.unique() ||               //
+                       !this->mod->functions.unique() ||    //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  //    correct
+  // 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  // `tgt_stmt` is a const reference, so this is a copy of the reference-counted handle
+  Stmt child_tgt_stmt = tgt_stmt;
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly pointed to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  if (debug_mode == -1 || (debug_mode & kVerifySRefTree)) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+/*!
+ * \brief Run the verifications selected by `debug_mode`
+ * \note `debug_mode` is either -1, which selects every verification,
+ * or is tested bitwise against the values of ScheduleDebugMask
+ */
+void ScheduleStateNode::DebugVerify() const {
+  ICHECK_GE(debug_mode, -1);
+  const bool verify_all = (debug_mode == -1);
+  // Whether the verification that `mask` stands for is switched on
+  const auto is_on = [this, verify_all](ScheduleDebugMask mask) -> bool {
+    return verify_all || (debug_mode & static_cast<int>(mask)) != 0;
+  };
+  if (is_on(ScheduleDebugMask::kVerifySRefTree)) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+  if (is_on(ScheduleDebugMask::kVerifyAffineBinding)) {
+    // TODO(@junrushao1994): Verify affine block binding
+  }
+  if (is_on(ScheduleDebugMask::kVerifyRegionCover)) {
+    // TODO(@junrushao1994): Verify region cover
+  }
+  if (is_on(ScheduleDebugMask::kVerifyStagePipeline)) {
+    // TODO(@junrushao1994): Verify stage pipeline
+  }
+}
+
+/**************** BlockInfo-related ****************/
+
+/*!
+ * \brief Look up the BlockInfo recorded for a block sref
+ * \param block_sref The sref, which must point to a block
+ * \return A copy of the entry stored in `block_info`
+ */
+BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
+ // Only serves as a check that the sref points to a block, `block` is not used afterwards
+ const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ auto it = this->block_info.find(block_sref);
+ CHECK(it != this->block_info.end())
+ << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
+ << GetRef<Stmt>(block_sref->stmt);
+ return it->second;
+}
+
+/**************** FFI ****************/
+
+TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
+// Constructor entry: dispatches on the runtime type of `obj`,
+// which has to be either a PrimFunc or an IRModule
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState").set_body_typed([](ObjectRef obj, int debug_mode) {
+ if (const auto* func = obj.as<PrimFuncNode>()) {
+ return ScheduleState(GetRef<PrimFunc>(func), debug_mode);
+ }
+ if (const auto* mod = obj.as<IRModuleNode>()) {
+ return ScheduleState(GetRef<IRModule>(mod), debug_mode);
+ }
+ LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey();
+ // NOTE(review): presumably never reached, assuming LOG(FATAL) does not return - confirm
+ throw;
+});
+// Bound directly to the member functions of ScheduleStateNode
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
+ .set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace")
+ .set_body_method<ScheduleState>(&ScheduleStateNode::Replace);
+// Looks `stmt` up in `stmt2ref`, giving NullOpt when the statement has no sref
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef")
+ .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional<StmtSRef> {
+ auto it = self->stmt2ref.find(stmt.get());
+ return it != self->stmt2ref.end() ? it->second : Optional<StmtSRef>(NullOpt);
+ });
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
new file mode 100644
index 0000000..63ec77d
--- /dev/null
+++ b/src/tir/schedule/utils.h
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be cast
+ * \param Type The type to be cast to, e.g. BlockNode or ForNode
+ * \note The expansion is two statements: the `StmtAs` expression terminated by `;`,
+ * followed by an unterminated `ICHECK(Result)`. It is therefore meant to be used as the
+ * initializer in the declaration of `Result`, optionally followed by `<<` and an error message.
+ */
+#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
+ SRef->StmtAs<Type>(); \
+ ICHECK(Result)
+
+/*!
+ * \brief A helper macro to convert an sref to the block it points to,
+ * throwing an internal error if downcasting fails
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be cast
+ * \note Meant to initialize the declaration of `Result`, for example
+ * `const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);`
+ * \note `SRef` appears several times in the expansion, so it should be free of side effects
+ */
+#define TVM_SREF_TO_BLOCK(Result, SRef) \
+ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \
+ << "TypeError: Expects StmtSRef `" << #SRef \
+ << "` points to `Block`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None")
+
+/*!
+ * \brief A helper macro to convert an sref to the for-loop it points to,
+ * throwing an internal error if downcasting fails
+ * \param Result The name of the result variable, used for checking
+ * \param SRef The SRef to be cast
+ * \note Meant to initialize the declaration of `Result`, for example
+ * `const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);`
+ * \note `SRef` appears several times in the expansion, so it should be free of side effects
+ */
+#define TVM_SREF_TO_FOR(Result, SRef) \
+ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \
+ << "TypeError: Expects StmtSRef `" << #SRef \
+ << "` points to `Loop`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None")
+
+/*!
+ * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param From The ObjectRef to be downcasted
+ * \param Type The type to be downcasted to
+ * \note The expansion is two statements: the `as<Type>` expression terminated by `;`,
+ * followed by an unterminated `ICHECK(Result)`. It is therefore meant to be used as the
+ * initializer in the declaration of `Result`, optionally followed by `<<` and an error message.
+ */
+#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
+ From.as<Type>(); \
+ ICHECK(Result)
+
+/*!
+ * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
+ * throwing an internal error if downcast fails.
+ * \param Result The result variable, used for checking
+ * \param From The ObjectRef to be downcasted
+ * \param Type The type to be downcasted to
+ * \note Meant to initialize the declaration of `Result`, for example
+ * `const auto* realize = TVM_TYPE_AS(realize, func->body, BlockRealizeNode);`
+ * \note `From` appears several times in the expansion, so it should be free of side effects
+ */
+#define TVM_TYPE_AS(Result, From, Type) \
+ TVM_TYPE_AS_OR_ERR(Result, From, Type) \
+ << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
+ << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None")
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_SCHEDULE_UTILS_H_
diff --git a/tests/python/unittest/test_tir_block_scope.py b/tests/python/unittest/test_tir_block_scope.py
new file mode 100644
index 0000000..4a914f5
--- /dev/null
+++ b/tests/python/unittest/test_tir_block_scope.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule import DepKind
+from tvm.tir.stmt_functor import post_order_visit
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128), "float32")
+ C = tir.match_buffer(c, (128, 128), "float32")
+ B = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "init") as [vi, vj]:
+ C[vi, vj] = tir.float32(0)
+ for k in range(0, 128):
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def war_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+# pylint: disable=invalid-name
+
+
+def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef:
+ result = None
+
+ def f_visit(node):
+ nonlocal result
+ if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint:
+ result = node
+
+ func = s.mod["main"]
+ post_order_visit(func.body, f_visit)
+ assert result is not None and isinstance(result, tvm.tir.Block)
+ return s.get_sref(result)
+
+
+def test_elementwise_dependency():
+ s = tir.ScheduleState(elementwise, debug_mode=True)
+ root = _get_block(s, "root")
+ block_b = _get_block(s, "B")
+ block_c = _get_block(s, "C")
+ # Check get_deps_by_src
+ (dep,) = s.get_block_scope(root).get_deps_by_src(block_b)
+ assert dep.src.same_as(block_b)
+ assert dep.dst.same_as(block_c)
+ assert dep.kind == DepKind.RAW
+ # Check get_deps_by_dst
+ (dep,) = s.get_block_scope(root).get_deps_by_dst(block_c)
+ assert dep.src.same_as(block_b)
+ assert dep.dst.same_as(block_c)
+ assert dep.kind == DepKind.RAW
+
+
+def test_matmul_dependency():
+ s = tir.ScheduleState(matmul, debug_mode=True)
+ root = _get_block(s, "root")
+ init = _get_block(s, "init")
+ update = _get_block(s, "update")
+ # Check get_deps_by_src
+ p0, p1 = s.get_block_scope(root).get_deps_by_src(init)
+ assert p0.src.same_as(init)
+ assert p0.dst.same_as(update)
+ assert p1.src.same_as(init)
+ assert p1.dst.same_as(update)
+ assert (p0.kind == DepKind.RAW and p1.kind == DepKind.WAW) or (
+ p0.kind == DepKind.WAW and p1.kind == DepKind.RAW
+ )
+ # Check get_deps_by_dst
+ p0, p1 = s.get_block_scope(root).get_deps_by_dst(update)
+ assert p0.src.same_as(init)
+ assert p0.dst.same_as(update)
+ assert p1.src.same_as(init)
+ assert p1.dst.same_as(update)
+ assert (p0.kind == DepKind.RAW and p1.kind == DepKind.WAW) or (
+ p0.kind == DepKind.WAW and p1.kind == DepKind.RAW
+ )
+
+
+def test_war_dependency():
+ s = tir.ScheduleState(war_dependency, debug_mode=True)
+ root = _get_block(s, "root")
+ block_c = _get_block(s, "C")
+ block_b = _get_block(s, "B")
+ # Check get_deps_by_src
+ (dep,) = s.get_block_scope(root).get_deps_by_src(block_c)
+ assert dep.src.same_as(block_c)
+ assert dep.dst.same_as(block_b)
+ assert dep.kind == DepKind.WAR
+ # Check get_deps_by_dst
+ (dep,) = s.get_block_scope(root).get_deps_by_dst(block_b)
+ assert dep.src.same_as(block_c)
+ assert dep.dst.same_as(block_b)
+ assert dep.kind == DepKind.WAR
+
+
+if __name__ == "__main__":
+ test_elementwise_dependency()
+ test_matmul_dependency()
+ test_war_dependency()
diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py
new file mode 100644
index 0000000..ac98725
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_state.py
@@ -0,0 +1,352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+
+import gc
+
+import tvm
+from tvm import tir
+from tvm.ir import IRModule
+from tvm.script import ty
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128), "float32")
+ C = tir.match_buffer(c, (128, 128), "float32")
+ B = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "init") as [vi, vj]:
+ C[vi, vj] = tir.float32(0)
+ for k in range(0, 128):
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128), "float32")
+ B = tir.match_buffer(b, (128, 128), "float32")
+ with tir.block([128], "B") as vi:
+ tir.reads([A[0:128, 0:128]])
+ tir.writes([B[0:128, 0:128]])
+ B[vi, 0] = A[vi, 0]
+ if A[vi, 0] == 0.0:
+ with tir.block([], "C"):
+ tir.reads([A[0:128, 0:128]])
+ tir.writes([B[0:128, 0:128]])
+ with tir.block([128], "D") as vj:
+ B[vi, vj] = A[vi, vj] * 3.0
+ else:
+ with tir.block([], "E"):
+ tir.reads([A[0:128, 0:128]])
+ tir.writes([B[0:128, 0:128]])
+ with tir.block([128], "F") as vj:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+def replace_ir_builder(deep_copy=False, realize=False):
+ new_func = tvm.script.from_source(tvm.script.asscript(elementwise))
+ s = tir.ScheduleState(new_func, debug_mode=True)
+ target = tvm.tir.Block(
+ iter_vars=[],
+ reads=[],
+ writes=[],
+ name_hint="target",
+ body=s.mod["main"].body.block.body[1],
+ init=None,
+ alloc_buffers=None,
+ match_buffers=None,
+ annotations=None,
+ )
+ if realize:
+ target = tvm.tir.BlockRealize(
+ iter_values=[],
+ predicate=True,
+ block=target,
+ )
+ if deep_copy:
+ target.__setstate__(target.__getstate__())
+ gc.collect()
+ return s, target
+
+
+def replace_ir_builder_module(deep_copy=False, realize=False):
+ new_func = tvm.script.from_source(tvm.script.asscript(elementwise))
+ other_func = tvm.script.from_source(tvm.script.asscript(elementwise))
+ mod = IRModule(functions={"main": new_func, "other": other_func})
+ s = tir.ScheduleState(mod, debug_mode=True)
+ target = tvm.tir.Block(
+ iter_vars=[],
+ reads=[],
+ writes=[],
+ name_hint="target",
+ body=s.mod["main"].body.block.body[1],
+ init=None,
+ alloc_buffers=None,
+ match_buffers=None,
+ annotations=None,
+ )
+ if realize:
+ target = tvm.tir.BlockRealize(
+ iter_values=[],
+ predicate=True,
+ block=target,
+ )
+ if deep_copy:
+ target.__setstate__(target.__getstate__())
+ gc.collect()
+ return s, target
+
+
+def replace_ir_builder_with_opaque():
+ func = tvm.script.from_source(tvm.script.asscript(block_in_opaque_block))
+ s = tir.ScheduleState(func, debug_mode=True)
+ gc.collect()
+ return s
+
+
+def test_replace_direct_write0():
+ s, target = replace_ir_builder(realize=True)
+ old_hash = s.mod["main"].__hash__()
+ sref = s.get_sref(s.mod["main"].body.block.body[1])
+ s.replace(sref, target)
+ # There is no other reference so the AST node can be written directly
+ assert old_hash == s.mod["main"].__hash__()
+ # Check the replaced part is equal to the target
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
+ # The target reuses the stmt of the sref, so `sref.stmt` won't be None
+ assert sref.stmt is not None
+
+
+def test_replace_direct_write1():
+ s, target = replace_ir_builder(realize=True)
+ old_hash = s.mod["main"].body.block.body.__hash__()
+ hold_ref = s.mod["main"].body.block.body[1]
+ sref = s.get_sref(s.mod["main"].body.block.body[1])
+ s.replace(sref, target)
+ # There is no other reference so the AST node can be written directly
+ assert old_hash == s.mod["main"].body.block.body.__hash__()
+ assert not tvm.ir.structural_equal(hold_ref.body, target)
+ # Check the replaced part is equal to the target
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
+ # The target reuses `sref.stmt`, so it won't become None
+ assert sref.stmt is not None
+
+
+def test_replace_copy():
+ s, target = replace_ir_builder(deep_copy=True, realize=True)
+ old_hash = s.mod["main"].__hash__()
+ # We hold another reference of func
+ old_func = s.mod["main"]
+ sref = s.get_sref(s.mod["main"].body.block.body[0])
+ s.replace(sref, target)
+ # The whole func must be copied so that `old_func` remains unchanged
+ assert old_hash != s.mod["main"].__hash__()
+ assert not tvm.ir.structural_equal(old_func.body, s.mod["main"].body)
+ assert old_hash == old_func.__hash__()
+ # Check the replaced part is equal to the target
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
+ # The replaced AST node will be deleted, so `sref.stmt` will be None
+ assert sref.stmt is None
+
+
+def test_replace_partial_copy0():
+ s, target = replace_ir_builder(deep_copy=True, realize=True)
+ func_old_hash = s.mod["main"].__hash__()
+ hold_ref = s.mod["main"].body.block.body[0]
+ ref_old_hash = hold_ref.__hash__()
+ sref = s.get_sref(s.mod["main"].body.block.body[0].body)
+ other_part_hash = s.mod["main"].body.block.body[1].__hash__()
+ s.replace(sref, target)
+ # The stmt is held by `hold_ref`, so it will be copied in copy-on-write because the ref count is not unique
+ assert ref_old_hash != s.mod["main"].body.block.body[0].__hash__()
+ assert not tvm.ir.structural_equal(hold_ref.body, target)
+ # The function and the other part stmt can be directly written
+ assert func_old_hash == s.mod["main"].__hash__()
+ assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
+ # Check the replaced part is equal to the target
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body, target)
+ # The replaced AST node will be deleted, so `sref.stmt` will be None
+ assert sref.stmt is None
+
+
+def test_replace_partial_copy1():
+ s, target = replace_ir_builder(deep_copy=True)
+ func_old_hash = s.mod["main"].__hash__()
+ hold_ref = s.mod["main"].body.block.body[0].body
+ stmt_old_hash = s.mod["main"].body.block.body[0].__hash__()
+ sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
+ other_part_hash = s.mod["main"].body.block.body[1].__hash__()
+ s.replace(sref, target)
+ # The parent stmt is written directly (its hash is unchanged) since there is only one reference
+ assert stmt_old_hash == s.mod["main"].body.block.body[0].__hash__()
+ assert not tvm.ir.structural_equal(hold_ref.body, target)
+ # The function and the other part stmt can be directly written
+ assert func_old_hash == s.mod["main"].__hash__()
+ assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
+ # Check the replaced part is equal to the target
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body.body.block, target)
+ # The replaced AST node will be deleted, so `sref.stmt` will be None
+ assert sref.stmt is None
+
+
+def test_replace_root_write():
+ s, target = replace_ir_builder()
+ old_hash = s.mod["main"].__hash__()
+ sref = s.get_sref(s.mod["main"].body.block)
+ s.replace(sref, target)
+ # Check the func hash is unchanged and the new `body.block` equals the target
+ assert old_hash == s.mod["main"].__hash__()
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+
+
+def test_replace_root_copy0():
+ s, target = replace_ir_builder(deep_copy=True)
+ old_hash = s.mod["main"].__hash__()
+ func_ref = s.mod["main"]
+ sref = s.get_sref(s.mod["main"].body.block)
+ s.replace(sref, target)
+ # Check the func hash changed and the new `body.block` equals the target
+ assert old_hash != s.mod["main"].__hash__()
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+ # Check the original func remains unchanged
+ assert old_hash == func_ref.__hash__()
+ assert not tvm.ir.structural_equal(func_ref.body, target)
+
+
+def test_replace_root_copy1():
+ s, target = replace_ir_builder(deep_copy=True, realize=True)
+ old_hash = s.mod["main"].body.block.__hash__()
+ func_ref = s.mod["main"].body.block
+ sref = s.get_sref(s.mod["main"].body.block.body[0])
+ s.replace(sref, target)
+ # Check the `body.block` hash changed and the replaced stmt equals the target
+ assert old_hash != s.mod["main"].body.block.__hash__()
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
+ # Check the originally held `body.block` (`func_ref`) remains unchanged
+ assert old_hash == func_ref.__hash__()
+ assert not tvm.ir.structural_equal(func_ref.body, target)
+
+
+def test_replace_root_copy2():
+ s, target = replace_ir_builder(deep_copy=True)
+ old_hash = s.mod.functions.__hash__()
+ func_ref = s.mod.functions
+ sref = s.get_sref(s.mod["main"].body.block)
+ s.replace(sref, target)
+ # Check the `s.mod.functions` hash changed and the new `body.block` equals the target
+ assert old_hash != s.mod.functions.__hash__()
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+ # Check the originally held `s.mod.functions` (`func_ref`) remains unchanged
+ assert old_hash == func_ref.__hash__()
+ for _, v in func_ref.items():
+ assert not tvm.ir.structural_equal(v.body.block, target)
+
+
+def test_replace_root_copy3():
+ s, target = replace_ir_builder(deep_copy=True)
+ old_hash = s.mod.__hash__()
+ func_ref = s.mod
+ sref = s.get_sref(s.mod["main"].body.block)
+ s.replace(sref, target)
+ # Check the `s.mod` hash changed and the new `body.block` equals the target
+ assert old_hash != s.mod.__hash__()
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+ # Check the originally held `s.mod` (`func_ref`) remains unchanged
+ assert old_hash == func_ref.__hash__()
+ assert not tvm.ir.structural_equal(func_ref["main"].body.block, target)
+
+
+def test_replace_block_remap():
+ func = elementwise
+ s = tir.ScheduleState(func, debug_mode=True)
+ # The target stmt
+ target = matmul.body.block.body.body.body[0].block
+ sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
+ s.replace(sref, target, {sref.stmt: target})
+ sref_new = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
+ # Check the original sref has been remapped
+ assert sref.__hash__() == sref_new.__hash__()
+ tvm.ir.assert_structural_equal(sref.stmt, target)
+
+
+def test_replace_block_in_opaque_block():
+ s = replace_ir_builder_with_opaque()
+ root_hash = s.mod["main"].__hash__()
+ for_loop = s.mod["main"].body.block.body.body.block.body[1].then_case.block.body
+ sref = s.get_sref(for_loop)
+ new_for_loop = tir.For(
+ loop_var=for_loop.loop_var,
+ min_val=0,
+ extent=128,
+ kind=tir.ForKind.SERIAL,
+ body=tir.Evaluate(0),
+ thread_binding=None,
+ annotations=None,
+ )
+ s.replace(sref, new_for_loop)
+ assert root_hash == s.mod["main"].__hash__()
+ tvm.ir.assert_structural_equal(sref.stmt, new_for_loop)
+
+
+def test_replace_ir_module():
+ s, target = replace_ir_builder_module(deep_copy=True)
+ old_hash = s.mod["main"].__hash__()
+ other_func_hash = s.mod["other"].__hash__()
+ func_ref = s.mod["main"]
+ sref = s.get_sref(s.mod["main"].body.block)
+ s.replace(sref, target)
+ # Check the `main` func hash changed and the new `body.block` equals the target
+ assert old_hash != s.mod["main"].__hash__()
+ tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+ # Check the original `main` func and the `other` func remain unchanged
+ assert old_hash == func_ref.__hash__()
+ assert not tvm.ir.structural_equal(func_ref.body, target)
+ assert other_func_hash == s.mod["other"].__hash__()
+
+
+if __name__ == "__main__":
+ test_replace_direct_write0()
+ test_replace_direct_write1()
+ test_replace_copy()
+ test_replace_partial_copy0()
+ test_replace_partial_copy1()
+ test_replace_root_write()
+ test_replace_root_copy0()
+ test_replace_root_copy1()
+ test_replace_root_copy2()
+ test_replace_root_copy3()
+ test_replace_block_remap()
+ test_replace_block_in_opaque_block()
+ test_replace_ir_module()