Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/23 13:40:34 UTC

[GitHub] [tvm] Lunderberg opened a new pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Lunderberg opened a new pull request #9091:
URL: https://github.com/apache/tvm/pull/9091


   This started after noticing that StorageFlatten incorrectly handled BufferLoad/BufferStore nodes that pointed to a buffer defined in an `attr::buffer_bind_scope` annotation.  Rather than adding more logic into the existing `StorageFlattener` mutator, I split up the existing behavior into multiple independent mutators.
   
   This PR includes a series of commits, each of which refactors one of the behaviors out of the `StorageFlattener` class and into a separate class.  While all of the transforms are called sequentially in `tir.transform.StorageFlatten` to maintain the same overall behavior, each individual transform results in a valid TIR tree.
   
   * BufferShapeLegalize, which rewrites Buffer nodes to have sizes that match the BufferRealize node in which they are defined.
   * BufferStrideLegalize, which rewrites the strides of Buffer nodes that are annotated with `attr::dim_align`.
   * ThreadScopePropagate, which defines the allocation scope of Buffer nodes based on the thread iter in which they are declared, if no allocation scope was already defined.
   * BufferBindUnwrapper, which rewrites accesses of Buffer objects that are defined by `attr::buffer_bind_scope`.  Refactoring this behavior into a separate mutator was my original goal, in order to resolve the issue of BufferLoad/BufferStore nodes that point to bound buffers, but doing so required the previous three behaviors to also be refactored into separate mutators.
   * StorageFlattener, which contains all remaining behavior from the original StorageFlattener, and outputs the final Allocate/Store/Load nodes.
   
   This refactor will also help in the future, when introducing layout transformations.
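   
   Roughly, the sequential composition inside `tir.transform.StorageFlatten` looks like the following (condensed from the updated pass body quoted later in this thread):
   
       IRVisitorWithAnalyzer bound_analyzer;
       bound_analyzer(fptr->body);
   
       // Each sub-pass consumes and produces a complete, valid TIR tree.
       fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
   
       auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
       fptr->body = stride_legalize(std::move(fptr->body));
       fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();
   
       fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
   
       fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
   
       // ...followed by the remaining StorageFlattener mutator.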
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] csullivan commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r715763492



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";

Review comment:
       Consider making this a CHECK, since external buffers can be provided by the user, and failing this check could indicate a usage error.
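       i.e., something like:
       
           CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
               << "External buffer realize has mismatched dimension";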

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    Array<PrimExpr> shape = buf->shape;
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.

Review comment:
       nit: An example of one such case that makes this assumption could be useful for context.

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
...
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {

Review comment:
       Add a comment describing that the buffer attributes in a `buffer_bind_scope` are updated according to the legalized buffer.  Similarly, add comments in the other passes that have `buffer_bind_scope` handler methods.
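       For example, something along these lines (wording illustrative only):
       
           // If either buffer in a buffer_bind_scope attribute was legalized,
           // the attribute must be rewritten to refer to the new buffers, with
           // the view's begins offset to match the target's new bounds.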







[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r715831607



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
...
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";

Review comment:
       Makes sense, and changed.







[GitHub] [tvm] tmoreau89 commented on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-926140430


   CC @kparzysz-quic 





[GitHub] [tvm] csullivan commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716909762



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -499,6 +1298,40 @@ class StorageFlattener : public StmtExprMutator {
   bool create_bound_attributes_{false};
 };
 
+// The specific tensor data layout is not determined before
+// the StorageFlatten pass. We use buffer_bind_scope
+// to specify beforehand that we want to bind a subregion
+// of a tensor to a symbolic buffer, which gets used in extern.
+//
+// Example:
+//
+// realize A in range [i*4, extent=10) {
+//   bind Ab to A in [i*4+1, extent=4) {
+//     call_func(Ab.ptr, Ab.shape[0])
+//   }
+// }
+//
+// After StorageFlatten
+//
+// alloc A[10]
+//   call(A + 1,  4)
+//
+// Buffer is a protocol to declare the specific
+// data layout and shape we expect.
+// So this function needs to check:
+// - If the bind range is within the realize range
+// - If we can match the requirements of the buffer
+// - Remap variables such as Ab.ptr to the actual value.
+//
+// Here are a few possible failure cases:
+// - Buffer is declared to have a constant shape,
+//   but we try to bind it to a different one.
+// - Buffer is declared to be compact (no strides),
+//   but the bound region is a subregion of
+//   a matrix (tensor), which means it requires strides.
+//
+// We do support a few relaxed cases, such as binding a
+// region with shape [1, 1, n, m] to a buffer with shape [n, m]

Review comment:
       These docs are duplicated from those on the BufferBindUnwrapper. Consider removing or updating.

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -507,8 +1340,20 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
 
     IRVisitorWithAnalyzer bound_analyzer;
     bound_analyzer(fptr->body);
+
+    fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+
+    auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
+    fptr->body = stride_legalize(std::move(fptr->body));
+    fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();
+
+    fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
+
+    fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));

Review comment:
       In reading through the refactor, it occurs to me that the passes prior to BufferBindUnwrapper could be simpler if the `buffer_bind_scope` were unwrapped earlier, e.g. prior to ThreadScopePropagate or perhaps first of all. Then each pass would not need to special-case the handling done in the variants of `HandleBufferBindScope`.
   
   Is there something that makes it difficult to apply `BufferBindUnwrapper` earlier?
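       
       That is, a sketch of the reordering in question (assuming no hidden ordering constraints between the passes):
       
           fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
           fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
           // ...stride legalization and thread scope propagation as before,
           // with no HandleBufferBindScope special cases needed in them.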

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be replaced to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all access of the bound buffers to be access into the
+ * linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before
+  // the StorageFlatten pass. We use buffer_bind_scope
+  // to specify beforehand that we want to bind a subregion
+  // of a tensor to a symbolic buffer, which gets used in extern.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare the specific
+  // data layout and shape we expect.
+  // So this function needs to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirements of the buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have a constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact (no strides),
+  //   but the bound region is a subregion of
+  //   a matrix (tensor), which means it requires strides.
+  //
+  // We do support a few relaxed cases, such as binding a
+  // region with shape [1, 1, n, m] to a buffer with shape [n, m]
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    // Unpack information from Attribute node
+    RemapInfo remap;
+
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    const Buffer source = Downcast<Buffer>(arr[0]);
+    ICHECK(source.defined());
+    remap.target = Downcast<Buffer>(arr[1]);
+    ICHECK(remap.target.defined());
+    const CallNode* tuple = op->value.as<CallNode>();
+    ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+
+    for (size_t i = 0; i < tuple->args.size(); i += 2) {
+      remap.begins.push_back(tuple->args[i]);
+      remap.extents.push_back(tuple->args[i + 1]);
+    }
+
+    // Determine bounds in the target buffer
+    auto it = buf_map_.find(remap.target.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find buffer " << remap.target << " @ "
+                                 << remap.target.get();
+    const BufferEntry& target_info = it->second;
+    ICHECK(target_info.in_scope) << "Cannot bind to a buffer that is out of scope";
+    ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
+        << "Incorrect number of arguments in buffer_bind_scope attribute.  "
+        << "Expected (min_0, extent_0, min_1, extent_1, ..., min_N, extent_N).";
+
+    if (target_info.bounds.size() > 0) {
+      Array<PrimExpr> mapped_begins;
+      for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
+        mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
+      }
+      remap.begins = std::move(mapped_begins);
+    }
+
+    ICHECK(target_info.remap == nullptr) << "Indirect remapping not handled";
+
+    for (size_t i = 0; i < remap.begins.size(); i++) {
+      remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
+      remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
+    }
+
+    // Add a buffer remap entry
+    {
+      BufferEntry source_info;
+      source_info.buffer = source;
+      source_info.remap = std::make_unique<RemapInfo>(remap);
+
+      buf_map_[source.get()] = std::move(source_info);
+    }
+
+    // Define remappings of any remaining Variables (e.g. Store/Load nodes).
+    ArgBinder binder(&var_remap_);
+
+    // Define a view that represents the source's view into the target
+    // buffer.  This Buffer object is only used to define the mapping
+    // to the target buffer, and never actually appears in the TIR
+    // graph.
+    Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
+    if (source->strides.size() == 0) {
+      ICHECK_EQ(view->strides.size(), 0U)
+          << "Cannot bind a compact buffer to a strided buffer" << view->strides;
+    } else {
+      // Add explicit strides to the view, in order to bind to source.strides[i].
+      view = view.MakeStrideView();
+    }
+    binder.BindBuffer(source, view, source->name, true);
+
+    // Apply the remaps
+    Stmt body = op->body;
+    body = MergeNest(binder.asserts(), body);
+    body = MergeNest(binder.init_nest(), body);
+    body = this->VisitStmt(body);
+    // remove the binds
+    for (const Var& v : binder.defs()) {
+      var_remap_.erase(v.get());
+    }
+    return body;
+  }
+
+  struct RemapInfo {
+    Buffer target;
+    Array<PrimExpr> begins;
+    Array<PrimExpr> extents;
+  };
+
+  // The buffer entry in the flatten map
+  struct BufferEntry {
+    // The storage buffer
+    Buffer buffer;
+    // the bounds of realization, can be null, means everything
+    Region bounds;
+    // Whether the buffer is external
+    bool external{false};
+    // Whether we are within the allocation scope of the buffer.
+    bool in_scope{true};
+
+    // The buffer to which the storage buffer should be remapped.
+    std::unique_ptr<RemapInfo> remap{nullptr};
+
+    PrimExpr ElemOffset() const {
+      ICHECK(remap);
+
+      Buffer copy = remap->target;
+      {
+        Array<PrimExpr> shape;
+        for (auto r : bounds) {
+          shape.push_back(r->extent);
+        }
+        copy.CopyOnWrite()->shape = std::move(shape);
+      }
+
+      Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents);
+      if (buffer->strides.size() == 0) {
+        ICHECK_EQ(target_slice->strides.size(), 0U)
+            << "Trying to bind compact buffer to strided one strides=" << target_slice->strides;
+      } else {
+        target_slice = target_slice.MakeStrideView();
+      }
+
+      return copy->ElemOffset(remap->begins);
+    }
+  };

Review comment:
      Not seeing `BufferEntry::ElemOffset` used anywhere. Consider removing it, or refactoring the flattening code to use the method?
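      If it is kept, here is a minimal sketch of the kind of call site that could use it (illustrative only; `FlattenRemappedAccess` and its wiring are not part of the PR, and it assumes the binder has already verified that the view's strides match the target slice):

    ```cpp
    // Flattened element offset for an access into a bound buffer:
    // the offset of the view's origin within the realized target,
    // plus the view-relative flattened index.
    PrimExpr FlattenRemappedAccess(const BufferEntry& e, const Array<PrimExpr>& indices) {
      ICHECK(e.remap != nullptr);
      return e.ElemOffset() + e.buffer->ElemOffset(indices);
    }
    ```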

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be replaced to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all access of the bound buffers to be access into the
+ * linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before
+  // StorageFlatten pass. We use buffer_bind_scope
+  // to specify before hand we want to bind a subregion
+  // of tensor to a symbolic buffer, which get used in extern.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare specific
+  // data layout and shape we expect.
+  // So this function need to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirement of buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact(no strides)
+  //   but this binded region is a subregion of
+  //   a matrix(tensor), which means it requires strides.
+  //
+  // We do support a few relaxed case, such as bindingx

Review comment:
       Typo,
   ```suggestion
     // We do support a few relaxed case, such as binding a
   ```

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -137,8 +152,15 @@ class BufferShapeLegalize : public StmtExprMutator {
           << "Inconsistent dimensions for buffer " << op->buffer->name;
 
       Array<PrimExpr> new_indices;
-      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
-        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+
+      // Pad leading indices with zero, matching the "fuzzy_match"
+      // behavior from ArgBinder::BindBuffer.
+      size_t diff = entry.realized_begins.size() - op->indices.size();
+      for (size_t i = 0; i < diff; i++) {
+        new_indices.push_back(0);

Review comment:
       Assuming this is for matching cases like [1, 1, n, m], do we need to check that the leading axes are indeed extent=1?
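      For concreteness, a minimal sketch of what that check could look like (illustrative, not part of the PR; it only covers the statically-provable case, since `is_one` recognizes just a constant 1):

    ```cpp
    // Before padding, verify that every padded leading axis of the
    // remapped buffer really has extent 1.
    size_t diff = entry.realized_begins.size() - op->indices.size();
    for (size_t i = 0; i < diff; i++) {
      ICHECK(is_one(entry.remap_to->shape[i]))
          << "Cannot pad index " << i << " of buffer " << op->buffer->name
          << ": the corresponding leading dimension has extent "
          << entry.remap_to->shape[i];
      new_indices.push_back(0);
    }
    ```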

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -137,8 +152,15 @@ class BufferShapeLegalize : public StmtExprMutator {
           << "Inconsistent dimensions for buffer " << op->buffer->name;
 
       Array<PrimExpr> new_indices;
-      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
-        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+
+      // Pad leading indices with zero, matching the "fuzzy_match"
+      // behavior from ArgBinder::BindBuffer.
+      size_t diff = entry.realized_begins.size() - op->indices.size();
+      for (size_t i = 0; i < diff; i++) {
+        new_indices.push_back(0);

Review comment:
      Let me ask a more general question: is it possible to expand or squeeze all shapes up front and then do exact matching as before? I'm noticing a fair amount of special-casing in this commit to handle the extra unit dimensions.
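      For example, a helper along these lines might let the later passes keep requiring exact ranks (a sketch under the assumption that only leading unit dimensions ever need padding; the name `PadLeadingIndices` is illustrative):

    ```cpp
    // Hypothetical helper: pad an index list with leading zeros up to
    // the buffer's rank, so downstream visitors can do exact matching.
    Array<PrimExpr> PadLeadingIndices(const Array<PrimExpr>& indices, size_t rank) {
      ICHECK_LE(indices.size(), rank);
      Array<PrimExpr> out;
      for (size_t i = indices.size(); i < rank; i++) {
        out.push_back(make_zero(DataType::Int(32)));
      }
      for (const PrimExpr& index : indices) {
        out.push_back(index);
      }
      return out;
    }
    ```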







[GitHub] [tvm] csullivan commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716909762



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -499,6 +1298,40 @@ class StorageFlattener : public StmtExprMutator {
   bool create_bound_attributes_{false};
 };
 
+// The specific tensor data layout is not determined before
+// StorageFlatten pass. We use buffer_bind_scope
+// to specify before hand we want to bind a subregion
+// of tensor to a symbolic buffer, which get used in extern.
+//
+// Example:
+//
+// realize A in range [i*4, extent=10) {
+//   bind Ab to A in [i*4+1, extent=4) {
+//     call_func(Ab.ptr, Ab.shape[0])
+//   }
+// }
+//
+// After StorageFlatten
+//
+// alloc A[10]
+//   call(A + 1,  4)
+//
+// Buffer is a protocol to declare specific
+// data layout and shape we expect.
+// So this function need to check:
+// - If the bind range is within the realize range
+// - If we can match the requirement of buffer
+// - Remap variables such as Ab.ptr to the actual value.
+//
+// Here are a few possible failure cases:
+// - Buffer is declared to have constant shape,
+//   but we try to bind it to a different one.
+// - Buffer is declared to be compact(no strides)
+//   but this binded region is a subregion of
+//   a matrix(tensor), which means it requires strides.
+//
+// We do support a few relaxed case, such as bindingx
+// region with shape [1, 1, n, m] to buffer with shape [n, m]

Review comment:
      These docs are duplicated from those on `BufferBindUnwrapper`. Consider removing or updating them.

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -507,8 +1340,20 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
 
     IRVisitorWithAnalyzer bound_analyzer;
     bound_analyzer(fptr->body);
+
+    fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+
+    auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
+    fptr->body = stride_legalize(std::move(fptr->body));
+    fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();
+
+    fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
+
+    fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));

Review comment:
      In reading through the refactor, it occurs to me that the passes prior to BufferBindUnwrapper could be simpler if the `buffer_bind_scope` were unwrapped earlier, e.g. before ThreadScopePropagate, or perhaps first of all. Then each pass would not need to special-case the handling done in the variants of `HandleBufferBindScope`.
   
   Is there something that makes it difficult to apply `BufferBindUnwrapper` earlier?
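      Concretely, the milder variant would look something like this in `tir.transform.StorageFlatten` (a sketch only, keeping `BufferBindUnwrapper` after `BufferStrideLegalize` so its check that all `buffer_dim_align` annotations are gone still holds; whether `ThreadScopePropagate` can do without the bind information is exactly the open question):

    ```cpp
    fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));

    auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
    fptr->body = stride_legalize(std::move(fptr->body));
    fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();

    // Unwrap buffer_bind_scope before propagating thread scopes, so that
    // ThreadScopePropagate no longer needs its HandleBufferBindScope variant.
    fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));

    fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
    ```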

+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before
+  // StorageFlatten pass. We use buffer_bind_scope
+  // to specify before hand we want to bind a subregion
+  // of tensor to a symbolic buffer, which get used in extern.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare specific
+  // data layout and shape we expect.
+  // So this function need to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirement of buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact(no strides)
+  //   but this binded region is a subregion of
+  //   a matrix(tensor), which means it requires strides.
+  //
+  // We do support a few relaxed case, such as binding a
+  // region with shape [1, 1, n, m] to buffer with shape [n, m]
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    // Unpack information from Attribute node
+    RemapInfo remap;
+
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    const Buffer source = Downcast<Buffer>(arr[0]);
+    ICHECK(source.defined());
+    remap.target = Downcast<Buffer>(arr[1]);
+    ICHECK(remap.target.defined());
+    const CallNode* tuple = op->value.as<CallNode>();
+    ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+
+    for (size_t i = 0; i < tuple->args.size(); i += 2) {
+      remap.begins.push_back(tuple->args[i]);
+      remap.extents.push_back(tuple->args[i + 1]);
+    }
+
+    // Determine bounds in the target buffer
+    auto it = buf_map_.find(remap.target.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find buffer " << remap.target << " @ "
+                                 << remap.target.get();
+    const BufferEntry& target_info = it->second;
+    ICHECK(target_info.in_scope) << "Cannot bind to a buffer that is out of scope";
+    ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
+        << "Incorrect number of arguments in buffer_bind_scope attribute.  "
+        << "Expected (min_0, extent_0, min_1, extent_0, ..., min_N, extent_N).";
+
+    if (target_info.bounds.size() > 0) {
+      Array<PrimExpr> mapped_begins;
+      for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
+        mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
+      }
+      remap.begins = std::move(mapped_begins);
+    }
+
+    ICHECK(target_info.remap == nullptr) << "Indirect remapping not handled";
+
+    for (size_t i = 0; i < remap.begins.size(); i++) {
+      remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
+      remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
+    }
+
+    // Add a buffer remap entry
+    {
+      BufferEntry source_info;
+      source_info.buffer = source;
+      source_info.remap = std::make_unique<RemapInfo>(remap);
+
+      buf_map_[source.get()] = std::move(source_info);
+    }
+
+    // Define remappings of any remaining Variables (e.g. Store/Load nodes).
+    ArgBinder binder(&var_remap_);
+
+    // Define a view that represents the source's view into the target
+    // buffer.  This Buffer object is only used to define the mapping
+    // to the target buffer, and never actually appears in the TIR
+    // graph.
+    Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
+    if (source->strides.size() == 0) {
+      ICHECK_EQ(view->strides.size(), 0U)
+          << "Cannot bind a compact buffer to a strided buffer" << view->strides;
+    } else {
+      // Add explicit strides to the view, in order to bind to source.strides[i].
+      view = view.MakeStrideView();
+    }
+    binder.BindBuffer(source, view, source->name, true);
+
+    // Apply the remaps
+    Stmt body = op->body;
+    body = MergeNest(binder.asserts(), body);
+    body = MergeNest(binder.init_nest(), body);
+    body = this->VisitStmt(body);
+    // remove the binds
+    for (const Var& v : binder.defs()) {
+      var_remap_.erase(v.get());
+    }
+    return body;
+  }
+
+  struct RemapInfo {
+    Buffer target;
+    Array<PrimExpr> begins;
+    Array<PrimExpr> extents;
+  };
+
+  // The buffer entry in the flatten map
+  struct BufferEntry {
+    // The storage buffer
+    Buffer buffer;
+    // the bounds of realization, can be null, means everything
+    Region bounds;
+    // Whether the buffer is external
+    bool external{false};
+    // Whether we are within the allocation scope of the buffer.
+    bool in_scope{true};
+
+    // The buffer to which the storage buffer should be remapped.
+    std::unique_ptr<RemapInfo> remap{nullptr};
+
+    PrimExpr ElemOffset() const {
+      ICHECK(remap);
+
+      Buffer copy = remap->target;
+      {
+        Array<PrimExpr> shape;
+        for (auto r : bounds) {
+          shape.push_back(r->extent);
+        }
+        copy.CopyOnWrite()->shape = std::move(shape);
+      }
+
+      Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents);
+      if (buffer->strides.size() == 0) {
+        ICHECK_EQ(target_slice->strides.size(), 0U)
+            << "Trying to bind compact buffer to strided one strides=" << target_slice->strides;
+      } else {
+        target_slice = target_slice.MakeStrideView();
+      }
+
+      return copy->ElemOffset(remap->begins);
+    }
+  };

Review comment:
      Not seeing the `BufferEntry::ElemOffset` method used anywhere. Consider removing it, or refactoring the pass to use the method?
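      For reference, for a compact row-major target the quantity this method
      computes reduces to the dot product of the view's origin with the
      target's implied strides.  A minimal standalone sketch, where
      `ViewOriginOffset` is a hypothetical helper rather than TVM API:

   ```cpp
   // Hypothetical sketch of what BufferEntry::ElemOffset boils down to for a
   // compact row-major target: the linear offset of the view origin `begins`
   // within a buffer whose extents are `shape`.  Assumes the tvm::tir types
   // already used by this file (PrimExpr, Array, make_const).
   PrimExpr ViewOriginOffset(const Array<PrimExpr>& shape, const Array<PrimExpr>& begins) {
     ICHECK_EQ(shape.size(), begins.size());
     PrimExpr offset = make_const(DataType::Int(32), 0);
     PrimExpr stride = make_const(DataType::Int(32), 1);
     for (size_t i = shape.size(); i != 0; --i) {
       offset = offset + begins[i - 1] * stride;  // begins[dim] * stride[dim]
       stride = stride * shape[i - 1];            // row-major stride update
     }
     return offset;
   }
   ```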

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be replaced to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to have matched behavior to previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all access of the bound buffers to be access into the
+ * linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before
+  // StorageFlatten pass. We use buffer_bind_scope
+  // to specify before hand we want to bind a subregion
+  // of tensor to a symbolic buffer, which get used in extern.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare specific
+  // data layout and shape we expect.
+  // So this function need to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirement of buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact(no strides)
+  //   but this binded region is a subregion of
+  //   a matrix(tensor), which means it requires strides.
+  //
+  // We do support a few relaxed case, such as bindingx

Review comment:
       Typo,
   ```suggestion
     // We do support a few relaxed case, such as binding a
   ```







[GitHub] [tvm] csullivan commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r717743045



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -137,8 +152,15 @@ class BufferShapeLegalize : public StmtExprMutator {
           << "Inconsistent dimensions for buffer " << op->buffer->name;
 
       Array<PrimExpr> new_indices;
-      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
-        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+
+      // Pad leading indices with zero, matching the "fuzzy_match"
+      // behavior from ArgBinder::BindBuffer.
+      size_t diff = entry.realized_begins.size() - op->indices.size();
+      for (size_t i = 0; i < diff; i++) {
+        new_indices.push_back(0);

Review comment:
       Assuming this is for matching cases like [1, 1, n, m], do we need to check that the leading axes are indeed extent=1?
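      A minimal sketch of such a guard, reusing the names from the hunk above
      (hypothetical; the pass may need a different treatment for symbolic
      extents):

   ```cpp
   // Hypothetical check: only pad leading indices with zero when the
   // corresponding realized dimensions provably have extent 1, so that the
   // inserted zero index is known to be in bounds.
   size_t diff = entry.realized_begins.size() - op->indices.size();
   for (size_t i = 0; i < diff; i++) {
     PrimExpr extent = bound_analyzer_->Simplify(entry.remap_to->shape[i]);
     ICHECK(is_one(extent)) << "Cannot insert a zero index for dimension " << i << " of buffer "
                            << op->buffer->name << ", whose extent is " << extent;
     new_indices.push_back(0);
   }
   ```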

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -137,8 +152,15 @@ class BufferShapeLegalize : public StmtExprMutator {
           << "Inconsistent dimensions for buffer " << op->buffer->name;
 
       Array<PrimExpr> new_indices;
-      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
-        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+
+      // Pad leading indices with zero, matching the "fuzzy_match"
+      // behavior from ArgBinder::BindBuffer.
+      size_t diff = entry.realized_begins.size() - op->indices.size();
+      for (size_t i = 0; i < diff; i++) {
+        new_indices.push_back(0);

Review comment:
      Let me ask a more general question: is it possible to expand or squeeze all shapes and then do exact matching as before? I'm noticing a fair amount of special casing in this commit for handling the extra unit dimensions.
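      As a rough illustration of the alternative, a normalization helper could
      squeeze provably extent-1 leading dimensions up front, after which exact
      matching would suffice (hypothetical sketch, not something this commit
      implements):

   ```cpp
   // Hypothetical normalization: drop extent-1 leading dimensions so that a
   // shape like [1, 1, n, m] matches a buffer declared with shape [n, m].
   Array<PrimExpr> SqueezeLeadingUnitDims(const Array<PrimExpr>& shape, size_t target_ndim) {
     Array<PrimExpr> out;
     for (size_t i = 0; i < shape.size(); i++) {
       size_t remaining_dims = shape.size() - i;
       if (remaining_dims > target_ndim && is_one(shape[i])) {
         continue;  // squeeze this leading unit dimension
       }
       out.push_back(shape[i]);
     }
     return out;
   }
   ```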







[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716977020



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -499,6 +1298,40 @@ class StorageFlattener : public StmtExprMutator {
   bool create_bound_attributes_{false};
 };
 
+// The specific tensor data layout is not determined before
+// StorageFlatten pass. We use buffer_bind_scope
+// to specify before hand we want to bind a subregion
+// of tensor to a symbolic buffer, which get used in extern.
+//
+// Example:
+//
+// realize A in range [i*4, extent=10) {
+//   bind Ab to A in [i*4+1, extent=4) {
+//     call_func(Ab.ptr, Ab.shape[0])
+//   }
+// }
+//
+// After StorageFlatten
+//
+// alloc A[10]
+//   call(A + 1,  4)
+//
+// Buffer is a protocol to declare specific
+// data layout and shape we expect.
+// So this function need to check:
+// - If the bind range is within the realize range
+// - If we can match the requirement of buffer
+// - Remap variables such as Ab.ptr to the actual value.
+//
+// Here are a few possible failure cases:
+// - Buffer is declared to have constant shape,
+//   but we try to bind it to a different one.
+// - Buffer is declared to be compact(no strides)
+//   but this binded region is a subregion of
+//   a matrix(tensor), which means it requires strides.
+//
+// We do support a few relaxed case, such as binding a
+// region with shape [1, 1, n, m] to buffer with shape [n, m]

Review comment:
      Thank you, and updated: BufferBindUnwrapper now has the shorter documentation, while StorageFlatten maintains the full documentation.







[GitHub] [tvm] tmoreau89 commented on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-926758244


   Interestingly, there is one CI unit test failure: 





[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r715844177



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    Array<PrimExpr> shape = buf->shape;
+
+    // Keeping this to have matched behavior to previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.

Review comment:
       Makes sense, and added.







[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716931686



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -507,8 +1340,20 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
 
     IRVisitorWithAnalyzer bound_analyzer;
     bound_analyzer(fptr->body);
+
+    fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+
+    auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
+    fptr->body = stride_legalize(std::move(fptr->body));
+    fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();
+
+    fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
+
+    fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));

Review comment:
       I agree that I'd want to unwrap the binds earlier, if I could, to reduce the amount of rewriting needed to pass the updated buffers along.  The main issue I ran into was for IR definitions that directly reference `buf.elem_offset` ([example](https://github.com/apache/tvm/blob/main/tests/python/unittest/test_te_schedule_tensor_core.py#L54)).  In order to determine the offset of the buffer view relative to the `data` pointer of the parent buffer, the shape and strides of the parent buffer need to be determined first.
   
   I have two ideas for making the implementation cleaner and more readable.  One is to change how data are packed in an `AttrStmtNode` for `buffer_bind_scope`, to use [the `MatchBufferRegion` class](https://github.com/apache/tvm/blob/main/include/tvm/tir/stmt.h#L1059).  The other is to extend `StmtExprMutator` to act on `BufferNode`, so that the buffer replacements only need to be made in a single location on each pass.  As it is, the `BufferStoreNode`, `BufferLoadNode`, `AttrStmtNode`, and `CallNode` visitors must each rewrite buffers every time a buffer is modified, even if it is just to use the modified Buffer object.
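   For the first idea, a rough sketch of the repacking (hypothetical: `view`, `target`, `begins`, and `extents` follow the names used in `HandleBufferBindScope`, and this PR does not actually make the change):

   ```cpp
   // Hypothetical: represent the binding as a MatchBufferRegion instead of an
   // AttrStmt holding (view, target) plus a tvm_tuple of interleaved
   // begins/extents.
   Array<Range> region;
   for (size_t i = 0; i < begins.size(); i++) {
     region.push_back(Range::FromMinExtent(begins[i], extents[i]));
   }
   MatchBufferRegion match(/*buffer=*/view, /*source=*/BufferRegion(target, region));
   ```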







[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716978069



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be replaced to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all accesses of the bound buffers into accesses of the
+ * linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before the
+  // StorageFlatten pass. We use buffer_bind_scope
+  // to specify beforehand that we want to bind a subregion
+  // of a tensor to a symbolic buffer, which gets used in extern.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare specific
+  // data layout and shape we expect.
+  // So this function needs to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirement of buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact(no strides)
+  //   but this bound region is a subregion of
+  //   a matrix(tensor), which means it requires strides.
+  //
+  // We do support a few relaxed case, such as bindingx

Review comment:
       Typo fixed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Lunderberg commented on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-930294838


   Latest round of changes: added another pass to remove assert statements that can be statically validated.  These are placed by ArgBinder::Bind when it can't verify a constraint at the time it binds a variable.  Even if later variable substitutions allow the constraint to be statically verified, the asserts would otherwise remain in the final generated code.  These didn't appear prior to the refactor, because `StorageFlatten` made a single substitution, whereas the refactor does so in multiple passes.
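   
   For illustration, a minimal sketch of what such a pass could look like (hypothetical `RemoveVerifiableAsserts` name, assuming the usual `StmtExprMutator` idioms rather than the exact implementation here):
   
   ```c++
   class RemoveVerifiableAsserts : public StmtExprMutator {
    public:
     Stmt VisitStmt_(const AssertStmtNode* op) final {
       // If the condition simplifies to a constant true, the assert has
       // been statically verified and only its body needs to be kept.
       PrimExpr cond = analyzer_.Simplify(op->condition);
       if (const auto* as_int = cond.as<IntImmNode>()) {
         if (as_int->value) {
           return this->VisitStmt(op->body);
         }
       }
       return StmtExprMutator::VisitStmt_(op);
     }
   
    private:
     arith::Analyzer analyzer_;
   };
   ```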


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716931686



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -507,8 +1340,20 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
 
     IRVisitorWithAnalyzer bound_analyzer;
     bound_analyzer(fptr->body);
+
+    fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+
+    auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
+    fptr->body = stride_legalize(std::move(fptr->body));
+    fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();
+
+    fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
+
+    fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));

Review comment:
    I agree that I'd want to unwrap the binds earlier, if I could, to reduce the amount of rewriting needed to pass the updated buffers along.  The main issue I ran into was for IR definitions that directly reference `buf.elem_offset` ([example](https://github.com/apache/tvm/blob/main/tests/python/unittest/test_te_schedule_tensor_core.py#L54)).  In order to determine the offset of the buffer view relative to the `data` pointer of the parent buffer, the shape and strides of the parent buffer need to be determined first.
   
   I have two ideas for making the implementation cleaner and more readable.  One is to change how data are packed in an `AttrStmtNode` for `buffer_bind_scope`, to use [the `MatchBufferRegion` class](https://github.com/apache/tvm/blob/main/include/tvm/tir/stmt.h#L1059).  The other is to extend `StmtExprMutator` to act on `BufferNode`, so that the buffer replacements only need to be done in a single location for each pass.  As it is, the `BufferStoreNode`, `BufferLoadNode`, `AttrStmtNode`, and `CallNode` visitors must all be rewritten each time the buffer gets modified, even if it's just to use the modified Buffer object.
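   
   For concreteness, a hedged sketch of the first idea (hypothetical helper; `begins`/`extents` stand in for the values currently unpacked from the `tvm_tuple()` arguments):
   
   ```c++
   // Repack a buffer_bind_scope annotation into the structured
   // MatchBufferRegion form instead of the tvm_tuple() encoding.
   MatchBufferRegion ToMatchBufferRegion(Buffer view, Buffer target,
                                         const Array<PrimExpr>& begins,
                                         const Array<PrimExpr>& extents) {
     ICHECK_EQ(begins.size(), extents.size());
     Array<Range> region;
     for (size_t i = 0; i < begins.size(); i++) {
       region.push_back(Range::FromMinExtent(begins[i], extents[i]));
     }
     return MatchBufferRegion(view, BufferRegion(target, region));
   }
   ```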

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -499,6 +1298,40 @@ class StorageFlattener : public StmtExprMutator {
   bool create_bound_attributes_{false};
 };
 
+// The specific tensor data layout is not determined before the
+// StorageFlatten pass. We use buffer_bind_scope
+// to specify beforehand that we want to bind a subregion
+// of a tensor to a symbolic buffer, which gets used in extern.
+//
+// Example:
+//
+// realize A in range [i*4, extent=10) {
+//   bind Ab to A in [i*4+1, extent=4) {
+//     call_func(Ab.ptr, Ab.shape[0])
+//   }
+// }
+//
+// After StorageFlatten
+//
+// alloc A[10]
+//   call(A + 1,  4)
+//
+// Buffer is a protocol to declare specific
+// data layout and shape we expect.
+// So this function needs to check:
+// - If the bind range is within the realize range
+// - If we can match the requirement of buffer
+// - Remap variables such as Ab.ptr to the actual value.
+//
+// Here are a few possible failure cases:
+// - Buffer is declared to have constant shape,
+//   but we try to bind it to a different one.
+// - Buffer is declared to be compact(no strides)
+//   but this bound region is a subregion of
+//   a matrix(tensor), which means it requires strides.
+//
+// We do support a few relaxed case, such as bindingx
+// region with shape [1, 1, n, m] to buffer with shape [n, m]

Review comment:
    Thank you, and updated to have shorter documentation in BufferBindUnwrapper, while StorageFlattener maintains the full documentation.

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be replaced to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all accesses of the bound buffers into accesses of the
+ * linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before the
+  // StorageFlatten pass.  We use buffer_bind_scope to specify
+  // beforehand that we want to bind a subregion of a tensor to a
+  // symbolic buffer, which is then used by an extern call.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare specific
+  // data layout and shape we expect.
+  // So this function needs to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirements of the buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact (no strides),
+  //   but the bound region is a subregion of
+  //   a matrix (tensor), which means it requires strides.
+  //
+  // We do support a few relaxed cases, such as binding a
+  // region with shape [1, 1, n, m] to a buffer with shape [n, m].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    // Unpack information from Attribute node
+    RemapInfo remap;
+
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    const Buffer source = Downcast<Buffer>(arr[0]);
+    ICHECK(source.defined());
+    remap.target = Downcast<Buffer>(arr[1]);
+    ICHECK(remap.target.defined());
+    const CallNode* tuple = op->value.as<CallNode>();
+    ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+
+    for (size_t i = 0; i < tuple->args.size(); i += 2) {
+      remap.begins.push_back(tuple->args[i]);
+      remap.extents.push_back(tuple->args[i + 1]);
+    }
+
+    // Determine bounds in the target buffer
+    auto it = buf_map_.find(remap.target.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find buffer " << remap.target << " @ "
+                                 << remap.target.get();
+    const BufferEntry& target_info = it->second;
+    ICHECK(target_info.in_scope) << "Cannot bind to a buffer that is out of scope";
+    ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
+        << "Incorrect number of arguments in buffer_bind_scope attribute.  "
+        << "Expected (min_0, extent_0, min_1, extent_0, ..., min_N, extent_N).";
+
+    if (target_info.bounds.size() > 0) {
+      Array<PrimExpr> mapped_begins;
+      for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
+        mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
+      }
+      remap.begins = std::move(mapped_begins);
+    }
+
+    ICHECK(target_info.remap == nullptr) << "Indirect remapping not handled";
+
+    for (size_t i = 0; i < remap.begins.size(); i++) {
+      remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
+      remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
+    }
+
+    // Add a buffer remap entry
+    {
+      BufferEntry source_info;
+      source_info.buffer = source;
+      source_info.remap = std::make_unique<RemapInfo>(remap);
+
+      buf_map_[source.get()] = std::move(source_info);
+    }
+
+    // Define remappings of any remaining Variables (e.g. Store/Load nodes).
+    ArgBinder binder(&var_remap_);
+
+    // Define a view that represents the source's view into the target
+    // buffer.  This Buffer object is only used to define the mapping
+    // to the target buffer, and never actually appears in the TIR
+    // graph.
+    Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
+    if (source->strides.size() == 0) {
+      ICHECK_EQ(view->strides.size(), 0U)
+          << "Cannot bind a compact buffer to a strided buffer" << view->strides;
+    } else {
+      // Add explicit strides to the view, in order to bind to source.strides[i].
+      view = view.MakeStrideView();
+    }
+    binder.BindBuffer(source, view, source->name, true);
+
+    // Apply the remaps
+    Stmt body = op->body;
+    body = MergeNest(binder.asserts(), body);
+    body = MergeNest(binder.init_nest(), body);
+    body = this->VisitStmt(body);
+    // remove the binds
+    for (const Var& v : binder.defs()) {
+      var_remap_.erase(v.get());
+    }
+    return body;
+  }
+
+  struct RemapInfo {
+    Buffer target;
+    Array<PrimExpr> begins;
+    Array<PrimExpr> extents;
+  };
+
+  // The buffer entry in the flatten map
+  struct BufferEntry {
+    // The storage buffer
+    Buffer buffer;
+    // The bounds of the realization; can be empty, meaning the entire buffer.
+    Region bounds;
+    // Whether the buffer is external
+    bool external{false};
+    // Whether we are within the allocation scope of the buffer.
+    bool in_scope{true};
+
+    // The buffer to which the storage buffer should be remapped.
+    std::unique_ptr<RemapInfo> remap{nullptr};
+
+    PrimExpr ElemOffset() const {
+      ICHECK(remap);
+
+      Buffer copy = remap->target;
+      {
+        Array<PrimExpr> shape;
+        for (auto r : bounds) {
+          shape.push_back(r->extent);
+        }
+        copy.CopyOnWrite()->shape = std::move(shape);
+      }
+
+      Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents);
+      if (buffer->strides.size() == 0) {
+        ICHECK_EQ(target_slice->strides.size(), 0U)
+            << "Trying to bind compact buffer to strided one strides=" << target_slice->strides;
+      } else {
+        target_slice = target_slice.MakeStrideView();
+      }
+
+      return copy->ElemOffset(remap->begins);
+    }
+  };

Review comment:
       And removed.  It was part of an earlier (and broken) implementation.
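
(An aside on the hunk quoted above: the index remapping in `remap_indices` is just an elementwise shift of the view's indices by the view's begin offsets in the target buffer.  A minimal plain-C++ sketch of that rule, using made-up integer indices in place of TVM PrimExprs:)

    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Analogue of BufferBindUnwrapper::remap_indices: each index into
    // the bound view is offset by the view's begin in the target buffer.
    std::vector<int> RemapIndices(const std::vector<int>& indices,
                                  const std::vector<int>& begins) {
      assert(indices.size() == begins.size());
      std::vector<int> out;
      for (std::size_t i = 0; i < begins.size(); i++) {
        out.push_back(begins[i] + indices[i]);
      }
      return out;
    }

    int main() {
      // A view bound at A[4:8, 8:16]; accessing view[1, 2] reads A[5, 10].
      std::vector<int> remapped = RemapIndices({1, 2}, {4, 8});
      std::printf("A[%d, %d]\n", remapped[0], remapped[1]);  // prints A[5, 10]
      return 0;
    }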

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be updated to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all accesses of the bound buffers to be accesses into
+ * the linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before the
+  // StorageFlatten pass.  We use buffer_bind_scope to specify
+  // beforehand that we want to bind a subregion of a tensor to a
+  // symbolic buffer, which is then used by an extern call.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol to declare specific
+  // data layout and shape we expect.
+  // So this function needs to check:
+  // - If the bind range is within the realize range
+  // - If we can match the requirements of the buffer
+  // - Remap variables such as Ab.ptr to the actual value.
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact (no strides),
+  //   but the bound region is a subregion of
+  //   a matrix (tensor), which means it requires strides.
+  //
+  // We do support a few relaxed case, such as bindingx

Review comment:
       Typo fixed.
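
(Another aside, since the hunk above quotes `WithStrides`: the buffer_dim_align stride rule pads each aligned dimension's stride so that `stride % factor == offset` before the next dimension multiplies it in.  A small plain-C++ sketch with hypothetical shape and alignment values, not code from the PR:)

    #include <cstdio>
    #include <vector>

    int main() {
      // A 2-D buffer of shape {4, 6}, with buffer_dim_align on dim 0
      // requesting align_factor = 8 and align_offset = 0.
      std::vector<int> shape = {4, 6};
      int factor = 8, offset = 0;

      // Walk dimensions from innermost to outermost, as WithStrides does.
      int stride = 1;
      std::vector<int> strides(2);
      strides[1] = stride;  // innermost dim: no alignment requested
      stride *= shape[1];   // the dim-0 stride would be 6...
      // ...but dim 0 is aligned: pad until stride % factor == offset,
      // which rounds 6 up to 8 here.
      stride += (factor + offset - stride % factor) % factor;
      strides[0] = stride;

      std::printf("strides = {%d, %d}\n", strides[0], strides[1]);  // {8, 1}
      return 0;
    }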







[GitHub] [tvm] tmoreau89 commented on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-931794142


   Thanks @csullivan and @Lunderberg, the PR has been merged.





[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r717957244



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -137,8 +152,15 @@ class BufferShapeLegalize : public StmtExprMutator {
           << "Inconsistent dimensions for buffer " << op->buffer->name;
 
       Array<PrimExpr> new_indices;
-      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
-        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+
+      // Pad leading indices with zero, matching the "fuzzy_match"
+      // behavior from ArgBinder::BindBuffer.
+      size_t diff = entry.realized_begins.size() - op->indices.size();
+      for (size_t i = 0; i < diff; i++) {
+        new_indices.push_back(0);

Review comment:
       For the specific case, there is a check as part of the call to `ArgBinder::BindBuffer`.  I'll add a comment to indicate that.
   
   For the general question, that was the intent in `BufferShapeLegalize`, so that afterwards the buffers all have a single well-defined shape.  It looks like I had missed one case where `BufferBindUnwrapper` changed the number of dimensions when binding to an external buffer, but I've added another commit to this PR to pull that behavior into `BufferShapeLegalize` instead.
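
(For readers following the fuzzy_match discussion: the padding treats missing leading dimensions as having extent 1 and index 0, so a lower-rank view can bind to a higher-rank buffer.  A hypothetical plain-C++ sketch of the rule, not taken from the PR:)

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Bind a lower-rank view (e.g. shape {n, m}) to a higher-rank buffer
    // (e.g. shape {1, 1, n, m}) by zero-filling the missing leading indices.
    std::vector<int> PadLeadingIndices(const std::vector<int>& indices,
                                       std::size_t buffer_ndim) {
      assert(indices.size() <= buffer_ndim);
      std::vector<int> padded(buffer_ndim - indices.size(), 0);
      padded.insert(padded.end(), indices.begin(), indices.end());
      return padded;
    }

    int main() {
      // view[2, 3] on a 4-D buffer becomes buf[0, 0, 2, 3].
      std::vector<int> padded = PadLeadingIndices({2, 3}, 4);
      assert((padded == std::vector<int>{0, 0, 2, 3}));
      return 0;
    }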







[GitHub] [tvm] tmoreau89 removed a comment on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
tmoreau89 removed a comment on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-926758244


   Interestingly there is one CI unit test failure: 





[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r716977857



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,933 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  // Any buffers that give views into a resized buffer should be
+  // updated, both to refer to the resized buffer and to have the view
+  // window updated.  For example, suppose B1 is a 1-D buffer of size
+  // 100 which is only realized on the range (10,50), and buffer V1 is
+  // a view into B1[25:35].  When B1 is replaced with B2, a buffer of
+  // size 40 realized on the range (0,40), V1 must be updated to be a
+  // view into B2[15:25].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    // Keeping this to match the behavior of the previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.  For example, ArgBinder::BindBuffer
+    // and tir.Specialize.
+    if (dim_align_.count(buf) == 0) {
+      return buf;
+    }
+
+    // Can't define the strides for a buffer without a known shape.
+    Array<PrimExpr> shape = buf->shape;
+    if (shape.size() == 0) {
+      return buf;
+    }
+
+    std::vector<PrimExpr> rstrides;
+    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+    int first_dim = 0;
+    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      if (dim < avec.size() && avec[dim].align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+        stride = bound_analyzer_->Simplify(stride);
+      }
+      rstrides.push_back(stride);
+      stride = stride * shape[dim];
+    }
+
+    auto ptr = buf.CopyOnWrite();
+    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+    return buf;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::buffer_dim_align) {
+      auto buffer = Downcast<tir::Buffer>(op->node);
+      const CallNode* tuple = op->value.as<CallNode>();
+      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+      auto& vinfo = dim_align_[buffer];
+      int dim = tuple->args[0].as<IntImmNode>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2U);
+      Buffer source = Downcast<Buffer>(arr[0]);
+      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+      Buffer source_with_strides = WithStrides(source);
+
+      {
+        BufferEntry entry;
+        entry.remap_to = source_with_strides;
+        entry.in_scope = true;
+        entry.is_external = false;
+        buf_map_[source] = entry;
+      }
+
+      Stmt body = this->VisitStmt(op->body);
+
+      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+                      op->value, body, op->span);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Buffer key = op->buffer;
+    Buffer with_strides = WithStrides(op->buffer);
+    {
+      BufferEntry entry;
+      entry.remap_to = with_strides;
+      entry.in_scope = true;
+      entry.is_external = false;
+      buf_map_[key] = entry;
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    buf_map_[key].in_scope = false;
+    op = stmt.as<BufferRealizeNode>();
+    ICHECK(op);
+
+    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    return BufferLoad(e.remap_to, op->indices, op->span);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    return BufferStore(e.remap_to, op->value, op->indices, op->span);
+  }
+
+ private:
+  Map<Var, Buffer> updated_extern_buffer_map_;
+
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  // Dimension alignment
+  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+  struct BufferEntry {
+    Buffer remap_to;
+    bool in_scope;
+    bool is_external;
+  };
+
+  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation.  All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+    // External buffers shouldn't be overwritten, even if they have a
+    // BufferRealizeNode.
+    for (auto kv : extern_buffer_map) {
+      external_buffers_.insert(kv.second);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = buf_remap_.find(GetRef<Var>(op));
+    if (it != buf_remap_.end()) {
+      return it->second->data;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "StorageFlattener assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      curr_thread_scope_.pop_back();
+      return stmt;
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    Var old_var = op->buffer->data;
+
+    // Don't remap buffers that already have an explicit scope,
+    // or external buffers.
+    std::string str_scope = GetPtrStorageScope(old_var);
+    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    ICHECK_EQ(buf_remap_.count(old_var), 0)
+        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
+
+    StorageScope skey;
+    if (curr_thread_scope_.size() == 0) {
+      skey.rank = StorageRank::kGlobal;
+    } else {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+    }
+
+    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+                old_var->span);
+
+    Buffer buf = op->buffer;
+    buf.CopyOnWrite()->data = new_var;
+
+    buf_remap_[old_var] = buf;
+
+    Stmt body = this->VisitStmt(op->body);
+    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferLoad(it->second, op->indices, op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = buf_remap_.find(op->buffer->data);
+    if (it != buf_remap_.end()) {
+      return BufferStore(it->second, op->value, op->indices, op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // If the rewritten buffers are part of a buffer_bind_scope, either
+  // as the buffer view or as the buffer being viewed, then the
+  // buffer_bind_scope must be rewritten to refer to the updated
+  // buffers.
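+  //
+  // For example (a sketch): if the data Var of A was remapped to add
+  // a storage scope, then attr [Ab, A] buffer_bind_scope is rewritten
+  // so that its buffer array refers to the remapped A.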
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    bool needs_rewrite = false;
+
+    {
+      auto it = buf_remap_.find(buffer->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        buffer = it->second;
+      }
+    }
+
+    {
+      auto it = buf_remap_.find(target->data);
+      if (it != buf_remap_.end()) {
+        needs_rewrite = true;
+        target = it->second;
+      }
+    }
+
+    if (needs_rewrite) {
+      Stmt body = this->VisitStmt(op->body);
+      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer.  This pass identifies such buffers, then
+ * rewrites all accesses of the bound buffers to be accesses into the
+ * linked buffer.
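+ *
+ * For instance (an illustrative sketch, not from this diff): if Ab is
+ * bound to a region of A with begins=[4] and extents=[4], then an
+ * access Ab[i] is rewritten to A[4 + i].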
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[kv.second.get()] = std::move(e);
+    }
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<StoreNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (StoreNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Store(new_buf_var, op->value, op->index, op->predicate);
+    } else {
+      return stmt;
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<LoadNode>();
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+      // TODO(Lunderberg): Change from warning to error once all mixed
+      // use of physical/logical layouts is removed.
+      DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+                    << "but is accessed as a pointer (LoadNode).";
+
+      ICHECK(it->second.as<VarNode>());
+      Var new_buf_var = Downcast<Var>(it->second);
+      return Load(op->dtype, new_buf_var, op->index, op->predicate);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
+        << "and all buffer_dim_align annotations are removed.  "
+        << "Please run BufferStrideLegalize first.";
+
+    if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
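+  // Remap indices of an access into a bound view to indices into the
+  // underlying target buffer.  For example (hypothetical values):
+  // begins = [4, 0] with indices = [i, j] yields [4 + i, j].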
+  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+                                Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return indices;
+    }
+
+    ICHECK_EQ(begins.size(), indices.size());
+
+    Array<PrimExpr> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(begins[i] + indices[i]);
+    }
+    return out;
+  }
+
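+  // Remap Prefetch bounds by the same offsets.  For example
+  // (hypothetical values): bounds = [(min=0, extent=16)] with
+  // begins = [4] becomes [(min=4, extent=16)].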
+  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+    ICHECK_EQ(begins.size(), extents.size());
+
+    if (begins.size() == 0) {
+      return bounds;
+    }
+
+    ICHECK_EQ(begins.size(), bounds.size());
+
+    Array<Range> out;
+    for (size_t i = 0; i < begins.size(); i++) {
+      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+    }
+    return out;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferLoad(e.remap->target,
+                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+
+    auto it = buf_map_.find(op->buffer.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+    ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+    if (e.remap) {
+      return BufferStore(e.remap->target, op->value,
+                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    const auto& key = op->buffer.get();
+
+    bool is_external = false;
+
+    if (buf_map_.count(key)) {
+      ICHECK(buf_map_.at(key).external)
+          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
+
+      is_external = true;
+    } else {
+      BufferEntry e;
+      e.bounds = op->bounds;
+      e.buffer = op->buffer;
+      buf_map_[key] = std::move(e);
+    }
+
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+    if (is_external) {
+      buf_map_[key].in_scope = false;
+    }
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const PrefetchNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<PrefetchNode>();
+    ICHECK(op != nullptr);
+
+    const auto& key = op->buffer.get();
+    auto it = buf_map_.find(key);
+    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+    const BufferEntry& e = it->second;
+
+    ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+        << "Prefetch dim should be the same as buffer dim";
+
+    if (e.remap) {
+      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+                      op->span);
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  // The specific tensor data layout is not determined before the
+  // StorageFlatten pass.  We use buffer_bind_scope to specify ahead
+  // of time that we want to bind a subregion of a tensor to a
+  // symbolic buffer, which gets used in an extern call.
+  //
+  // Example:
+  //
+  // realize A in range [i*4, extent=10) {
+  //   bind Ab to A in [i*4+1, extent=4) {
+  //     call_func(Ab.ptr, Ab.shape[0])
+  //   }
+  // }
+  //
+  // After StorageFlatten
+  //
+  // alloc A[10]
+  //   call(A + 1,  4)
+  //
+  // Buffer is a protocol for declaring the specific
+  // data layout and shape that we expect.
+  // So this function needs to check:
+  // - whether the bind range is within the realize range
+  // - whether we can match the requirements of the buffer
+  // - how to remap variables such as Ab.ptr to their actual values
+  //
+  // Here are a few possible failure cases:
+  // - Buffer is declared to have constant shape,
+  //   but we try to bind it to a different one.
+  // - Buffer is declared to be compact (no strides),
+  //   but the bound region is a subregion of
+  //   a matrix (tensor), which means it requires strides.
+  //
+  // We do support a few relaxed cases, such as binding a
+  // region with shape [1, 1, n, m] to a buffer with shape [n, m].
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    // Unpack information from Attribute node
+    RemapInfo remap;
+
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    const Buffer source = Downcast<Buffer>(arr[0]);
+    ICHECK(source.defined());
+    remap.target = Downcast<Buffer>(arr[1]);
+    ICHECK(remap.target.defined());
+    const CallNode* tuple = op->value.as<CallNode>();
+    ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+
+    for (size_t i = 0; i < tuple->args.size(); i += 2) {
+      remap.begins.push_back(tuple->args[i]);
+      remap.extents.push_back(tuple->args[i + 1]);
+    }
+
+    // Determine bounds in the target buffer
+    auto it = buf_map_.find(remap.target.get());
+    ICHECK(it != buf_map_.end()) << "Cannot find buffer " << remap.target << " @ "
+                                 << remap.target.get();
+    const BufferEntry& target_info = it->second;
+    ICHECK(target_info.in_scope) << "Cannot bind to a buffer that is out of scope";
+    ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
+        << "Incorrect number of arguments in buffer_bind_scope attribute.  "
+        << "Expected (min_0, extent_0, min_1, extent_0, ..., min_N, extent_N).";
+
+    if (target_info.bounds.size() > 0) {
+      Array<PrimExpr> mapped_begins;
+      for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
+        mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
+      }
+      remap.begins = std::move(mapped_begins);
+    }
+
+    ICHECK(target_info.remap == nullptr) << "Indirect remapping not handled";
+
+    for (size_t i = 0; i < remap.begins.size(); i++) {
+      remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
+      remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
+    }
+
+    // Add a buffer remap entry
+    {
+      BufferEntry source_info;
+      source_info.buffer = source;
+      source_info.remap = std::make_unique<RemapInfo>(remap);
+
+      buf_map_[source.get()] = std::move(source_info);
+    }
+
+    // Define remappings of any remaining Variables (e.g. Store/Load nodes).
+    ArgBinder binder(&var_remap_);
+
+    // Define a view that represents the source's view into the target
+    // buffer.  This Buffer object is only used to define the mapping
+    // to the target buffer, and never actually appears in the TIR
+    // graph.
+    Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
+    if (source->strides.size() == 0) {
+      ICHECK_EQ(view->strides.size(), 0U)
+          << "Cannot bind a compact buffer to a strided buffer" << view->strides;
+    } else {
+      // Add explicit strides to the view, in order to bind to source.strides[i].
+      view = view.MakeStrideView();
+    }
+    binder.BindBuffer(source, view, source->name, true);
+
+    // Apply the remaps
+    Stmt body = op->body;
+    body = MergeNest(binder.asserts(), body);
+    body = MergeNest(binder.init_nest(), body);
+    body = this->VisitStmt(body);
+    // remove the binds
+    for (const Var& v : binder.defs()) {
+      var_remap_.erase(v.get());
+    }
+    return body;
+  }
+
+  struct RemapInfo {
+    Buffer target;
+    Array<PrimExpr> begins;
+    Array<PrimExpr> extents;
+  };
+
+  // The buffer entry in the flatten map
+  struct BufferEntry {
+    // The storage buffer
+    Buffer buffer;
+    // The bounds of the realization; may be empty, meaning the entire buffer.
+    Region bounds;
+    // Whether the buffer is external
+    bool external{false};
+    // Whether we are within the allocation scope of the buffer.
+    bool in_scope{true};
+
+    // The buffer to which the storage buffer should be remapped.
+    std::unique_ptr<RemapInfo> remap{nullptr};
+
+    PrimExpr ElemOffset() const {
+      ICHECK(remap);
+
+      Buffer copy = remap->target;
+      {
+        Array<PrimExpr> shape;
+        for (auto r : bounds) {
+          shape.push_back(r->extent);
+        }
+        copy.CopyOnWrite()->shape = std::move(shape);
+      }
+
+      Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents);
+      if (buffer->strides.size() == 0) {
+        ICHECK_EQ(target_slice->strides.size(), 0U)
+            << "Trying to bind compact buffer to strided one strides=" << target_slice->strides;
+      } else {
+        target_slice = target_slice.MakeStrideView();
+      }
+
+      return copy->ElemOffset(remap->begins);
+    }
+  };

Review comment:
       And removed.  It was part of an earlier (and broken) implementation.







[GitHub] [tvm] Lunderberg commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r715837497



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
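+  // For example (an illustrative sketch): realize A in range
+  // [i*4, extent=10) rewrites A's shape to [10], the realize bounds
+  // to [0, 10), and each access A[j] within the body to A[j - i*4].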
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer accesses.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {

Review comment:
       Makes sense, and comments have been added.







[GitHub] [tvm] tmoreau89 merged pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
tmoreau89 merged pull request #9091:
URL: https://github.com/apache/tvm/pull/9091


   





[GitHub] [tvm] Lunderberg commented on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-925826598


   @csullivan 





[GitHub] [tvm] junrushao1994 commented on pull request #9091: [TIR] tir.transform.StorageFlatten refactor

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#issuecomment-926141602


   CC: @vinx13 

