You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/16 18:46:23 UTC

[GitHub] [tvm] junrushao commented on a diff in pull request #13327: [TIR] Unify index data type when creating prim func

junrushao commented on code in PR #13327:
URL: https://github.com/apache/tvm/pull/13327#discussion_r1024384023


##########
src/tir/ir/data_type_rewriter.cc:
##########
@@ -191,5 +191,352 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   return e;
 }
 
+/*!
+ * \brief Rewrite an Allocate statement.
+ *
+ * The extents and the condition are visited with `is_enabled_` forced to true;
+ * the caller's setting is restored before the body is visited. The original
+ * node is returned untouched when none of the three children changed.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
+  const bool saved_is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Array<PrimExpr> extents =
+      op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); });
+  PrimExpr condition = VisitExpr(op->condition);
+  is_enabled_ = saved_is_enabled;
+
+  Stmt body = this->VisitStmt(op->body);
+
+  const bool unchanged = extents.same_as(op->extents) && condition.same_as(op->condition) &&
+                         body.same_as(op->body);
+  if (unchanged) {
+    return GetRef<Stmt>(op);
+  }
+
+  Allocate rewritten = GetRef<Allocate>(op);
+  AllocateNode* node = rewritten.CopyOnWrite();
+  node->extents = std::move(extents);
+  node->condition = std::move(condition);
+  node->body = std::move(body);
+  return std::move(rewritten);
+}
+
+/*!
+ * \brief Rewrite a DeclBuffer statement.
+ *
+ * The declared buffer is rewritten first through VisitBuffer, which records any
+ * replacement in buffer_remap_. Only then is the default StmtExprMutator
+ * traversal run, so BufferLoad/BufferStore nodes reached in the body can find
+ * the replacement via GetRemappedBuffer.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
+  Buffer buffer = VisitBuffer(op->buffer);
+  DeclBuffer result = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+  if (buffer.same_as(op->buffer)) {
+    return std::move(result);
+  }
+  result.CopyOnWrite()->buffer = std::move(buffer);
+  return std::move(result);
+}
+
+/*!
+ * \brief Rewrite a BlockRealize statement.
+ *
+ * The predicate is visited with `is_condition_` set to true. The iter_values
+ * are visited with `is_enabled_` set to true. Each flag is restored to the
+ * caller's value right after its use, and the realized block is then visited
+ * with the caller's flags.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) {
+  const bool saved_is_condition = is_condition_;
+  is_condition_ = true;
+  PrimExpr predicate = VisitExpr(op->predicate);
+  is_condition_ = saved_is_condition;
+
+  const bool saved_is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Array<PrimExpr> iter_values =
+      op->iter_values.Map([this](const PrimExpr& value) { return this->VisitExpr(value); });
+  is_enabled_ = saved_is_enabled;
+
+  Block block = Downcast<Block>(this->VisitStmt(op->block));
+
+  if (predicate.same_as(op->predicate) && iter_values.same_as(op->iter_values) &&
+      block.same_as(op->block)) {
+    return GetRef<Stmt>(op);
+  }
+
+  BlockRealize rewritten = GetRef<BlockRealize>(op);
+  BlockRealizeNode* node = rewritten.CopyOnWrite();
+  node->predicate = std::move(predicate);
+  node->iter_values = std::move(iter_values);
+  node->block = std::move(block);
+  return std::move(rewritten);
+}
+
+/*!
+ * \brief Rewrite a Block: its buffers, buffer regions, iter vars, init and body.
+ *
+ * Visit order:
+ *  - alloc_buffers, through VisitBuffer;
+ *  - match_buffers, through VisitBuffer on the matched buffer and
+ *    VisitBufferRegion on its source;
+ *  - reads, then writes;
+ *  - iter_vars;
+ *  - the optional init, then the body.
+ * A copy of the block is made only when at least one of these changed;
+ * otherwise the original node is returned.
+ *
+ * NOTE(review): annotations are carried over unchanged here even though
+ * VisitBlockAnnotations exists; confirm whether it should be applied.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) {
+  Array<Buffer> new_alloc_buffers =
+      op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); });
+  Array<MatchBufferRegion> new_match_buffers =
+      op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) {
+        Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer);
+        BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source);
+        if (!new_buffer.same_as(match_buffer_region->buffer) ||
+            !new_buffer_region.same_as(match_buffer_region->source)) {
+          return MatchBufferRegion(new_buffer, new_buffer_region);
+        } else {
+          return match_buffer_region;
+        }
+      });
+  Array<BufferRegion> new_reads = op->reads.Map(
+      [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); });
+  Array<BufferRegion> new_writes = op->writes.Map(
+      [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); });
+  Array<IterVar> new_iter_vars =
+      op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); });
+  Optional<Stmt> new_init = NullOpt;
+  if (op->init.defined()) {
+    new_init = this->VisitStmt(op->init.value());
+  }
+  Stmt new_body = this->VisitStmt(op->body);
+
+  // Rebuild only if some field changed. Every comparison must be negated: a
+  // non-negated iter_vars check would drop rewritten iter vars whenever nothing
+  // else changed, and would copy blocks that are in fact unchanged.
+  const bool changed = !new_init.same_as(op->init) || !new_body.same_as(op->body) ||
+                       !new_alloc_buffers.same_as(op->alloc_buffers) ||
+                       !new_match_buffers.same_as(op->match_buffers) ||
+                       !new_reads.same_as(op->reads) || !new_writes.same_as(op->writes) ||
+                       !new_iter_vars.same_as(op->iter_vars);
+  if (!changed) {
+    return GetRef<Stmt>(op);
+  }
+
+  Block new_block = GetRef<Block>(op);
+  BlockNode* n = new_block.CopyOnWrite();
+  n->alloc_buffers = std::move(new_alloc_buffers);
+  n->match_buffers = std::move(new_match_buffers);
+  n->reads = std::move(new_reads);
+  n->writes = std::move(new_writes);
+  n->iter_vars = std::move(new_iter_vars);
+  n->init = std::move(new_init);
+  n->body = std::move(new_body);
+  return std::move(new_block);
+}
+
+/*!
+ * \brief Replace Buffer objects referenced from block annotations.
+ *
+ * Each annotation value is walked recursively:
+ *  - a Buffer is swapped for its entry in buffer_remap_, if any;
+ *  - an Array has each of its elements walked;
+ *  - any other object, including an undefined one, is kept as is.
+ * Only keys whose value actually changed are overwritten in the returned map.
+ */
+Map<String, ObjectRef> IndexDataTypeRewriter::VisitBlockAnnotations(
+    const Map<String, ObjectRef>& annotations) {
+  // Declared first so the lambda can recurse through the std::function.
+  std::function<ObjectRef(const ObjectRef&)> remap_object;
+  remap_object = [this, &remap_object](const ObjectRef& obj) -> ObjectRef {
+    if (!obj.defined()) {
+      return obj;
+    }
+    if (obj->IsInstance<ArrayNode>()) {
+      return Downcast<Array<ObjectRef>>(obj).Map(remap_object);
+    }
+    if (obj->IsInstance<BufferNode>()) {
+      // GetRemappedBuffer hands back the same buffer when no replacement exists.
+      return GetRemappedBuffer(Downcast<Buffer>(obj));
+    }
+    return obj;
+  };
+
+  Map<String, ObjectRef> result = annotations;
+  for (const auto& [key, value] : annotations) {
+    ObjectRef remapped = remap_object(value);
+    if (!remapped.same_as(value)) {
+      result.Set(key, remapped);
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Look up the replacement recorded for \p buffer in buffer_remap_.
+ * \return The recorded replacement, or \p buffer itself when there is none.
+ */
+Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) {
+  auto entry = buffer_remap_.find(buffer);
+  if (entry == buffer_remap_.end()) {
+    return buffer;
+  }
+  return (*entry).second;
+}
+
+/*!
+ * \brief Rewrite an IterVar's variable and domain.
+ *
+ * The var, dom->min and dom->extent are visited with `is_enabled_` forced to
+ * true; the caller's setting is restored afterwards. A copy of the IterVar is
+ * made only when one of them changed.
+ */
+IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Var new_var = Downcast<Var>(VisitExpr(iter_var->var));
+  PrimExpr min = VisitExpr(iter_var->dom->min);
+  PrimExpr extent = VisitExpr(iter_var->dom->extent);
+  is_enabled_ = is_enabled;
+  if (!new_var.same_as(iter_var->var) || !min.same_as(iter_var->dom->min) ||
+      !extent.same_as(iter_var->dom->extent)) {
+    IterVar new_iter_var = iter_var;
+    IterVarNode* n = new_iter_var.CopyOnWrite();
+    n->var = std::move(new_var);
+    // `Range(a, b)` is the interval [a, b) and stores `b - a` as its extent.
+    // Here we hold a (min, extent) pair, so FromMinExtent is the correct factory;
+    // using Range(min, extent) would shrink the extent whenever min != 0.
+    n->dom = Range::FromMinExtent(min, extent);
+    return new_iter_var;
+  }
+  return iter_var;
+}
+
+/*!
+ * \brief Rewrite the shape, strides and elem_offset of \p buffer.
+ *
+ * All three are visited with `is_enabled_` forced to true; the caller's
+ * setting is restored afterwards. When anything changed, the rewritten buffer
+ * is recorded in buffer_remap_ under the original buffer, so later
+ * GetRemappedBuffer lookups return it.
+ */
+Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
+  auto visit = [this](const PrimExpr& expr) { return this->VisitExpr(expr); };
+
+  const bool saved_is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Array<PrimExpr> shape = buffer->shape.Map(visit);
+  Array<PrimExpr> strides = buffer->strides.Map(visit);
+  PrimExpr elem_offset = VisitExpr(buffer->elem_offset);
+  is_enabled_ = saved_is_enabled;
+
+  if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
+      elem_offset.same_as(buffer->elem_offset)) {
+    return buffer;
+  }
+
+  Buffer rewritten = buffer;
+  BufferNode* node = rewritten.CopyOnWrite();
+  node->shape = std::move(shape);
+  node->strides = std::move(strides);
+  node->elem_offset = std::move(elem_offset);
+  buffer_remap_.Set(buffer, rewritten);
+  return rewritten;
+}
+
+/*!
+ * \brief Rewrite a BufferRegion.
+ *
+ * The buffer is replaced via GetRemappedBuffer; it is not re-visited here.
+ * Each range's min and extent are visited with `is_enabled_` forced to true,
+ * and the caller's setting is restored afterwards. The original region object
+ * is returned when neither the buffer nor any range changed.
+ */
+BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer_region) {
+  Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
+
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_region = buffer_region->region.Map([&](const Range& range) {
+    PrimExpr new_min = this->VisitExpr(range->min);
+    PrimExpr new_extent = this->VisitExpr(range->extent);
+    // Hand back the original Range when nothing changed. FromMinExtent always
+    // allocates a fresh node, so rebuilding unconditionally would make the
+    // `same_as` check below always fail. That would force a new BufferRegion
+    // (and a new enclosing Block) even when no index was rewritten.
+    if (new_min.same_as(range->min) && new_extent.same_as(range->extent)) {
+      return range;
+    }
+    return Range::FromMinExtent(new_min, new_extent);
+  });
+  is_enabled_ = is_enabled;
+
+  if (!remapped_buffer.same_as(buffer_region->buffer) ||
+      !new_region.same_as(buffer_region->region)) {
+    return BufferRegion(remapped_buffer, new_region);
+  } else {
+    return buffer_region;
+  }
+}
+
+/*!
+ * \brief Rewrite a BufferStore.
+ *
+ * The buffer is replaced via GetRemappedBuffer. The stored value is visited
+ * with the caller's flags, and the indices go through VisitIndices.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
+  Buffer buffer = GetRemappedBuffer(op->buffer);
+  PrimExpr value = this->VisitExpr(op->value);
+  Array<PrimExpr> indices = VisitIndices(op->indices);
+
+  BufferStore store = GetRef<BufferStore>(op);
+  const bool unchanged = buffer.same_as(op->buffer) && value.same_as(op->value) &&
+                         indices.same_as(op->indices);
+  if (!unchanged) {
+    BufferStoreNode* node = store.CopyOnWrite();
+    node->buffer = buffer;
+    node->value = value;
+    node->indices = indices;
+  }
+  return std::move(store);
+}
+
+/*!
+ * \brief Rewrite a BufferLoad.
+ *
+ * The buffer is replaced via GetRemappedBuffer and the indices go through
+ * VisitIndices. The node is copied only when one of them changed.
+ */
+PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) {
+  Buffer buffer = GetRemappedBuffer(op->buffer);
+  Array<PrimExpr> indices = VisitIndices(op->indices);
+
+  BufferLoad load = GetRef<BufferLoad>(op);
+  if (buffer.same_as(op->buffer) && indices.same_as(op->indices)) {
+    return std::move(load);
+  }
+  BufferLoadNode* node = load.CopyOnWrite();
+  node->indices = indices;
+  node->buffer = buffer;
+  return std::move(load);
+}
+
+/*!
+ * \brief Visit every index expression with `is_enabled_` forced to true.
+ *
+ * \p indices is taken by value and updated through MutateByApply. The caller's
+ * `is_enabled_` is restored before returning.
+ */
+Array<PrimExpr> IndexDataTypeRewriter::VisitIndices(Array<PrimExpr> indices) {
+  const bool saved_is_enabled = is_enabled_;
+  is_enabled_ = true;
+  indices.MutateByApply([this](const PrimExpr& index) { return this->VisitExpr(index); });
+  is_enabled_ = saved_is_enabled;
+  return indices;
+}
+
+/*!
+ * \brief Rewrite an IfThenElse statement.
+ *
+ * The condition is visited with `is_condition_` set to true, and the caller's
+ * value is restored afterwards. Both branches are then visited with the
+ * caller's flags; the else branch is visited only if it exists.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) {
+  const bool saved_is_condition = is_condition_;
+  is_condition_ = true;
+  PrimExpr condition = VisitExpr(op->condition);
+  is_condition_ = saved_is_condition;
+
+  Stmt then_case = VisitStmt(op->then_case);
+  Optional<Stmt> else_case = NullOpt;
+  if (op->else_case.defined()) {
+    else_case = VisitStmt(op->else_case.value());
+  }
+
+  if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
+      else_case.same_as(op->else_case)) {
+    return GetRef<Stmt>(op);
+  }
+
+  IfThenElse rewritten = GetRef<IfThenElse>(op);
+  IfThenElseNode* node = rewritten.CopyOnWrite();
+  node->condition = std::move(condition);
+  node->then_case = std::move(then_case);
+  node->else_case = std::move(else_case);
+  return std::move(rewritten);
+}
+
+/*!
+ * \brief Rewrite a For loop.
+ *
+ * The loop var, min and extent are visited with `is_enabled_` forced to true,
+ * and the caller's setting is restored before the body is visited. When
+ * anything changed, min and extent are cast to the dtype of the (possibly
+ * rewritten) loop var.
+ */
+Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
+  const bool saved_is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Var loop_var = Downcast<Var>(VisitExpr(op->loop_var));
+  PrimExpr min = VisitExpr(op->min);
+  PrimExpr extent = VisitExpr(op->extent);
+  is_enabled_ = saved_is_enabled;
+
+  Stmt body = VisitStmt(op->body);
+
+  const bool unchanged = loop_var.same_as(op->loop_var) && min.same_as(op->min) &&
+                         extent.same_as(op->extent) && body.same_as(op->body);
+  if (unchanged) {
+    return GetRef<Stmt>(op);
+  }
+
+  For rewritten = GetRef<For>(op);
+  ForNode* node = rewritten.CopyOnWrite();
+  node->loop_var = loop_var;
+  node->min = cast(loop_var.dtype(), min);
+  node->extent = cast(loop_var.dtype(), extent);
+  node->body = body;
+  return std::move(rewritten);
+}
+
+// Defines IndexDataTypeRewriter::VisitExpr_ for the comparison node type OP.
+// While the comparison is visited through Parent::VisitExpr_, `is_enabled_` is
+// true only if `is_condition_` is set (e.g. by the IfThenElse and BlockRealize
+// visitors) and both operands have an integer dtype. The caller's `is_enabled_`
+// is restored afterwards.
+// NOTE(review): FUNC is not referenced by the expansion. Per review, the macro
+// name should carry a TVM_ prefix; renaming it must also update every invocation
+// and any matching #undef further down the file.
+#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)                         \
+  PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) {                       \
+    const bool saved_is_enabled = is_enabled_;                                     \
+    is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \
+    PrimExpr result = Parent::VisitExpr_(op);                                      \
+    is_enabled_ = saved_is_enabled;                                                \
+    return result;                                                                 \
+  }
+
+DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);

Review Comment:
   Prefix it with TVM_



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org