Posted to commits@tvm.apache.org by wu...@apache.org on 2022/11/14 22:51:45 UTC

[tvm] branch main updated: [TIR][Bugfix] Fix AXIS_SEPARATORS in tir.Schedule.transform_layout (#13326)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b6fae9b35e [TIR][Bugfix] Fix AXIS_SEPARATORS in tir.Schedule.transform_layout (#13326)
b6fae9b35e is described below

commit b6fae9b35eff4ad1f7cc2e83d8d7da5d701d8e44
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Nov 14 16:51:38 2022 -0600

    [TIR][Bugfix] Fix AXIS_SEPARATORS in tir.Schedule.transform_layout (#13326)
    
    Previously, the block SREF reuse only included a single step of
    changes, and would have an incorrect mapping if multiple sequential
    changes to the TIR block occurred.  This could happen if a
    `BufferStore` was updated, followed by replacement of `Block` iter
    vars/values.  This commit tracks the Block replacements across each
    rewrite, to ensure that the SREF instances remain valid.
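    
    As a rough illustration of the new bookkeeping (Python pseudocode only;
    the real implementation is the `new_block_to_old` Map<Block, Block> added
    in layout_transformation.cc, and the helper names below are invented for
    this sketch):
    
        def record_replacement(new_block_to_old, before, after):
            # Record that `before` was rewritten into `after`, always resolving
            # `before` back to the earliest (original) block in the chain.
            if before == after:
                return
            assert after not in new_block_to_old
            while before in new_block_to_old:
                before = new_block_to_old[before]
            new_block_to_old[after] = before
    
        def finalize_sref_reuse(new_block_to_old):
            # Invert the chained new->old map into the old->new form handed to
            # SRef reuse (mirrors the loop added to TransformLayoutRewriter::Rewrite).
            block_sref_reuse = {}
            for after, before in new_block_to_old.items():
                while before in new_block_to_old:
                    before = new_block_to_old[before]
                while after in block_sref_reuse:
                    after = block_sref_reuse[after]
                block_sref_reuse[before] = after
            return block_sref_reuse
    
        # Two sequential rewrites of the same TIR block: a BufferStore update,
        # followed by replacement of the Block iter vars/values.
        replacements = {}
        record_replacement(replacements, "block_v0", "block_v1")
        record_replacement(replacements, "block_v1", "block_v2")
        assert finalize_sref_reuse(replacements) == {"block_v0": "block_v2"}
    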
---
 .../schedule/primitive/layout_transformation.cc    | 462 +++++++++++++--------
 src/tir/schedule/state.cc                          |   7 +-
 .../unittest/test_tir_schedule_transform_layout.py |  48 ++-
 3 files changed, 326 insertions(+), 191 deletions(-)

diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index e4c91dac58..c0b4ddfb4a 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -73,7 +73,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
   // Loops within the analyzed block that should be replaced
   struct ReplacementPlan {
     Map<For, Stmt> replacements;
-    Map<Block, Block> block_sref_reuse;
+    Map<Block, Block> new_block_to_old;
   };
 
   // The block to be inserted, along with the location at which it
@@ -100,6 +100,25 @@ class TransformLayoutPlanner : private StmtExprVisitor {
   }
 
  private:
+  struct WriteInfo {
+    // The BufferStore object
+    BufferStore store;
+
+    // The block realize that contains the store, if any.
+    Optional<BlockRealize> innermost_block_realize;
+
+    // The nested loops whose values contribute to the indices used in
+    // the store.  Not all loop variables in the loopnest need to
+    // contribute, but the first and last must.
+    std::vector<For> dependent_loopnest;
+
+    // Whether the padding could be represented as a tir::if_then_else
+    // node.  This requires that the surrounding loop iterators
+    // iterate over all pre-transformation buffer axes, that there are
+    // no data dependencies between loop iterations, and that
+    bool contains_row_major_traversal{false};
+  };
+
   explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}
 
   void VisitStmt_(const ForNode* op) override {
@@ -197,33 +216,217 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
   class BufferStoreReplacer : public StmtExprMutator {
    public:
-    BufferStoreReplacer(std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store,
-                        std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
-                            replace_block_realize)
-        : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {}
+    BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate,
+                        const IndexMap& inverse, const Optional<IndexMap>& pad_value,
+                        Map<Block, Block>* new_block_to_old)
+        : info(info),
+          new_buffer(new_buffer),
+          new_indices(inverse->initial_indices.Map([](const Var& var) -> PrimExpr { return var; })),
+          padding_predicate(padding_predicate),
+          inverse(inverse),
+          pad_value(pad_value),
+          new_block_to_old(*new_block_to_old) {
+      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
+      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
+        Var var = info.dependent_loopnest[i]->loop_var;
+        PrimExpr expr = inverse->final_indices[i];
+        var_remap.Set(var, expr);
+      }
+
+      DefineBlockUpdates();
+    }
+
+    bool is_all_stores_replaced() const { return all_stores_replaced; }
+
+   private:
+    void DefineBlockUpdates() {
+      if (!info.innermost_block_realize) {
+        return;
+      }
+
+      BlockRealize block_realize = info.innermost_block_realize.value();
+      const auto& block = block_realize->block;
+      const Array<PrimExpr>& old_indices = info.store->indices;
+      const auto& old_iter_vars = block->iter_vars;
+
+      this->new_iter_vars = old_iter_vars;
+      this->new_iter_values = block_realize->iter_values;
+
+      if (old_indices.empty()) {
+        return;
+      }
+
+      // Find the block iterators that are used to access the buffer.  Must be in the same
+      // order as they appear in the indices.
+      if (block->iter_vars.size() < old_indices.size()) {
+        return;
+      }
+
+      size_t block_index_start = 0;
+      for (; block_index_start < old_iter_vars.size() - old_indices.size(); block_index_start++) {
+        if (old_indices[0].same_as(old_iter_vars[block_index_start]->var)) {
+          break;
+        }
+      }
+      if (block_index_start > old_iter_vars.size() - old_indices.size()) {
+        return;
+      }
+
+      for (size_t i = 0; i < old_indices.size(); i++) {
+        if (!old_indices[i].same_as(old_iter_vars[block_index_start + i]->var) ||
+            old_iter_vars[block_index_start + i]->iter_type != kDataPar) {
+          return;
+        }
+      }
+
+      // If we got to this point, all indices used to access the
+      // buffer are virtual indices defined in the innermost block.
+      // Therefore, generate new virtual indices for iterating over
+      // the post-transform buffer.
+
+      new_indices = inverse->initial_indices.Map([](Var var) -> PrimExpr {
+        std::stringstream ss;
+        ss << "v_" << var->name_hint;
+        return Var(ss.str(), var.dtype());
+      });
+
+      Map<Var, PrimExpr>
+          loop_var_to_virtual_var;  // For updating padding_predicate in terms of the new indices
+      Array<PrimExpr> new_iter_values;  // For BlockRealize
+      Array<IterVar> new_iter_vars;     // For Block
+
+      for (size_t i = 0; i < block_index_start; i++) {
+        new_iter_vars.push_back(old_iter_vars[i]);
+        new_iter_values.push_back(block_realize->iter_values[i]);
+      }
+
+      ICHECK_EQ(new_indices.size(), new_buffer->shape.size());
+      for (size_t i = 0; i < new_indices.size(); i++) {
+        Var var = inverse->initial_indices[i];
+        Var virtual_var = Downcast<Var>(new_indices[i]);
+        PrimExpr dim = new_buffer->shape[i];
+        new_iter_values.push_back(var);
+        new_iter_vars.push_back(
+            IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
+        loop_var_to_virtual_var.Set(var, virtual_var);
+      }
+
+      for (size_t i = block_index_start + old_indices.size(); i < old_iter_vars.size(); i++) {
+        new_iter_vars.push_back(old_iter_vars[i]);
+        new_iter_values.push_back(block_realize->iter_values[i]);
+      }
+
+      ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
+      for (size_t i = 0; i < old_indices.size(); i++) {
+        Var var = Downcast<Var>(old_indices[i]);
+        PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
+        var_remap.Set(var, expr);
+      }
+
+      padding_predicate = Substitute(padding_predicate, loop_var_to_virtual_var);
+
+      this->new_iter_vars = new_iter_vars;
+      this->new_iter_values = new_iter_values;
+    }
 
     Stmt VisitStmt_(const BufferStoreNode* op) final {
-      if (auto replacement = replace_store_(op)) {
-        auto store = Downcast<BufferStore>(replacement.value());
-        return StmtExprMutator::VisitStmt_(store.get());
+      bool can_replace = [&]() -> bool {
+        if (!op->buffer.same_as(info.store->buffer)) {
+          return false;
+        }
+
+        const Array<PrimExpr>& old_indices = info.store->indices;
+
+        ICHECK_EQ(old_indices.size(), op->indices.size());
+        ExprDeepEqual expr_equal;
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!expr_equal(old_indices[i], op->indices[i])) {
+            return false;
+          }
+        }
+        return true;
+      }();
+
+      BufferStore store = GetRef<BufferStore>(op);
+      if (can_replace) {
+        PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0];
+        store =
+            BufferStore(new_buffer, if_then_else(padding_predicate, pad_value_at_index, op->value),
+                        new_indices);
       } else {
-        return StmtExprMutator::VisitStmt_(op);
+        all_stores_replaced = false;
       }
+      return StmtExprMutator::VisitStmt_(store.get());
     }
 
     Stmt VisitStmt_(const BlockRealizeNode* op) final {
-      auto realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
-      if (auto replacement = replace_block_realize_(op, realize)) {
-        return replacement.value();
+      BlockRealize realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
+
+      if (op == info.innermost_block_realize.get()) {
+        Block block = realize->block;
+        if (!block->iter_vars.same_as(this->new_iter_vars)) {
+          block.CopyOnWrite()->iter_vars = this->new_iter_vars;
+          RecordReplacement(op->block, block);
+        }
+
+        if (!block.same_as(realize->block) ||
+            !realize->iter_values.same_as(this->new_iter_values)) {
+          auto write_ptr = realize.CopyOnWrite();
+          write_ptr->block = block;
+          write_ptr->iter_values = this->new_iter_values;
+        }
+      }
+
+      return std::move(realize);
+    }
+
+    Stmt VisitStmt_(const BlockNode* op) final {
+      Block orig = GetRef<Block>(op);
+      Block mutated = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+
+      RecordReplacement(orig, mutated);
+      return std::move(mutated);
+    }
+
+    PrimExpr VisitExpr_(const VarNode* op) final {
+      Var var = GetRef<Var>(op);
+      if (auto opt = var_remap.Get(var)) {
+        return opt.value();
       } else {
-        return std::move(realize);
+        return std::move(var);
       }
     }
 
-   private:
-    std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store_;
-    std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
-        replace_block_realize_;
+    void RecordReplacement(Block before, Block after) {
+      if (before.same_as(after)) {
+        return;
+      }
+
+      ICHECK(!new_block_to_old.count(after));
+
+      while (true) {
+        if (auto opt = new_block_to_old.Get(before)) {
+          before = opt.value();
+        } else {
+          break;
+        }
+      }
+
+      new_block_to_old.Set(after, before);
+    }
+
+    const WriteInfo& info;
+    const Buffer& new_buffer;
+    Array<PrimExpr> new_indices;
+    Array<IterVar> new_iter_vars;
+    Array<PrimExpr> new_iter_values;
+    PrimExpr padding_predicate;
+    const IndexMap& inverse;
+    const Optional<IndexMap>& pad_value;
+    Map<Block, Block>& new_block_to_old;
+    bool all_stores_replaced{true};
+
+    Map<Var, PrimExpr> var_remap;
   };
 
   TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
@@ -296,159 +499,20 @@ class TransformLayoutPlanner : private StmtExprVisitor {
       return std::nullopt;
     }
 
+    Map<Block, Block> new_block_to_old;
     auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> {
       if (!info.contains_row_major_traversal || !pad_value.defined() ||
           is_zero(padding_predicate)) {
         return NullOpt;
       }
 
-      Array<PrimExpr> old_indices = info.store->indices;
-      PrimExpr if_then_else_condition = padding_predicate;
-      Array<PrimExpr> new_indices;
-      for (const auto& var : inverse->initial_indices) {
-        new_indices.push_back(var);
-      }
-
-      auto replace_block_realize =
-          [&]() -> std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)> {
-        auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional<Stmt> {
-          return NullOpt;
-        };
-        if (!info.innermost_block_realize) {
-          return no_change;
-        }
-        if (old_indices.empty()) {
-          return no_change;
-        }
-
-        BlockRealize block_realize = info.innermost_block_realize.value();
-        const auto& block = block_realize->block;
-
-        // Find the block iterators that are used to access the buffer.  Must be in the same order
-        // as they appear in the indices.
-        if (block->iter_vars.size() < old_indices.size()) {
-          return no_change;
-        }
-        const auto& iter_vars = block->iter_vars;
-        size_t block_index_start = 0;
-        for (; block_index_start < iter_vars.size() - old_indices.size(); block_index_start++) {
-          if (old_indices[0].same_as(iter_vars[block_index_start]->var)) {
-            break;
-          }
-        }
-        if (block_index_start > iter_vars.size() - old_indices.size()) {
-          return no_change;
-        }
-
-        for (size_t i = 0; i < old_indices.size(); i++) {
-          if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) ||
-              iter_vars[block_index_start + i]->iter_type != kDataPar) {
-            return no_change;
-          }
-        }
-
-        // If we got to this point, all indices used to access the
-        // buffer are virtual indices defined in the innermost block.
-        // Therefore, generate new virtual indices for iterating over
-        // the post-transform buffer.
-        Array<PrimExpr> new_iter_values;             // For BlockRealize
-        Array<IterVar> new_iter_vars;                // For Block
-        Array<PrimExpr> new_access_indices;          // For BufferStore
-        Map<Var, PrimExpr> loop_var_to_virtual_var;  // For updating if_then_else_condition
-
-        for (size_t i = 0; i < block_index_start; i++) {
-          new_iter_vars.push_back(iter_vars[i]);
-          new_iter_values.push_back(block_realize->iter_values[i]);
-        }
-
-        ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
-        for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
-          Var var = inverse->initial_indices[i];
-          PrimExpr dim = new_buffer->shape[i];
-          std::stringstream ss;
-          ss << "v_" << var->name_hint;
-          Var virtual_var(ss.str(), var.dtype());
-          new_iter_values.push_back(var);
-          new_iter_vars.push_back(
-              IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
-          new_access_indices.push_back(virtual_var);
-          loop_var_to_virtual_var.Set(var, virtual_var);
-        }
-
-        for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) {
-          new_iter_vars.push_back(iter_vars[i]);
-          new_iter_values.push_back(block_realize->iter_values[i]);
-        }
-
-        Map<Var, PrimExpr> old_virtual_var_to_new_virtual_var;
-        ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
-        for (size_t i = 0; i < old_indices.size(); i++) {
-          Var var = Downcast<Var>(old_indices[i]);
-          PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
-          old_virtual_var_to_new_virtual_var.Set(var, expr);
-        }
-
-        if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var);
-        new_indices = new_access_indices;
-
-        return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values,
-                old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op,
-                                                    const BlockRealize& visited) -> Optional<Stmt> {
-          if (op == target_realize.get()) {
-            Block block = visited->block;
-            block =
-                Downcast<Block>(Substitute(std::move(block), old_virtual_var_to_new_virtual_var));
-            block.CopyOnWrite()->iter_vars = new_iter_vars;
-
-            BlockRealize realize = visited;
-            {
-              auto write_ptr = realize.CopyOnWrite();
-              write_ptr->block = block;
-              write_ptr->iter_values = new_iter_values;
-            }
-            return realize;
-          } else {
-            return NullOpt;
-          }
-        };
-      }();
-
-      bool all_stores_replaced = true;
-      auto replace_store = [&](const BufferStoreNode* op) -> Optional<Stmt> {
-        if (!op->buffer.same_as(info.store->buffer)) {
-          all_stores_replaced = false;
-          return NullOpt;
-        }
-        ICHECK_EQ(old_indices.size(), op->indices.size());
-        ExprDeepEqual expr_equal;
-        for (size_t i = 0; i < old_indices.size(); i++) {
-          if (!expr_equal(old_indices[i], op->indices[i])) {
-            all_stores_replaced = false;
-            return NullOpt;
-          }
-        }
-
-        PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0];
-        return BufferStore(new_buffer,
-                           if_then_else(if_then_else_condition, pad_value_at_index, op->value),
-                           new_indices);
-      };
-
-      BufferStoreReplacer replacer(replace_store, replace_block_realize);
+      BufferStoreReplacer replacer(info, new_buffer, padding_predicate, inverse, pad_value,
+                                   &new_block_to_old);
       Stmt stmt = replacer(info.dependent_loopnest.back()->body);
-      if (!all_stores_replaced) {
+      if (!replacer.is_all_stores_replaced()) {
         return NullOpt;
       }
 
-      std::unordered_map<const VarNode*, PrimExpr> var_remap;
-      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
-      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
-        Var var = info.dependent_loopnest[i]->loop_var;
-        PrimExpr expr = inverse->final_indices[i];
-        var_remap[var.get()] = expr;
-      }
-      stmt = Substitute(std::move(stmt), var_remap);
-
       ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
       for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
         size_t i = (inverse->initial_indices.size() - 1) - rev_i;
@@ -471,7 +535,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
     }
 
     if (loop_replacements.size()) {
-      return ReplacementPlan{std::move(loop_replacements)};
+      return ReplacementPlan{std::move(loop_replacements), std::move(new_block_to_old)};
     } else {
       return std::nullopt;
     }
@@ -603,25 +667,6 @@ class TransformLayoutPlanner : private StmtExprVisitor {
     std::vector<BindVariableDefinition> bound_vars_;
   };
 
-  struct WriteInfo {
-    // The BufferStore object
-    BufferStore store;
-
-    // The block realize that contains the store, if any.
-    Optional<BlockRealize> innermost_block_realize;
-
-    // The nested loops whose values contribute to the indices used in
-    // the store.  Not all loop variables in the loopnest need to
-    // contribute, but the first and last must.
-    std::vector<For> dependent_loopnest;
-
-    // Whether the padding could be represented as a tir::if_then_else
-    // node.  This requires that the surrounding loop iterators
-    // iterate over all pre-transformation buffer axes, that there are
-    // no data dependencies between loop iterations, and that
-    bool contains_row_major_traversal{false};
-  };
-
   /*! \brief Collected information about each BufferStore */
   std::vector<WriteInfo> write_info_;
 
@@ -683,7 +728,20 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
       auto write_ptr = result.CopyOnWrite();
       write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
     }
-    return {result, rewriter.block_sref_reuse_};
+
+    Map<Block, Block> block_sref_reuse;
+    for (auto [after, before] : rewriter.new_block_to_old_) {
+      while (auto opt = rewriter.new_block_to_old_.Get(before)) {
+        before = opt.value();
+      }
+      while (auto opt = block_sref_reuse.Get(after)) {
+        after = opt.value();
+      }
+
+      block_sref_reuse.Set(before, after);
+    }
+
+    return {result, block_sref_reuse};
   }
 
  private:
@@ -696,7 +754,11 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
         new_buffer_(new_buffer),
         index_map_(index_map),
         plan_(plan),
-        buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {}
+        buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {
+    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ReplacementPlan>(&plan_)) {
+      new_block_to_old_ = plan_ptr->new_block_to_old;
+    }
+  }
 
   void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
     *buffer = new_buffer_;
@@ -765,7 +827,20 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
   }
 
   Stmt VisitStmt_(const BlockNode* op) final {
+    Block orig = [&]() {
+      Block block = GetRef<Block>(op);
+      while (true) {
+        if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) {
+          block = (*it).second;
+        } else {
+          break;
+        }
+      }
+      return block;
+    }();
+
     Block block = Downcast<Block>(Parent::VisitStmt_(op));
+
     auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
     auto* n = block.CopyOnWrite();
     RewriteAccessRegion(&n->reads, infered_access_regions[0]);
@@ -777,16 +852,35 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
         return buffer;
       }
     });
-    block_sref_reuse_.Set(GetRef<Block>(op), block);
+
+    RecordReplacement(orig, block);
     return std::move(block);
   }
 
+  void RecordReplacement(Block before, Block after) {
+    if (before.same_as(after)) {
+      return;
+    }
+
+    ICHECK(!new_block_to_old_.count(after));
+
+    while (true) {
+      if (auto opt = new_block_to_old_.Get(before)) {
+        before = opt.value();
+      } else {
+        break;
+      }
+    }
+
+    new_block_to_old_.Set(after, before);
+  }
+
   const Buffer& old_buffer_;
   const Buffer& new_buffer_;
   const IndexMap& index_map_;
   const TransformLayoutPlanner::TransformPlan& plan_;
   Map<Var, Buffer> buffer_data_to_buffer_;
-  Map<Block, Block> block_sref_reuse_;
+  Map<Block, Block> new_block_to_old_;
 };
 
 class BufferIsSubregionError : public ScheduleError {
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 27056124d9..a901eff6f2 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -662,10 +662,11 @@ class SRefTreePruner : public StmtVisitor {
         << GetRef<Block>(op);
     StmtSRef& sref = it->second;
     // Detect reuse
-    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
-    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+    const auto& sref_reuse = reuse_info_.block_sref_reuse;
+    if (auto reuse_it = sref_reuse.find(op); reuse_it != sref_reuse.end()) {
+      const BlockNode* to_reuse = reuse_it->second;
       // sref can be reused
-      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+      reused_srefs_.emplace(to_reuse, std::move(sref));
     } else {
       sref->Reset();
       self_->block_info.erase(sref);
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index ca5ac12a97..e904789223 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -415,13 +415,13 @@ class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
 
     transformed_buffer = tvm.testing.parameter("A")
 
+    index_map = tvm.testing.parameter(lambda i: [i // 4, i % 4])
+
     @pytest.fixture
-    def transform(self, pad_value, transformed_buffer):
+    def transform(self, pad_value, transformed_buffer, index_map):
         def transform(mod):
             sch = tir.Schedule(mod)
-            sch.transform_layout(
-                "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value
-            )
+            sch.transform_layout("block", transformed_buffer, index_map, pad_value=pad_value)
             return sch.mod
 
         return transform
@@ -885,5 +885,45 @@ class TestTransformLayoutWithVar(tvm.testing.CompareBeforeAfter):
                 )
 
 
+class TestTransformWithAxisSeparators(BasePaddingCompare):
+    """Axis separators may be specified in a transform"""
+
+    index_map = tvm.testing.parameter(lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4])
+    pad_value = tvm.testing.parameter(0)
+
+    def before(a: T.handle):
+        A = T.match_buffer(a, [14], "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                A[vi] = 42
+
+    def expected(a: T.handle):
+        A = T.match_buffer(a, [4, 4], "int32", axis_separators=[1])
+        for i, j in T.grid(4, 4):
+            with T.block("block"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                A[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 42, dtype="int32")
+
+
+class TestTransformWithAxisSeparatorsOpaqueBlock(BasePaddingCompare):
+    """Axis separators may be specified in a transform of opaque block"""
+
+    index_map = tvm.testing.parameter(lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4])
+    pad_value = tvm.testing.parameter(0)
+
+    def before(a: T.handle):
+        A = T.match_buffer(a, [14], "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                A[i] = 42
+
+    def expected(a: T.handle):
+        A = T.match_buffer(a, [4, 4], "int32", axis_separators=[1])
+        for i, j in T.grid(4, 4):
+            with T.block("block"):
+                A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")
+
+
 if __name__ == "__main__":
     tvm.testing.main()
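

For readers trying out the fixed primitive directly, the following is a
minimal usage sketch based on the new test cases above. It assumes a TVM
build that contains this commit; the function is simply the `before` payload
from TestTransformWithAxisSeparators, run through
`tir.Schedule.transform_layout`.

    import tvm
    from tvm import tir
    from tvm.script import tir as T


    @T.prim_func
    def before(a: T.handle):
        A = T.match_buffer(a, [14], "int32")
        for i in T.serial(14):
            with T.block("block"):
                vi = T.axis.remap("S", [i])
                A[vi] = 42


    sch = tir.Schedule(before)
    # Reshape the flat 14-element buffer into a padded 4x4 layout, placing an
    # axis separator between the two new axes and filling the padding with 0.
    sch.transform_layout(
        "block",
        "A",
        lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4],
        pad_value=0,
    )
    print(sch.mod)  # the expected() function above shows the resulting padded block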