You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/29 19:19:08 UTC

[GitHub] [tvm] junrushao1994 opened a new pull request #7765: [M1b] Scaffolding ScheduleState data structure

junrushao1994 opened a new pull request #7765:
URL: https://github.com/apache/tvm/pull/7765


   This PR is part of the stage M1b, TensorIR upstreaming plan (https://github.com/apache/tvm/issues/7527), on the core data structure, ScheduleState.
   
   This PR introduces two key concepts: BlockScope and ScheduleState. The ScheduleState provides a key method `Replace`, which allows all the schedule primitives to be developed around.
   
   Detailed explanation of all the terminologies, concepts and algorithms is provided in the documentation.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-809648900


   cc @comaniac @jroesch @yzhliu @icemelon9 @jcf94 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603697766



##########
File path: src/tir/schedule/block_scope.cc
##########
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Utility functions ********/
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief Add a dependency relation.
+ * \param src The source of the dependency
+ * \param dst The destination of the dependecy
+ * \param kind Type of the dependency
+ * \note This method is effectively NOP on self-loops
+ */
+void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
+  if (!src.same_as(dst)) {
+    Dependency dep(src, dst, kind);
+    self->src2deps[src].push_back(dep);
+    self->dst2deps[dst].push_back(dep);
+  }
+}
+
+/******** Constructors ********/
+
+StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
+  ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
+  n->stmt = stmt;
+  n->parent = parent;
+  n->seq_index = seq_index;
+  data_ = std::move(n);
+}
+
+StmtSRef StmtSRef::InlineMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+StmtSRef StmtSRef::RootMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
+  ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
+  node->src = std::move(src);
+  node->dst = std::move(dst);
+  node->kind = kind;
+  data_ = std::move(node);
+}
+
+BlockScope::BlockScope() { data_ = make_object<BlockScopeNode>(); }
+
+BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
+  ObjectPtr<BlockScopeNode> n = make_object<BlockScopeNode>();
+  SMap<Buffer, Array<StmtSRef>> buffer_readers;
+  SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
+  for (const StmtSRef& child_block_sref : child_block_srefs) {
+    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref);
+    // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
+    for (const BufferRegion& region : child_block->reads) {
+      buffer_readers[region->buffer].push_back(child_block_sref);
+    }
+    for (const BufferRegion& region : child_block->writes) {
+      buffer_writers[region->buffer].push_back(child_block_sref);
+    }
+    // Step 2. Update RAW dependency
+    for (const BufferRegion& region : child_block->reads) {
+      auto it = buffer_writers.find(region->buffer);
+      if (it != buffer_writers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kRAW);
+        }
+      }
+    }
+    // Step 3. Update WAW dependency
+    for (const BufferRegion& region : child_block->writes) {
+      auto it = buffer_writers.find(region->buffer);
+      if (it != buffer_writers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kWAW);
+        }
+      }
+    }
+    // Step 4. Update WAR dependency
+    for (const BufferRegion& region : child_block->writes) {
+      auto it = buffer_readers.find(region->buffer);
+      if (it != buffer_readers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kWAR);
+        }
+      }
+    }

Review comment:
       Yes we can, but I would like to make those steps split fare so that it looks clearer




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-810515117


   Although out of the scope of this RP, I am really glad that we have the discussion about block names.
   
   @comaniac brought up the point https://github.com/apache/tvm/pull/7765#discussion_r603657740:
   
   > This function makes me think that we should make root as a preserved block name, and we should not allow duplicated block names in every tree of a PrimFunc.
   
   I kinda agree with Cody about his points, but would love to hear more discussion on the block name. Particularly, we have three points to discuss:
   - A1. Block names need to be unique. The reason is that the canonical way of retrieving a block is to use its name, i.e. `schedule.get_block(name)`. Without a unique name, we are unable to even retrieve a block, which makes scheduling almost impossible. (of course, it is possible to retrieve a block by the buffer it produces or via a statement, but it is not the canonical way)
   - A2. We need reserved names for the root block. I am kinda in favor of this idea too, because we do provide syntactic sugar to auto complete the root block with the name "root". This could help us eliminate possible name conflicts.
   - A3. Users could specify the names of newly created blocks/loops. Yes, it is doable when implementing schedule primitives.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608159960



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>

Review comment:
       Yeah we did use a lot of non-owned pointers (raw pointers) in the schedule's internal state, and it is intentional to avoid cyclic dependency. Introducing weak references is indeed an overkill to represent those objects, but the mechanism doesn't help in our particular case, because it cannot guarantee weak objects are not released (otherwise it is strong reference), so we still need to manually provide such guarantee in the `Replace` API. 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604371963



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {

Review comment:
       I brought the discussion here: https://github.com/apache/tvm/pull/7765#issuecomment-810515117




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603737159



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.

Review comment:
       We use `seq_index` for the srefs whose parent in TensorIR AST is a `SeqStmt`, since `SeqStmt` may have multiple children. For a child of `SeqStmt`, if it is a Block/For, we set the `seq_index` of its corresponding sref to its index among the children of the `SeqStmt`. For other srefs(whose parents in AST are not `SeqStmt`), we set their `seq_index` to -1.
   
   Yeah, the document here is not pretty clear. @junrushao1994 Perhaps we can update it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604293800



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       Yeah I think `TVM_SREF_AS_OR_ERR ` is a better name. I will go with the name :-)
   
   The reason that we introduced this macro is that there are too many places using such conversion-and-check with almost the same error message, because most of the subsequent methods/analysis/primitives are mostly based on sref, not stmt. Introducing such macro could help alleviate the burden of writing several lines of almost identical checks and provide a consistent template of the error message.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608132415



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;

Review comment:
       Yeah I think both `int64_t` or `Optional<Integer>` works, and in this particular case where the API is not user-facing, to keep the overhead as low, let's go with the native int64_t :-)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603796202



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.

Review comment:
       Yeah it doesn't read well. I rephrased it a bit. What about this one:
   
   > If the statement the sref points to is an element of a SeqStmt in the AST, then `seq_index` is set to its index; otherwise `seq_index` is -1




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604369710



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       Aha, a better name might be `TVM_SREF_STMT_AS_OR_ERR`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603697955



##########
File path: python/tvm/tir/schedule/state.py
##########
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import Dict, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.tir import Block, BlockRealize, For, PrimFunc
+
+from . import _ffi_api_schedule
+from .block_scope import BlockScope, StmtSRef
+
+
+class ScheduleDebugMask(IntEnum):
+    """The bitmask of the debug flag in the ScheduleStateNode.

Review comment:
       Yes we should. Thanks for pointing out!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603804112



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       The macro with "E" suffix is short for "error", which means the error message is customizable. It is only exposed for flexibility, and rarely used in the codebase. Given it is an internal util macro, and has been well documented, I think it is fine to keep it here. Of course better names are definitely welcome :-)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603805772



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {

Review comment:
       I'd like to see more discussion about uniqueness of block names and reserved block names.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608110217



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited

Review comment:
       Oh we don't want to visit the weak references in the visitors, because those pointers are less meaningful on the python side. Instead, we provide FFI functions that return strong references: see [block_scope.cc:144-151](https://github.com/apache/tvm/pull/7765/files#diff-32dfb07672aaa02e5e57bae323fb938ddb88db0f4cd6bda0f84a20299d4cf5c0R144-R151)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603625732



##########
File path: python/tvm/tir/schedule/block_scope.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir import Block, For
+
+from . import _ffi_api_schedule
+
+
+@register_object("tir.StmtSRef")
+class StmtSRef(Object):
+    """An object that refers to schedulable elements in the TensorIR, aka "sref".
+
+    Glossary
+    - Block sref: An StmtSref that points to a TensorIR block.
+    - Loop sref: An StmtSRef that points to a TensorIR for loop.
+    - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+    schedulable statement of its ancestors on the TensorIR AST.
+    - Root sref: Sref to the root block. Every sref has exactly one parent sref
+    except for root sref.
+    - Sref tree: The parent-children-relationship of srefs that forms a tree,
+    uniquely determined by the TensorIR AST.
+    """
+
+    seq_index: int
+
+    @property
+    def stmt(self) -> Optional[Union[Block, For]]:
+        """The block/for stmt the object refers to"""
+        return _ffi_api_schedule.StmtSRefStmt(self)  # pylint: disable=no-member
+
+    @property
+    def parent(self) -> Optional[StmtSRef]:
+        """The parent sref"""
+        return _ffi_api_schedule.StmtSRefParent(self)  # pylint: disable=no-member
+
+    @staticmethod
+    def inline_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do the work of compute-inline"""
+        return _ffi_api_schedule.StmtSRefInlineMark()  # pylint: disable=no-member
+
+    @staticmethod
+    def root_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do nothing"""
+        return _ffi_api_schedule.StmtSRefRootMark()  # pylint: disable=no-member
+
+
+class DepKind(IntEnum):
+    """Type of dependency.
+
+    Attributes
+    ----------
+    RAW : int = 0
+        Read-after-write dependency
+    WAW : int = 1
+        Write-after-write dependency
+    WAR : int = 2
+        Write-after-read dependency. Not supported in TensorIR for now.
+    OPAQUE: int = 3
+        Opaque dependency
+    """
+
+    RAW = 0
+    WAW = 1
+    WAR = 2
+    OPAQUE = 3

Review comment:
       1. Out of curiosity, when will we get opaque dependency? Is this determined by the power of the analyzer?
   2. What about RAR?

##########
File path: python/tvm/tir/stmt.py
##########
@@ -607,12 +608,18 @@ class BlockRealize(Stmt):
     def __init__(
         self,
         iter_values: List[PrimExpr],
-        predicate: PrimExpr,
+        predicate: Union[PrimExpr, bool],

Review comment:
       Need to update the docstring.

##########
File path: python/tvm/tir/schedule/block_scope.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir import Block, For
+
+from . import _ffi_api_schedule
+
+
+@register_object("tir.StmtSRef")
+class StmtSRef(Object):
+    """An object that refers to schedulable elements in the TensorIR, aka "sref".
+
+    Glossary
+    - Block sref: An StmtSref that points to a TensorIR block.
+    - Loop sref: An StmtSRef that points to a TensorIR for loop.
+    - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+    schedulable statement of its ancestors on the TensorIR AST.
+    - Root sref: Sref to the root block. Every sref has exactly one parent sref
+    except for root sref.
+    - Sref tree: The parent-children-relationship of srefs that forms a tree,
+    uniquely determined by the TensorIR AST.
+    """
+
+    seq_index: int
+
+    @property
+    def stmt(self) -> Optional[Union[Block, For]]:
+        """The block/for stmt the object refers to"""
+        return _ffi_api_schedule.StmtSRefStmt(self)  # pylint: disable=no-member
+
+    @property
+    def parent(self) -> Optional[StmtSRef]:
+        """The parent sref"""
+        return _ffi_api_schedule.StmtSRefParent(self)  # pylint: disable=no-member
+
+    @staticmethod
+    def inline_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do the work of compute-inline"""
+        return _ffi_api_schedule.StmtSRefInlineMark()  # pylint: disable=no-member
+
+    @staticmethod
+    def root_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do nothing"""
+        return _ffi_api_schedule.StmtSRefRootMark()  # pylint: disable=no-member
+
+
+class DepKind(IntEnum):
+    """Type of dependency.
+
+    Attributes
+    ----------
+    RAW : int = 0
+        Read-after-write dependency
+    WAW : int = 1
+        Write-after-write dependency
+    WAR : int = 2
+        Write-after-read dependency. Not supported in TensorIR for now.
+    OPAQUE: int = 3
+        Opaque dependency
+    """
+
+    RAW = 0
+    WAW = 1
+    WAR = 2
+    OPAQUE = 3
+
+
+@register_object("tir.Dependency")
+class Dependency(Object):
+    """A tuple (src, dst, kind) representing certain types of dependency.
+    For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+    read-after-write, which means block B reads the result written by block A.
+
+    Parameters
+    ----------
+    src : StmtSRef
+        The source of the dependency relation
+    dst : StmtSRef
+        The destination of the dependency relation
+    kind : DepKind
+        The dependency kind
+    """
+
+    src: StmtSRef
+    dst: StmtSRef
+    kind: DepKind
+
+
+@register_object("tir.BlockScope")
+class BlockScope(Object):
+    """An object corresponds to each block sref in the sref tree,
+       which tracks the producer-consumer dependency between blocks.
+
+    Glossary:
+    - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+    whose components are:
+        - scope root: a block sref
+        - internal srefs: loop srefs
+        - scope leaves: block srefs
+    - Child block: The scope leaf blocks under the scope root or a specific internal sref
+    """
+
+    def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]:
+        """Get all dependencies whose `src` equals `block`

Review comment:
       ```suggestion
           """Get all dependencies whose `src` is the target`block`.
   ```

##########
File path: src/tir/schedule/block_scope.cc
##########
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Utility functions ********/
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief Add a dependency relation.
+ * \param src The source of the dependency
+ * \param dst The destination of the dependecy
+ * \param kind Type of the dependency
+ * \note This method is effectively NOP on self-loops
+ */
+void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
+  if (!src.same_as(dst)) {
+    Dependency dep(src, dst, kind);
+    self->src2deps[src].push_back(dep);
+    self->dst2deps[dst].push_back(dep);
+  }
+}
+
+/******** Constructors ********/
+
+StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
+  ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
+  n->stmt = stmt;
+  n->parent = parent;
+  n->seq_index = seq_index;
+  data_ = std::move(n);
+}
+
+StmtSRef StmtSRef::InlineMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+StmtSRef StmtSRef::RootMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}

Review comment:
       These two functions look exactly the same. Could you explain how they work?

##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {

Review comment:
       This function makes me think that we should make `root` as a preserved block name, and we should not allow duplicated block names in every tree of a PrimFunc.

##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       We can just assign 8 to the `True` case, so that you don't need to deal with negative debug_mode.

##########
File path: python/tvm/tir/schedule/block_scope.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir import Block, For
+
+from . import _ffi_api_schedule
+
+
+@register_object("tir.StmtSRef")
+class StmtSRef(Object):
+    """An object that refers to schedulable elements in the TensorIR, aka "sref".
+
+    Glossary
+    - Block sref: An StmtSref that points to a TensorIR block.
+    - Loop sref: An StmtSRef that points to a TensorIR for loop.
+    - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+    schedulable statement of its ancestors on the TensorIR AST.
+    - Root sref: Sref to the root block. Every sref has exactly one parent sref
+    except for root sref.
+    - Sref tree: The parent-children-relationship of srefs that forms a tree,
+    uniquely determined by the TensorIR AST.
+    """
+
+    seq_index: int
+
+    @property
+    def stmt(self) -> Optional[Union[Block, For]]:
+        """The block/for stmt the object refers to"""
+        return _ffi_api_schedule.StmtSRefStmt(self)  # pylint: disable=no-member
+
+    @property
+    def parent(self) -> Optional[StmtSRef]:
+        """The parent sref"""
+        return _ffi_api_schedule.StmtSRefParent(self)  # pylint: disable=no-member
+
+    @staticmethod
+    def inline_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do the work of compute-inline"""
+        return _ffi_api_schedule.StmtSRefInlineMark()  # pylint: disable=no-member
+
+    @staticmethod
+    def root_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do nothing"""
+        return _ffi_api_schedule.StmtSRefRootMark()  # pylint: disable=no-member
+
+
+class DepKind(IntEnum):
+    """Type of dependency.
+
+    Attributes
+    ----------
+    RAW : int = 0
+        Read-after-write dependency
+    WAW : int = 1
+        Write-after-write dependency
+    WAR : int = 2
+        Write-after-read dependency. Not supported in TensorIR for now.
+    OPAQUE: int = 3
+        Opaque dependency
+    """
+
+    RAW = 0
+    WAW = 1
+    WAR = 2
+    OPAQUE = 3
+
+
+@register_object("tir.Dependency")
+class Dependency(Object):
+    """A tuple (src, dst, kind) representing certain types of dependency.
+    For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+    read-after-write, which means block B reads the result written by block A.
+
+    Parameters
+    ----------
+    src : StmtSRef
+        The source of the dependency relation
+    dst : StmtSRef
+        The destination of the dependency relation
+    kind : DepKind
+        The dependency kind
+    """
+
+    src: StmtSRef
+    dst: StmtSRef
+    kind: DepKind
+
+
+@register_object("tir.BlockScope")
+class BlockScope(Object):
+    """An object corresponds to each block sref in the sref tree,
+       which tracks the producer-consumer dependency between blocks.
+
+    Glossary:
+    - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+    whose components are:
+        - scope root: a block sref
+        - internal srefs: loop srefs
+        - scope leaves: block srefs
+    - Child block: The scope leaf blocks under the scope root or a specific internal sref
+    """
+
+    def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]:
+        """Get all dependencies whose `src` equals `block`
+
+        Parameters
+        ----------
+        block: StmtSRef
+            The queried block
+
+        Returns
+        -------
+        blocks: List[Dependency]
+            The dependencies
+        """
+        return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block)  # pylint: disable=no-member
+
+    def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]:
+        """Get all dependencies whose `dst` equals `block`

Review comment:
       ```suggestion
           """Get all dependencies whose `dst` is the target `block`.
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief Property of a block scope root at the block, indicaiting if the scope is an equivalence
+   * of a stage pipeline. Conditions:
+   * 1) The region cover property holds for every of its child blocks
+   * 2) No write-after-read dependency
+   */
+  bool stage_pipeline{false};
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "tir.BlockScope";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object);
+
+ public:

Review comment:
       Ditto

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.

Review comment:
       Not very clear to me. Is this the index of this statement from its parent's point of view? IIUC, the value when the parent has a single child seems to be 0 for consistency.

##########
File path: src/tir/schedule/block_scope.cc
##########
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Utility functions ********/
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief Add a dependency relation.
+ * \param src The source of the dependency
+ * \param dst The destination of the dependecy
+ * \param kind Type of the dependency
+ * \note This method is effectively NOP on self-loops
+ */
+void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
+  if (!src.same_as(dst)) {
+    Dependency dep(src, dst, kind);
+    self->src2deps[src].push_back(dep);
+    self->dst2deps[dst].push_back(dep);
+  }
+}
+
+/******** Constructors ********/
+
+StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
+  ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
+  n->stmt = stmt;
+  n->parent = parent;
+  n->seq_index = seq_index;
+  data_ = std::move(n);
+}
+
+StmtSRef StmtSRef::InlineMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+StmtSRef StmtSRef::RootMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
+  ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
+  node->src = std::move(src);
+  node->dst = std::move(dst);
+  node->kind = kind;
+  data_ = std::move(node);
+}
+
+BlockScope::BlockScope() { data_ = make_object<BlockScopeNode>(); }
+
+BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
+  ObjectPtr<BlockScopeNode> n = make_object<BlockScopeNode>();
+  SMap<Buffer, Array<StmtSRef>> buffer_readers;
+  SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
+  for (const StmtSRef& child_block_sref : child_block_srefs) {
+    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref);
+    // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
+    for (const BufferRegion& region : child_block->reads) {
+      buffer_readers[region->buffer].push_back(child_block_sref);
+    }
+    for (const BufferRegion& region : child_block->writes) {
+      buffer_writers[region->buffer].push_back(child_block_sref);
+    }
+    // Step 2. Update RAW dependency
+    for (const BufferRegion& region : child_block->reads) {
+      auto it = buffer_writers.find(region->buffer);
+      if (it != buffer_writers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kRAW);
+        }
+      }
+    }
+    // Step 3. Update WAW dependency
+    for (const BufferRegion& region : child_block->writes) {
+      auto it = buffer_writers.find(region->buffer);
+      if (it != buffer_writers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kWAW);
+        }
+      }
+    }
+    // Step 4. Update WAR dependency
+    for (const BufferRegion& region : child_block->writes) {
+      auto it = buffer_readers.find(region->buffer);
+      if (it != buffer_readers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kWAR);
+        }
+      }
+    }

Review comment:
       Should be able to merge these two loops?

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:

Review comment:
       Duplicated.

##########
File path: python/tvm/tir/schedule/state.py
##########
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import Dict, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.tir import Block, BlockRealize, For, PrimFunc
+
+from . import _ffi_api_schedule
+from .block_scope import BlockScope, StmtSRef
+
+
+class ScheduleDebugMask(IntEnum):
+    """The bitmask of the debug flag in the ScheduleStateNode.
+
+    Attributes
+    ----------
+    VERIFY_SREF_TREE : int = 1
+        Verify the correctness of the sref tree
+    VERIFY_AFFINE_BINDING : int = 2
+        Verify the correctness of affine_binding
+    VERIFY_REGION_COVER : int = 4
+        Verify the correctness of region_cover
+    VERIFY_STAGE_PIPELINE: int = 8
+        Verify the correctness of stage_pipeline
+    """
+
+    VERIFY_SREF_TREE = 1
+    VERIFY_AFFINE_BINDING = 2
+    VERIFY_REGION_COVER = 4
+    VERIFY_STAGE_PIPELINE = 8
+
+
+@register_object("tir.ScheduleState")
+class ScheduleState(Object):
+    """The state of scheduling, which exposes a `Replace` method as
+    the primary resort for all the scheduling primitives to manipulate the TensorIR.
+
+    The data structure contains the following information
+    1) The AST being scheduled (mod)
+    2) The sref tree of schedulable statements (indicated by the srefs)
+    3) The dependency information of each block scope (block_info)
+    4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
+    5) A debug flag, if set, extra checking is enabled (debug_mode)
+
+    Parameters
+    ----------
+    mod : IRModule
+        The AST of the module being scheduled
+    debug_mode : int
+        Do extra correctness checking after the class creation

Review comment:
       ```suggestion
           Do extra correctness checking after the object construction
   ```

##########
File path: python/tvm/tir/schedule/state.py
##########
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import Dict, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.tir import Block, BlockRealize, For, PrimFunc
+
+from . import _ffi_api_schedule
+from .block_scope import BlockScope, StmtSRef
+
+
+class ScheduleDebugMask(IntEnum):
+    """The bitmask of the debug flag in the ScheduleStateNode.

Review comment:
       IMHO, it would be better to put the comment of debug_mode in the C++ side here, because this is the place most users would trace to when they are interested in using the debug mode.

##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       * The `E` in macro looks confusing, and I still cannot get the point even after reading the note...
   * This macro itself is also a bit confusion, as Result is being assigned in the LHS. I would suggest simply using inline functions for the cases in this file.

##########
File path: python/tvm/tir/schedule/state.py
##########
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import Dict, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.tir import Block, BlockRealize, For, PrimFunc
+
+from . import _ffi_api_schedule
+from .block_scope import BlockScope, StmtSRef
+
+
+class ScheduleDebugMask(IntEnum):
+    """The bitmask of the debug flag in the ScheduleStateNode.
+
+    Attributes
+    ----------
+    VERIFY_SREF_TREE : int = 1
+        Verify the correctness of the sref tree
+    VERIFY_AFFINE_BINDING : int = 2
+        Verify the correctness of affine_binding
+    VERIFY_REGION_COVER : int = 4
+        Verify the correctness of region_cover
+    VERIFY_STAGE_PIPELINE: int = 8
+        Verify the correctness of stage_pipeline
+    """
+
+    VERIFY_SREF_TREE = 1
+    VERIFY_AFFINE_BINDING = 2
+    VERIFY_REGION_COVER = 4
+    VERIFY_STAGE_PIPELINE = 8
+
+
+@register_object("tir.ScheduleState")
+class ScheduleState(Object):
+    """The state of scheduling, which exposes a `Replace` method as
+    the primary resort for all the scheduling primitives to manipulate the TensorIR.
+
+    The data structure contains the following information
+    1) The AST being scheduled (mod)
+    2) The sref tree of schedulable statements (indicated by the srefs)
+    3) The dependency information of each block scope (block_info)
+    4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
+    5) A debug flag, if set, extra checking is enabled (debug_mode)
+
+    Parameters
+    ----------
+    mod : IRModule
+        The AST of the module being scheduled
+    debug_mode : int
+        Do extra correctness checking after the class creation
+        and each time after calling the Replace method.
+    """
+
+    mod: IRModule
+    debug_mode: int
+
+    def __init__(
+        self,
+        func_or_mod: Union[PrimFunc, IRModule],
+        debug_mode: Union[bool, int] = False,
+    ):
+        """Construct a schedule state from an IRModule or a PrimFunc
+
+        Parameters
+        ----------
+        func_or_mod : Union[PrimFunc, IRModule]
+            The IRModule or PrimFunc to be scheduled
+        debug_mode : Union[bool, int]
+            Do extra correctness checking after the class creation and each time
+            after calling the Replace method.
+            Possible choices of `debug_mode`:
+            1) True - Turn on all the checks
+            2) False - Turn off all the checks
+            3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
+        """
+        if isinstance(debug_mode, bool):
+            if debug_mode:
+                debug_mode = -1
+            else:
+                debug_mode = 0
+        assert isinstance(debug_mode, int)

Review comment:
       Add error message.

##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;

Review comment:
       docstring




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603698616



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       I prefer not to. The reason is that we don't want to introduce a really magic number, and even worse perhaps in the future we do want to extend the verification so the magic number may change over time




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603699556



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       Sorry I meant 15 (1111b, a full mask).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603697517



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:

Review comment:
       I use the "public" here as a visual separator between data fields and methods, so i suppose it is fine




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-813120191


   Hey would you guys take another look? Thanks a lot! @comaniac @jcf94 @MasterJH5574 @jroesch 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608138483



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief Property of a block scope root at the block, indicaiting if the scope is an equivalence
+   * of a stage pipeline. Conditions:
+   * 1) The region cover property holds for every of its child blocks
+   * 2) No write-after-read dependency
+   */
+  bool stage_pipeline{false};
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "tir.BlockScope";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object);
+
+ public:
+  /******** Dependency ********/
+  /*!
+   * \brief Get all dependencies whose `src` equals `src`
+   * \param src The queried block
+   * \return The dependencies
+   */
+  TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
+  /*!
+   * \brief Get all dependencies whose `dst` equals `dst`
+   * \param dst The queried block
+   * \return The dependencies
+   */
+  TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
+};
+
+/*!
+ * \brief Managed reference to BlockScopeNode
+ * \sa BlockScopeNode
+ */
+class BlockScope : public ObjectRef {
+ public:
+  /*! \brief The constructor creating an empty block scope with on dependency information */
+  TVM_DLL BlockScope();
+  /*!
+   * \brief Create the object with the specific leaf blocks, and complete the dependency information

Review comment:
       Right. `compute` is a better word




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608159960



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>

Review comment:
       Yeah we did use a lot of non-owned pointers (raw pointers) in the schedule's internal state, and it is intentional to avoid cyclic dependency. Introducing weak references is indeed an overkill to represent those objects, but the mechanism doesn't help in our particular case, because it cannot guarantee weak objects are not released (otherwise it is strong reference), so we still need to manually provide such guarantee in the `Replace` API.  (and that's why `Replace` is so complicated and we wrote a lot of tests and some proof to make sure it works)

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>

Review comment:
       Yeah we did use a lot of non-owned pointers (raw pointers) in the schedule's internal state, and it is intentional to avoid cyclic dependency. Introducing weak references is indeed an overkill to represent those objects, but the mechanism doesn't help in our particular case, because it cannot guarantee weak objects are not released (otherwise it is strong reference), so we still need to manually provide such guarantee in the `Replace` API.  (and that's why `Replace` is so complicated and we wrote a lot of tests and some proof to make sure it works as expected)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608138217



##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+  /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+  BlockScope scope{nullptr};
+  // The properties below are information about the current block realization under its parent scope
+  /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+  bool affine_binding{false};
+  /*!
+   * \brief Property of a block, indicating each of the block's read regions is fully
+   * produced by its producers
+   */
+  bool region_cover{false};
+
+  BlockInfo() = default;
+
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+      : scope(std::move(scope)),         //
+        affine_binding(affine_binding),  //
+        region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+  /*! \brief Verify the correctness of the sref tree */
+  kVerifySRefTree = 1,
+  /*! \brief Verify the correctness of affine_binding */
+  kVerifyAffineBinding = 2,
+  /*! \brief Verify the correctness of region_cover */
+  kVerifyRegionCover = 4,
+  /*! \brief Verify the correctness of stage_pipeline */
+  kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary resort for all the scheduling primitives to manipulate the TensorIR.

Review comment:
       `interface` is definitely a better word




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-810547269


   These are great points, although i think they are somewhat parallel to the data structure itself and have things to do with primitive implementations. 
   
   So we could try to make discussions in parallel with respect to this PR.
   
   In terms of the "root" name, given that we are uniquely identifying function already via the global names, an easy way is to just use function name in the module to obtain the root, which removes on concept here.
   
   The main Q for the block name uniqueness is about how to enforce them. For manual operations they certainly makes sense. For general automated transformations it might create an extra burden to introduce name tables or allocation mechanism. Since automated transformations rules works on a sub-region and may not be aware of the names from other parts. Due to that reason, allowing pointer uniqueness might still be a better approach. This also aligns with our existing approach to handle loop vars, which saves a lot of trouble during automatic transformations.
   
   This being said, we should be able to introduce canonicalization pass to uniquely rename block names. We can also add a flag in the Schedule to enforce such uniqueness if it is turned on
   
   
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jroesch commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608061682



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;

Review comment:
       Could we maybe use `optional` or something else here? I feel like sentinel values unless really needed for some efficiency reason are not a great design as it requires users to remember special values. 

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>

Review comment:
       It feels like we should maybe introduce a WeakObject for these use cases? I can see having to duplicate a lot of functionality

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest

Review comment:
       ```suggestion
    * - Parent sref: The parent reference of an sref is the block or loop reference to the closest schedulable statement.
         We define closest to be the nearest schedulable statement of an ancestor in the AST. 
   ```
   This could use some clarification. 

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief Property of a block scope root at the block, indicaiting if the scope is an equivalence

Review comment:
       ```suggestion
      * \brief This property indicates that the block scope (rooted at its corresponding block) is equivalent to
   ```

##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+  /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+  BlockScope scope{nullptr};
+  // The properties below are information about the current block realization under its parent scope
+  /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+  bool affine_binding{false};
+  /*!
+   * \brief Property of a block, indicating each of the block's read regions is fully
+   * produced by its producers
+   */
+  bool region_cover{false};
+
+  BlockInfo() = default;
+
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+      : scope(std::move(scope)),         //
+        affine_binding(affine_binding),  //
+        region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+  /*! \brief Verify the correctness of the sref tree */
+  kVerifySRefTree = 1,
+  /*! \brief Verify the correctness of affine_binding */
+  kVerifyAffineBinding = 2,
+  /*! \brief Verify the correctness of region_cover */
+  kVerifyRegionCover = 4,
+  /*! \brief Verify the correctness of stage_pipeline */
+  kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary resort for all the scheduling primitives to manipulate the TensorIR.
+ *
+ * The data structure contains the following information
+ * 1) The AST being scheduled (mod)
+ * 2) The sref tree of schedulable statements (indicated by the srefs)
+ * 3) The dependency information of each block scope (block_info)
+ * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
+ * 5) A debug flag, if set, extra checking is enabled (debug_mode)
+ */
+class ScheduleStateNode : public Object {
+ public:
+  /*! \brief The AST of the module being scheduled */
+  IRModule mod;
+  /*!
+   * \brief Mapping from a block sref to its correpsonding BlockInfo,
+   * tracking the dependency inside the block scope,
+   * and storing necessary information flags for scheduling
+   */
+  std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual> block_info;
+  /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
+  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
+  /*!
+   * \brief Do extra correctness checking after the class creation
+   * and each time after calling the Replace method.
+   * \sa ScheduleDebugMask
+   */
+  int debug_mode;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("mod", &mod);
+    // `block_info` is not visited
+    // `stmt2ref` is not visited
+    v->Visit("debug_mode", &debug_mode);
+  }
+  /*!
+   * \brief Replace the part of the AST, as being pointed to by `src_sref`,

Review comment:
       This is good example of a comment, thanks for this one!

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited

Review comment:
       Is this because of the weak pointer optimization? It isn't clear why I can't read these fields 

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to

Review comment:
       ```suggestion
      * \brief The block or `for` stmt the object refers to
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.

Review comment:
       Same comment as above, might be worth factoring this pattern out, I've seen this done multiple times. 

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,

Review comment:
       ```suggestion
    * \brief An object with 1-to-1 correspondence with each block reference in the sref tree. 
    * This data structure is used to track the producer-consumer dependencies between blocks. 
    * For example even leaf nodes have a scope node, even though they have no dependencies .
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by

Review comment:
       It might be good to refer to the `sref` as `reference` in the english as repeating sref multiple times makes it much harder to read from my PoV. 

##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+  /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+  BlockScope scope{nullptr};
+  // The properties below are information about the current block realization under its parent scope
+  /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+  bool affine_binding{false};
+  /*!
+   * \brief Property of a block, indicating each of the block's read regions is fully
+   * produced by its producers
+   */
+  bool region_cover{false};
+
+  BlockInfo() = default;
+
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+      : scope(std::move(scope)),         //
+        affine_binding(affine_binding),  //
+        region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+  /*! \brief Verify the correctness of the sref tree */
+  kVerifySRefTree = 1,
+  /*! \brief Verify the correctness of affine_binding */
+  kVerifyAffineBinding = 2,
+  /*! \brief Verify the correctness of region_cover */
+  kVerifyRegionCover = 4,
+  /*! \brief Verify the correctness of stage_pipeline */
+  kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary resort for all the scheduling primitives to manipulate the TensorIR.

Review comment:
       ```suggestion
    * the primary interface for all the scheduling primitives to manipulate the TensorIR.
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief Property of a block scope root at the block, indicaiting if the scope is an equivalence
+   * of a stage pipeline. Conditions:

Review comment:
       ```suggestion
      * a TVM/Halide stage pipeline. Under the following conditions:
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.

Review comment:
       ```suggestion
    * - Block sref: A StmtSRef that points to a TensorIR block.
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.

Review comment:
       ```suggestion
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.

Review comment:
       ```suggestion
    * - Loop sref: A StmtSRef that points to a TensorIR for loop.
   ```

##########
File path: python/tvm/tir/schedule/__init__.py
##########
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""Namespace for TensorIR schedule."""

Review comment:
       ```suggestion
   """Namespace for the TensorIR schedule API."""
   ```

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief Property of a block scope root at the block, indicaiting if the scope is an equivalence
+   * of a stage pipeline. Conditions:
+   * 1) The region cover property holds for every of its child blocks
+   * 2) No write-after-read dependency
+   */
+  bool stage_pipeline{false};
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "tir.BlockScope";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object);
+
+ public:
+  /******** Dependency ********/
+  /*!
+   * \brief Get all dependencies whose `src` equals `src`
+   * \param src The queried block
+   * \return The dependencies
+   */
+  TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
+  /*!
+   * \brief Get all dependencies whose `dst` equals `dst`
+   * \param dst The queried block
+   * \return The dependencies
+   */
+  TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
+};
+
+/*!
+ * \brief Managed reference to BlockScopeNode
+ * \sa BlockScopeNode
+ */
+class BlockScope : public ObjectRef {
+ public:
+  /*! \brief The constructor creating an empty block scope with on dependency information */
+  TVM_DLL BlockScope();
+  /*!
+   * \brief Create the object with the specific leaf blocks, and complete the dependency information

Review comment:
       By complete do you mean, compute the dependency info then store it?

##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+  /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+  BlockScope scope{nullptr};
+  // The properties below are information about the current block realization under its parent scope
+  /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+  bool affine_binding{false};
+  /*!
+   * \brief Property of a block, indicating each of the block's read regions is fully
+   * produced by its producers
+   */
+  bool region_cover{false};
+
+  BlockInfo() = default;
+
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+      : scope(std::move(scope)),         //
+        affine_binding(affine_binding),  //
+        region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+  /*! \brief Verify the correctness of the sref tree */
+  kVerifySRefTree = 1,
+  /*! \brief Verify the correctness of affine_binding */
+  kVerifyAffineBinding = 2,
+  /*! \brief Verify the correctness of region_cover */
+  kVerifyRegionCover = 4,
+  /*! \brief Verify the correctness of stage_pipeline */
+  kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary resort for all the scheduling primitives to manipulate the TensorIR.

Review comment:
       Not sure about this one

##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

Review comment:
       > We had been trying with `tvm::Map<Buffer, Array<StmtSRef>>` in the very beginning, but it turned out that we need the values (the `Array<StmtSRef>`) of the map to be mutable to make sure they are maintained properly during transformations, but with `tvm::Map` we are unable to do so in an easy way :-( Therefore, we have to provide workarounds like providing APIs `get_deps_by_src` on the python side.
   
   might be worth writing this down for future people in a NB




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603798568



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       yeah a full mask (like `INT_MAX`) definitely works, but I would prefer -1 here to make the logic clear and invariant




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604595202



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

Review comment:
       We had been trying with `tvm::Map<Buffer, Array<StmtSRef>>` in the very beginning, but it turned out that we need the values (the `Array<StmtSRef>`) of the map to be mutable to make sure they are maintained properly during transformations, but with `tvm::Map` we are unable to do so in an easy way :-(  Therefore, we have to provide workarounds like providing APIs `get_deps_by_src` on the python side.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603696615



##########
File path: python/tvm/tir/schedule/block_scope.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope."""
+from __future__ import annotations
+
+from enum import IntEnum
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir import Block, For
+
+from . import _ffi_api_schedule
+
+
+@register_object("tir.StmtSRef")
+class StmtSRef(Object):
+    """An object that refers to schedulable elements in the TensorIR, aka "sref".
+
+    Glossary
+    - Block sref: An StmtSref that points to a TensorIR block.
+    - Loop sref: An StmtSRef that points to a TensorIR for loop.
+    - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+    schedulable statement of its ancestors on the TensorIR AST.
+    - Root sref: Sref to the root block. Every sref has exactly one parent sref
+    except for root sref.
+    - Sref tree: The parent-children-relationship of srefs that forms a tree,
+    uniquely determined by the TensorIR AST.
+    """
+
+    seq_index: int
+
+    @property
+    def stmt(self) -> Optional[Union[Block, For]]:
+        """The block/for stmt the object refers to"""
+        return _ffi_api_schedule.StmtSRefStmt(self)  # pylint: disable=no-member
+
+    @property
+    def parent(self) -> Optional[StmtSRef]:
+        """The parent sref"""
+        return _ffi_api_schedule.StmtSRefParent(self)  # pylint: disable=no-member
+
+    @staticmethod
+    def inline_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do the work of compute-inline"""
+        return _ffi_api_schedule.StmtSRefInlineMark()  # pylint: disable=no-member
+
+    @staticmethod
+    def root_mark() -> StmtSRef:
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do nothing"""
+        return _ffi_api_schedule.StmtSRefRootMark()  # pylint: disable=no-member
+
+
+class DepKind(IntEnum):
+    """Type of dependency.
+
+    Attributes
+    ----------
+    RAW : int = 0
+        Read-after-write dependency
+    WAW : int = 1
+        Write-after-write dependency
+    WAR : int = 2
+        Write-after-read dependency. Not supported in TensorIR for now.
+    OPAQUE: int = 3
+        Opaque dependency
+    """
+
+    RAW = 0
+    WAW = 1
+    WAR = 2
+    OPAQUE = 3

Review comment:
       1. This is a slot we left for future "unknown" dependencies. Right now the codebase doesn't use this kind of dependency yet.
   2. Strictly speaking read-after-read is not a dependency, so we didn't have a slot for it right now




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608050396



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       Since Python doesn't have unsigned int, changing to unsigned int will result in inconsistency between Python and C++. Plus this debug_mode won't be increased to a large number in the future AFAIK, I'll take the current solution.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen merged pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #7765:
URL: https://github.com/apache/tvm/pull/7765


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604595202



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

Review comment:
       We had been trying with `tvm::Map<Buffer, Array<StmtSRef>>` in the very beginning, but it turned out that we need the values (the `Array<StmtSRef>`) of the map to be mutable so that we can properly maintain them during transformations, but with `tvm::Map` we are unable to do so in an easy way :-(




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jroesch commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-809856737


   Would like to get a chance to read this throughly, put some time on calendar to do it tomorrow. 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608115357



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief Property of a block scope root at the block, indicaiting if the scope is an equivalence
+   * of a stage pipeline. Conditions:

Review comment:
       given that the stage definition can evolve and tensorIR generalizes over the original TVM, maybe we can avoid directly referring to the existing definition, but instead use the conditions as source of truth




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 edited a comment on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-810515117


   Although out of the scope of this PR, I am really glad that we have the discussion about block names.
   
   @comaniac brought up the point https://github.com/apache/tvm/pull/7765#discussion_r603657740:
   
   > This function makes me think that we should make root as a preserved block name, and we should not allow duplicated block names in every tree of a PrimFunc.
   
   I kinda agree with Cody about his points, but would love to hear more discussion on the block name. Particularly, we have three points to discuss:
   - A1. Block names need to be unique. The reason is that the canonical way of retrieving a block is to use its name, i.e. `schedule.get_block(name)`. Without a unique name, we are unable to even retrieve a block, which makes scheduling almost impossible. (of course, it is possible to retrieve a block by the buffer it produces or via a statement, but it is not the canonical way)
   - A2. We need reserved names for the root block. I am kinda in favor of this idea too, because we do provide syntactic sugar to auto complete the root block with the name "root". This could help us eliminate possible name conflicts.
   - A3. Users could specify the names of newly created blocks/loops. Yes, it is doable when implementing schedule primitives.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-814307970


   Per offline discussion with @comaniac:
   
   More documentation on `InlineMark` and `RootMark` is desirable. Especially we should mention that they are only used in `ComputeAt`/`ReverseComputeAt` to change the compute-at to compute-inline/no-op. This will be done as we upstreaming the schedule class.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604260468



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       The reasons you illustrated make sense to me, although I still don't like the coding style. It seems like extracting a common piese of _expression_ to be a macro. Anyways, as you mentioned, this macro is an internal utility which is more like just a helper, so I'm not strongly against it.
   
   For the naming, probably `TVM_SREF_AS_OR_ERR` would be more straightforward? "E" is not a proper and common abbrevation of "Error" so I don't think it is informative.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608137176



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by

Review comment:
       "sref" is a term we made up to represent our particular objects that reference the AST statements, and they have specific properties (like they could form a tree, a block scope, etc) and are used specifically in internal state manipulation in scheduling, so I think "sref" is appropriate in this case. What do you think?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-814860499


   Thanks @junrushao1994 for keep improving the PR. 
   Thanks @jroesch @comaniac @jcf94 @Hzfengsy @MasterJH5574 for reviewing.  This PR is merged..We can also followup with more PRs to add additional clarifications when we see future needs


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen edited a comment on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-810547269


   Thanks @junrushao1994 @comaniac These are great points, although i think they are somewhat parallel to the data structure itself and have things to do with primitive implementations. 
   
   So we could try to make discussions in parallel with respect to this PR.
   
   In terms of the "root" name, given that we are uniquely identifying function already via the global names, an easy way is to just use function name in the module to obtain the root, which removes on concept here.
   
   The main Q for the block name uniqueness is about how to enforce them. For manual operations they certainly makes sense. For general automated transformations it might create an extra burden to introduce name tables or allocation mechanism. Since automated transformations rules works on a sub-region and may not be aware of the names from other parts. Due to that reason, allowing pointer uniqueness might still be a better approach. This also aligns with our existing approach to handle loop vars, which saves a lot of trouble during automatic transformations.
   
   This being said, we should be able to introduce canonicalization pass to uniquely rename block names. We can also add a flag in the Schedule to enforce such uniqueness if it is turned on
   
   
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-810559637


   > Thanks @junrushao1994 @comaniac These are great points, although i think they are somewhat parallel to the data structure itself and have things to do with primitive implementations.
   
   Make sense. I'm fine with a follow-up PR to implement the result of this discussion.
   
   > 
   > So we could try to make discussions in parallel with respect to this PR.
   > 
   > In terms of the "root" name, given that we are uniquely identifying function already via the global names, an easy way is to just use function name in the module to obtain the root, which removes on concept here.
   
   This is also a good point. IMHO, as long as the interface makes sense to schedule primitive developers, it should be fine.
   
   > 
   > The main Q for the block name uniqueness is about how to enforce them. For manual operations they certainly makes sense. For general automated transformations it might create an extra burden to introduce name tables or allocation mechanism. Since automated transformations rules works on a sub-region and may not be aware of the names from other parts. Due to that reason, allowing pointer uniqueness might still be a better approach. This also aligns with our existing approach to handle loop vars, which saves a lot of trouble during automatic transformations.
   > 
   
   It makes sense to use unique pointers in the automation framework. One thing I would like to highlight is that even we leverage unique pointer to access blocks and don't have to worry about their names during optimization, it might still be worthwhile to maintain block name uniqueness. The reason is, IIUC, we will have a mechanism to print out the schedule in Python format for debugging and investigation. In the printed schedule, block name will be the only referenced.
   
   > This being said, we should be able to introduce canonicalization pass to uniquely rename block names. We can also add a flag in the Schedule to enforce such uniqueness if it is turned on
   
   Exactly. Calling a canonicalization pass before printing out the schedule could also solve the issue I mentioned above.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608137987



##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+  /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+  BlockScope scope{nullptr};
+  // The properties below are information about the current block realization under its parent scope
+  /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+  bool affine_binding{false};
+  /*!
+   * \brief Property of a block, indicating each of the block's read regions is fully
+   * produced by its producers
+   */
+  bool region_cover{false};
+
+  BlockInfo() = default;
+
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+      : scope(std::move(scope)),         //
+        affine_binding(affine_binding),  //
+        region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+  /*! \brief Verify the correctness of the sref tree */
+  kVerifySRefTree = 1,
+  /*! \brief Verify the correctness of affine_binding */
+  kVerifyAffineBinding = 2,
+  /*! \brief Verify the correctness of region_cover */
+  kVerifyRegionCover = 4,
+  /*! \brief Verify the correctness of stage_pipeline */
+  kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary resort for all the scheduling primitives to manipulate the TensorIR.
+ *
+ * The data structure contains the following information
+ * 1) The AST being scheduled (mod)
+ * 2) The sref tree of schedulable statements (indicated by the srefs)
+ * 3) The dependency information of each block scope (block_info)
+ * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
+ * 5) A debug flag, if set, extra checking is enabled (debug_mode)
+ */
+class ScheduleStateNode : public Object {
+ public:
+  /*! \brief The AST of the module being scheduled */
+  IRModule mod;
+  /*!
+   * \brief Mapping from a block sref to its correpsonding BlockInfo,
+   * tracking the dependency inside the block scope,
+   * and storing necessary information flags for scheduling
+   */
+  std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual> block_info;
+  /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
+  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
+  /*!
+   * \brief Do extra correctness checking after the class creation
+   * and each time after calling the Replace method.
+   * \sa ScheduleDebugMask
+   */
+  int debug_mode;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("mod", &mod);
+    // `block_info` is not visited
+    // `stmt2ref` is not visited
+    v->Visit("debug_mode", &debug_mode);
+  }
+  /*!
+   * \brief Replace the part of the AST, as being pointed to by `src_sref`,

Review comment:
       Haha yeah, we spend a lot of time on this API because it is the core one that all schedule primitives will use




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608137649



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

Review comment:
       Yeah I will add this to "\note"




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 edited a comment on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-813120191


   Hey would you guys take another look? Thanks a lot! @comaniac @jcf94 @MasterJH5574 @jroesch @tqchen 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603804112



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       The macro with "E" suffix means the error message is customizable, and we expose it only for flexibility. It is rarely used in the codebase, and given it is an internal util macro, and being well documented, I think it is fine to keep it here, and better names are definitely welcome :-)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604594379



##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       Changing to unsigned integers makes sense to me, and we could use like (11...11)_2 for the full mask. What I was worrying about is passing unsigned integers around the TVM FFI would cause potential issues




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603697340



##########
File path: src/tir/schedule/block_scope.cc
##########
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Utility functions ********/
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief Add a dependency relation.
+ * \param src The source of the dependency
+ * \param dst The destination of the dependecy
+ * \param kind Type of the dependency
+ * \note This method is effectively NOP on self-loops
+ */
+void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
+  if (!src.same_as(dst)) {
+    Dependency dep(src, dst, kind);
+    self->src2deps[src].push_back(dep);
+    self->dst2deps[dst].push_back(dep);
+  }
+}
+
+/******** Constructors ********/
+
+StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
+  ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
+  n->stmt = stmt;
+  n->parent = parent;
+  n->seq_index = seq_index;
+  data_ = std::move(n);
+}
+
+StmtSRef StmtSRef::InlineMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+StmtSRef StmtSRef::RootMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}

Review comment:
       It is a "mark" that doesn't refer to any loops. It is a trick that if `compute_at` sees the inline-mark, then it turns itself into `compute_inline`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jcf94 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r604555466



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object corresponds to each block sref in the sref tree,
+ * which tracks the producer-consumer dependency between blocks.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*! \brief Lookup table for the `src` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

Review comment:
       Why not to use `tvm::Map` here instead of `std::unordered_map`? Then these members can be visited.

##########
File path: src/tir/schedule/state.cc
##########
@@ -0,0 +1,863 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    const BlockNode* block = realize->block.get();
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* block = stmt.as<BlockNode>()) {
+    ICHECK(self->stmt2ref.count(block));
+    self->stmt2ref.at(block)->seq_index = seq_index;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    ICHECK(self->stmt2ref.count(loop));
+    self->stmt2ref.at(loop)->seq_index = seq_index;
+  } else {
+    // do nothing
+  }
+}
+
+/*!
+ * \brief Update seq_index of the children of a SeqStmt
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children need updating
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  int i = 0;
+  for (const Stmt& stmt : seq_stmt->seq) {
+    SetSeqIndex(self, stmt, i);
+    ++i;
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to
+ * \param mod The IRModule
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var The result GlobalVar
+ * \return The result PrimFunc where the root block belongs to
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& g_var = kv.first;
+    const BaseFunc& base_func = kv.second;
+    if (const auto* func = base_func.as<PrimFuncNode>()) {
+      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
+        if (realize->block.get() == root_block) {
+          *result_g_var = g_var;
+          return func;
+        }
+      }
+    }
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    block_frames_.emplace({});
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  data_ = StateCreator::Create(mod, debug_mode);
+  (*this)->DebugVerify();
+}
+
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new
+ * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused,
+ * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the
+ * original loop has a child block, then the sref to the child block should be reused, so that users
+ * won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused in the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
+ *
+ * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
+ * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
+ * which means we do not need to visit into the subtree of the old statement.
+ *
+ * 2) Reused block/loop: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * and there is correspondence between them,
+ * which makes us to reuse the sref pointing to `src`, and changes it to point to `tgt`,
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector collector(self);
+    collector.VisitStmt(tgt_stmt);
+    ReuseInfo result;
+    result.intact = {collector.intact_.begin(), collector.intact_.end()};
+    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
+    // `result.block_reuse ` is not set here because ReuseCollector doesn't collect it,
+    // and it is supposed to be properly set by the caller.
+    return result;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      // Collect loop vars for detecting reuse of loop sref
+      loop_vars_.push_back(op->loop_var.get());
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      intact_.push_back(op);
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  /*! \brief The schedule state to be worked on */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements we have collected along the way of visiting */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variable we collected in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      BlockInfo& info = insert_result.first->second;
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  const StmtNode* src_stmt_;
+  const Stmt& tgt_stmt_;
+  int seq_index_;
+};
+
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  Stmt child_tgt_stmt = std::move(tgt_stmt);
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to Assign
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  if (this->debug_mode & 1) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  constexpr int kVerifyAffineBinding = static_cast<int>(ScheduleDebugMask::kVerifyAffineBinding);
+  constexpr int kVerifyRegionCover = static_cast<int>(ScheduleDebugMask::kVerifyRegionCover);
+  constexpr int kVerifyStagePipeline = static_cast<int>(ScheduleDebugMask::kVerifyStagePipeline);
+  ICHECK_GE(debug_mode, -1);

Review comment:
       ... In my experience, I'm thinking that it will be better to use unsigned int as bit masks. 😄 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603803031



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       Inlined functions are generally good, but I am not in favor of it in this particular case -  I considered this before but didn't go with that idea, and here is the reason:
   
   - When an error occurs, we want to print the exact line/function/file that throws that error: if we use an inline function, then instead of rendering the caller, it throws in the inline function in utils.h, which is much less informative.
   - The caller should be responsible for writing the declaration of the variable. Comparing the following two, I would go for the first one, because it writes the type clearly, allows re-assignment of a variable, and makes it really clear what we are doing.
   - The only disadvantage is that we need to repeat the name "block" twice, which is not quite inconvenient IMO.
   
   ```C++
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   // compared with 
   TVM_SREF_TO_BLOCK(block, block_sref);
   ```
   
   A good alternative I considered before is to use inlined lambda function which expands like:
   
   ```C++
   const auto* block = [&]() -> const BlockNode* {
     const BlockNode* stmt = sref->StmtAs<BlockNode>();
     ICHECK(stmt != nullptr) << "Error Message";
     return stmt;
   }();
   ```
   
   The disadvantage of this approach is that the error message is not customizable, and it really depends on the compiler to optimize the lambda out.
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#issuecomment-810576037


   Per discussion with @tqchen.
   
   A1. I agree that enforcing name uniqueness in scheduling is important and it is something we should do. We also recognize the name uniqueness is a bit misleading and not super useful in subsequent IR passes. Therefore, we want to divide the problems in two steps:
   - Scheduling: Require name uniqueness - we can keep a table in the schedule class.
   - Passes after scheduling: Don't require name uniqueness.
   
   A2. Reserve names for the root block: Yes, we should do that. We have two proposals:
   - A2.1. Use "root" as the reserved name for the block
   - A2.2. Use the PrimFunc's name in the IRModule, e..g "main", as the reserved name for the block
   
   A3. Yes, we want to enable users who call scheduling primitives to specify the names of the blocks. Particularly, we want to hear some further discussions on the user experience should look like. Here is our proposal:
   - A3.1. Error out when user provide a duplicate name.
   - A3.2. If the name string the user provided is suffixed with "*", e.g. "unique_name*", then our system will find a unique name whose prefix is "unique_name" and doesn't conflict with other names.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r603803031



##########
File path: src/tir/schedule/utils.h
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be casted
+ * \param Type The type to be casted to, can be Block or For
+ * \note The `E` in the macro means `error`, which means allowing to customize error message
+ */
+#define TVM_SREF_TO_E(Result, SRef, Type) \

Review comment:
       Inlined functions are generally good, but I am not in favor of this particularly in this case. I considered this before but didn't go with that idea. Here is the reason:
   
   - When an error occurs, we want to print the exact line/function/file that throws that error: if we use an inline function, then instead of rendering the caller, it throws in the inline function in utils.h, which is much less informative.
   - The caller should be responsible for writing the declaration of the variable. Comparing the following two, I would go for the first one, because it writes the type clearly, allows re-assignment of a variable, and makes it really clear what we are doing.
   - The only disadvantage is that we need to repeat the name "block" twice, which is not quite inconvenient IMO.
   
   ```C++
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   // compared with 
   TVM_SREF_TO_BLOCK(block, block_sref);
   ```
   
   A good alternative I considered before is to use inlined lambda function which expands like:
   
   ```C++
   const auto* block = [&]() -> const BlockNode* {
     const BlockNode* stmt = sref->StmtAs<BlockNode>();
     ICHECK(stmt != nullptr) << "Error Message";
     return stmt;
   }();
   ```
   
   The disadvantage of this approach is that the error message is not customizable, and it really depends on the compiler to optimize the lambda out.
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7765: [M1b] Scaffolding ScheduleState data structure

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7765:
URL: https://github.com/apache/tvm/pull/7765#discussion_r608110217



##########
File path: include/tvm/tir/schedule/block_scope.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: An StmtSRef that points to a TensorIR block.
+ * - Loop sref: An StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+ * schedulable statement of its ancestors on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block/for stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited

Review comment:
       Oh we don't want to visit the weak references in the visitors, because those void pointers are less meaningful on the python side. Instead, we provide FFI functions that return strong references: see [block_scope.cc:144-151](https://github.com/apache/tvm/pull/7765/files#diff-32dfb07672aaa02e5e57bae323fb938ddb88db0f4cd6bda0f84a20299d4cf5c0R144-R151)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org