Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/07 20:13:46 UTC

[GitHub] [tvm] vinx13 commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

vinx13 commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965151002


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object {
    * \param buffer_index The index of the buffer in block's read or write region.
    * \param buffer_index_type The type of the buffer index, kRead or kWrite.
    * \param index_map The transformation to apply.
+   * \param pad_value The value to write into padding introduced by the transformation.

Review Comment:
   When `pad_value` is incorrect, this can affect the correctness of the program. It would be great to mention this explicitly.
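   For instance, the warning could go into the doc comment itself; a rough sketch (wording is only a suggestion):
   ```cpp
   * \param pad_value The value to write into padding introduced by the
   *     transformation.  Correctness of the program depends on this
   *     value: if it differs from the values actually present in the
   *     padding, reads of the padding may return incorrect results.
   ```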



##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2479,6 +2480,20 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
+
+            The value to be used for any padding introduced by the
+            transformation.
+
+            If None, the transformation may not introduce padding.
+
+            If an int, float or PrimExpr, this is the specific value
+            to be present in the padding.
+
+            If an IndexMap or Callable, this defines the value to be
+            present in the padding in terms of the transformed
+            indices.

Review Comment:
   The C++ side only accepts `Optional<PrimExpr>`; it seems this is not supported?
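   If the IndexMap/Callable forms are meant to be supported, the C++ entry point would presumably need to accept the pad value as a function of the transformed indices rather than as a bare expression. A hypothetical sketch of what I mean (the signature is illustrative, not the actual API):
   ```cpp
   // Hypothetical: pad_value expressed as an IndexMap from the
   // transformed indices to the value stored in the padding.
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
                                BufferIndexType buffer_index_type,
                                const IndexMap& index_map,
                                const Optional<IndexMap>& pad_value) = 0;
   ```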



##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -16,12 +16,580 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+class LayoutTransformPlanner : private StmtExprVisitor {
+ public:
+  // Statement to be inserted prior to the analyzed block
+  struct ProloguePlan {
+    Stmt prologue;
+  };
+
+  // Loops within the analyzed block that should be replaced
+  struct ReplacementPlan {
+    Map<For, Stmt> replacements;
+    Map<Block, Block> block_sref_reuse;
+  };
+
+  // The block to be inserted, along with the location at which it
+  // should be inserted.  The location will be either a For or a
+  // Block, and will be after all writes to the transformed buffer.
+  struct EpiloguePlan {
+    Stmt insert_after;
+    Stmt new_block;
+  };
+
+  struct NoPaddingRequired {};
+
+  using TransformPlan =
+      std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;
+
+  static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
+                            IndexMap inverse, PrimExpr padding_predicate,
+                            Optional<PrimExpr> pad_value) {
+    LayoutTransformPlanner visitor(old_buffer);
+    visitor(block);
+    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
+  }
+
+ private:
+  explicit LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}
+
+  void VisitStmt_(const ForNode* op) override {
+    BindLoopVar context(this, GetRef<For>(op));
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const LetStmtNode* op) override {
+    BindLetVar context(this, op->var, op->value);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BlockRealizeNode* op) override {
+    BindBlockRealize context(this, GetRef<BlockRealize>(op));
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) override {
+    if (!op->buffer.same_as(old_buffer_)) {
+      return;
+    }
+
+    std::optional<std::pair<size_t, size_t>> loop_dependency_range = std::nullopt;
+    for (const auto& index : op->indices) {
+      if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) {
+        if (loop_dependency_range) {
+          loop_dependency_range = {
+              std::min(loop_dependency_range.value().first, index_depth.value().first),
+              std::max(loop_dependency_range.value().second, index_depth.value().second)};
+        } else {
+          loop_dependency_range = index_depth;
+        }
+      }
+    }
+
+    WriteInfo write_info;
+    write_info.store = GetRef<BufferStore>(op);
+    if (loop_dependency_range) {
+      size_t i = loop_dependency_range.value().first;
+      size_t j = loop_dependency_range.value().second;
+      ICHECK_LT(i, active_loops_.size());
+      ICHECK_LT(j, active_loops_.size());
+
+      write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1};
+    }
+    write_info.innermost_block_realize = innermost_block_realize_;
+
+    write_info.contains_row_major_traversal = [&]() -> bool {
+      const auto& loopnest = write_info.dependent_loopnest;
+      if (loopnest.empty()) {
+        return false;
+      }
+
+      if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) {
+        return false;
+      }
+
+      for (size_t i = 0; i < loopnest.size(); i++) {
+        const For& loop = loopnest[i];
+        const PrimExpr& buffer_dim = old_buffer_->shape[i];
+        PrimExpr index = Substitute(op->indices[i], active_let_bindings_);
+        bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) &&
+                                 ExprDeepEqual()(loop->extent, buffer_dim) &&
+                                 loop->kind == ForKind::kSerial;
+        if (!is_loop_over_axis) {
+          return false;
+        }
+      }
+
+      return true;
+    }();
+
+    write_info_.push_back(write_info);
+
+    // Don't need to continue recursing, as the entire goal was to
+    // find the BufferStore.
+  }
+
+  std::optional<std::pair<size_t, size_t>> LoopDependencyRange(const PrimExpr& expr) const {
+    std::optional<std::pair<size_t, size_t>> prev = std::nullopt;
+    for (const auto& var : UndefinedVars(expr)) {
+      auto it = loop_depth_lookup_.find(var.get());
+      if (it != loop_depth_lookup_.end()) {
+        if (prev.has_value()) {
+          prev = {std::min(prev.value().first, it->second.first),
+                  std::max(prev.value().second, it->second.second)};
+        } else {
+          prev = it->second;
+        }
+      }
+    }
+
+    return prev;
+  }
+
+  class BufferStoreReplacer : public StmtExprMutator {
+   public:
+    BufferStoreReplacer(std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store,
+                        std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
+                            replace_block_realize)
+        : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {}
+
+    Stmt VisitStmt_(const BufferStoreNode* op) final {
+      if (auto replacement = replace_store_(op)) {
+        auto store = Downcast<BufferStore>(replacement.value());
+        return StmtExprMutator::VisitStmt_(store.get());
+      } else {
+        return StmtExprMutator::VisitStmt_(op);
+      }
+    }
+
+    Stmt VisitStmt_(const BlockRealizeNode* op) final {
+      auto realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
+      if (auto replacement = replace_block_realize_(op, realize)) {
+        return replacement.value();
+      } else {
+        return std::move(realize);
+      }
+    }
+
+   private:
+    std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store_;
+    std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
+        replace_block_realize_;
+  };
+
+  TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
+                         PrimExpr padding_predicate, Optional<PrimExpr> pad_value) const {
+    if (auto prologue_plan =
+            FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value);
+        prologue_plan.has_value()) {
+      return prologue_plan.value();
+    } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse,
+                                                               padding_predicate, pad_value);
+               replacement_plan.has_value()) {
+      return replacement_plan.value();
+    } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse,
+                                                         padding_predicate, pad_value);
+               epilogue_plan.has_value()) {
+      return epilogue_plan.value();
+    } else {
+      return NoPaddingRequired();
+    }
+  }
+
+  std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<PrimExpr> pad_value) const {
+    if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    Map<Var, PrimExpr> loop_indices_to_block_indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      loop_indices_to_block_indices.Set(loop_var, block_var);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
+
+    PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value.value());
+    Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_assumptions";
+    auto read_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, Bool(true),
+                        Block(iter_vars, {read_region}, {}, block_name.str(), stmt));
+
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+    return ProloguePlan{stmt};
+  }
+
+  std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map,
+                                                         IndexMap inverse,
+                                                         PrimExpr padding_predicate,
+                                                         Optional<PrimExpr> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> {
+      if (!info.contains_row_major_traversal || !pad_value.defined() ||
+          is_zero(padding_predicate)) {
+        return NullOpt;
+      }
+
+      Array<PrimExpr> old_indices = info.store->indices;
+      PrimExpr if_then_else_condition = padding_predicate;
+      Array<PrimExpr> new_indices;
+      for (const auto& var : inverse->initial_indices) {
+        new_indices.push_back(var);
+      }
+
+      auto replace_block_realize =
+          [&]() -> std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)> {
+        auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional<Stmt> {
+          return NullOpt;
+        };
+        if (!info.innermost_block_realize) {
+          return no_change;
+        }
+        if (old_indices.empty()) {
+          return no_change;
+        }
+
+        BlockRealize block_realize = info.innermost_block_realize.value();
+        const auto& block = block_realize->block;
+
+        // Find the block iterators that are used to access the buffer.  Must be in the same order
+        // as they appear in the indices.
+        if (block->iter_vars.size() < old_indices.size()) {
+          return no_change;
+        }
+        const auto& iter_vars = block->iter_vars;
+        size_t block_index_start = 0;
+        for (; block_index_start < iter_vars.size() - old_indices.size(); block_index_start++) {
+          if (old_indices[0].same_as(iter_vars[block_index_start]->var)) {
+            break;
+          }
+        }
+        if (block_index_start > iter_vars.size() - old_indices.size()) {
+          return no_change;
+        }
+
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) ||
+              iter_vars[block_index_start + i]->iter_type != kDataPar) {
+            return no_change;
+          }
+        }
+
+        // If we got to this point, all indices used to access the
+        // buffer are virtual indices defined in the innermost block.
+        // Therefore, generate new virtual indices for iterating over
+        // the post-transform buffer.
+        Array<PrimExpr> new_iter_values;             // For BlockRealize
+        Array<IterVar> new_iter_vars;                // For Block
+        Array<PrimExpr> new_access_indices;          // For BufferStore
+        Map<Var, PrimExpr> loop_var_to_virtual_var;  // For updating if_then_else_condition
+
+        for (size_t i = 0; i < block_index_start; i++) {
+          new_iter_vars.push_back(iter_vars[i]);
+          new_iter_values.push_back(block_realize->iter_values[i]);
+        }
+
+        ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+        for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+          Var var = inverse->initial_indices[i];
+          PrimExpr dim = new_buffer->shape[i];
+          std::stringstream ss;
+          ss << "v_" << var->name_hint;
+          Var virtual_var(ss.str(), var.dtype());
+          new_iter_values.push_back(var);
+          new_iter_vars.push_back(IterVar(Range::FromMinExtent(0, dim), virtual_var, kDataPar));
+          new_access_indices.push_back(virtual_var);
+          loop_var_to_virtual_var.Set(var, virtual_var);
+        }
+
+        for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) {
+          new_iter_vars.push_back(iter_vars[i]);
+          new_iter_values.push_back(block_realize->iter_values[i]);
+        }
+
+        Map<Var, PrimExpr> old_virtual_var_to_new_virtual_var;
+        ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          Var var = Downcast<Var>(old_indices[i]);
+          PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
+          old_virtual_var_to_new_virtual_var.Set(var, expr);
+        }
+
+        if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var);
+        new_indices = new_access_indices;
+
+        return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values,
+                old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op,
+                                                    const BlockRealize& visited) -> Optional<Stmt> {
+          if (op == target_realize.get()) {
+            Block block = visited->block;
+            block =
+                Downcast<Block>(Substitute(std::move(block), old_virtual_var_to_new_virtual_var));
+            block.CopyOnWrite()->iter_vars = new_iter_vars;
+
+            BlockRealize realize = visited;
+            {
+              auto write_ptr = realize.CopyOnWrite();
+              write_ptr->block = block;
+              write_ptr->iter_values = new_iter_values;
+            }
+            return realize;
+          } else {
+            return NullOpt;
+          }
+        };
+      }();
+
+      bool all_stores_replaced = true;
+      auto replace_store = [&](const BufferStoreNode* op) -> Optional<Stmt> {
+        if (!op->buffer.same_as(info.store->buffer)) {
+          all_stores_replaced = false;
+          return NullOpt;
+        }
+        ICHECK_EQ(old_indices.size(), op->indices.size());
+        ExprDeepEqual expr_equal;
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!expr_equal(old_indices[i], op->indices[i])) {
+            all_stores_replaced = false;
+            return NullOpt;
+          }
+        }
+
+        return BufferStore(new_buffer,
+                           if_then_else(if_then_else_condition, pad_value.value(), op->value),
+                           new_indices);
+      };
+
+      BufferStoreReplacer replacer(replace_store, replace_block_realize);
+      Stmt stmt = replacer(info.dependent_loopnest.back()->body);
+      if (!all_stores_replaced) {
+        return NullOpt;
+      }
+
+      std::unordered_map<const VarNode*, PrimExpr> var_remap;
+      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
+      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
+        Var var = info.dependent_loopnest[i]->loop_var;
+        PrimExpr expr = inverse->final_indices[i];
+        var_remap[var.get()] = expr;
+      }
+      stmt = Substitute(std::move(stmt), var_remap);
+
+      ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+      for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+        size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+        Var loop_var = inverse->initial_indices[i];
+        PrimExpr extent = new_buffer->shape[i];
+        stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+      }
+
+      return stmt;
+    };
+
+    Map<For, Stmt> loop_replacements;
+
+    for (const auto& info : write_info_) {
+      if (info.dependent_loopnest.size()) {
+        if (auto opt_stmt = generate_if_then_else_block(info)) {
+          loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value());
+        }
+      }
+    }
+
+    if (loop_replacements.size()) {
+      return ReplacementPlan{std::move(loop_replacements)};
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<PrimExpr> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    Map<Var, PrimExpr> loop_indices_to_block_indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      loop_indices_to_block_indices.Set(loop_var, block_var);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
+
+    Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices);
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_padding";
+    auto write_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, padding_predicate,
+                        Block(iter_vars, {}, {write_region}, block_name.str(), stmt));
+
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+
+    const auto& info = write_info_.back();
+    Stmt insert_after = [&]() -> Stmt {
+      if (info.dependent_loopnest.size()) {
+        return info.dependent_loopnest.front();
+      } else if (info.innermost_block_realize) {
+        return info.innermost_block_realize.value();
+      } else {
+        LOG(FATAL) << "Write occurred outside of any block/loop";
+        return Stmt();
+      }
+    }();
+    return EpiloguePlan{insert_after, stmt};
+  }
+
+  struct BindLoopVar {
+    BindLoopVar(LayoutTransformPlanner* self, For for_node)
+        : self_(self), var_(for_node->loop_var) {
+      size_t loop_depth = self_->active_loops_.size();
+      self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth};
+      self_->active_loops_.push_back(std::move(for_node));
+    }
+    ~BindLoopVar() {
+      self_->active_loops_.pop_back();
+      self_->loop_depth_lookup_.erase(var_.get());
+    }
+    BindLoopVar(const BindLoopVar&) = delete;
+    BindLoopVar& operator=(const BindLoopVar&) = delete;
+    BindLoopVar(BindLoopVar&&) = delete;
+    BindLoopVar& operator=(BindLoopVar&&) = delete;
+
+    LayoutTransformPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindLetVar {
+    BindLetVar() {}
+    BindLetVar(LayoutTransformPlanner* self, Var var, PrimExpr value) : self_(self), var_(var) {
+      if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) {
+        self_->loop_depth_lookup_[var_.get()] = loop_depth.value();
+        self_->active_let_bindings_[var_.get()] = Substitute(value, self_->active_let_bindings_);
+      }
+    }
+    ~BindLetVar() {
+      if (self_) {
+        self_->loop_depth_lookup_.erase(var_.get());
+        self_->active_let_bindings_.erase(var_.get());
+      }
+    }
+    BindLetVar(const BindLetVar&) = delete;
+    BindLetVar& operator=(const BindLetVar&) = delete;
+    BindLetVar(BindLetVar&& other) : BindLetVar() { swap(other); }
+    BindLetVar& operator=(BindLetVar&& other) {
+      swap(other);
+      return *this;
+    }
+    void swap(BindLetVar& other) {
+      std::swap(self_, other.self_);
+      std::swap(var_, other.var_);
+    }
+
+    LayoutTransformPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindBlockRealize {
+    BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) {
+      ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size());
+      for (size_t i = 0; i < block_realize->iter_values.size(); i++) {
+        bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var,
+                                 block_realize->iter_values[i]);
+      }
+      cache_ = std::move(block_realize);
+      std::swap(self_->innermost_block_realize_, cache_);
+    }
+    ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); }
+    BindBlockRealize(const BindBlockRealize&) = delete;
+    BindBlockRealize& operator=(const BindBlockRealize&) = delete;
+    BindBlockRealize(BindBlockRealize&&) = delete;
+    BindBlockRealize& operator=(BindBlockRealize&&) = delete;
+
+    LayoutTransformPlanner* self_{nullptr};
+    Optional<BlockRealize> cache_;
+    std::vector<BindLetVar> bound_vars_;
+  };
+
+  struct WriteInfo {
+    // The BufferStore object
+    BufferStore store;
+
+    // The block realize that contains the store, if any.
+    Optional<BlockRealize> innermost_block_realize;
+
+    // The nested loops whose values contribute to the indices used in
+    // the store.  Not all loop variables in the loopnest need to
+    // contribute, but the first and last must.
+    std::vector<For> dependent_loopnest;
+
+    // Whether the padding could be represented as a tir::if_then_else
+    // node.  This requires that the surrounding loop iterators
+    // iterate over all pre-transformation buffer axes, that there are
+    // no data dependencies between loop iterations, and that the
+    // traversal is in row-major order.
+    bool contains_row_major_traversal{false};
+  };
+
+  struct LoopEntry {};
+
+  std::vector<WriteInfo> write_info_;
+  std::vector<For> active_loops_;
+  std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_;
+  std::unordered_map<const VarNode*, PrimExpr> active_let_bindings_;
+  Optional<BlockRealize> innermost_block_realize_{NullOpt};

Review Comment:
   document these fields
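   For example, a sketch based on my reading of the code (feel free to reword):
   ```cpp
   // Writes to the buffer being transformed, in order of visitation.
   std::vector<WriteInfo> write_info_;
   // Stack of loops enclosing the statement currently being visited.
   std::vector<For> active_loops_;
   // Maps a loop variable, or a let-bound alias of one, to the
   // [min, max] range of loop depths that it depends on.
   std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_;
   // Maps a let-bound variable to its value, with any enclosing let
   // bindings already substituted in.
   std::unordered_map<const VarNode*, PrimExpr> active_let_bindings_;
   // The BlockRealize nearest to the statement currently being visited.
   Optional<BlockRealize> innermost_block_realize_{NullOpt};
   ```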



##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -16,12 +16,580 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+class LayoutTransformPlanner : private StmtExprVisitor {

Review Comment:
   document the high level algorithm
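   Something along these lines, if I'm reading the flow correctly (a sketch only):
   ```cpp
   /*! \brief Plans how the padding introduced by a layout
    *  transformation should be handled.
    *
    *  Walks the block that accesses the buffer, collects every write
    *  to the buffer, and then selects one of four plans:
    *
    *  1. ProloguePlan: the block never writes to the buffer, so a
    *     block of builtin::assume() calls describing the contents of
    *     the padding is inserted before it.
    *  2. ReplacementPlan: writes occur in row-major loopnests over
    *     the whole buffer, so each such loopnest is rewritten to
    *     iterate over the transformed buffer, selecting pad_value
    *     with if_then_else where the padding predicate holds.
    *  3. EpiloguePlan: otherwise, a loopnest that stores pad_value
    *     into the padding is inserted after the last write.
    *  4. NoPaddingRequired: the transformation introduces no padding
    *     (or no pad_value was given), so nothing is inserted.
    */
   ```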



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org