You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/24 17:35:17 UTC

[GitHub] [tvm] csullivan commented on a change in pull request #9091: [TIR] tir.transform.StorageFlatten refactor

csullivan commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r715763492



##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";

Review comment:
      Consider making this a CHECK rather than an ICHECK. External buffers can be provided by the user, so failing this check could indicate a user error rather than an internal bug.

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer = Downcast<Buffer>(arr[0]);
+    ICHECK(buffer.defined());
+    Buffer target = Downcast<Buffer>(arr[1]);
+    ICHECK(target.defined());
+
+    auto it = internal_buf_map_.find(target);
+    if (it == internal_buf_map_.end()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const InternalBufferRemap& target_remap = it->second;
+
+    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
+
+    Call tuple = Downcast<Call>(op->value);
+    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+    Array<PrimExpr> new_tuple_args;
+    Array<PrimExpr> realized_begins;
+    Array<PrimExpr> realized_shape;
+    ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2);
+    for (size_t i = 0; i < target_remap.realized_begins.size(); i++) {
+      PrimExpr parent_begin = tuple->args[2 * i];
+      PrimExpr view_extent = tuple->args[2 * i + 1];
+      // Offset the begin of the buffer view by the offset of the target buffer.
+      new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]);
+      // Keep the extent of the buffer view the same.
+      new_tuple_args.push_back(view_extent);
+      // Use the extent of the buffer view to define the buffer view's shape.
+      realized_shape.push_back(view_extent);
+      // Within the buffer view, indices start at 0.
+      realized_begins.push_back(0);
+    }
+
+    Buffer key = buffer;
+
+    auto write_ptr = buffer.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.realized_begins = realized_begins;
+      remap.remap_to = buffer;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span),
+                         this->VisitStmt(op->body));
+    internal_buf_map_.at(key).in_scope = false;
+    return stmt;
+  }
+
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+  struct InternalBufferRemap {
+    Buffer remap_to;
+    Array<PrimExpr> realized_begins;
+    bool in_scope;
+  };
+
+  std::unordered_map<Buffer, InternalBufferRemap, ObjectPtrHash, ObjectPtrEqual> internal_buf_map_;
+
+  IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                                IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      Buffer buf = kv.second;
+      Buffer with_strides = WithStrides(buf);
+      {
+        BufferEntry entry;
+        entry.remap_to = with_strides;
+        entry.in_scope = true;
+        entry.is_external = true;
+        buf_map_[buf] = entry;
+      }
+      updated_extern_buffer_map_.Set(kv.first, with_strides);
+    }
+  }
+
+  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+  Buffer WithStrides(Buffer buf) {
+    auto it = buf_map_.find(buf);
+    if (it != buf_map_.end()) {
+      const BufferEntry& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+      return entry.remap_to;
+    }
+
+    if (buf->strides.size()) {
+      ICHECK_EQ(buf->strides.size(), buf->shape.size());
+      return buf;
+    }
+
+    Array<PrimExpr> shape = buf->shape;
+
+    // Keeping this to have matched behavior to previous version.
+    // There are many parts of the codebase that assume that a strided
+    // array cannot be compact.

Review comment:
       nit: An example of one such case that makes this assumption could be useful for context.

##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer.  For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                               IRVisitorWithAnalyzer* bound_analyzer)
+      : bound_analyzer_(bound_analyzer) {
+    for (auto kv : extern_buffer_map) {
+      extern_buffers_.insert(kv.second);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    // External buffers should not be changed.
+    if (extern_buffers_.count(op->buffer)) {
+      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+          << "External buffer realize has mismatched dimension";
+      Stmt stmt = StmtExprMutator::VisitStmt_(op);
+      op = stmt.as<BufferRealizeNode>();
+      ICHECK(op);
+
+      for (size_t i = 0; i < op->bounds.size(); i++) {
+        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+        std::ostringstream ss;
+        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
+        if (auto eq_int = eq.as<IntImmNode>()) {
+          ICHECK(eq_int->value) << ss.str();
+        } else {
+          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+        }
+      }
+      return stmt;
+    }
+
+    // Compute the new buffer shape, new realization bounds, and the
+    // offsets to be applied to buffer access.
+    Array<PrimExpr> realized_shape;
+    Array<PrimExpr> realized_begins;
+    Array<Range> new_bounds;
+    for (size_t i = 0; i < op->bounds.size(); i++) {
+      const Range& bound = op->bounds[i];
+      realized_shape.push_back(bound->extent);
+      realized_begins.push_back(bound->min);
+      new_bounds.push_back({0, bound->extent});
+    }
+
+    Buffer key = op->buffer;
+
+    Buffer buf = op->buffer;
+    auto write_ptr = buf.CopyOnWrite();
+    write_ptr->shape = realized_shape;
+
+    {
+      InternalBufferRemap remap;
+      remap.remap_to = buf;
+      remap.realized_begins = realized_begins;
+      remap.in_scope = true;
+      internal_buf_map_[key] = remap;
+    }
+
+    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+    internal_buf_map_.at(key).in_scope = false;
+
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferStore updated = GetRef<BufferStore>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      stmt = updated;
+    }
+
+    return stmt;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op);
+
+    auto it = internal_buf_map_.find(op->buffer);
+    if (it != internal_buf_map_.end()) {
+      const InternalBufferRemap& entry = it->second;
+      ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+      ICHECK_EQ(entry.realized_begins.size(), op->indices.size())
+          << "Inconsistent dimensions for buffer " << op->buffer->name;
+
+      Array<PrimExpr> new_indices;
+      for (size_t i = 0; i < entry.realized_begins.size(); i++) {
+        new_indices.push_back(op->indices[i] - entry.realized_begins[i]);
+      }
+
+      BufferLoad updated = GetRef<BufferLoad>(op);
+      auto write_ptr = updated.CopyOnWrite();
+      write_ptr->indices = new_indices;
+      write_ptr->buffer = entry.remap_to;
+      expr = updated;
+    }
+
+    return expr;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->node->IsInstance<tir::BufferNode>()) {
+      // Visit body before checking internal_buf_map_, because we
+      // don't know if the BufferNode needs to be changed until we
+      // look in the body for a BufferRealizeNode with different
+      // extents.
+      Stmt body = this->VisitStmt(op->body);
+
+      Buffer buffer = Downcast<tir::Buffer>(op->node);
+      auto it = internal_buf_map_.find(buffer);
+      if (it != internal_buf_map_.end()) {
+        buffer = it->second.remap_to;
+        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+      }
+      return AttrStmt(buffer, op->attr_key, op->value, body);
+
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
+    }
+
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {

Review comment:
      Add a comment explaining that the buffers referenced by the `buffer_bind_scope` attribute are updated to match the legalized buffer. Similarly, add comments in the other passes that have `HandleBufferBindScope` handler methods.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org