Posted to commits@tvm.apache.org by tq...@apache.org on 2020/06/07 21:11:16 UTC

[incubator-tvm] branch master updated: [REFACTOR][TE][TIR] Call::Halide => ProducerLoad, DSL/TIR decouple. (#5743)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ae439c  [REFACTOR][TE][TIR] Call::Halide => ProducerLoad, DSL/TIR decouple. (#5743)
6ae439c is described below

commit 6ae439c8c58dd0118d2f2c5d1c4bcb650df47104
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Jun 7 14:11:05 2020 -0700

    [REFACTOR][TE][TIR] Call::Halide => ProducerLoad, DSL/TIR decouple. (#5743)
    
    In HalideIR's design, the DSL components and the IR are mixed together.
    For example, Call::Halide can contain a reference to a function that is
    constructed in the tensor expression language.
    
    While this coupled design simplifies certain aspects of DSL construction,
    it prevents the TIR from evolving into a clean, standalone IR:
    
    - The additional tensor expression provided in the function is opaque to the IR
      and may become obsolete as we transform the IR.
    - Duplicating information between the DSL tensor and the IR makes it hard to
      design a standalone text format, because some elements are shared between
      tensor expressions and normal statements.
    
    This PR aims to cleanly decouple the TIR from high-level DSL structures (tensor expressions),
    while still providing clear extension points for building DSLs on top of the TIR.
    
    We introduce DataProducer as a base class for high-level tensor expression objects
    that produce data. We then introduce ProducerLoad to replace the uses of Call::Halide,
    so that the Call node is always self-contained and used only for low-level calls.
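
    The class relationship described above can be modeled in a few lines of plain
    Python. This is a minimal sketch under stated assumptions, not TVM code: the names
    `DataProducer`, `Tensor`, `ProducerLoad` and the snake_case methods are hypothetical
    stand-ins that only mirror the C++ interface added by this patch (`GetShape`,
    `GetDataType`, `GetNameHint`, and identity-based structural equality for the
    opaque producer). Taking the load's dtype from the producer is also an assumption
    of the sketch.

```python
from abc import ABC, abstractmethod


class DataProducer(ABC):
    """Hypothetical stand-in for tir::DataProducerNode; opaque to the IR."""

    @abstractmethod
    def get_shape(self):
        """Return the shape of the produced array."""

    @abstractmethod
    def get_dtype(self):
        """Return the element data type."""

    @abstractmethod
    def get_name_hint(self):
        """Return a human-readable name for printing."""

    def structural_equal(self, other):
        # The producer is opaque, so equality falls back to identity,
        # mirroring `return this == other;` in the C++ SEqualReduce.
        return self is other


class Tensor(DataProducer):
    """Toy tensor-expression tensor; `op_name` stands in for the Operation."""

    def __init__(self, shape, dtype, op_name):
        self.shape, self.dtype, self.op_name = tuple(shape), dtype, op_name

    def get_shape(self):
        return self.shape

    def get_dtype(self):
        return self.dtype

    def get_name_hint(self):
        return self.op_name

    def __call__(self, *indices):
        # Indexing a tensor builds a ProducerLoad instead of a Halide-style Call.
        return ProducerLoad(self, list(indices))


class ProducerLoad:
    """Toy expression node: read `producer` at `indices`."""

    def __init__(self, producer, indices):
        self.producer = producer
        self.indices = list(indices)
        self.dtype = producer.get_dtype()  # assumption: dtype comes from the producer

    def structural_equal(self, other):
        return (isinstance(other, ProducerLoad)
                and self.dtype == other.dtype
                and self.producer.structural_equal(other.producer)
                and self.indices == other.indices)
```

    Two tensors with identical shape, dtype and name still compare unequal here,
    because the producer carries opaque information that the IR cannot inspect.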
    
    The high-level tensor expression DSL can still generate a PrimExpr that contains a ProducerLoad.
    These PrimExprs contain fragments of information that can be combined to
    generate a low-level TIR PrimFunc.
    
    We also state clearly that DataProducer **should not** appear in any TIR PrimFunc.
    Instead, the high-level DSL layer should lower DataProducers to Buffers and to the TIR
    statements that produce those buffers. We can further add verification to check this invariant.
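
    The lowering contract and its invariant can be illustrated with a small toy
    expression tree. Everything in this sketch is hypothetical and is not TVM's
    lowering code: producers and buffers are represented by plain string keys,
    `lower_producer_loads` rewrites each ProducerLoad into a BufferLoad using a
    producer-to-buffer map, and `verify_no_producer` checks that no ProducerLoad
    survives lowering.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class ProducerLoad:
    producer: str  # key standing in for an opaque DataProducer
    indices: Tuple


@dataclass(frozen=True)
class BufferLoad:
    buffer: str  # key standing in for a tir::Buffer
    indices: Tuple


def lower_producer_loads(expr, producer_to_buffer):
    """Rewrite every ProducerLoad (including nested ones) into a BufferLoad."""
    def lower(e):
        if isinstance(e, ProducerLoad):
            return BufferLoad(producer_to_buffer[e.producer], tuple(lower(i) for i in e.indices))
        if isinstance(e, BufferLoad):
            return BufferLoad(e.buffer, tuple(lower(i) for i in e.indices))
        if isinstance(e, Add):
            return Add(lower(e.a), lower(e.b))
        return e  # leaves such as Var are unchanged
    return lower(expr)


def verify_no_producer(expr):
    """Return True if no ProducerLoad remains anywhere in `expr`."""
    if isinstance(expr, ProducerLoad):
        return False
    if isinstance(expr, BufferLoad):
        return all(verify_no_producer(i) for i in expr.indices)
    if isinstance(expr, Add):
        return verify_no_producer(expr.a) and verify_no_producer(expr.b)
    return True
```

    Keeping the check as a separate pass means it can run after any lowering step
    without that step having to re-validate its own output.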
    
    Changes:
    - Introduce DataProducer to serve as a base class for Tensor in tensor expressions.
    - Migrate uses of Call::Halide to ProducerLoad.
    - Migrate the other uses of Call.
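
    As an illustration of the migration pattern, the autotvm flop counter in this
    patch replaces a `call_type == Call.Halide` test with an
    `isinstance(exp, ProducerLoad)` test and counts a tensor read as zero flops.
    The node classes and `count_flop` below are hypothetical stand-ins written for
    this sketch; they are not the autotvm implementation, and counting one flop per
    Add/Mul is an assumption of the toy.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class IntImm:
    value: int


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Mul:
    a: object
    b: object


@dataclass(frozen=True)
class ProducerLoad:
    producer: str
    indices: Tuple


@dataclass(frozen=True)
class Call:
    name: str
    args: Tuple


def count_flop(exp):
    """Count arithmetic ops, ignoring the index math inside tensor reads."""
    if isinstance(exp, IntImm):
        return 0
    if isinstance(exp, (Add, Mul)):
        return 1 + count_flop(exp.a) + count_flop(exp.b)
    if isinstance(exp, ProducerLoad):
        # A read from a producer: its indexing expressions are not counted.
        return 0
    if isinstance(exp, Call):
        # Only the arguments contribute; the call itself is free in this toy.
        return sum(count_flop(x) for x in exp.args)
    raise ValueError("unsupported expression: %r" % (exp,))
```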
    
    We will also create follow-up PRs to migrate the remaining two DSL-related IR nodes (Realize/Provide)
    to use the DataProducer.
---
 include/tvm/te/tensor.h                            |  15 +-
 include/tvm/tir/buffer.h                           |  55 +++++
 include/tvm/tir/expr.h                             |  81 +++++---
 include/tvm/tir/expr_functor.h                     |   4 +
 python/tvm/autotvm/task/task.py                    |   8 +-
 python/tvm/target/datatype.py                      |   6 +-
 python/tvm/te/hybrid/parser.py                     |   8 +-
 python/tvm/te/hybrid/util.py                       |   7 +-
 python/tvm/te/tensor.py                            |   9 +-
 python/tvm/tir/__init__.py                         |   4 +-
 python/tvm/tir/buffer.py                           |   5 +
 python/tvm/tir/expr.py                             |  30 ++-
 python/tvm/tir/ir_builder.py                       |   2 +-
 python/tvm/tir/op.py                               |  18 +-
 src/contrib/hybrid/codegen_hybrid.cc               |  25 ++-
 src/contrib/hybrid/codegen_hybrid.h                |   1 +
 src/printer/text_printer.h                         |   1 +
 src/printer/tir_text_printer.cc                    |  12 +-
 src/te/autodiff/jacobian.cc                        |  32 +--
 src/te/operation/compute_op.cc                     |  12 +-
 src/te/operation/hybrid_op.cc                      |   5 +-
 src/te/operation/op_util.cc                        |  23 ++-
 src/te/operation/tensorize.cc                      |  25 +--
 src/te/schedule/graph.cc                           |  24 +--
 src/te/schedule/operation_inline.cc                |  19 +-
 src/te/schedule/schedule_ops.cc                    |  25 ++-
 .../schedule_postproc_rewrite_for_tensor_core.cc   | 223 +++++++++++----------
 src/te/schedule/schedule_postproc_to_primfunc.cc   |  15 +-
 src/te/tensor.cc                                   |  10 +-
 src/tir/ir/expr.cc                                 |  46 +++--
 src/tir/ir/expr_functor.cc                         |  16 +-
 src/tir/transforms/storage_flatten.cc              |   7 +-
 src/tir/transforms/vectorize_loop.cc               |   9 +-
 tests/lint/git-clang-format.sh                     |  18 +-
 .../unittest/test_arith_canonical_simplify.py      |   2 +-
 tests/python/unittest/test_target_codegen_llvm.py  |   2 +-
 tests/python/unittest/test_te_hybrid_script.py     |  32 +--
 tests/python/unittest/test_tir_constructor.py      |   4 +-
 tests/python/unittest/test_tir_nodes.py            |  26 +--
 .../test_tir_transform_combine_context_call.py     |   2 +-
 topi/python/topi/cuda/rcnn/proposal.py             |   4 +-
 topi/python/topi/cuda/sort.py                      |   6 +-
 vta/python/vta/environment.py                      |   2 +-
 vta/python/vta/transform.py                        |   2 +-
 44 files changed, 519 insertions(+), 363 deletions(-)

diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h
index e2d847f..045d186 100644
--- a/include/tvm/te/tensor.h
+++ b/include/tvm/te/tensor.h
@@ -49,11 +49,11 @@ class OperationNode;
  * \brief Tensor structure representing a possible input,
  *  or intermediate computation result.
  */
-class Tensor : public ObjectRef {
+class Tensor : public DataProducer {
  public:
   /*! \brief default constructor, used internally */
   Tensor() {}
-  explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
+  explicit Tensor(ObjectPtr<Object> n) : DataProducer(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -157,7 +157,7 @@ class Operation : public tir::FunctionRef {
 };
 
 /*! \brief Node to represent a tensor */
-class TensorNode : public Object {
+class TensorNode : public DataProducerNode {
  public:
   /*! \brief The shape of the tensor */
   Array<PrimExpr> shape;
@@ -176,10 +176,17 @@ class TensorNode : public Object {
     v->Visit("op", &op);
     v->Visit("value_index", &value_index);
   }
+
+  Array<PrimExpr> GetShape() const final { return shape; }
+
+  DataType GetDataType() const final { return dtype; }
+
+  TVM_DLL String GetNameHint() const final;
+
   TVM_DLL static Tensor make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
 
   static constexpr const char* _type_key = "Tensor";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
 };
 
 // Implementations of inline functions
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 5d4e860..6904f2a 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -203,6 +203,61 @@ inline const BufferNode* Buffer::operator->() const {
  */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
                            std::string name = "buffer");
+
+/*!
+ * \brief Base node for data producers.
+ *
+ *  A DataProducer stores necessary information(e.g. a tensor expression) to produce
+ *  a multi-dimensional array. The stored information is opaque to the TIR.
+ *  DataProducer can appear in high-level DSLs that are built on top of the TIR.
+ *
+ *  A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower
+ *  all DataProducers to Buffers before TIR transformations.
+ *
+ * \sa tvm::te::Tensor
+ */
+class DataProducerNode : public Object {
+ public:
+  /*! \brief destructor. */
+  virtual ~DataProducerNode() {}
+  /*!
+   * \brief Get the shape of the result.
+   * \return The shape.
+   */
+  virtual Array<PrimExpr> GetShape() const = 0;
+  /*!
+   * \brief Get the data type of the result.
+   * \return The data type.
+   */
+  virtual DataType GetDataType() const = 0;
+  /*!
+   * \brief Get the name hint of the data producer.
+   * \return The data type.
+   */
+  virtual String GetNameHint() const = 0;
+
+  bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {
+    // because buffer producer is opaque, we just do pointer equality.
+    return this == other;
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {}
+
+  static constexpr const char* _type_key = "DataProducer";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DataProducerNode.
+ * \sa DataProducerNode
+ */
+class DataProducer : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
+};
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_BUFFER_H_
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 5909a24..d34165e 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -449,6 +449,10 @@ class BufferLoadNode : public PrimExprNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
 };
 
+/*!
+ * \brief Managed reference to BufferLoadNode.
+ * \sa BufferLoadNode
+ */
 class BufferLoad : public PrimExpr {
  public:
   TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices);
@@ -456,6 +460,54 @@ class BufferLoad : public PrimExpr {
 };
 
 /*!
+ * \brief Load value from the result produced by the producer.
+ *
+ * \note This node only appears in high-level DSLs that are built on top of the TIR.
+ *       It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
+ *       this node before TIR transformations.
+ *
+ * \sa ProducerLoad, DataProducerNode
+ */
+class ProducerLoadNode : public PrimExprNode {
+ public:
+  /*! \brief The buffer producer. */
+  DataProducer producer;
+  /*! \brief The location arguments. */
+  Array<PrimExpr> indices;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &(this->dtype));
+    v->Visit("producer", &producer);
+    v->Visit("indices", &indices);
+  }
+
+  bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype) && equal(producer, other->producer) &&
+           equal(indices, other->indices);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(dtype);
+    hash_reduce(producer);
+    hash_reduce(indices);
+  }
+
+  static constexpr const char* _type_key = "ProducerLoad";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference to ProducerLoadNode.
+ * \sa ProducerLoadNode
+ */
+class ProducerLoad : public PrimExpr {
+ public:
+  TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
+};
+
+/*!
  * \brief Load the value from buffer_var.
  *
  *  Equivalent to ((DType*)buffer_var)[index]
@@ -661,11 +713,6 @@ class CallNode : public PrimExprNode {
     ExternCPlusPlus = 1,
     /*! \brief Extern "C" without side-effect. */
     PureExtern = 2,
-    /*!
-     * \brief Halide-style call, evaluates func(args).
-     * \note Deprecated, move to BufferLoad in the future.
-     */
-    Halide = 3,
     /*! \brief Intrinsic functions. */
     Intrinsic = 4,
     /*! \brief Intrinsic functions that are pure. */
@@ -677,30 +724,17 @@ class CallNode : public PrimExprNode {
   Array<PrimExpr> args;
   /*! \brief Type of calls. */
   CallType call_type;
-  /*!
-   * \brief The function to be called.
-   * \note Deprecated, move to BufferLoad in the future.
-   */
-  FunctionRef func;
-  /*!
-   * \brief The output value index if func's value is a tuple.
-   * \note Deprecated, move to BufferLoad in the future.
-   */
-  int value_index{0};
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
     v->Visit("name", &name);
     v->Visit("args", &args);
     v->Visit("call_type", &call_type);
-    v->Visit("func", &func);
-    v->Visit("value_index", &value_index);
   }
 
   bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
     return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) &&
-           equal(call_type, other->call_type) && equal(func, other->func) &&
-           equal(value_index, other->value_index);
+           equal(call_type, other->call_type);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -708,18 +742,13 @@ class CallNode : public PrimExprNode {
     hash_reduce(name);
     hash_reduce(args);
     hash_reduce(call_type);
-    hash_reduce(func);
-    hash_reduce(value_index);
   }
 
   TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array<PrimExpr> args,
-                               CallType call_type, FunctionRef func = FunctionRef(),
-                               int value_index = 0);
+                               CallType call_type);
 
   /*! \return Whether call node is pure. */
-  bool is_pure() const {
-    return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide);
-  }
+  bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }
 
   /*!
    * \return Whether call node corresponds to a defined intrinsic.
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index 15ec3d2..a6c90b3 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
   }
   virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -163,6 +164,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
     IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
     IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
     IR_EXPR_FUNCTOR_DISPATCH(LetNode);
     IR_EXPR_FUNCTOR_DISPATCH(CallNode);
     IR_EXPR_FUNCTOR_DISPATCH(AddNode);
@@ -213,6 +215,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
   void VisitExpr_(const SizeVarNode* op) override;
   void VisitExpr_(const LoadNode* op) override;
   void VisitExpr_(const BufferLoadNode* op) override;
+  void VisitExpr_(const ProducerLoadNode* op) override;
   void VisitExpr_(const LetNode* op) override;
   void VisitExpr_(const CallNode* op) override;
   void VisitExpr_(const AddNode* op) override;
@@ -258,6 +261,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
   PrimExpr VisitExpr_(const SizeVarNode* op) override;
   PrimExpr VisitExpr_(const LoadNode* op) override;
   PrimExpr VisitExpr_(const BufferLoadNode* op) override;
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
   PrimExpr VisitExpr_(const LetNode* op) override;
   PrimExpr VisitExpr_(const CallNode* op) override;
   PrimExpr VisitExpr_(const AddNode* op) override;
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index 00b6676..b7cd6f2 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -495,11 +495,11 @@ def compute_flop(sch):
         if isinstance(exp, expr.Select):
             return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
                                                     _count_flop(exp.false_value))
-        if isinstance(exp, expr.Call):
-            if exp.call_type == expr.Call.Halide:
-                # Ignore flops from indexing expressions.
-                return 0
+        if isinstance(exp, expr.ProducerLoad):
+            # Ignore flops from indexing expressions.
+            return 0
 
+        if isinstance(exp, expr.Call):
             return sum([_count_flop(x) for x in exp.args])
 
         raise FlopCalculationError("Found unsupported operator in the compute expr")
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
index 328568a..e42ac6b 100644
--- a/python/tvm/target/datatype.py
+++ b/python/tvm/target/datatype.py
@@ -88,7 +88,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
 
     op_name : str
         The name of the operation which the function computes, given by its
-        Halide::Internal class name (e.g. Add, LE, Cast).
+        class name (e.g. Add, LE, Cast).
 
     target : str
         The name of codegen target.
@@ -136,8 +136,8 @@ def create_lower_func(extern_func_name):
                 dtype += "x" + str(t.lanes)
         if isinstance(op, (_Cast, _FloatImm)):
             return _Call(dtype, extern_func_name, convert([op.value]),
-                         _Call.Extern, None, 0)
+                         _Call.Extern)
         return _Call(dtype, extern_func_name, convert([op.a, op.b]),
-                     _Call.Extern, None, 0)
+                     _Call.Extern)
 
     return lower
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index 765efa0..75300ab 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -272,8 +272,7 @@ class HybridParser(ast.NodeVisitor):
             return entry if isinstance(node.ctx, ast.Load) else None
         if ty is Symbol.BufferVar:
             if isinstance(node.ctx, ast.Load):
-                return tvm.tir.Call(entry.dtype, entry.name, [tvm.runtime.const(0, 'int32')], \
-                                  _expr.Call.Halide, entry.op, entry.value_index)
+                return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, 'int32')])
             return entry, [tvm.runtime.const(0, 'int32')]
         # Do I need any assertion here?
         return entry
@@ -305,7 +304,7 @@ class HybridParser(ast.NodeVisitor):
             args = [tvm.runtime.const(0, 'int32')]
         _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
 
-        read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
+        read = tvm.tir.ProducerLoad(buf, args)
         value = HybridParser._binop_maker[type(node.op)](read, rhs)
 
         return tvm.tir.Provide(buf.op, 0, value, args)
@@ -392,8 +391,7 @@ class HybridParser(ast.NodeVisitor):
                     arr = arr[i.value]
             return arr
         if isinstance(node.ctx, ast.Load):
-            return tvm.tir.Call(arr.dtype, arr.name, args,
-                                _expr.Call.Halide, arr.op, arr.value_index)
+            return tvm.tir.ProducerLoad(arr, args)
         return arr, args
 
     def visit_With(self, node):
diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py
index 01eeeec..35c59f1 100644
--- a/python/tvm/te/hybrid/util.py
+++ b/python/tvm/te/hybrid/util.py
@@ -78,10 +78,9 @@ def replace_io(body, rmap):
         if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
             buf = rmap[op.func]
             return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
-        if isinstance(op, _expr.Call) and  op.func in rmap.keys():
-            buf = rmap[op.func]
-            return _expr.Call(buf.dtype, buf.name, op.args, \
-                              _expr.Call.Halide, buf.op, buf.value_index)
+        if isinstance(op, _expr.ProducerLoad) and  op.producer.op in rmap.keys():
+            buf = rmap[op.producer.op]
+            return _expr.ProducerLoad(buf, op.indices)
         return None
 
     return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index 739268a..7d73bf4 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -19,7 +19,7 @@
 import tvm._ffi
 
 from tvm.runtime import Object, ObjectGeneric, convert_to_object
-from tvm.tir import expr as _expr
+from tvm.tir import expr as _expr, DataProducer
 
 from . import _ffi_api
 
@@ -52,7 +52,7 @@ class TensorIntrinCall(Object):
 
 
 @tvm._ffi.register_object
-class Tensor(Object, _expr.ExprOp):
+class Tensor(DataProducer, _expr.ExprOp):
     """Tensor object, to construct, see function.Tensor"""
 
     def __call__(self, *indices):
@@ -69,9 +69,8 @@ class Tensor(Object, _expr.ExprOp):
             else:
                 raise ValueError("The indices must be expression")
 
-        return _expr.Call(self.dtype, self.op.name,
-                          args, _expr.Call.Halide,
-                          self.op, self.value_index)
+        return _expr.ProducerLoad(self, args)
+
 
     def __getitem__(self, indices):
         return TensorSlice(self, indices)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 07e0c9c..9aec24a 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,12 +19,12 @@
 from tvm.ir import PrimExpr
 from tvm.runtime import const
 
-from .buffer import Buffer, decl_buffer
+from .buffer import Buffer, decl_buffer, DataProducer
 from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
 from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
 from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
 from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
+from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
 from .expr import IterVar, Any
 
 from .stmt import Stmt, LetStmt, AssertStmt, For
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 0c7753e..e4dec5f 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -245,3 +245,8 @@ def decl_buffer(shape,
     return _ffi_api.Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor, buffer_type)
+
+
+@tvm._ffi.register_object
+class DataProducer(Object):
+    pass
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index aca5e5a..d55370e 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -144,7 +144,7 @@ class ExprOp(object):
     def __invert__(self):
         if _dtype_is_float(self):
             raise RuntimeError("Cannot use ~ operator on float type Expr.")
-        return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
+        return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic)
 
     def __lt__(self, other):
         return _ffi_api._OpLT(self, other)
@@ -889,6 +889,23 @@ class BufferLoad(PrimExprWithOp):
 
 
 @tvm._ffi.register_object
+class ProducerLoad(PrimExprWithOp):
+    """Producer load node.
+
+    Parameters
+    ----------
+    producer : DataProducer
+        The buffer to be loaded.
+
+    indices : List[PrimExpr]
+        The buffer indices.
+    """
+    def __init__(self, producer, indices):
+        self.__init_handle_by_constructor__(
+            _ffi_api.ProducerLoad, producer, indices)
+
+
+@tvm._ffi.register_object
 class Ramp(PrimExprWithOp):
     """Ramp node.
 
@@ -959,22 +976,15 @@ class Call(PrimExprWithOp):
 
     call_type : int
         The type of the call
-
-    func : Operation, optional
-        Operation if call_type is Halide
-
-    value_index : int
-        The output value index
     """
     Extern = 0
     ExternCPlusPlus = 1
     PureExtern = 2
-    Halide = 3
     Intrinsic = 4
     PureIntrinsic = 5
-    def __init__(self, dtype, name, args, call_type, func, value_index):
+    def __init__(self, dtype, name, args, call_type):
         self.__init_handle_by_constructor__(
-            _ffi_api.Call, dtype, name, args, call_type, func, value_index)
+            _ffi_api.Call, dtype, name, args, call_type)
 
 
 @tvm._ffi.register_object
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 4dd541e..47ba2e2 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -380,7 +380,7 @@ class IRBuilder(object):
             The expression will likely tag.
         """
         return _expr.Call(expr.dtype, "likely", [expr],
-                          _expr.Call.PureIntrinsic, None, 0)
+                          _expr.Call.PureIntrinsic)
 
     def get(self):
         """Return the builded IR.
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index b87db19..929d422 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -30,9 +30,9 @@ def _pack_buffer(buf):
     """
     assert buf.shape
     shape = Call("handle", "tvm_stack_make_shape", buf.shape,
-                 Call.Intrinsic, None, 0)
+                 Call.Intrinsic)
     strides = Call("handle", "tvm_stack_make_shape", buf.strides,
-                   Call.Intrinsic, None, 0) if buf.strides else 0
+                   Call.Intrinsic) if buf.strides else 0
     pack_args = [buf.data,
                  shape,
                  strides,
@@ -40,7 +40,7 @@ def _pack_buffer(buf):
                  const(0, dtype=buf.dtype),
                  buf.elem_offset]
     return Call("handle", "tvm_stack_make_array",
-                pack_args, Call.Intrinsic, None, 0)
+                pack_args, Call.Intrinsic)
 
 def call_packed(*args):
     """Build expression by call an external packed function.
@@ -68,7 +68,7 @@ def call_packed(*args):
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
     return Call(
-        "int32", "tvm_call_packed", call_args, Call.Intrinsic, None, 0)
+        "int32", "tvm_call_packed", call_args, Call.Intrinsic)
 
 
 def call_pure_intrin(dtype, func_name, *args):
@@ -95,7 +95,7 @@ def call_pure_intrin(dtype, func_name, *args):
     """
     args = convert(args)
     return Call(
-        dtype, func_name, convert(args), Call.PureIntrinsic, None, 0)
+        dtype, func_name, convert(args), Call.PureIntrinsic)
 
 
 def call_intrin(dtype, func_name, *args):
@@ -122,7 +122,7 @@ def call_intrin(dtype, func_name, *args):
     """
     args = convert(args)
     return Call(
-        dtype, func_name, convert(args), Call.Intrinsic, None, 0)
+        dtype, func_name, convert(args), Call.Intrinsic)
 
 
 def call_pure_extern(dtype, func_name, *args):
@@ -145,7 +145,7 @@ def call_pure_extern(dtype, func_name, *args):
         The call expression.
     """
     return Call(
-        dtype, func_name, convert(args), Call.PureExtern, None, 0)
+        dtype, func_name, convert(args), Call.PureExtern)
 
 
 def call_extern(dtype, func_name, *args):
@@ -168,7 +168,7 @@ def call_extern(dtype, func_name, *args):
         The call expression.
     """
     return Call(
-        dtype, func_name, convert(args), Call.Extern, None, 0)
+        dtype, func_name, convert(args), Call.Extern)
 
 
 def call_llvm_intrin(dtype, name, *args):
@@ -278,7 +278,7 @@ def trace(args, trace_action="tvm.default_trace_action"):
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
     call_args.insert(0, trace_action)
     return tvm.tir.Call(
-        args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic, None, 0)
+        args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic)
 
 
 
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index f61ad33..7062520 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -202,18 +202,21 @@ void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) {  // NOLINT
   PrintExpr(op->a, os);
 }
 
+void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) {  // NOLINT(*)
+  auto tensor = Downcast<Tensor>(op->producer);
+
+  os << GetTensorID(tensor->op, tensor->value_index);
+  os << "[";
+  for (size_t i = 0; i < op->indices.size(); ++i) {
+    if (i) os << ", ";
+    std::stringstream idx;
+    PrintExpr(op->indices[i], idx);
+    os << idx.str();
+  }
+  os << "]";
+}
 void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
-  if (op->call_type == CallNode::Halide) {
-    os << GetTensorID(op->func, op->value_index);
-    os << "[";
-    for (size_t i = 0; i < op->args.size(); ++i) {
-      if (i) os << ", ";
-      std::stringstream idx;
-      PrintExpr(op->args[i], idx);
-      os << idx.str();
-    }
-    os << "]";
-  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+  if (op->is_intrinsic(CallNode::bitwise_and)) {
     PrintBinaryIntrinsitc(op, "&", os, this);
   } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
     PrintBinaryIntrinsitc(op, "^", os, this);
diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index 78a22b5..8a31e09 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -90,6 +90,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   void VisitExpr_(const LoadNode* op, std::ostream& os) override;       // NOLINT(*)
   void VisitExpr_(const LetNode* op, std::ostream& os) override;        // NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) override;       // NOLINT(*)
+  void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const AddNode* op, std::ostream& os) override;        // NOLINT(*)
   void VisitExpr_(const SubNode* op, std::ostream& os) override;        // NOLINT(*)
   void VisitExpr_(const MulNode* op, std::ostream& os) override;        // NOLINT(*)
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index fdf14d9..c7b2b31 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -286,6 +286,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitExpr_(const NotNode* op) override;
   Doc VisitExpr_(const SelectNode* op) override;
   Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const ProducerLoadNode* op) override;
   Doc VisitExpr_(const LoadNode* op) override;
   Doc VisitExpr_(const RampNode* op) override;
   Doc VisitExpr_(const BroadcastNode* op) override;
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 4d22cbb..2992737 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -291,6 +291,13 @@ Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
   return doc;
 }
 
+Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) {
+  // TODO(tvm-team): consider make a better text format for producer.
+  Doc doc;
+  doc << op->producer->GetNameHint() << Print(op->indices);
+  return doc;
+}
+
 Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
   Doc doc;
   doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index)
@@ -327,8 +334,6 @@ inline const char* CallType2String(CallNode::CallType t) {
       return "extern_cpp";
     case CallNode::PureExtern:
       return "pure_extern";
-    case CallNode::Halide:
-      return "halide";
     case CallNode::Intrinsic:
       return "intrin";
     case CallNode::PureIntrinsic:
@@ -346,8 +351,7 @@ Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
     args.push_back(Print(arg));
   }
   doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype)
-      << ", type=" << Doc::StrLiteral(CallType2String(op->call_type))
-      << ", index=" << op->value_index << ")";
+      << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ")";
   return doc;
 }
 
diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc
index f770169..ecddf5e 100644
--- a/src/te/autodiff/jacobian.cc
+++ b/src/te/autodiff/jacobian.cc
@@ -78,21 +78,24 @@ class JacobianMutator : public ExprMutator {
   PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED;
   PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED;
 
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    auto tensor = Downcast<te::Tensor>(op->producer);
+    if (input_.get() && tensor == input_) {
+      // Tensor(indices)
+      CHECK_EQ(indices_.size(), op->indices.size());
+      PrimExpr condition = const_true();
+      for (size_t i = 0; i < input_.ndim(); ++i) {
+        condition = AndNode::make(condition, EQNode::make(indices_[i], op->indices[i]));
+      }
+      return CastNode::make(op->dtype, condition);
+    } else {
+      return make_zero(op->dtype);
+    }
+  }
+
   PrimExpr VisitExpr_(const CallNode* op) {
     PrimExpr expr = GetRef<PrimExpr>(op);
-    if (op->call_type == CallNode::CallType::Halide) {
-      if (input_.get() && op->func.same_as(input_->op) && op->value_index == input_->value_index) {
-        // Tensor(indices)
-        CHECK_EQ(indices_.size(), op->args.size());
-        PrimExpr condition = const_true();
-        for (size_t i = 0; i < input_.ndim(); ++i) {
-          condition = AndNode::make(condition, EQNode::make(indices_[i], op->args[i]));
-        }
-        return CastNode::make(op->dtype, condition);
-      } else {
-        return make_zero(op->dtype);
-      }
-    } else if (op->call_type == CallNode::CallType::PureIntrinsic) {
+    if (op->call_type == CallNode::CallType::PureIntrinsic) {
       static std::unordered_set<std::string> piecewise_const = {"floor", "ceil", "trunc", "round"};
       if (op->name == "exp") {
         return MulNode::make(Mutate(op->args[0]), expr);
@@ -116,8 +119,7 @@ class JacobianMutator : public ExprMutator {
                                               FloatImm(type, 1.0), FloatImm(type, -1.0)));
       } else if (op->name == intrinsic::tvm_if_then_else) {
         Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
-        return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func,
-                              op->value_index);
+        return CallNode::make(op->dtype, op->name, new_args, op->call_type);
       } else if (piecewise_const.count(op->name)) {
         return FloatImm(expr.dtype(), 0.0);
       } else {
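The new `ProducerLoadNode` visitor in `JacobianMutator` differentiates a tensor read with respect to the input tensor as an indicator. The result is 1 when the read targets the same tensor at exactly the differentiated indices, and 0 otherwise. The sketch below is minimal and makes two assumptions. First, it uses a hypothetical `Load` struct in place of `ProducerLoadNode`. Second, it evaluates the condition on concrete integers, whereas TVM builds a symbolic `AndNode`/`EQNode` expression and casts it to the load's dtype.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for tir::ProducerLoadNode: which tensor is read, and where.
struct Load {
  std::string producer;
  std::vector<int> indices;
};

// d load / d input[wrt]: 1.0 if the load reads `input` at exactly `wrt`, else 0.0.
// Follows the shape of JacobianMutator::VisitExpr_(const ProducerLoadNode*), but on
// concrete integers instead of building symbolic IR.
double JacobianOfLoad(const Load& load, const std::string& input, const std::vector<int>& wrt) {
  if (load.producer != input) return 0.0;     // reads another tensor: zero derivative
  assert(wrt.size() == load.indices.size());  // plays the role of CHECK_EQ in the patch
  bool condition = true;
  for (std::size_t i = 0; i < wrt.size(); ++i) {
    condition = condition && (wrt[i] == load.indices[i]);
  }
  return condition ? 1.0 : 0.0;  // analogous to CastNode::make(op->dtype, condition)
}
```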
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index cc843a4..41bf49f 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -153,9 +153,8 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
   std::unordered_set<Tensor> visited;
   for (auto& e : body) {
     tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
-      const tir::CallNode* call = n.as<tir::CallNode>();
-      if (call != nullptr && call->func.defined()) {
-        Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+      if (auto* pload = n.as<tir::ProducerLoadNode>()) {
+        Tensor t = Downcast<Tensor>(pload->producer);
         if (!visited.count(t)) {
           ret.push_back(t);
           visited.insert(t);
@@ -202,9 +201,8 @@ void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* an
                                       std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   CHECK_EQ(self.operator->(), this);
   auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
-    auto* call = n.as<tir::CallNode>();
-    if (call != nullptr && call->func.defined()) {
-      Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+    if (auto* pload = n.as<tir::ProducerLoadNode>()) {
+      Tensor t = Downcast<Tensor>(pload->producer);
       if (t->op.defined() && out_dom_map->count(t)) {
         TensorDom& dom = out_dom_map->at(t);
         for (size_t i = 0; i < t.ndim(); ++i) {
@@ -212,7 +210,7 @@ void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* an
           // undefined behaviour), so we can intersect the estimated set of the argument with the
           // range expected by the tensor. However, intersection may result in overly complex
           // expressions, so we perform a more relaxed form of intersection.
-          IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
+          IntSet arg_intset = analyzer->int_set(pload->indices[i], ConvertDomMap(dom_map));
           const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
           if (arg_interval) {
             PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
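With `ProducerLoadNode`, `ComputeOpNode::InputTensors` no longer checks `call->func.defined()` or rebuilds the tensor from `func` and `value_index`. It matches a producer load and downcasts its producer. The sketch below illustrates the resulting collection pattern: a post-order walk that records each distinct producer once, in first-visit order. It uses a hypothetical two-node expression tree (loads and adds) and string names in place of `te::Tensor`.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical miniature IR: an expression is either a producer load or an add.
struct Expr {
  enum Kind { kProducerLoad, kAdd } kind;
  std::string producer;          // used when kind == kProducerLoad
  std::shared_ptr<Expr> a, b;    // used when kind == kAdd
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Load(const std::string& name) {
  return std::make_shared<Expr>(Expr{Expr::kProducerLoad, name, nullptr, nullptr});
}
ExprPtr Add(ExprPtr a, ExprPtr b) { return std::make_shared<Expr>(Expr{Expr::kAdd, "", a, b}); }

// Post-order walk; the `visited` set keeps each producer once, as in InputTensors.
void Collect(const ExprPtr& e, std::vector<std::string>* ret,
             std::unordered_set<std::string>* visited) {
  if (e->kind == Expr::kAdd) {
    Collect(e->a, ret, visited);
    Collect(e->b, ret, visited);
    return;
  }
  if (visited->insert(e->producer).second) ret->push_back(e->producer);
}

std::vector<std::string> InputTensors(const ExprPtr& body) {
  std::vector<std::string> ret;
  std::unordered_set<std::string> visited;
  Collect(body, &ret, &visited);
  return ret;
}
```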
diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc
index 7ee5833..55996a5 100644
--- a/src/te/operation/hybrid_op.cc
+++ b/src/te/operation/hybrid_op.cc
@@ -86,9 +86,8 @@ Array<Tensor> HybridOpNode::InputTensors() const {
   std::unordered_set<Tensor> visited;
   Array<Tensor> curr_inputs;
   tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
-    const tir::CallNode* call = n.as<tir::CallNode>();
-    if (call != nullptr && call->func.defined()) {
-      Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+    if (auto* pload = n.as<tir::ProducerLoadNode>()) {
+      Tensor t = Downcast<Tensor>(pload->producer);
       if (orig_inputs.count(t) && !visited.count(t)) {
         curr_inputs.push_back(t);
         visited.insert(t);
diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index 341e761..936781d 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -206,18 +206,19 @@ class TensorReplacer : public tir::StmtExprMutator {
  public:
   explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {}
 
-  PrimExpr VisitExpr_(const tir::CallNode* op) final {
-    if (op->call_type == tir::CallNode::Halide) {
-      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
-      auto it = vmap_.find(t);
-      if (it != vmap_.end()) {
-        PrimExpr ret = tir::CallNode::make(op->dtype, it->second->op->name, op->args, op->call_type,
-                                           it->second->op, it->second->value_index);
-        found = true;
-        return this->VisitExpr(ret);
-      }
+  PrimExpr VisitExpr_(const tir::ProducerLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<tir::ProducerLoadNode>();
+    CHECK(op != nullptr);
+
+    Tensor t = Downcast<Tensor>(op->producer);
+    auto it = vmap_.find(t);
+    if (it != vmap_.end()) {
+      found = true;
+      return tir::ProducerLoad(it->second, op->indices);
+    } else {
+      return expr;
     }
-    return StmtExprMutator::VisitExpr_(op);
   }
 
   // whether it is found.
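The rewritten `TensorReplacer` uses a fixed order. It first lets `StmtExprMutator` rewrite the children, so loads nested in the indices are handled. It then swaps the producer of the visited load if that producer appears in `vmap_`, and sets `found`. The sketch below follows this order. Its names are hypothetical: `LoadExpr` stands in for `ProducerLoadNode`, indices are modeled only as nested loads, and strings replace tensors. For simplicity it always builds a new node, whereas TVM returns the visited expression unchanged when no replacement applies.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for tir::ProducerLoadNode. Indices are modeled only as
// nested loads; a load with no indices is a leaf.
struct LoadExpr {
  std::string producer;
  std::vector<std::shared_ptr<LoadExpr>> indices;
};
using LoadPtr = std::shared_ptr<LoadExpr>;

class TensorReplacer {
 public:
  explicit TensorReplacer(std::map<std::string, std::string> vmap) : vmap_(std::move(vmap)) {}

  LoadPtr Visit(const LoadPtr& op) {
    // 1. Rewrite the children first, as the patch does via StmtExprMutator::VisitExpr_(op).
    std::vector<LoadPtr> new_indices;
    for (const LoadPtr& index : op->indices) new_indices.push_back(Visit(index));
    // 2. Then swap this load's producer if it appears in the replacement map.
    std::string producer = op->producer;
    auto it = vmap_.find(producer);
    if (it != vmap_.end()) {
      found = true;
      producer = it->second;
    }
    return std::make_shared<LoadExpr>(LoadExpr{producer, new_indices});
  }

  // Whether any replacement happened, mirroring TensorReplacer::found.
  bool found{false};

 private:
  std::map<std::string, std::string> vmap_;
};
```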
diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc
index f322e12..ddc0595 100644
--- a/src/te/operation/tensorize.cc
+++ b/src/te/operation/tensorize.cc
@@ -156,22 +156,19 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
 // Remap the tensor placeholder, index and inline things.
 class TensorIntrinMatcher final : public StmtExprMutator {
  public:
-  PrimExpr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<CallNode>();
-    if (op->call_type == CallNode::Halide) {
-      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
-      auto it = in_remap_.find(t);
-      if (it != in_remap_.end()) {
-        const InputEntry& e = it->second;
-        CHECK_EQ(op->args.size(), e.region.size());
-        Array<PrimExpr> args;
-        for (size_t i = e.start; i < e.region.size(); ++i) {
-          args.push_back(op->args[i] - e.region[i]->min);
-        }
-        return CallNode::make(op->dtype, e.tensor->op->name, args, op->call_type, e.tensor->op,
-                              e.tensor->value_index);
+    op = expr.as<ProducerLoadNode>();
+    auto t = Downcast<Tensor>(op->producer);
+    auto it = in_remap_.find(t);
+    if (it != in_remap_.end()) {
+      const InputEntry& e = it->second;
+      CHECK_EQ(op->indices.size(), e.region.size());
+      Array<PrimExpr> indices;
+      for (size_t i = e.start; i < e.region.size(); ++i) {
+        indices.push_back(op->indices[i] - e.region[i]->min);
       }
+      return ProducerLoad(e.tensor, indices);
     }
     return expr;
   }
diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc
index bcde680..62557ed 100644
--- a/src/te/schedule/graph.cc
+++ b/src/te/schedule/graph.cc
@@ -40,8 +40,6 @@ struct TensorDimKey {
   int value_index;
   int dim;
   TensorDimKey() {}
-  TensorDimKey(const tir::CallNode* op, int dim)
-      : f(op->func), value_index(op->value_index), dim(dim) {}
   TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {}
   TensorDimKey(const Tensor& t, size_t dim)
       : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
@@ -240,11 +238,11 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
         reach[TensorDimKey(t, i)] = {};
       }
       auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
-        const tir::CallNode* call = n.as<tir::CallNode>();
-        if (call != nullptr && call->func.defined()) {
-          if (!bset.count(call->func.get())) return;
-          for (size_t i = 0; i < call->args.size(); ++i) {
-            TensorDimKey dkey(call, static_cast<int>(i));
+        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
+          Tensor t = Downcast<Tensor>(pload->producer);
+          if (!bset.count(t->op.get())) return;
+          for (size_t i = 0; i < pload->indices.size(); ++i) {
+            TensorDimKey dkey(t, static_cast<int>(i));
             auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
               const VarNode* v = node.as<VarNode>();
               auto it = vmap.find(v);
@@ -252,7 +250,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
                 reach[it->second].push_back(dkey);
               }
             };
-            tir::PostOrderVisit(call->args[i], fpush);
+            tir::PostOrderVisit(pload->indices[i], fpush);
           }
         }
       };
@@ -328,11 +326,11 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
         vmap[axis[i]->var.get()] = std::move(keys);
       }
       auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) {
-        const tir::CallNode* call = n.as<tir::CallNode>();
-        if (call != nullptr && call->func.defined()) {
-          for (size_t i = 0; i < call->args.size(); ++i) {
-            auto it = vmap.find(call->args[i].get());
-            TensorDimKey src(call, static_cast<int>(i));
+        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
+          Tensor t = Downcast<Tensor>(pload->producer);
+          for (size_t i = 0; i < pload->indices.size(); ++i) {
+            auto it = vmap.find(pload->indices[i].get());
+            TensorDimKey src(t, static_cast<int>(i));
             if (it != vmap.end()) {
               const std::vector<TensorDimKey>& keys = it->second;
               for (const auto& key : keys) {
diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc
index 8c8f092..8a130e9 100644
--- a/src/te/schedule/operation_inline.cc
+++ b/src/te/schedule/operation_inline.cc
@@ -42,27 +42,28 @@ class OperationInliner final : public StmtExprMutator {
   OperationInliner(Operation op, Array<Var> args, PrimExpr body)
       : operation_(op), args_(args), body_(body) {}
 
-  PrimExpr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<CallNode>();
+    op = expr.as<ProducerLoadNode>();
+    auto tensor = Downcast<Tensor>(op->producer);
 
-    if (op->func.same_as(operation_)) {
-      CHECK_EQ(op->value_index, 0);
+    if (tensor->op.same_as(operation_)) {
+      CHECK_EQ(tensor->value_index, 0);
       expr = body_;
-      CHECK_EQ(args_.size(), op->args.size());
+      CHECK_EQ(args_.size(), op->indices.size());
 
       bool has_side_effect = false;
-      for (size_t i = 0; i < op->args.size(); ++i) {
-        if (HasSideEffect(op->args[i])) has_side_effect = true;
+      for (size_t i = 0; i < op->indices.size(); ++i) {
+        if (HasSideEffect(op->indices[i])) has_side_effect = true;
       }
       if (has_side_effect) {
         for (size_t i = 0; i < args_.size(); ++i) {
-          expr = LetNode::make(args_[i], op->args[i], expr);
+          expr = LetNode::make(args_[i], op->indices[i], expr);
         }
       } else {
         Map<Var, PrimExpr> vmap;
         for (size_t i = 0; i < args_.size(); ++i) {
-          vmap.Set(args_[i], op->args[i]);
+          vmap.Set(args_[i], op->indices[i]);
         }
         expr = Substitute(EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
       }
diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index 6cc04d9..10f1ed3 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -245,18 +245,21 @@ class SchedulePostProc : public StmtExprMutator {
     }
   }
 
-  PrimExpr VisitExpr_(const CallNode* op) final {
-    if (op->call_type == CallNode::Halide) {
-      TensorKey key{op->func, op->value_index};
-      auto it = replace_buffer_.find(key);
-      if (it != replace_buffer_.end()) {
-        const Tensor& dst = it->second;
-        PrimExpr ret = CallNode::make(op->dtype, dst->op->name, op->args, op->call_type, dst->op,
-                                      dst->value_index);
-        return this->VisitExpr(ret);
-      }
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<ProducerLoadNode>();
+    CHECK(op != nullptr);
+
+    auto tensor = Downcast<Tensor>(op->producer);
+    TensorKey key{tensor->op, tensor->value_index};
+
+    auto it = replace_buffer_.find(key);
+    if (it != replace_buffer_.end()) {
+      const Tensor& dst = it->second;
+      return ProducerLoad(dst, op->indices);
+    } else {
+      return expr;
     }
-    return StmtExprMutator::VisitExpr_(op);
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index 46fc91be..e0d5882 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -53,6 +53,11 @@ struct Tile {
   int k{-1};
 };
 
+TensorKey TensorKeyFromProducer(DataProducer producer) {
+  auto tensor = Downcast<Tensor>(producer);
+  return TensorKey{tensor->op, tensor->value_index};
+}
+
 std::string simplify_name(std::string input) {
   auto pos = input.find(".");
   if (pos != std::string::npos) {
@@ -152,27 +157,25 @@ class MMAMatcher : public StmtVisitor {
   };
 
   // Check whether the storage scope is local
-  bool check_local_buffer_(const CallNode* op, BufferInfo* bi) {
-    if (op->call_type == CallNode::Halide) {
-      auto it = storage_scope_.find(op->func.get());
-      if (it == storage_scope_.end()) {
-        return false;
-      }
-      const std::string& strkey = it->second;
-      if (strkey != "local") {
-        return false;
-      }
-      auto it1 = buf_map_.find(TensorKey{op->func, op->value_index});
-      if (it1 == buf_map_.end()) {
-        return false;
-      }
-      *bi = it1->second;
-      if (bi->released) {
-        return false;
-      }
-      return true;
+  bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) {
+    auto tensor = Downcast<Tensor>(op->producer);
+    auto it = storage_scope_.find(tensor.get());
+    if (it == storage_scope_.end()) {
+      return false;
     }
-    return false;
+    const std::string& strkey = it->second;
+    if (strkey != "local") {
+      return false;
+    }
+    auto it1 = buf_map_.find(TensorKey{tensor->op, tensor->value_index});
+    if (it1 == buf_map_.end()) {
+      return false;
+    }
+    *bi = it1->second;
+    if (bi->released) {
+      return false;
+    }
+    return true;
   }
 
   // Do the pattern matching
@@ -182,7 +185,7 @@ class MMAMatcher : public StmtVisitor {
       return false;
     }
 
-    auto* load_c = add->a.as<CallNode>();
+    auto* load_c = add->a.as<ProducerLoadNode>();
     BufferInfo buffer_c;
     if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) ||
         !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) {
@@ -195,7 +198,7 @@ class MMAMatcher : public StmtVisitor {
     }
 
     auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
-    auto load_a = load_a_expr.as<CallNode>();
+    auto load_a = load_a_expr.as<ProducerLoadNode>();
     BufferInfo buffer_a;
     if (!check_local_buffer_(load_a, &buffer_a) ||
         !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) ||
@@ -205,7 +208,7 @@ class MMAMatcher : public StmtVisitor {
     }
 
     auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
-    auto load_b = load_b_expr.as<CallNode>();
+    auto load_b = load_b_expr.as<ProducerLoadNode>();
     BufferInfo buffer_b;
     if (!check_local_buffer_(load_b, &buffer_b) ||
         !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) ||
@@ -469,7 +472,7 @@ class BufferAnalyser : public StmtExprVisitor {
     strides_.insert(std::make_pair(key.GetName(), strides));
 
     if (frag_reg_.count(bi.name)) {
-      PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0);
+      PrimExpr dst = ProducerLoad(Downcast<Operation>(op->func).output(0), op->args);
       frag_load_.insert(std::make_pair(op, dst));
 
       auto rel_index = bi.RelIndex(op->args);
@@ -524,69 +527,70 @@ class BufferAnalyser : public StmtExprVisitor {
       }
     }
 
-    const CallNode* value = op->value.as<CallNode>();
-    if (value != nullptr && frag_reg_.count(value->name)) {
-      PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0);
+    const ProducerLoadNode* value = op->value.as<ProducerLoadNode>();
+    // TODO(tvm-team): string matching is dangerous, consider other means.
+    if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) {
+      PrimExpr dst = ProducerLoad(Downcast<Operation>(op->func).output(0), op->args);
       frag_store_.insert(std::make_pair(op, dst));
     }
   }
 
-  void VisitExpr_(const CallNode* op) final {
+  void VisitExpr_(const ProducerLoadNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
-    if (op->call_type == CallNode::Halide) {
-      TensorKey key{op->func, op->value_index};
-      auto it = buf_map_.find(key);
-      CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f;
-      const BufferInfo& bi = it->second;
-      CHECK(!bi.released) << "Read a buffer that is already out of scope";
 
-      if (matrix_abc_.count(op->name)) {
-        if (bi.shape.size() < 2) {
+    auto tensor = Downcast<Tensor>(op->producer);
+    TensorKey key{tensor->op, tensor->value_index};
+    auto it = buf_map_.find(key);
+    CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f;
+    const BufferInfo& bi = it->second;
+    CHECK(!bi.released) << "Read a buffer that is already out of scope";
+
+    if (matrix_abc_.count(tensor->op->name)) {
+      if (bi.shape.size() < 2) {
+        invalid_ = true;
+        return;
+      }
+      for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
+        const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
+        if (shape == nullptr || shape->value % 16 != 0) {
           invalid_ = true;
           return;
         }
-        for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
-          const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
-          if (shape == nullptr || shape->value % 16 != 0) {
-            invalid_ = true;
-            return;
-          }
-        }
       }
+    }
 
-      Array<PrimExpr> strides;
-      if (bi.strides.size() > 0) {
-        strides = bi.strides;
-      } else {
-        for (size_t i = 1; i < bi.shape.size(); ++i) {
-          PrimExpr stride = IntImm(DataType::Int(32), 1);
-          for (size_t j = bi.shape.size() - 1; j >= i; --j) {
-            stride = MulNode::make(stride, bi.shape[j]);
-          }
-          strides.push_back(stride);
+    Array<PrimExpr> strides;
+    if (bi.strides.size() > 0) {
+      strides = bi.strides;
+    } else {
+      for (size_t i = 1; i < bi.shape.size(); ++i) {
+        PrimExpr stride = IntImm(DataType::Int(32), 1);
+        for (size_t j = bi.shape.size() - 1; j >= i; --j) {
+          stride = MulNode::make(stride, bi.shape[j]);
         }
-        strides.push_back(make_const(DataType::Int(32), 1));
+        strides.push_back(stride);
       }
-      strides_.insert(std::make_pair(key.GetName(), strides));
+      strides.push_back(make_const(DataType::Int(32), 1));
+    }
+    strides_.insert(std::make_pair(key.GetName(), strides));
 
-      if (!frag_reg_.count(bi.name)) {
-        return;
-      }
+    if (!frag_reg_.count(bi.name)) {
+      return;
+    }
 
-      auto rel_index = bi.RelIndex(op->args);
-      if (op->args.size() < 2) {
-        invalid_ = true;
-        return;
-      }
-      for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
-        index_visitor.scaling_factor_ = 16;
-        if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
-          index_visitor.scaling_factor_ = shape->value;
-        }
-        auto index = rel_index[i];
-        auto simplified_index = analyzer_.Simplify(index);
-        index_visitor(simplified_index);
+    auto rel_index = bi.RelIndex(op->indices);
+    if (op->indices.size() < 2) {
+      invalid_ = true;
+      return;
+    }
+    for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) {
+      index_visitor.scaling_factor_ = 16;
+      if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
+        index_visitor.scaling_factor_ = shape->value;
       }
+      auto index = rel_index[i];
+      auto simplified_index = analyzer_.Simplify(index);
+      index_visitor(simplified_index);
     }
   }
 
@@ -836,11 +840,11 @@ class TensorCoreIRMutator : public StmtExprMutator {
     if (it != mma_sync_.end()) {
       const auto& operands = it->second;
       PrimExpr a = operands[0];
-      auto ca = a.as<CallNode>();
+      auto ca = a.as<ProducerLoadNode>();
       PrimExpr b = operands[1];
-      auto cb = b.as<CallNode>();
+      auto cb = b.as<ProducerLoadNode>();
       PrimExpr c = operands[2];
-      auto cc = c.as<CallNode>();
+      auto cc = c.as<ProducerLoadNode>();
 
       ObjectPtr<BufferNode> buffer_node_a = make_object<BufferNode>();
       ObjectPtr<BufferNode> buffer_node_b = make_object<BufferNode>();
@@ -865,24 +869,24 @@ class TensorCoreIRMutator : public StmtExprMutator {
       };
 
       auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) {
-        return add_buffer_bind_scope_(cc, buffer_node_c, TensorKey{cc->func, cc->value_index},
-                                      mma_sync_call, cc->dtype);
+        return add_buffer_bind_scope_(cc, buffer_node_c, TensorKeyFromProducer(cc->producer),
+                                      mma_sync_call);
       };
 
       auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) {
-        return add_buffer_bind_scope_(cb, buffer_node_b, TensorKey{cb->func, cb->value_index},
-                                      call_add_c, cb->dtype);
+        return add_buffer_bind_scope_(cb, buffer_node_b, TensorKeyFromProducer(cb->producer),
+                                      call_add_c);
       };
 
-      return add_buffer_bind_scope_(ca, buffer_node_a, TensorKey{ca->func, ca->value_index},
-                                    call_add_b, ca->dtype);
+      return add_buffer_bind_scope_(ca, buffer_node_a, TensorKeyFromProducer(ca->producer),
+                                    call_add_b);
     }
 
     auto it2 = frag_load_.find(op);
     if (it2 != frag_load_.end()) {
       PrimExpr dst = it2->second;
       if (op->value.as<FloatImmNode>() != nullptr || op->value.as<IntImmNode>() != nullptr) {
-        auto call = dst.as<CallNode>();
+        auto pload = dst.as<ProducerLoadNode>();
 
         auto fill_fragment_call = [this, &op](const Buffer& buffer) {
           return EvaluateNode::make(CallNode::make(DataType::Handle(), intrinsic::tvm_fill_fragment,
@@ -892,8 +896,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
         };
 
         ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
-        return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index},
-                                      fill_fragment_call, call->dtype);
+        return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer),
+                                      fill_fragment_call);
       }
 
       const CallNode* value = op->value.as<CallNode>();
@@ -911,16 +915,17 @@ class TensorCoreIRMutator : public StmtExprMutator {
       PrimExpr mutated_value = thread_idx_mutator(op->value);
       PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern);
 
-      auto call = dst.as<CallNode>();
+      auto pload = dst.as<ProducerLoadNode>();
       PrimExpr matrix_major;
-      auto iter2 = matrix_major_.find(simplify_name(call->name));
-      CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << call->name;
+      auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint()));
+      CHECK(iter2 != matrix_major_.end())
+          << "Can not determine matrix major for " << pload->producer->GetNameHint();
       if (iter2->second == "col_major") {
         matrix_major = StringImmNode::make("col_major");
       } else if (iter2->second == "row_major") {
         matrix_major = StringImmNode::make("row_major");
       } else {
-        LOG(FATAL) << "invalid matrix major for " << call->name;
+        LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint();
       }
 
       auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
@@ -932,8 +937,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
-      return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index},
-                                    load_matrix_call, call->dtype);
+      return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer),
+                                    load_matrix_call);
     }
 
     auto it3 = frag_store_.find(op);
@@ -952,7 +957,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       dst = thread_idx_mutator(dst);
       dst = CallNode::make(DataType::Handle(), "&", {dst}, CallNode::Extern);
 
-      auto call = op->value.as<CallNode>();
+      auto pload = op->value.as<ProducerLoadNode>();
 
       auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
         return EvaluateNode::make(
@@ -963,8 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
-      return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index},
-                                    store_matrix_call, call->dtype);
+      return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer),
+                                    store_matrix_call);
     }
 
     return stmt;
@@ -1022,10 +1027,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
     return tile_size;
   }
 
-  Stmt add_buffer_bind_scope_(const CallNode* call, const ObjectPtr<BufferNode>& buffer_node,
-                              const TensorKey& key,
-                              const std::function<Stmt(const Buffer& buffer)>& call_back,
-                              DataType datatype) {
+  Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload,
+                              const ObjectPtr<BufferNode>& buffer_node, const TensorKey& key,
+                              const std::function<Stmt(const Buffer& buffer)>& call_back) {
+    auto tensor = Downcast<Tensor>(pload->producer);
     auto it = bounds_.find(key);
     CHECK(it != bounds_.end());
     Array<PrimExpr> min_bound;
@@ -1038,7 +1043,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
     for (size_t i = 0; i < it->second.size() - 2; ++i) {
       shape.push_back(it->second[i]->extent);
     }
-    auto tile_size = get_tile_size_(simplify_name(call->name));
+    auto tile_size = get_tile_size_(simplify_name(tensor->op->name));
     shape.push_back(tile_size[0]);
     shape.push_back(tile_size[1]);
 
@@ -1053,18 +1058,18 @@ class TensorCoreIRMutator : public StmtExprMutator {
     strides.push_back(make_const(DataType::Int(32), 1));
 
     PrimExpr elem_offset = IntImm(DataType::Int(32), 0);
-    CHECK_EQ(call->args.size(), min_bound.size());
+    CHECK_EQ(pload->indices.size(), min_bound.size());
     for (size_t i = 0; i < min_bound.size(); i++) {
       elem_offset = AddNode::make(
-          elem_offset, MulNode::make(strides[i], SubNode::make(call->args[i], min_bound[i])));
+          elem_offset, MulNode::make(strides[i], SubNode::make(pload->indices[i], min_bound[i])));
     }
 
-    auto it2 = matrix_abc_.find(simplify_name(call->name));
-    CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << call->name;
-    buffer_node->data = Var(call->name, DataType::Handle());
-    buffer_node->name = call->name;
+    auto it2 = matrix_abc_.find(simplify_name(tensor->op->name));
+    CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name;
+    buffer_node->data = Var(tensor->op->name, DataType::Handle());
+    buffer_node->name = tensor->op->name;
     buffer_node->scope = "wmma." + it2->second;
-    buffer_node->dtype = datatype;
+    buffer_node->dtype = tensor->dtype;
     buffer_node->strides = strides;
     buffer_node->shape = shape;
     buffer_node->data_alignment = 1;
@@ -1076,17 +1081,17 @@ class TensorCoreIRMutator : public StmtExprMutator {
     tensor_node->value_index = key.value_index;
     tensor_node->op = Downcast<te::Operation>(key.f);
     tensor_node->shape = shape;
-    tensor_node->dtype = datatype;
-    Tensor tensor(tensor_node);
+    tensor_node->dtype = tensor->dtype;
+    Tensor tensor_bind(tensor_node);
 
     Array<PrimExpr> args;
-    for (size_t i = 0; i < call->args.size(); ++i) {
-      args.push_back(call->args[i]);
+    for (size_t i = 0; i < pload->indices.size(); ++i) {
+      args.push_back(pload->indices[i]);
       args.push_back(shape[i]);
     }
     auto tuple =
         CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic);
-    Array<ObjectRef> node = {buffer, tensor};
+    Array<ObjectRef> node = {buffer, tensor_bind};
     return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer));
   }
 
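The buffer-bind rewrite in the hunk above computes a flat element offset from the ProducerLoad's indices and passes (index, extent) pairs to `tvm_tuple`. Below is a minimal plain-Python sketch of that arithmetic. It assumes integer indices and an offset that starts from zero, because the initial value is outside this hunk. `elem_offset` and `tuple_args` are hypothetical helper names, not TVM APIs.

```python
from typing import List, Sequence


def elem_offset(indices: Sequence[int], min_bound: Sequence[int],
                strides: Sequence[int]) -> int:
    """Flat offset = sum(strides[i] * (indices[i] - min_bound[i]))."""
    # Mirrors CHECK_EQ(pload->indices.size(), min_bound.size()).
    if len(indices) != len(min_bound):
        raise ValueError("indices and min_bound must have the same length")
    offset = 0  # assumption: the accumulation starts from zero
    for idx, lo, stride in zip(indices, min_bound, strides):
        offset += stride * (idx - lo)
    return offset


def tuple_args(indices: Sequence[int], shape: Sequence[int]) -> List[int]:
    """Interleave each index with its extent: [i0, e0, i1, e1, ...]."""
    args: List[int] = []
    for idx, extent in zip(indices, shape):
        args.extend([idx, extent])
    return args


print(elem_offset([3, 5], [1, 2], [16, 1]))  # 16*2 + 1*3 = 35
print(tuple_args([3, 5], [16, 16]))          # [3, 16, 5, 16]
```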
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 57e5528..96df24d 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -114,17 +114,12 @@ class TensorToBufferMapper : public StmtExprMutator {
     return BufferStore(buffer, op->value, op->args);
   }
 
-  PrimExpr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
     auto ret = StmtExprMutator::VisitExpr_(op);
-    op = ret.as<CallNode>();
-
-    if (op->call_type == CallNode::Halide) {
-      Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
-      Buffer buffer = GetBuffer(tensor);
-      return tir::BufferLoad(buffer, op->args);
-    } else {
-      return ret;
-    }
+    op = ret.as<ProducerLoadNode>();
+    Tensor tensor = Downcast<Tensor>(op->producer);
+    Buffer buffer = GetBuffer(tensor);
+    return tir::BufferLoad(buffer, op->indices);
   }
 
  private:
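With ProducerLoad in place, TensorToBufferMapper no longer filters calls by `call_type`. It rewrites every ProducerLoad into a BufferLoad on the tensor's buffer. The Python model below uses stand-in dataclasses rather than TVM's classes. It assumes `GetBuffer` returns the same buffer for repeated lookups of one tensor.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class Tensor:
    name: str


@dataclass(frozen=True)
class Buffer:
    name: str


@dataclass(frozen=True)
class ProducerLoad:
    producer: Tensor
    indices: Tuple[object, ...]


@dataclass(frozen=True)
class BufferLoad:
    buffer: Buffer
    indices: Tuple[object, ...]


class TensorToBufferMapper:
    """Rewrites every ProducerLoad into a BufferLoad on the tensor's buffer."""

    def __init__(self) -> None:
        self._buffers: Dict[Tensor, Buffer] = {}

    def get_buffer(self, tensor: Tensor) -> Buffer:
        # Assumption: one Buffer per Tensor, created on first use.
        if tensor not in self._buffers:
            self._buffers[tensor] = Buffer(tensor.name)
        return self._buffers[tensor]

    def visit_expr(self, expr: object) -> object:
        if isinstance(expr, ProducerLoad):
            # Rewrite the indices first (as the base mutator does), then the load.
            indices = tuple(self.visit_expr(i) for i in expr.indices)
            return BufferLoad(self.get_buffer(expr.producer), indices)
        return expr  # leaves such as plain ints are kept unchanged


A, B = Tensor("A"), Tensor("B")
mapper = TensorToBufferMapper()
print(mapper.visit_expr(ProducerLoad(A, (ProducerLoad(B, (0,)),))))
```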
diff --git a/src/te/tensor.cc b/src/te/tensor.cc
index 606797d..1a31a85 100644
--- a/src/te/tensor.cc
+++ b/src/te/tensor.cc
@@ -47,14 +47,16 @@ PrimExpr Tensor::operator()(Array<Var> indices) const {
 }
 
 PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
-  using tir::CallNode;
   if (ndim() != 0) {
     CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read"
                                      << "ndim = " << ndim() << ", indices.size=" << indices.size();
   }
-  auto n = CallNode::make((*this)->dtype, (*this)->op->name, indices, CallNode::Halide, (*this)->op,
-                          (*this)->value_index);
-  return n;
+
+  return ProducerLoad((*this), indices);
+}
+
+String TensorNode::GetNameHint() const {
+  return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index));
 }
 
 Tensor Operation::output(size_t i) const {
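After this change, indexing a te Tensor builds a ProducerLoad that points back at the tensor, and the printed name comes from `GetNameHint`. The sketch below models both rules in plain Python. The `Tensor` and `ProducerLoad` classes here are simplified stand-ins, not the real TVM objects.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ProducerLoad:
    producer: "Tensor"
    indices: Tuple[object, ...]


@dataclass(frozen=True)
class Tensor:
    op_name: str
    num_outputs: int
    value_index: int
    ndim: int

    def __call__(self, *indices: object) -> ProducerLoad:
        # Same rule as Tensor::operator(): a 0-d tensor skips the arity check.
        if self.ndim != 0 and self.ndim != len(indices):
            raise ValueError(
                f"Tensor dimension mismatch in read: ndim = {self.ndim}, "
                f"indices.size = {len(indices)}")
        return ProducerLoad(self, tuple(indices))

    def name_hint(self) -> str:
        # Single-output ops reuse the op name; others append ".v<value_index>".
        if self.num_outputs == 1:
            return self.op_name
        return f"{self.op_name}.v{self.value_index}"


A = Tensor("A", num_outputs=1, value_index=0, ndim=2)
S = Tensor("sort", num_outputs=2, value_index=1, ndim=1)
print(A.name_hint(), S.name_hint())  # A sort.v1
```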
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 8b9a8e2..e1d8b3f 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -270,25 +270,17 @@ bool CallNode::is_vectorizable() const {
   return false;
 }
 
-PrimExpr CallNode::make(DataType dtype, std::string name, Array<PrimExpr> args, CallType call_type,
-                        FunctionRef func, int value_index) {
+PrimExpr CallNode::make(DataType dtype, std::string name, Array<PrimExpr> args,
+                        CallType call_type) {
   for (size_t i = 0; i < args.size(); ++i) {
     CHECK(args[i].defined());
   }
 
-  if (call_type == Halide) {
-    for (size_t i = 0; i < args.size(); ++i) {
-      CHECK(args[i].dtype().is_int());
-    }
-  }
-
   ObjectPtr<CallNode> node = make_object<CallNode>();
   node->dtype = dtype;
   node->name = std::move(name);
   node->args = std::move(args);
   node->call_type = call_type;
-  node->func = std::move(func);
-  node->value_index = value_index;
   return PrimExpr(node);
 }
 
@@ -403,6 +395,21 @@ TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array<Pri
 
 TVM_REGISTER_NODE_TYPE(BufferLoadNode);
 
+ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices) {
+  ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();
+  node->dtype = producer->GetDataType();
+  node->producer = std::move(producer);
+  node->indices = std::move(indices);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.ProducerLoad")
+    .set_body_typed([](DataProducer producer, Array<PrimExpr> indices) {
+      return ProducerLoad(producer, indices);
+    });
+
+TVM_REGISTER_NODE_TYPE(ProducerLoadNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const StringImmNode*>(node.get());
@@ -639,6 +646,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ProducerLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const ProducerLoadNode*>(node.get());
+      p->stream << op->producer->GetNameHint() << "[";
+      for (size_t i = 0; i < op->indices.size(); ++i) {
+        p->Print(op->indices[i]);
+        if (i < op->indices.size() - 1) {
+          p->stream << ", ";
+        }
+      }
+      p->stream << "]";
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const LetNode*>(node.get());
       p->stream << "(let " << op->var << " = ";
@@ -758,8 +778,7 @@ TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) {
 });
 
 TVM_REGISTER_GLOBAL("tir.Call")
-    .set_body_typed([](DataType type, std::string name, Array<ObjectRef> args, int call_type,
-                       FunctionRef func, int value_index) {
+    .set_body_typed([](DataType type, std::string name, Array<ObjectRef> args, int call_type) {
       Array<PrimExpr> prim_expr_args;
       for (const auto& it : args) {
         CHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>());
@@ -769,8 +788,7 @@ TVM_REGISTER_GLOBAL("tir.Call")
           prim_expr_args.push_back(Downcast<PrimExpr>(it));
         }
       }
-      return CallNode::make(type, name, prim_expr_args, static_cast<CallNode::CallType>(call_type),
-                            func, value_index);
+      return CallNode::make(type, name, prim_expr_args, static_cast<CallNode::CallType>(call_type));
     });
 
 }  // namespace tir
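The new ReprPrinter dispatch prints a ProducerLoad as the producer's name hint followed by its indices in square brackets, with ", " only between indices. The formatting amounts to a join. `repr_producer_load` is a hypothetical helper that takes index strings that have already been printed.

```python
from typing import Sequence


def repr_producer_load(name_hint: str, indices: Sequence[str]) -> str:
    # Same separator logic as the C++ loop: ", " only between indices.
    return name_hint + "[" + ", ".join(indices) + "]"


print(repr_producer_load("A", ["i", "j"]))   # A[i, j]
print(repr_producer_load("sort.v1", ["k"]))  # sort.v1[k]
```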
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 7f30abe..98d61a0 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -41,6 +41,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
 
+void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) {
+  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
+}
+
 void ExprVisitor::VisitExpr_(const LetNode* op) {
   this->VisitExpr(op->value);
   this->VisitExpr(op->body);
@@ -135,6 +139,16 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
   }
 }
 
+PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
+  auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+  Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
+  if (indices.same_as(op->indices)) {
+    return GetRef<PrimExpr>(op);
+  } else {
+    return ProducerLoad(op->producer, indices);
+  }
+}
+
 PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
   PrimExpr body = this->VisitExpr(op->body);
@@ -152,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
   if (args.same_as(op->args)) {
     return GetRef<PrimExpr>(op);
   } else {
-    return CallNode::make(op->dtype, op->name, args, op->call_type, op->func, op->value_index);
+    return CallNode::make(op->dtype, op->name, args, op->call_type);
   }
 }
 
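The new ExprMutator overload follows the usual copy-on-write rule. It reuses the node when no index changed, and otherwise rebuilds a ProducerLoad with the same producer. The Python model below uses stand-in types, and `mutate_array` only approximates TVM's `MutateArray`.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class ProducerLoad:
    producer: str  # stand-in for a DataProducer such as a te.Tensor
    indices: Tuple[object, ...]


def mutate_array(arr: Tuple[object, ...],
                 fmutate: Callable[[object], object]) -> Tuple[object, ...]:
    new = tuple(fmutate(x) for x in arr)
    # Keep the original container when every element is the same object,
    # analogous to the same_as() check in ExprMutator.
    if all(a is b for a, b in zip(arr, new)):
        return arr
    return new


def mutate_producer_load(op: ProducerLoad,
                         fmutate: Callable[[object], object]) -> ProducerLoad:
    indices = mutate_array(op.indices, fmutate)
    if indices is op.indices:
        return op  # nothing changed: reuse the node
    return ProducerLoad(op.producer, indices)  # same producer, new indices


load = ProducerLoad("A", (1, 2))
print(mutate_producer_load(load, lambda e: e) is load)  # True
print(mutate_producer_load(load, lambda e: e + 1))
```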
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 1e656ce..447a1e3 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -335,10 +335,9 @@ class StorageFlattener : public StmtExprMutator {
     return stmt;
   }
 
-  PrimExpr VisitExpr_(const CallNode* op) final {
-    CHECK(op->call_type != CallNode::Halide) << "Cannot handle Halide calls "
-                                             << " please run SchedulePostProcToPrimFunc first";
-    return StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
+    LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const ProvideNode* op) final {
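StorageFlattener now treats any ProducerLoad as a fatal error. By that point, the TensorToBufferMapper in SchedulePostProcToPrimFunc should already have turned every ProducerLoad into a BufferLoad. A caller could check this invariant before flattening. The walker below is a hypothetical Python sketch over stand-in node classes, not a TVM pass.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ProducerLoad:
    producer: str
    indices: Tuple[object, ...]


@dataclass(frozen=True)
class BufferLoad:
    buffer: str
    indices: Tuple[object, ...]


@dataclass(frozen=True)
class Add:
    a: object
    b: object


def contains_producer_load(expr: object) -> bool:
    """True if any ProducerLoad survives anywhere in the expression tree."""
    if isinstance(expr, ProducerLoad):
        return True
    if isinstance(expr, BufferLoad):
        return any(contains_producer_load(i) for i in expr.indices)
    if isinstance(expr, Add):
        return contains_producer_load(expr.a) or contains_producer_load(expr.b)
    return False  # leaves (ints, variable names) cannot hold a load


ok = Add(BufferLoad("A", (0,)), 1)
bad = Add(1, BufferLoad("A", (ProducerLoad("B", (0,)),)))
print(contains_producer_load(ok), contains_producer_load(bad))  # False True
```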
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 91993ac..61ec572 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -214,8 +214,7 @@ class Vectorizer : public StmtExprMutator {
       int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
       t = BroadcastTo(t, lanes);
       f = BroadcastTo(f, lanes);
-      return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type,
-                            op->func, op->value_index);
+      return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type);
     }
   }
   // Call
@@ -237,8 +236,7 @@ class Vectorizer : public StmtExprMutator {
       if (op->args.same_as(new_args)) {
         return GetRef<PrimExpr>(op);
       } else {
-        return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func,
-                              op->value_index);
+        return CallNode::make(op->dtype, op->name, new_args, op->call_type);
       }
     } else {
       int lane = 0;
@@ -247,8 +245,7 @@ class Vectorizer : public StmtExprMutator {
       if (op->args.same_as(new_args)) {
         return GetRef<PrimExpr>(op);
       } else {
-        return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type,
-                              op->func, op->value_index);
+        return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type);
       }
     }
   }
diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh
index c76d4f4..dc4450a 100755
--- a/tests/lint/git-clang-format.sh
+++ b/tests/lint/git-clang-format.sh
@@ -19,13 +19,21 @@ set -e
 set -u
 set -o pipefail
 
-if [ "$#" -lt 1 ]; then
-    echo "Usage: tests/lint/git-clang-format.sh <commit>"
+if [[ "$1" == "-i" ]]; then
+    INPLACE_FORMAT=1
+    shift 1
+else
+    INPLACE_FORMAT=0
+fi
+
+if [[ "$#" -lt 1 ]]; then
+    echo "Usage: tests/lint/git-clang-format.sh [-i] <commit>"
     echo ""
     echo "Run clang-format on files that changed since <commit>"
     echo "Examples:"
     echo "- Compare last one commit: tests/lint/git-clang-format.sh HEAD~1"
     echo "- Compare against upstream/master: tests/lint/git-clang-format.sh upsstream/master"
+    echo "You can also add -i option to do inplace format"
     exit 1
 fi
 
@@ -50,6 +58,12 @@ fi
 # Print out specific version
 ${CLANG_FORMAT} --version
 
+if [[ ${INPLACE_FORMAT} -eq 1 ]]; then
+    echo "Running inplace git-clang-format against" $1
+    git-${CLANG_FORMAT} --extensions h,mm,c,cc --binary=${CLANG_FORMAT} $1
+    exit 0
+fi
+
 echo "Running git-clang-format against" $1
 git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc --binary=${CLANG_FORMAT} $1 1> /tmp/$$.clang-format.txt
 echo "---------clang-format log----------"
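The new `-i` handling removes an optional leading flag and then requires a commit argument. With `-i`, the script runs git-clang-format without `--diff`, so it rewrites the changed files instead of printing a diff. The sketch below shows the same control flow in Python and is illustrative only, since the script itself is bash. It checks that an argument exists before reading the first one. In bash under `set -u`, the equivalent guard is to write `${1:-}` rather than `$1`.

```python
from typing import List, Tuple

USAGE = "Usage: tests/lint/git-clang-format.sh [-i] <commit>"


def parse_args(argv: List[str]) -> Tuple[bool, str]:
    """Return (inplace, commit) following the script's -i convention."""
    inplace = False
    # Check for an empty list before peeking at argv[0].
    if argv and argv[0] == "-i":
        inplace = True
        argv = argv[1:]  # same effect as `shift 1`
    if len(argv) < 1:
        raise SystemExit(USAGE)
    return inplace, argv[0]


print(parse_args(["-i", "HEAD~1"]))  # (True, 'HEAD~1')
```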
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py
index 1791522..525cd6c 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -202,7 +202,7 @@ def test_reduce_combiner_simplify():
             assert tvm.ir.structural_equal(lhs, rhs)
 
     # Test that components with side effects are not removed
-    side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0)
+    side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic)
     ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
              sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
     ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 2e53bfd..34db08f 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -34,7 +34,7 @@ def test_llvm_intrin():
     ]
     ib.emit(tvm.tir.Evaluate(
         tvm.tir.Call(
-            "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
+            "int32", "prefetch", args, tvm.tir.Call.Intrinsic)))
     body = ib.get()
 
     mod = tvm.IRModule.from_expr(
diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py
index ea4179d..c6f28ad 100644
--- a/tests/python/unittest/test_te_hybrid_script.py
+++ b/tests/python/unittest/test_te_hybrid_script.py
@@ -138,9 +138,9 @@ def test_outer_product():
     assert jbody.args[1].name == 'j'
     assert isinstance(jbody.value, tvm.tir.Mul)
     mul = jbody.value
-    assert isinstance(mul.a, tvm.tir.Call)
-    assert mul.a.name == 'a'
-    assert mul.b.name == 'b'
+    assert isinstance(mul.a, tvm.tir.ProducerLoad)
+    assert mul.a.producer.name == 'a'
+    assert mul.b.producer.name == 'b'
 
     func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
     temp = util.tempdir()
@@ -209,29 +209,29 @@ def test_fanout():
     assert jbody.func.name == 'sigma'
     assert isinstance(jbody.value, tvm.tir.Add)
     value = jbody.value
-    assert isinstance(value.a, tvm.tir.Call)
-    assert value.a.name == 'sigma'
-    assert len(value.a.args) == 1
-    assert value.a.args[0].value == 0
-    assert value.b.name == 'a'
-    assert len(value.b.args) == 1
-    assert tvm.ir.structural_equal(value.b.args[0], ir.loop_var + jloop.loop_var)
+    assert isinstance(value.a, tvm.tir.ProducerLoad)
+    assert value.a.producer.name == 'sigma'
+    assert len(value.a.indices) == 1
+    assert value.a.indices[0].value == 0
+    assert value.b.producer.name == 'a'
+    assert len(value.b.indices) == 1
+    assert tvm.ir.structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var)
     divide= rbody[2]
     assert isinstance(divide, tvm.tir.Provide)
     assert len(divide.args) == 1
     assert divide.args[0].value == 0
     value = divide.value
     assert isinstance(value, tvm.tir.Mul)
-    assert value.a.name == 'sigma'
-    assert len(value.a.args) == 1
-    assert value.a.args[0].value == 0
+    assert value.a.producer.name == 'sigma'
+    assert len(value.a.indices) == 1
+    assert value.a.indices[0].value == 0
     assert abs(value.b.value - (1 / 3.0)) < 1e-5
     write = rbody[3]
     assert isinstance(write, tvm.tir.Provide)
     assert write.func.name == 'b'
-    assert write.value.name == 'sigma'
-    assert len(write.value.args) == 1
-    assert write.value.args[0].value == 0
+    assert write.value.producer.name == 'sigma'
+    assert len(write.value.indices) == 1
+    assert write.value.indices[0].value == 0
 
     func, ins, outs = run_and_check(fanout, [n, a], {n: 10})
     run_and_check(func, ins, {n: 10}, outs=outs)
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 4af93fd..86f8734 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -112,14 +112,12 @@ def test_expr_constructor():
     assert x.vectors[0] == a
     assert x.indices[0].value == 0
 
-    x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0)
+    x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern)
     assert isinstance(x, tvm.tir.Call)
     assert x.dtype == "float32"
     assert x.name == "xyz"
     assert x.args[0] == a
     assert x.call_type == tvm.tir.Call.Extern
-    assert x.func == None
-    assert x.value_index == 0
 
     v = te.var("aa")
     x = tvm.tir.Let(v, 1, v)
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 36c9c76..e632259 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -171,18 +171,18 @@ def test_all():
 def test_bitwise():
     x = te.var('x')
     y = te.var('y')
-    assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
-    assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+    assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+    assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+    assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+    assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+    assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")'
+    assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")'
+    assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")'
+    assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")'
+    assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin")'
     assert str(10 % x) == 'floormod(10, x: int32)'
-    assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")'
     assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
     assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
     assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
@@ -239,10 +239,10 @@ def test_divide_by_zero():
 
 def test_isnan():
     x = te.var('x', 'float32')
-    assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin", index=0)'
+    assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")'
     assert str(tvm.tir.isnan(x).dtype) == 'bool'
     y = te.var('y', 'float16')
-    assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin", index=0)'
+    assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")'
     z = te.var('z', 'int32')
     assert str(tvm.tir.isnan(z)) == 'False'
     k = te.var('k', 'int8x2')
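The updated `test_bitwise` and `test_isnan` expectations show how an intrinsic call prints once `value_index` is gone. The output lists the arguments, then `dtype=` and `type=`, with no trailing `index=0`. `call_repr` is a hypothetical helper that reproduces that string from arguments that have already been printed.

```python
from typing import Sequence


def call_repr(name: str, args: Sequence[str], dtype: str, call_type: str) -> str:
    # Produces: @name(arg0, arg1, ..., dtype=<dtype>, type="<call_type>")
    parts = list(args) + ["dtype=" + dtype, 'type="' + call_type + '"']
    return "@" + name + "(" + ", ".join(parts) + ")"


print(call_repr("shift_left", ["x: int32", "y: int32"], "int32", "pure_intrin"))
```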
diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py
index 7fd2593..29a3303 100644
--- a/tests/python/unittest/test_tir_transform_combine_context_call.py
+++ b/tests/python/unittest/test_tir_transform_combine_context_call.py
@@ -22,7 +22,7 @@ def test_for():
     def device_context(dev_id):
         ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
         return tvm.tir.Call(
-            "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0)
+            "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic)
 
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py
index b52fc91..f713bb2 100644
--- a/topi/python/topi/cuda/rcnn/proposal.py
+++ b/topi/python/topi/cuda/rcnn/proposal.py
@@ -187,7 +187,7 @@ def argsort_ir(data_buf, out_index_buf):
                 index_out[offset + 1] = temp_index[0]
             ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                                  tvm.runtime.convert(['shared']),
-                                 tvm.tir.Call.Intrinsic, None, 0))
+                                 tvm.tir.Call.Intrinsic))
     return ib.get()
 
 
@@ -248,7 +248,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
                     p_out[base_idx + i] = True
         ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                              tvm.runtime.convert(['shared']),
-                             tvm.tir.Call.Intrinsic, None, 0))
+                             tvm.tir.Call.Intrinsic))
     return ib.get()
 
 
diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py
index a1c70c4..ddae2bd 100644
--- a/topi/python/topi/cuda/sort.py
+++ b/topi/python/topi/cuda/sort.py
@@ -117,7 +117,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                         tvm.tir.generic.cast(tid, indices_out.dtype)
     ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                          tvm.runtime.convert(['shared']),
-                         tvm.tir.Call.Intrinsic, None, 0))
+                         tvm.tir.Call.Intrinsic))
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
 
@@ -145,7 +145,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                             indices_out[offset + axis_mul_after] = temp_index[0]
                 ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                                      tvm.runtime.convert(['shared']),
-                                     tvm.tir.Call.Intrinsic, None, 0))
+                                     tvm.tir.Call.Intrinsic))
 
     return ib.get()
 
@@ -237,7 +237,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
                         output[offset + axis_mul_after] = temp_index[0]
                 ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                                      tvm.runtime.convert(['shared']),
-                                     tvm.tir.Call.Intrinsic, None, 0))
+                                     tvm.tir.Call.Intrinsic))
 
     return ib.get()
 
diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py
index bbaac2c..e68f098 100644
--- a/vta/python/vta/environment.py
+++ b/vta/python/vta/environment.py
@@ -80,7 +80,7 @@ class DevContext(object):
         ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle")
         self.command_handle = tvm.tir.Call(
             "handle", "tvm_thread_context", [ctx],
-            tvm.tir.Call.Intrinsic, None, 0)
+            tvm.tir.Call.Intrinsic)
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
         self.gemm = intrin.gemm(env, env.mock_mode)
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 1d54bb0..37b4e0e 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -297,7 +297,7 @@ def InjectCoProcSync():
             if _match_pragma(stmt, "coproc_sync"):
                 success[0] = True
                 sync = tvm.tir.Call(
-                    "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
+                    "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic)
                 return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
             if _match_pragma(stmt, "trim_loop"):
                 op = stmt.body