You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/25 13:49:33 UTC

[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #12172: [TIR Pass] Decouple flatten buffer to lower opaque block and flatten buffer.

wrongtest-intellif commented on code in PR #12172:
URL: https://github.com/apache/tvm/pull/12172#discussion_r928907466


##########
src/tir/transforms/flatten_buffer.cc:
##########
@@ -68,76 +52,25 @@ class BufferFlattener : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const BlockRealizeNode* op) final {
-    // We have convert blocks into opaque blocks in previous passes.
-    ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
-                                       "call pass ConvertBlocksToOpaque before.";
-    // Step 1. Visit the body
-    Block new_block = Downcast<Block>(this->VisitStmt(op->block));
-    PrimExpr predicate = this->VisitExpr(op->predicate);
-    // Step 2. Transform the `predicate` to if-then-else
-    Stmt body = new_block->body;
-    if (!is_one(predicate)) {
-      body = IfThenElse(predicate, std::move(body));
-    }
-    // Step 3. Handle allocations in reverse order
-    for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
-      Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]);
-      body = Allocate(buffer->data, buffer->dtype, buffer->shape, const_true(), std::move(body));
-    }
-    return body;
-  }
-
-  Stmt VisitStmt_(const ForNode* op) final {
-    // Step 1. Update unit loop info.
-    PrimExpr min = this->VisitExpr(op->min);
-    PrimExpr extent = this->VisitExpr(op->extent);
-    if (is_one(extent) && op->annotations.empty()) {
-      // handling unit loop
-      unit_loop_vars_[op->loop_var] = min;
-    }
-    // Step 2. Visit recursively
-    Stmt body = this->VisitStmt(op->body);
-    // Step 3. Create new For loop accordingly
-    if (op->kind == ForKind::kThreadBinding) {
-      // Case 1. Thread binding
-      ICHECK(op->thread_binding.defined());
-      String thread_tag = op->thread_binding.value()->thread_tag;
-      body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
-    } else if (is_one(extent) && op->annotations.empty()) {
-      // Case 2. Unit loop
-      return body;
-    } else {
-      // Case 3. An ordinary loop
-      body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
-    }
-    // Step 4. Handle annotations
-    std::set<std::string> ordered_ann_keys;
-    for (const auto& annotation : op->annotations) {
-      ordered_ann_keys.insert(annotation.first);
-    }
-    for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
-      const std::string& ann_key = *it;
-      const ObjectRef& ann_value = op->annotations.at(ann_key);
-      if (attr::IsPragmaKey(ann_key)) {
-        body =
-            AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
-      }
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+    // TODO(Lunderberg): Move the handling of boolean into a
+    // dedicated pass.
+    if (alloc->dtype == DataType::Bool()) {
+      auto writer = alloc.CopyOnWrite();
+      writer->dtype = DataType::Int(8);
     }
-    return body;
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    Var var = GetRef<Var>(op);
-    auto it = unit_loop_vars_.find(var);
-    if (it == unit_loop_vars_.end()) {
-      return std::move(var);
+    // Handle multi-dimension allocations
+    if (alloc->extents.size() == 1) {
+      return std::move(alloc);
     } else {
-      PrimExpr expr = it->second;
-      if (expr.dtype() != var.dtype()) {
-        expr = tvm::cast(var.dtype(), std::move(expr));
+      Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);

Review Comment:
   Why is the size `static_cast<size_t>(1)` rather than `alloc->extents.size()`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org