You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/02 19:37:06 UTC

[incubator-tvm] branch master updated: [TIR] Introduce BufferLoad/Store (#5205)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 88d2f34  [TIR] Introduce BufferLoad/Store (#5205)
88d2f34 is described below

commit 88d2f34b981f5a4c32b5d5d9c6141ef6cfa74a6a
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Apr 2 12:36:55 2020 -0700

    [TIR] Introduce BufferLoad/Store (#5205)
    
    Co-authored-by: Siyuan Feng <hz...@sjtu.edu.cn>
    
    This PR introduces BufferLoad/BufferStore to TIR. The new nodes will replace
    Provide and Call with Tensor arguments in subsequent refactors.
---
 include/tvm/tir/buffer.h                           |  14 +-
 include/tvm/tir/expr.h                             | 356 +++------------------
 include/tvm/tir/expr_functor.h                     |   4 +
 include/tvm/tir/stmt.h                             |  51 +++
 include/tvm/tir/stmt_functor.h                     |   3 +
 include/tvm/tir/var.h                              | 343 ++++++++++++++++++++
 python/tvm/tir/__init__.py                         |   4 +-
 python/tvm/tir/expr.py                             |  26 +-
 python/tvm/tir/stmt.py                             |  27 +-
 src/tir/ir/expr.cc                                 |  16 +
 src/tir/ir/expr_functor.cc                         |  14 +
 src/tir/ir/stmt.cc                                 |  15 +
 src/tir/ir/stmt_functor.cc                         |  15 +
 tests/python/unittest/test_tir_nodes.py            |  11 +
 .../unittest/test_tir_structural_equal_hash.py     |  17 +
 15 files changed, 590 insertions(+), 326 deletions(-)

diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 7b15776..08a8e69 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -25,9 +25,8 @@
 #define TVM_TIR_BUFFER_H_
 
 #include <tvm/node/container.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
-
+#include <tvm/ir/expr.h>
+#include <tvm/tir/var.h>
 #include <string>
 
 
@@ -36,6 +35,9 @@ namespace tir {
 // Internal node container Buffer
 class BufferNode;
 
+// forward declare Stmt
+class Stmt;
+
 /*! \brief buffer type */
 enum BufferType : int {
   kDefault = 1,
@@ -75,9 +77,9 @@ class Buffer : public ObjectRef {
    * \param offset The offset of ptr.
    */
   TVM_DLL PrimExpr access_ptr(int access_mask,
-                          DataType ptr_type = DataType::Handle(),
-                          int content_lanes = 1,
-                          PrimExpr offset = make_const(DataType::Int(32), 0)) const;
+                              DataType ptr_type = DataType::Handle(),
+                              int content_lanes = 1,
+                              PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
   /*!
    * \brief Create an Expr that does a vector load at begin index.
    * \param begin The beginning index
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 7b8ab44..6295a36 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -31,6 +31,8 @@
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/ir/expr.h>
+#include <tvm/tir/var.h>
+#include <tvm/tir/buffer.h>
 
 #include <string>
 #include <algorithm>
@@ -42,313 +44,6 @@
 namespace tvm {
 namespace tir {
 
-/*!
- * \brief A variable node in the IR.
- *
- * A variable is uniquely identified by its address.
- *
- * Each variable is only binded once in the following nodes:
- * - Allocate
- * - For
- * - Let
- * - LetStmt
- */
-class VarNode : public PrimExprNode {
- public:
-  /*!
-   * \brief The hint to the variable name.
-   * \note Each variable is uniquely identified by its address.
-   */
-  std::string name_hint;
-  /*!
-   * \brief type annotaion of the variable.
-   *
-   * It is an optional field that provides a refined type of the variable than dtype.
-   *
-   * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
-   */
-  Type type_annotation;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("dtype", &dtype);
-    v->Visit("name", &name_hint);
-    v->Visit("type_annotation", &type_annotation);
-  }
-
-  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
-    if (!equal(dtype, other->dtype)) return false;
-    if (!equal(type_annotation, other->type_annotation)) return false;
-    return equal.FreeVarEqualImpl(this, other);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(dtype);
-    hash_reduce(type_annotation);
-    hash_reduce.FreeVarHashImpl(this);
-  }
-
-  static constexpr const char* _type_key = "tir.Var";
-  TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
-};
-
-/*! \brief a named variable in TVM */
-class Var : public PrimExpr {
- public:
-  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
-  /*!
-   * \brief Constructor
-   * \param name_hint variable name
-   * \param dtype data type
-   */
-  TVM_DLL explicit Var(std::string name_hint = "v",
-                       DataType dtype = DataType::Int(32));
-  /*!
-   * \brief Constructor which provides a more detailed type annotation.
-   * \param name_hint variable name.
-   * \param type_annotation The type annotation.
-   */
-  TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
-  /*!
-   * \brief Make a new copy of var with same type, append suffix
-   * \param suffix The suffix to be appended.
-   * \return the new Var copy
-   */
-  TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
-  /*!
-   * \brief Get pointer to the internal value.
-   * \return the corresponding Variable.
-   */
-  const VarNode* operator->() const {
-    return get();
-  }
-  /*!
-   * \brief Get pointer to the internal value.
-   * \return the corresponding Variable.
-   */
-  const VarNode* get() const {
-    return static_cast<const VarNode*>(data_.get());
-  }
-  /*! \brief type indicate the container type */
-  using ContainerType = VarNode;
-};
-
-/*!
- * \brief A variable node represent a tensor index size,
- * whose value must be non-negative.
- */
-class SizeVarNode : public VarNode {
- public:
-  static constexpr const char* _type_key = "tir.SizeVar";
-  TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
-};
-
-/*! \brief a named variable represents a tensor index size */
-class SizeVar : public Var {
- public:
-  explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
-  /*!
-   * \brief constructor
-   * \param name_hint variable name
-   * \param t data type
-   */
-  TVM_DLL explicit SizeVar(std::string name_hint = "s",
-                           DataType t = DataType::Int(32));
-  /*!
-   * \brief Get pointer to the internal value.
-   * \return the corresponding Variable.
-   */
-  const SizeVarNode* operator->() const {
-    return get();
-  }
-  /*!
-   * \brief Get pointer to the internal value.
-   * \return the corresponding Variable.
-   */
-  const SizeVarNode* get() const {
-    return static_cast<const SizeVarNode*>(data_.get());
-  }
-  /*! \brief type indicate the container type */
-  using ContainerType = SizeVarNode;
-};
-
-
-/*! \brief container class of iteration variable. */
-class IterVarNode;
-
-using Region = Array<Range>;
-
-/*!
- * \brief Type of iteration variable.
- *  Each IterVar have a specific type.
- *
- *  The type of iter var can be overriden via
- *  stage.iter_var_attrs given they are compatible.
- */
-enum IterVarType : int {
-  /*!
-   * \brief Data parallel iteration.
-   *  This normally corresponds to axis of Tensor.
-   *  Allow all IterVar manipulations.
-   *
-   * \note This does not mean the loop
-   *  have to be executed in parallel fashion.
-   */
-  kDataPar = 0,
-  /*!
-   * \brief The IterVar itself is a thread-index
-   *  of a fixed thread launching group.
-   *  Note that this is already assumed to be paralellized.
-   *
-   *  Disallow: split/fuse/vectorize/parallel
-   */
-  kThreadIndex = 1,
-  /*!
-   * \brief Communicative reduction.
-   *  Cannot be directly parallelized.
-   *
-   *  Disallow: parallel/vectorize
-   */
-  kCommReduce = 2,
-  /*!
-   * \brief Serial loops with loop carry dependency,
-   *  the iteration must execute in order.
-   *  Cannot be re-ordered.
-   *
-   *  Disallow: reorder/parallel/vectorize
-   */
-  kOrdered = 3,
-  /*!
-   * \brief IterVar is opaque,
-   *
-   *  May not corresponds to any generated loop
-   *  Disallow all IterVar manipulations and compute_at
-   *
-   * \note This is usually used to implement composite op
-   *  or external op, where the
-   */
-  kOpaque = 4,
-  // The following are possible additional
-  // types that are provided during schedule
-  /*!
-   * \brief The execution is unrolled.
-   */
-  kUnrolled = 5,
-  /*!
-   * \brief The loop is vectorized.
-   */
-  kVectorized = 6,
-  /*!
-   * \brief The loop is parallelized.
-   */
-  kParallelized = 7,
-  /*!
-   * \brief Marks boundary of tensorization intrinsic.
-   */
-  kTensorized = 8
-};
-
-/*!
- * \brief Iteration Variable,
- *  represents an iteration over an integer interval.
- */
-class IterVar : public ObjectRef {
- public:
-  // construct a new iter var without a domain
-  IterVar() {}
-  // construct from shared ptr.
-  explicit IterVar(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const IterVarNode* operator->() const;
-  /*!
-   * \return the corresponding var in the IterVar.
-   */
-  inline operator PrimExpr() const;
-  /*! \brief specify container node */
-  using ContainerType = IterVarNode;
-};
-
-using Domain = Array<Range>;
-
-/*!
- * \brief An iteration variable representing an iteration
- *  over a one dimensional interval.
- */
-class IterVarNode : public Object {
- public:
-  /*!
-   * \brief the domain of iteration, if known, can be None
-   *  For the intermediate schedule node, before schedule.
-   */
-  Range dom;
-  /*! \brief The looping variable */
-  Var var;
-  /*! \brief The type of the IterVar */
-  IterVarType iter_type;
-  /*!
-   * \brief additional tag on the iteration variable,
-   *  set this if this is binded already to a known thread tag.
-   */
-  std::string thread_tag;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("dom", &dom);
-    v->Visit("var", &var);
-    v->Visit("iter_type", &iter_type);
-    v->Visit("thread_tag", &thread_tag);
-  }
-
-  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
-    return
-        equal(dom, other->dom) &&
-        equal.DefEqual(var, other->var) &&
-        equal(iter_type, other->iter_type) &&
-        equal(thread_tag, other->thread_tag);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(dom);
-    hash_reduce.DefHash(var);
-    hash_reduce(iter_type);
-    hash_reduce(thread_tag);
-  }
-
-  TVM_DLL static IterVar make(Range dom, Var var,
-                              IterVarType iter_type,
-                              std::string thread_tag = "");
-
-  static constexpr const char* _type_key = "IterVar";
-  static constexpr const bool _type_has_method_sequal_reduce = true;
-  static constexpr const bool _type_has_method_shash_reduce = true;
-  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
-};
-
-// inline implementations
-inline const IterVarNode* IterVar::operator->() const {
-  return static_cast<const IterVarNode*>(data_.get());
-}
-
-inline IterVar::operator PrimExpr() const {
-  return (*this)->var;
-}
-
-inline const char* IterVarType2String(IterVarType t) {
-  switch (t) {
-    case kDataPar: return "DataPar";
-    case kThreadIndex: return "ThreadIndex";
-    case kCommReduce: return "CommReduce";
-    case kOrdered: return "Ordered";
-    case kOpaque: return "Opaque";
-    case kUnrolled: return "Unrolled";
-    case kVectorized: return "Vectorized";
-    case kParallelized: return "Parallelized";
-    case kTensorized: return "Tensorized";
-  }
-  return "Unknown";
-}
-
 using IntImmNode = tvm::IntImmNode;
 using FloatImmNode = tvm::FloatImmNode;
 
@@ -734,6 +429,53 @@ class SelectNode : public PrimExprNode {
 };
 
 /*!
+ * \brief Load a value from a multi-dimensional buffer.
+ *
+ * \code
+ *
+ *  value = buffer[i, j];
+ *
+ * \endcode
+ * \sa BufferStore
+ */
+class BufferLoadNode : public PrimExprNode {
+ public:
+  /*! \brief The buffer to load from. */
+  Buffer buffer;
+  /*! \brief The indices of the element to be loaded. */
+  Array<PrimExpr> indices;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &(this->dtype));
+    v->Visit("buffer", &buffer);
+    v->Visit("indices", &indices);
+  }
+
+  bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(buffer, other->buffer) &&
+        equal(indices, other->indices);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(dtype);
+    hash_reduce(buffer);
+    hash_reduce(indices);
+  }
+
+  static constexpr const char* _type_key = "BufferLoad";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
+};
+
+class BufferLoad : public PrimExpr {
+ public:
+  TVM_DLL explicit BufferLoad(Buffer buffer,
+                              Array<PrimExpr> indices);  // dtype is taken from buffer->dtype
+  TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
+};
+
+/*!
  * \brief Load the value from buffer_var.
  *
  *  Equivalent to ((DType*)buffer_var)[index]
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index 0de05a6..dcf04c3 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -121,6 +121,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
   virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
     return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
   }
+  virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -164,6 +165,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(VarNode);
     IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
     IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
+    IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
     IR_EXPR_FUNCTOR_DISPATCH(LetNode);
     IR_EXPR_FUNCTOR_DISPATCH(CallNode);
     IR_EXPR_FUNCTOR_DISPATCH(AddNode);
@@ -214,6 +216,7 @@ class TVM_DLL ExprVisitor :
   void VisitExpr_(const VarNode* op) override;
   void VisitExpr_(const SizeVarNode* op) override;
   void VisitExpr_(const LoadNode* op) override;
+  void VisitExpr_(const BufferLoadNode* op) override;
   void VisitExpr_(const LetNode* op) override;
   void VisitExpr_(const CallNode* op) override;
   void VisitExpr_(const AddNode* op) override;
@@ -259,6 +262,7 @@ class TVM_DLL ExprMutator :
   PrimExpr VisitExpr_(const VarNode* op) override;
   PrimExpr VisitExpr_(const SizeVarNode* op) override;
   PrimExpr VisitExpr_(const LoadNode* op) override;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
   PrimExpr VisitExpr_(const LetNode* op) override;
   PrimExpr VisitExpr_(const CallNode* op) override;
   PrimExpr VisitExpr_(const AddNode* op) override;
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 47ec305..fe0d9ed 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -275,6 +275,57 @@ class StoreNode : public StmtNode {
 };
 
 /*!
+ * \brief Store a value into a multi-dimensional buffer.
+ *
+ * \code
+ *
+ *  buffer[i, j] = value;
+ *
+ * \endcode
+ * \sa BufferLoad
+ */
+class BufferStore;
+class BufferStoreNode : public StmtNode {
+ public:
+  /*! \brief The buffer to store into. */
+  Buffer buffer;
+  /*! \brief The value to be stored. */
+  PrimExpr value;
+  /*! \brief The indices of the element to be stored. */
+  Array<PrimExpr> indices;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer", &buffer);
+    v->Visit("value", &value);
+    v->Visit("indices", &indices);
+  }
+
+  bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
+    return
+        equal(buffer, other->buffer) &&
+        equal(value, other->value) &&
+        equal(indices, other->indices);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(buffer);
+    hash_reduce(value);
+    hash_reduce(indices);
+  }
+
+  static constexpr const char* _type_key = "BufferStore";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
+};
+
+class BufferStore : public Stmt {
+ public:
+  TVM_DLL explicit BufferStore(Buffer buffer,
+                               PrimExpr value,
+                               Array<PrimExpr> indices);
+  TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
+};
+
+/*!
  * \brief Store value into mult-dimensional array defined by func.
  */
 class ProvideNode : public StmtNode {
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index c880a48..6824022 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -91,6 +91,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -154,6 +155,7 @@ class TVM_DLL StmtVisitor :
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const AllocateNode* op) override;
   void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const FreeNode* op) override;
   void VisitStmt_(const AssertStmtNode* op) override;
   void VisitStmt_(const ProducerConsumerNode* op) override;
@@ -248,6 +250,7 @@ class TVM_DLL StmtMutator :
   Stmt VisitStmt_(const ForNode* op) override;
   Stmt VisitStmt_(const AllocateNode* op) override;
   Stmt VisitStmt_(const StoreNode* op) override;
+  Stmt VisitStmt_(const BufferStoreNode* op) override;
   Stmt VisitStmt_(const FreeNode* op) override;
   Stmt VisitStmt_(const AssertStmtNode* op) override;
   Stmt VisitStmt_(const ProducerConsumerNode* op) override;
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
new file mode 100644
index 0000000..19c904a
--- /dev/null
+++ b/include/tvm/tir/var.h
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/var.h
+ * \brief Variables in the TIR.
+ */
+#ifndef TVM_TIR_VAR_H_
+#define TVM_TIR_VAR_H_
+
+#include <tvm/node/node.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/ir/expr.h>
+#include <string>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A variable node in the IR.
+ *
+ * A variable is uniquely identified by its address.
+ *
+ * Each variable is bound only once, in one of the following nodes:
+ * - Allocate
+ * - For
+ * - Let
+ * - LetStmt
+ */
+class VarNode : public PrimExprNode {
+ public:
+  /*!
+   * \brief The hint to the variable name.
+   * \note Each variable is uniquely identified by its address.
+   */
+  std::string name_hint;
+  /*!
+   * \brief Type annotation of the variable.
+   *
+   * It is an optional field that provides a more refined type of the variable than dtype.
+   *
+   * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
+   */
+  Type type_annotation;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &dtype);
+    v->Visit("name", &name_hint);
+    v->Visit("type_annotation", &type_annotation);
+  }
+
+  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(type_annotation, other->type_annotation)) return false;
+    return equal.FreeVarEqualImpl(this, other);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(dtype);
+    hash_reduce(type_annotation);
+    hash_reduce.FreeVarHashImpl(this);
+  }
+
+  static constexpr const char* _type_key = "tir.Var";
+  TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
+};
+
+/*! \brief A named variable in TIR. */
+class Var : public PrimExpr {
+ public:
+  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
+  /*!
+   * \brief Constructor
+   * \param name_hint variable name
+   * \param dtype data type
+   */
+  TVM_DLL explicit Var(std::string name_hint = "v",
+                       DataType dtype = DataType::Int(32));
+  /*!
+   * \brief Constructor which provides a more detailed type annotation.
+   * \param name_hint variable name.
+   * \param type_annotation The type annotation.
+   */
+  TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
+  /*!
+   * \brief Make a new copy of var with same type, append suffix
+   * \param suffix The suffix to be appended.
+   * \return the new Var copy
+   */
+  TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the corresponding Variable.
+   */
+  const VarNode* operator->() const {
+    return get();
+  }
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the corresponding Variable.
+   */
+  const VarNode* get() const {
+    return static_cast<const VarNode*>(data_.get());
+  }
+  /*! \brief type indicate the container type */
+  using ContainerType = VarNode;
+};
+
+/*!
+ * \brief A variable node that represents a tensor index size,
+ * whose value must be non-negative.
+ */
+class SizeVarNode : public VarNode {
+ public:
+  static constexpr const char* _type_key = "tir.SizeVar";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
+};
+
+/*! \brief A named variable that represents a tensor index size. */
+class SizeVar : public Var {
+ public:
+  explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
+  /*!
+   * \brief constructor
+   * \param name_hint variable name
+   * \param t data type
+   */
+  TVM_DLL explicit SizeVar(std::string name_hint = "s",
+                           DataType t = DataType::Int(32));
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the corresponding Variable.
+   */
+  const SizeVarNode* operator->() const {
+    return get();
+  }
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the corresponding Variable.
+   */
+  const SizeVarNode* get() const {
+    return static_cast<const SizeVarNode*>(data_.get());
+  }
+  /*! \brief type indicate the container type */
+  using ContainerType = SizeVarNode;
+};
+
+
+/*! \brief Container class of an iteration variable. */
+class IterVarNode;
+
+using Region = Array<Range>;
+
+/*!
+ * \brief Type of iteration variable.
+ *  Each IterVar has a specific type.
+ *
+ *  The type of iter var can be overridden via
+ *  stage.iter_var_attrs given they are compatible.
+ */
+enum IterVarType : int {
+  /*!
+   * \brief Data parallel iteration.
+   *  This normally corresponds to axis of Tensor.
+   *  Allow all IterVar manipulations.
+   *
+   * \note This does not mean the loop
+   *  has to be executed in a parallel fashion.
+   */
+  kDataPar = 0,
+  /*!
+   * \brief The IterVar itself is a thread-index
+   *  of a fixed thread launching group.
+   *  Note that this is already assumed to be parallelized.
+   *
+   *  Disallow: split/fuse/vectorize/parallel
+   */
+  kThreadIndex = 1,
+  /*!
+   * \brief Commutative reduction.
+   *  Cannot be directly parallelized.
+   *
+   *  Disallow: parallel/vectorize
+   */
+  kCommReduce = 2,
+  /*!
+   * \brief Serial loops with loop carry dependency,
+   *  the iteration must execute in order.
+   *  Cannot be re-ordered.
+   *
+   *  Disallow: reorder/parallel/vectorize
+   */
+  kOrdered = 3,
+  /*!
+   * \brief IterVar is opaque,
+   *
+   *  May not correspond to any generated loop
+   *  Disallow all IterVar manipulations and compute_at
+   *
+   * \note This is usually used to implement a composite op
+   *  or an external op.
+   */
+  kOpaque = 4,
+  // The following are possible additional
+  // types that are provided during schedule
+  /*!
+   * \brief The execution is unrolled.
+   */
+  kUnrolled = 5,
+  /*!
+   * \brief The loop is vectorized.
+   */
+  kVectorized = 6,
+  /*!
+   * \brief The loop is parallelized.
+   */
+  kParallelized = 7,
+  /*!
+   * \brief Marks boundary of tensorization intrinsic.
+   */
+  kTensorized = 8
+};
+
+/*!
+ * \brief Iteration Variable,
+ *  represents an iteration over an integer interval.
+ */
+class IterVar : public ObjectRef {
+ public:
+  // construct an undefined IterVar (no underlying node is created)
+  IterVar() {}
+  // construct from an existing object pointer.
+  explicit IterVar(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const IterVarNode* operator->() const;
+  /*!
+   * \return the corresponding var in the IterVar.
+   */
+  inline operator PrimExpr() const;
+  /*! \brief specify container node */
+  using ContainerType = IterVarNode;
+};
+
+using Domain = Array<Range>;
+
+/*!
+ * \brief An iteration variable representing an iteration
+ *  over a one dimensional interval.
+ */
+class IterVarNode : public Object {
+ public:
+  /*!
+   * \brief The domain of iteration, if known; can be None
+   *  for the intermediate schedule node, before scheduling.
+   */
+  Range dom;
+  /*! \brief The looping variable */
+  Var var;
+  /*! \brief The type of the IterVar */
+  IterVarType iter_type;
+  /*!
+   * \brief additional tag on the iteration variable,
+   *  set this if it is already bound to a known thread tag.
+   */
+  std::string thread_tag;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dom", &dom);
+    v->Visit("var", &var);
+    v->Visit("iter_type", &iter_type);
+    v->Visit("thread_tag", &thread_tag);
+  }
+
+  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
+    return
+        equal(dom, other->dom) &&
+        equal.DefEqual(var, other->var) &&
+        equal(iter_type, other->iter_type) &&
+        equal(thread_tag, other->thread_tag);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(dom);
+    hash_reduce.DefHash(var);
+    hash_reduce(iter_type);
+    hash_reduce(thread_tag);
+  }
+
+  TVM_DLL static IterVar make(Range dom, Var var,
+                              IterVarType iter_type,
+                              std::string thread_tag = "");
+
+  static constexpr const char* _type_key = "IterVar";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
+};
+
+// inline implementations
+inline const IterVarNode* IterVar::operator->() const {
+  return static_cast<const IterVarNode*>(data_.get());
+}
+
+inline IterVar::operator PrimExpr() const {
+  return (*this)->var;
+}
+
+inline const char* IterVarType2String(IterVarType t) {
+  switch (t) {
+    case kDataPar: return "DataPar";
+    case kThreadIndex: return "ThreadIndex";
+    case kCommReduce: return "CommReduce";
+    case kOrdered: return "Ordered";
+    case kOpaque: return "Opaque";
+    case kUnrolled: return "Unrolled";
+    case kVectorized: return "Vectorized";
+    case kParallelized: return "Parallelized";
+    case kTensorized: return "Tensorized";
+  }
+  return "Unknown";
+}
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_VAR_H_
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 653c395..bd8e33f 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -24,11 +24,11 @@ from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
 from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
 from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
 from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
+from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
 from .expr import IterVar, Any
 
 from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
-from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
+from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
 from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
 
 from .function import PrimFunc
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index a192fce..20a3bca 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -14,18 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Expression AST Node in TVM.
+# pylint: disable=redefined-builtin
+"""TIR expression nodes.
 
-User do not need to deal with expression AST node directly.
-But they can be helpful for developer to do quick proptyping.
-While not displayed in the document and python file.
 Each expression node have subfields that can be visited from python side.
-
 For example, you can use addexp.a to get the left operand of an Add node.
 
 .. code-block:: python
 
-  x = te.var("n")
+  x = tvm.tir.Var("n", "int32")
   y = x + 2
   assert(isinstance(y, tvm.tir.Add))
   assert(y.a == x)
@@ -859,6 +856,23 @@ class Load(PrimExprWithOp):
 
 
 @tvm._ffi.register_object
+class BufferLoad(PrimExprWithOp):
+    """Buffer load node.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The buffer to load from.
+
+    indices : List[PrimExpr]
+        The indices of the element to load.
+    """
+    def __init__(self, buffer, indices):
+        self.__init_handle_by_constructor__(
+            _ffi_api.BufferLoad, buffer, indices)
+
+
+@tvm._ffi.register_object
 class Ramp(PrimExprWithOp):
     """Ramp node.
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 65c72dd..0badad3 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -16,15 +16,12 @@
 # under the License.
 """Statement AST Node in TVM.
 
-User do not need to deal with AST node directly.
-But they can be helpful for developer to do quick proptyping.
-While not displayed in the document and python file.
 Each statement node have subfields that can be visited from python side.
 
 .. code-block:: python
 
-    x = te.var("n")
-    a = te.var("array", "handle")
+    x = tvm.tir.Var("n", "int32")
+    a = tvm.tir.Var("array", "handle")
     st = tvm.tir.stmt.Store(a, x + 1, 1)
     assert isinstance(st, tvm.tir.stmt.Store)
     assert(st.buffer_var == a)
@@ -164,6 +161,26 @@ class Store(Stmt):
 
 
 @tvm._ffi.register_object
+class BufferStore(Stmt):
+    """Buffer store node: writes ``value`` into ``buffer`` at ``indices``.
+
+    Parameters
+    ----------
+    buffer : Buffer
+        The buffer to write into.
+
+    value : PrimExpr
+        The value to be stored.
+
+    indices : List[PrimExpr]
+        The index expressions of the location to store to.
+    """
+    def __init__(self, buffer, value, indices):
+        self.__init_handle_by_constructor__(
+            _ffi_api.BufferStore, buffer, value, indices)
+
+
+@tvm._ffi.register_object
 class Provide(Stmt):
     """Provide node.
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index bee0256..891d137 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -407,6 +407,22 @@ PrimExpr AnyNode::make() {
   return PrimExpr(n);
 }
 
+BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices) {
+  auto n = make_object<BufferLoadNode>();
+  n->dtype = buffer->dtype;  // read before `buffer` is moved below
+  n->indices = std::move(indices);
+  n->buffer = std::move(buffer);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferLoad")
+.set_body_typed([](Buffer buf, Array<PrimExpr> idx) -> BufferLoad {
+  return BufferLoad(std::move(buf), std::move(idx));
+});
+
+TVM_REGISTER_NODE_TYPE(BufferLoadNode);
+
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const StringImmNode*>(node.get());
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index f8371f3..57ff627 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -36,6 +36,10 @@ void ExprVisitor::VisitExpr_(const LoadNode* op) {
   this->VisitExpr(op->predicate);
 }
 
+void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
+  for (const PrimExpr& index : op->indices) this->VisitExpr(index);
+}
+
 void ExprVisitor::VisitExpr_(const LetNode* op) {
   this->VisitExpr(op->value);
   this->VisitExpr(op->body);
@@ -128,6 +132,16 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
   }
 }
 
+PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
+  Array<PrimExpr> new_indices = MutateArray(op->indices, [this](const PrimExpr& index) {
+    return this->VisitExpr(index);
+  });
+  if (!new_indices.same_as(op->indices)) {
+    return BufferLoad(op->buffer, new_indices);
+  }
+  return GetRef<PrimExpr>(op);
+}
+
 PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
   PrimExpr body = this->VisitExpr(op->body);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index a8fe9cd..64e7ef5 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -324,6 +324,21 @@ Stmt EvaluateNode::make(PrimExpr value) {
 TVM_REGISTER_GLOBAL("tir.Evaluate")
 .set_body_typed(EvaluateNode::make);
 
+BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  auto n = make_object<BufferStoreNode>();
+  n->indices = std::move(indices);
+  n->value = std::move(value);
+  n->buffer = std::move(buffer);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferStore")
+.set_body_typed([](Buffer buf, PrimExpr val, Array<PrimExpr> idx) -> BufferStore {
+  return BufferStore(std::move(buf), std::move(val), std::move(idx));
+});
+
+TVM_REGISTER_NODE_TYPE(BufferStoreNode);
+
 // Printers
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index b4b27b9..ea19982 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -160,6 +160,12 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) {
   this->VisitExpr(op->predicate);
 }
 
+void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
+  // The stored value is an expression operand too; skipping it hides its sub-expressions.
+  this->VisitExpr(op->value);
+  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
+}
+
 void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
   this->VisitExpr(op->condition);
   this->VisitStmt(op->then_case);
@@ -343,6 +347,20 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
   }
 }
 
+Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
+  // Mutate the stored value as well as the indices; copy the node only if either changed.
+  PrimExpr value = this->VisitExpr(op->value);
+  Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
+  if (value.same_as(op->value) && indices.same_as(op->indices)) {
+    return GetRef<Stmt>(op);
+  } else {
+    auto n = CopyOnWrite(op);
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+}
+
 Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
   Array<PrimExpr> args = Internal::Mutate(this, op->args);
   PrimExpr value = this->VisitExpr(op->value);
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 2904953..2e23a61 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -292,7 +292,18 @@ def test_vars():
     assert isinstance(ptype.element_type, tvm.ir.PrimType)
 
 
+def test_buffer_load_store():
+    buf = tvm.tir.decl_buffer((10,), "float32")
+    load = tvm.tir.BufferLoad(buf, [0])
+    assert isinstance(load, tvm.tir.BufferLoad)
+    assert load.dtype == "float32"
+    assert load.buffer == buf
+    store = tvm.tir.BufferStore(buf, 0.1, [0])
+    assert isinstance(store, tvm.tir.BufferStore)
+
+
 if __name__ == "__main__":
+    test_buffer_load_store()
     test_vars()
     test_prim_func()
     test_cast()
diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py
index 3fcdc65..593b845 100644
--- a/tests/python/unittest/test_tir_structural_equal_hash.py
+++ b/tests/python/unittest/test_tir_structural_equal_hash.py
@@ -166,6 +166,25 @@ def test_stmt():
     assert consistent_equal(func2(), func2())
 
 
+def test_buffer_load_store():
+    b = tvm.tir.decl_buffer((10, 10), "float32")
+    x = tvm.tir.BufferLoad(b, [0, 1])
+    y = tvm.tir.BufferLoad(b, [0, 1])
+    z = tvm.tir.BufferLoad(b, [1, 2])
+    assert consistent_equal(y, x)
+    assert not consistent_equal(y, z)
+
+    i = tvm.tir.Var("i", "int32")
+    sx = tvm.tir.BufferStore(b, 0.1, [0, i])
+    sy = tvm.tir.BufferStore(b, 0.1, [0, i])
+    sz = tvm.tir.BufferStore(b, 0.1, [1, i])
+    assert consistent_equal(sy, sx)
+    assert not consistent_equal(sy, sz)
+    # Stores that differ only in the stored value must not compare equal.
+    sw = tvm.tir.BufferStore(b, 0.2, [0, i])
+    assert not consistent_equal(sy, sw)
+
+
 if __name__ == "__main__":
     test_exprs()
     test_prim_func()
@@ -173,3 +189,4 @@ if __name__ == "__main__":
     test_array()
     test_env_func()
     test_stmt()
+    test_buffer_load_store()