Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/09 00:02:32 UTC

[GitHub] [tvm] vinx13 opened a new pull request, #13327: [TIR] Unify index data type when creating prim func

vinx13 opened a new pull request, #13327:
URL: https://github.com/apache/tvm/pull/13327

   * Added a data type unification pass that, by default, promotes the data types of all indices and shapes to int64 when creating a prim func.
   * Fixed several lowering passes to make them compatible with the int64 data type.
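A minimal sketch of what promoting all indices to int64 buys. It uses a hypothetical toy expression type: `IndexExpr`, `MakeVar`, `MakeConst`, `MakeAdd` and `PromoteToInt64` are illustrative names, not TVM APIs. The toy `MakeAdd` rejects operands of different widths, modelling the dtype mismatch described in this PR, and promoting every index node first removes that conflict.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy index expression (hypothetical; not TVM's PrimExpr). `bits` stands in
// for the dtype: 32 = int32, 64 = int64.
struct IndexExpr {
  enum class Kind { kVar, kConst, kAdd };
  Kind kind;
  int bits;
  std::string name;                 // used by kVar
  long long value;                  // used by kConst
  std::shared_ptr<IndexExpr> a, b;  // used by kAdd
};
using ExprPtr = std::shared_ptr<IndexExpr>;

ExprPtr MakeVar(const std::string& name, int bits) {
  auto e = std::make_shared<IndexExpr>();
  e->kind = IndexExpr::Kind::kVar;
  e->bits = bits;
  e->name = name;
  return e;
}

ExprPtr MakeConst(long long value, int bits) {
  auto e = std::make_shared<IndexExpr>();
  e->kind = IndexExpr::Kind::kConst;
  e->bits = bits;
  e->value = value;
  return e;
}

// The toy Add refuses mixed widths, standing in for a dtype mismatch error.
ExprPtr MakeAdd(const ExprPtr& a, const ExprPtr& b) {
  assert(a->bits == b->bits && "index dtype mismatch");
  auto e = std::make_shared<IndexExpr>();
  e->kind = IndexExpr::Kind::kAdd;
  e->bits = a->bits;
  e->a = a;
  e->b = b;
  return e;
}

// Rebuilds `e` with every node promoted to int64, so indices that came from
// different sources can be combined afterwards.
ExprPtr PromoteToInt64(const ExprPtr& e) {
  switch (e->kind) {
    case IndexExpr::Kind::kVar:
      return MakeVar(e->name, 64);
    case IndexExpr::Kind::kConst:
      return MakeConst(e->value, 64);
    case IndexExpr::Kind::kAdd:
      return MakeAdd(PromoteToInt64(e->a), PromoteToInt64(e->b));
  }
  return e;  // unreachable for the kinds above
}
```

For example, `MakeAdd(PromoteToInt64(MakeVar("i", 32)), PromoteToInt64(MakeVar("n", 64)))` yields an int64 sum, whereas adding the unpromoted operands trips the assertion. A real rewriter would also keep a map from each old variable to exactly one rewritten variable, so that every use of a variable stays consistent.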


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] vinx13 commented on pull request #13327: [TIR] Unify index data type when creating prim func

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on PR #13327:
URL: https://github.com/apache/tvm/pull/13327#issuecomment-1458543609

   Sometimes a model contains mixed index types (e.g., both int32 and int64), which causes dtype mismatch errors during scheduling. This pass is not expected to hurt performance, because another pass, NarrowDataType, should convert indices back to int32 (or another smaller type) where possible. 
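The narrowing step this argument relies on can be sketched as a bound check: if the value range [lo, hi] of an index is known, a 64-bit index can be narrowed to `target_bits` whenever every value fits. This is a hypothetical simplification (`FitsSigned` and `NarrowIndexBits` are illustrative names); the actual NarrowDataType pass derives bounds through its own analysis and handles many more cases.

```cpp
#include <cassert>
#include <cstdint>

// Returns true if every value in [lo, hi] is representable as a signed
// integer of `bits` bits (2 <= bits <= 64).
bool FitsSigned(int64_t lo, int64_t hi, int bits) {
  assert(bits >= 2 && bits <= 64 && lo <= hi);
  if (bits == 64) return true;
  const int64_t max_v = (int64_t{1} << (bits - 1)) - 1;
  const int64_t min_v = -max_v - 1;
  return lo >= min_v && hi <= max_v;
}

// Hypothetical narrowing decision: keep a 64-bit index only when its known
// bounds do not fit in target_bits.
int NarrowIndexBits(int64_t lo, int64_t hi, int target_bits) {
  return FitsSigned(lo, hi, target_bits) ? target_bits : 64;
}
```

Under this rule an index bounded by [0, 1023] narrows to int32, while one that can reach 2^40 stays int64. The sketch also shows the limitation: narrowing is only possible when the bounds can actually be proven.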


[GitHub] [tvm] junrushao merged pull request #13327: [TIR] Unify index data type when creating prim func

Posted by GitBox <gi...@apache.org>.
junrushao merged PR #13327:
URL: https://github.com/apache/tvm/pull/13327


[GitHub] [tvm] junrushao commented on a diff in pull request #13327: [TIR] Unify index data type when creating prim func

Posted by GitBox <gi...@apache.org>.
junrushao commented on code in PR #13327:
URL: https://github.com/apache/tvm/pull/13327#discussion_r1024384378


##########
src/tir/transforms/narrow_datatype.cc:
##########
@@ -315,65 +265,25 @@ class DataTypeRewriter : public DataTypeLegalizer {
     return Parent::VisitExpr_(op);
   }
 
-  PrimExpr VisitExpr_(const EQNode* op) final;
-  PrimExpr VisitExpr_(const NENode* op) final;
-  PrimExpr VisitExpr_(const LTNode* op) final;
-  PrimExpr VisitExpr_(const LENode* op) final;
-  PrimExpr VisitExpr_(const GTNode* op) final;
-  PrimExpr VisitExpr_(const GENode* op) final;
-  PrimExpr VisitExpr_(const CallNode* op) final;
-
  private:
   // the internal visitor to deduce the narrowed dtype
   DataTypeVisitor visitor_;
   // a map from Var before rewrite to that after rewrite,
   // ensures one old Var maps to exactly one new Var
   std::unordered_map<const VarNode*, Var> vmap_;
-  // indicator of index expr to rewrite
-  bool is_index_{false};
-  // indicator of condition
-  bool is_condition_{false};
 };
 
-#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)                          \
-  PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) {                             \
-    bool is_index = is_index_;                                                      \
-    bool rewrite = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \
-    if (rewrite) {                                                                  \
-      is_index_ = true;                                                             \
-    }                                                                               \
-    auto result = Parent::VisitExpr_(op);                                           \
-    is_index_ = is_index;                                                           \
-    return std::move(result);                                                       \
-  }
-
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<);  // NOLINT(*)
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>);  // NOLINT(*)
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
-
-PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
-  // handle if_then_else condition
-  if (op->op.same_as(builtin::if_then_else())) {
-    bool is_condition = is_condition_;
-    is_condition_ = true;
-    PrimExpr cond = VisitExpr(op->args[0]);
-    is_condition_ = is_condition;
-    return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2]));
-  }
-  return Parent::VisitExpr_(op);
+Stmt NarrowDataType(Stmt stmt, int target_bits) {
+  return NarrowDataTypeRewriter(target_bits)(stmt);
 }
 
-Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); }
-
 namespace transform {
 
 Pass NarrowDataType(int target_bits) {
   auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = DataTypeRewriter(target_bits)(std::move(n->body));
+    n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body));
+    // LOG(INFO) << "AfterNarrow: " << tir::AsTVMScript(f);

Review Comment:
   Remove this line?



[GitHub] [tvm] junrushao commented on pull request #13327: [TIR] Unify index data type when creating prim func

Posted by GitBox <gi...@apache.org>.
junrushao commented on PR #13327:
URL: https://github.com/apache/tvm/pull/13327#issuecomment-1315831212

   Happy to review and let's fix the CI :-)


[GitHub] [tvm] tvm-bot commented on pull request #13327: [TIR] Unify index data type when creating prim func

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13327:
URL: https://github.com/apache/tvm/pull/13327#issuecomment-1308000619

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.

   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)


[GitHub] [tvm] junrushao commented on a diff in pull request #13327: [TIR] Unify index data type when creating prim func

Posted by GitBox <gi...@apache.org>.
junrushao commented on code in PR #13327:
URL: https://github.com/apache/tvm/pull/13327#discussion_r1024384023


##########
src/tir/ir/data_type_rewriter.cc:
##########
@@ -191,5 +191,352 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   return e;
 }
 
+Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_extents = op->extents.Map([this](const PrimExpr& e) { return this->VisitExpr(e); });
+  auto new_cond = VisitExpr(op->condition);
+  is_enabled_ = is_enabled;
+  auto new_body = this->VisitStmt(op->body);
+  if (!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) ||
+      !new_body.same_as(op->body)) {
+    Allocate new_allocate = GetRef<Allocate>(op);
+    auto* n = new_allocate.CopyOnWrite();
+    n->extents = std::move(new_extents);
+    n->condition = std::move(new_cond);
+    n->body = std::move(new_body);
+    return std::move(new_allocate);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
+  Buffer new_buffer = VisitBuffer(op->buffer);
+  DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+  if (!new_buffer.same_as(op->buffer)) {
+    decl_buffer.CopyOnWrite()->buffer = new_buffer;
+  }
+  return std::move(decl_buffer);
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) {
+  bool is_condition = is_condition_;
+  is_condition_ = true;
+  auto new_predicate = VisitExpr(op->predicate);
+  is_condition_ = is_condition;
+
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_iter_values =
+      op->iter_values.Map([this](const PrimExpr& e) { return this->VisitExpr(e); });
+  is_enabled_ = is_enabled;
+  Block new_body = Downcast<Block>(this->VisitStmt(op->block));
+  if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) ||
+      !new_body.same_as(op->block)) {
+    BlockRealize new_block_realize = GetRef<BlockRealize>(op);
+    auto* n = new_block_realize.CopyOnWrite();
+    n->predicate = std::move(new_predicate);
+    n->iter_values = std::move(new_iter_values);
+    n->block = std::move(new_body);
+    return std::move(new_block_realize);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) {
+  Array<Buffer> new_alloc_buffers =
+      op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); });
+  Array<MatchBufferRegion> new_match_buffers =
+      op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) {
+        Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer);
+        BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source);
+        if (!new_buffer.same_as(match_buffer_region->buffer) ||
+            !new_buffer_region.same_as(match_buffer_region->source)) {
+          return MatchBufferRegion(new_buffer, new_buffer_region);
+        } else {
+          return match_buffer_region;
+        }
+      });
+  Array<BufferRegion> new_reads = op->reads.Map(
+      [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); });
+  Array<BufferRegion> new_writes = op->writes.Map(
+      [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); });
+  Array<IterVar> new_iter_vars =
+      op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); });
+  Optional<Stmt> new_init = NullOpt;
+  if (op->init.defined()) {
+    new_init = this->VisitStmt(op->init.value());
+  }
+  Stmt new_body = this->VisitStmt(op->body);
+
+  if (!new_init.same_as(op->init) || !new_body.same_as(op->body) ||
+      !new_alloc_buffers.same_as(op->alloc_buffers) ||
+      !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) ||
+      !new_writes.same_as(op->writes) || !new_iter_vars.same_as(op->iter_vars)) {
+    Block new_block = GetRef<Block>(op);
+    BlockNode* n = new_block.CopyOnWrite();
+    n->alloc_buffers = std::move(new_alloc_buffers);
+    n->match_buffers = std::move(new_match_buffers);
+    n->reads = std::move(new_reads);
+    n->writes = std::move(new_writes);
+    n->iter_vars = std::move(new_iter_vars);
+    n->init = std::move(new_init);
+    n->body = std::move(new_body);
+    return std::move(new_block);
+  }
+  return GetRef<Stmt>(op);
+}
+
+Map<String, ObjectRef> IndexDataTypeRewriter::VisitBlockAnnotations(
+    const Map<String, ObjectRef>& annotations) {
+  auto new_annotations = annotations;
+
+  std::function<ObjectRef(const ObjectRef&)> f_mutate_obj =
+      [this, &f_mutate_obj](const ObjectRef& obj) -> ObjectRef {
+    if (!obj.defined()) {
+      return obj;
+    }
+    if (obj->IsInstance<BufferNode>()) {
+      Buffer buffer = Downcast<Buffer>(obj);
+      if (Buffer new_buffer = GetRemappedBuffer(buffer); !new_buffer.same_as(buffer)) {
+        return new_buffer;
+      }
+    } else if (obj->IsInstance<ArrayNode>()) {
+      return Downcast<Array<ObjectRef>>(obj).Map(f_mutate_obj);
+    }
+    return obj;
+  };
+  for (const auto& [key, value] : annotations) {
+    auto new_value = f_mutate_obj(value);
+    if (!new_value.same_as(value)) {
+      new_annotations.Set(key, new_value);
+    }
+  }
+  return new_annotations;
+}
+
+Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) {
+  if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+    return (*it).second;
+  }
+  return buffer;
+}
+
+IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Var new_var = Downcast<Var>(VisitExpr(iter_var->var));
+  PrimExpr min = VisitExpr(iter_var->dom->min);
+  PrimExpr extent = VisitExpr(iter_var->dom->extent);
+  is_enabled_ = is_enabled;
+  if (!new_var.same_as(iter_var->var) || !min.same_as(iter_var->dom->min) ||
+      !extent.same_as(iter_var->dom->extent)) {
+    IterVar new_iter_var = iter_var;
+    IterVarNode* n = new_iter_var.CopyOnWrite();
+    n->var = std::move(new_var);
+    n->dom = Range(min, extent);
+    return new_iter_var;
+  }
+  return iter_var;
+}
+
+Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
+  bool is_enabled = is_enabled_;
+
+  is_enabled_ = true;
+  Array<PrimExpr> new_shape =
+      buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
+  Array<PrimExpr> new_strides =
+      buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
+  auto new_elem_offset = VisitExpr(buffer->elem_offset);
+  is_enabled_ = is_enabled;
+
+  if (!buffer->shape.same_as(new_shape) || !buffer->strides.same_as(new_strides) ||
+      !buffer->elem_offset.same_as(new_elem_offset)) {
+    Buffer new_buffer = buffer;
+    BufferNode* new_buffer_node = new_buffer.CopyOnWrite();
+    new_buffer_node->shape = std::move(new_shape);
+    new_buffer_node->strides = std::move(new_strides);
+    new_buffer_node->elem_offset = std::move(new_elem_offset);
+    buffer_remap_.Set(buffer, new_buffer);
+    return new_buffer;
+  } else {
+    return buffer;
+  }
+}
+
+BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer_region) {
+  Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
+
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_region = buffer_region->region.Map([&](const Range& range) {
+    return Range::FromMinExtent(this->VisitExpr(range->min), this->VisitExpr(range->extent));
+  });
+  is_enabled_ = is_enabled;
+
+  if (!remapped_buffer.same_as(buffer_region->buffer) ||
+      !new_region.same_as(buffer_region->region)) {
+    return BufferRegion(remapped_buffer, new_region);
+  } else {
+    return buffer_region;
+  }
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
+  BufferStore store = GetRef<BufferStore>(op);
+
+  Buffer new_buffer = GetRemappedBuffer(op->buffer);
+  auto value = this->VisitExpr(op->value);
+  auto indices = VisitIndices(op->indices);
+
+  if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) ||
+      !indices.same_as(op->indices)) {
+    auto writer = store.CopyOnWrite();
+    writer->buffer = new_buffer;
+    writer->value = value;
+    writer->indices = indices;
+  }
+
+  return std::move(store);
+}
+
+PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) {
+  BufferLoad load = GetRef<BufferLoad>(op);
+
+  Buffer new_buffer = GetRemappedBuffer(op->buffer);
+  auto indices = VisitIndices(op->indices);
+
+  if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) {
+    auto writer = load.CopyOnWrite();
+    writer->indices = indices;
+    writer->buffer = new_buffer;
+  }
+
+  return std::move(load);
+}
+
+Array<PrimExpr> IndexDataTypeRewriter::VisitIndices(Array<PrimExpr> indices) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+
+  auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
+  indices.MutateByApply(fmutate);
+
+  is_enabled_ = is_enabled;
+
+  return indices;
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) {
+  bool is_condition = is_condition_;
+  is_condition_ = true;
+  PrimExpr cond = VisitExpr(op->condition);
+  is_condition_ = is_condition;
+
+  Stmt then_case = VisitStmt(op->then_case);
+  Optional<Stmt> else_case =
+      op->else_case.defined() ? Optional<Stmt>{VisitStmt(op->else_case.value())} : NullOpt;
+  if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) ||
+      !else_case.same_as(op->else_case)) {
+    IfThenElse new_stmt = GetRef<IfThenElse>(op);
+    auto* n = new_stmt.CopyOnWrite();
+    n->condition = std::move(cond);
+    n->then_case = std::move(then_case);
+    n->else_case = std::move(else_case);
+    return std::move(new_stmt);
+  }
+  return GetRef<Stmt>(op);
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Var new_loop_var = Downcast<Var>(VisitExpr(op->loop_var));
+  PrimExpr min = VisitExpr(op->min);
+  PrimExpr extent = VisitExpr(op->extent);
+  is_enabled_ = is_enabled;
+
+  Stmt new_body = VisitStmt(op->body);
+
+  if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) ||
+      !new_body.same_as(op->body)) {
+    For new_for = GetRef<For>(op);
+    auto* n = new_for.CopyOnWrite();
+    n->loop_var = new_loop_var;
+    n->min = cast(new_loop_var.dtype(), min);
+    n->extent = cast(new_loop_var.dtype(), extent);
+    n->body = new_body;
+    return std::move(new_for);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
+#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)                         \
+  PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) {                       \
+    bool is_enabled = is_enabled_;                                                 \
+    is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \
+    auto result = Parent::VisitExpr_(op);                                          \
+    is_enabled_ = is_enabled;                                                      \
+    return std::move(result);                                                      \
+  }
+
+DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);

Review Comment:
   Prefix it with TVM_
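The reviewer's request can be sketched in a self-contained way. The sketch uses hypothetical toy types (`ToyEQ`, `ToyLT`, `ToyIndexRewriter`) and a hypothetical macro name `TVM_TOY_DEFINE_CMPOP_EXPR_MUTATE`; it is not the code that was merged. It also swaps the repeated manual save/restore of `is_enabled_` for a small RAII guard (`ScopedFlag`, also hypothetical), which restores the flag on every exit path.

```cpp
#include <cassert>

// Hypothetical RAII helper: sets a bool flag for the current scope and
// restores the previous value on exit, replacing the manual
// `bool saved = flag_; flag_ = ...; ...; flag_ = saved;` sequence.
class ScopedFlag {
 public:
  ScopedFlag(bool* flag, bool value) : flag_(flag), saved_(*flag) { *flag_ = value; }
  ~ScopedFlag() { *flag_ = saved_; }
  ScopedFlag(const ScopedFlag&) = delete;
  ScopedFlag& operator=(const ScopedFlag&) = delete;

 private:
  bool* flag_;
  bool saved_;
};

// Toy comparison nodes standing in for EQNode, LTNode, ... (hypothetical).
struct ToyEQ { bool a_is_int; bool b_is_int; };
struct ToyLT { bool a_is_int; bool b_is_int; };

class ToyIndexRewriter {
 public:
  bool is_condition_ = false;
  bool is_enabled_ = false;
  bool VisitExpr_(const ToyEQ* op);
  bool VisitExpr_(const ToyLT* op);

 private:
  // Stand-in for Parent::VisitExpr_: reports whether index rewriting is active.
  bool VisitOperands() const { return is_enabled_; }
};

// Project-prefixed macro name, so it cannot collide with a same-named macro
// from another header.
#define TVM_TOY_DEFINE_CMPOP_EXPR_MUTATE(OP)                                      \
  bool ToyIndexRewriter::VisitExpr_(const OP* op) {                               \
    ScopedFlag guard(&is_enabled_, is_condition_ && op->a_is_int && op->b_is_int); \
    return VisitOperands();                                                       \
  }

TVM_TOY_DEFINE_CMPOP_EXPR_MUTATE(ToyEQ)
TVM_TOY_DEFINE_CMPOP_EXPR_MUTATE(ToyLT)
#undef TVM_TOY_DEFINE_CMPOP_EXPR_MUTATE
```

The `#undef` right after the last use keeps the helper macro from leaking into the rest of the translation unit. The return value is computed before `guard` is destroyed, so the operands are visited with the scoped flag and the caller sees the original flag afterwards.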



[GitHub] [tvm] ibsidorenko commented on pull request #13327: [TIR] Unify index data type when creating prim func

Posted by "ibsidorenko (via GitHub)" <gi...@apache.org>.
ibsidorenko commented on PR #13327:
URL: https://github.com/apache/tvm/pull/13327#issuecomment-1458536268

   Hello, @vinx13!
   
   I have one question about this PR. Could you give more information about the motivation for converting indices to the "int64" data type?
   
   A few words about why I am asking:
   I am working on MetaScheduler for the Hexagon target and found that this PR dramatically reduces performance for some operations.
   
   **Example**: Average Pooling 2D
   This operator uses indices in its compute function and in the pool2d divisor. 
   **Before IndexDataTypeNormalizer**:
   `pool_avg[ax0, ax1, ax2, ax3, ax4] = (pool_sum[ax0, ax1, ax2, ax3, ax4] / max((((min(1, (34 - ax2)) + 2) - max((1 - ax2), 0))*((min(1, (34 - ax3)) + 2) - max((1 - ax3), 0))), 1))`
   **After IndexDataTypeNormalizer**:
   `pool_avg[ax0, ax1, ax2, ax3, ax4] = cast(int32, (cast(int64, pool_sum[ax0, ax1, ax2, ax3, ax4]) / max((((min(1i64, (34i64 - ax2)) + 2i64) - max((1i64 - ax2), 0i64))*((min(1i64, (34i64 - ax3)) + 2i64) - max((1i64 - ax3), 0i64))), 1i64)))`
   
   As you can see, we get an extra cast to "int64". Unfortunately, Hexagon does not support vectorization of "int64" data types, so performance became very poor.
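The divisor in the quoted expression is itself an index computation with a small value range. A minimal sketch, assuming `ax2` and `ax3` each range over [0, 34] (the loop extents are not shown in the comment, so an extent of 35 is an assumption), enumerates the divisor and reports its range. `AxisWindow` and `DivisorRange` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <utility>

// Per-axis window size taken from the avg_pool2d divisor quoted above:
// (min(1, 34 - a) + 2) - max(1 - a, 0)
int64_t AxisWindow(int64_t a) {
  return (std::min<int64_t>(1, 34 - a) + 2) - std::max<int64_t>(1 - a, 0);
}

// Enumerates ax2, ax3 in [0, extent) and returns {min, max} of the full
// divisor max(window(ax2) * window(ax3), 1).
std::pair<int64_t, int64_t> DivisorRange(int64_t extent) {
  assert(extent > 0);
  int64_t lo = std::numeric_limits<int64_t>::max();
  int64_t hi = std::numeric_limits<int64_t>::min();
  for (int64_t ax2 = 0; ax2 < extent; ++ax2) {
    for (int64_t ax3 = 0; ax3 < extent; ++ax3) {
      const int64_t d = std::max<int64_t>(AxisWindow(ax2) * AxisWindow(ax3), 1);
      lo = std::min(lo, d);
      hi = std::max(hi, d);
    }
  }
  return {lo, hi};
}
```

Under that assumption, `DivisorRange(35)` returns {4, 9}: each axis window is 2 at the borders (`ax = 0` or `ax = 34`) and 3 elsewhere. A range this small fits easily in int32, which is the kind of case NarrowDataType is meant to narrow; whether it does so for this Hexagon schedule is not established in this thread.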
   
   P.S. As an experiment, I reverted the conversion of indices to int64 and got a **+40%** performance gain (!!!).
   
   So I would like to fix this, but first I would like to understand the motivation for this PR. 
   
   Thank you in advance!
   
   Just FYI cc @masahi 
   

