You are viewing a plain-text version of this content. The canonical link to it is on the original archive page (the hyperlink was lost in this plain-text rendering).
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/07 08:24:23 UTC

[tvm] 01/01: Squashed

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch ir-builder
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3857a97931deb50e9ca6e33721a81476935d18ca
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon May 23 17:06:45 2022 -0700

    Squashed
---
 include/tvm/ir/expr.h                              |   20 +-
 include/tvm/ir/ir_builder.h                        |  188 ++++
 include/tvm/support/with.h                         |    2 +
 include/tvm/tir/ir_builder.h                       |  138 +++
 include/tvm/tir/ir_builder_frame.h                 |  454 ++++++++
 include/tvm/tir/op.h                               |   34 +-
 python/tvm/ir/__init__.py                          |   55 +-
 .../_ffi_api.py => ir/_ffi_ir_builder_api.py}      |    4 +-
 python/tvm/ir/ir_builder.py                        |   84 ++
 python/tvm/script/__init__.py                      |   17 +-
 python/tvm/script/{ => parser}/__init__.py         |   14 +-
 python/tvm/script/parser/diagnostics.py            |   60 ++
 python/tvm/script/parser/dispatch.py               |   63 ++
 python/tvm/script/parser/doc.py                    |  341 ++++++
 python/tvm/script/parser/doc_core.py               | 1140 ++++++++++++++++++++
 .../script/{tir/prim_func.py => parser/entry.py}   |   46 +-
 python/tvm/script/parser/evaluator.py              |  282 +++++
 .../script/{_ffi_api.py => parser/ir/__init__.py}  |    7 +-
 .../tvm/script/{__init__.py => parser/ir/entry.py} |   19 +-
 .../script/{__init__.py => parser/ir/parser.py}    |   24 +-
 python/tvm/script/parser/parser.py                 |  182 ++++
 python/tvm/script/parser/source.py                 |   89 ++
 python/tvm/script/{ => parser/tir}/__init__.py     |    9 +-
 python/tvm/script/parser/tir/entry.py              |   97 ++
 python/tvm/script/parser/tir/operation.py          |   85 ++
 python/tvm/script/parser/tir/parser.py             |  262 +++++
 python/tvm/script/parser/utils.py                  |   63 ++
 python/tvm/script/parser/var_table.py              |   71 ++
 python/tvm/script/{ => parser_v1}/__init__.py      |    3 +-
 python/tvm/script/{ => parser_v1}/_ffi_api.py      |    0
 .../script/{ => parser_v1}/context_maintainer.py   |    8 +-
 python/tvm/script/{ => parser_v1}/diagnostics.py   |    6 +-
 python/tvm/script/{ => parser_v1}/highlight.py     |    0
 python/tvm/script/{ => parser_v1}/meta_unparser.py |    0
 python/tvm/script/{ => parser_v1}/parser.py        |   23 +-
 python/tvm/script/{ => parser_v1}/registry.py      |    2 +-
 python/tvm/script/{ => parser_v1}/tir/__init__.py  |    0
 python/tvm/script/{ => parser_v1}/tir/__init__.pyi |    0
 python/tvm/script/{ => parser_v1}/tir/intrin.py    |    5 +-
 python/tvm/script/{ => parser_v1}/tir/node.py      |    7 +-
 python/tvm/script/{ => parser_v1}/tir/prim_func.py |    3 +-
 .../script/{ => parser_v1}/tir/scope_handler.py    |   17 +-
 .../tvm/script/{ => parser_v1}/tir/special_stmt.py |   16 +-
 python/tvm/script/{ => parser_v1}/tir/ty.py        |    1 +
 python/tvm/script/{ => parser_v1}/utils.py         |    7 +-
 python/tvm/tir/__init__.py                         |  212 +++-
 .../_ffi_api.py => tir/_ffi_ir_builder_api.py}     |    4 +-
 python/tvm/tir/analysis/analysis.py                |    4 +-
 python/tvm/tir/buffer.py                           |   36 +-
 python/tvm/tir/expr.py                             |   15 +-
 python/tvm/tir/ir_builder_frame.py                 |  118 ++
 python/tvm/tir/{ir_builder.py => ir_builder_v1.py} |    6 +-
 python/tvm/tir/ir_builder_v2.py                    |  901 ++++++++++++++++
 python/tvm/tir/op.py                               |  551 +++++++++-
 python/tvm/tir/schedule/block_scope.py             |    2 +-
 python/tvm/tir/schedule/schedule.py                |    6 +-
 python/tvm/tir/schedule/state.py                   |    3 +-
 python/tvm/tir/stmt.py                             |    4 +
 python/tvm/tir/usmp/transform/transform.py         |    5 +-
 src/ir/expr.cc                                     |   13 +
 src/ir/ir_builder.cc                               |  134 +++
 src/tir/ir/expr.cc                                 |   19 +-
 src/tir/ir/script/script_complete.cc               |    5 +-
 src/tir/ir/script/script_complete.h                |   35 +
 src/tir/ir/stmt.cc                                 |    2 +
 src/tir/ir_builder/ir_builder.cc                   |  637 +++++++++++
 src/tir/ir_builder/ir_builder_frame.cc             |  207 ++++
 src/tir/ir_builder/utils.h                         |   92 ++
 src/tir/op/op.cc                                   |   24 +
 tests/python/tvmscript/test_builder_basic.py       |  227 ++++
 tests/python/tvmscript/test_parse_basic.py         |  118 ++
 tests/python/tvmscript/test_parser_capture.py      |   43 +
 72 files changed, 7174 insertions(+), 197 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 5e358ed50e..dcabd7d3f1 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -764,16 +764,32 @@ struct PackedFuncValueConverter<PrimExpr> {
       return PrimExpr(ObjectPtr<Object>(nullptr));
     }
     if (val.type_code() == kDLInt) {
-      return PrimExpr(val.operator int());
+      return IntImm(runtime::DataType::Int(32), val.operator int());
     }
     if (val.type_code() == kDLFloat) {
-      return PrimExpr(static_cast<float>(val.operator double()));
+      return FloatImm(runtime::DataType::Float(32), val.operator double());
     }
 
     return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
   }
 };
 
+template <>
+struct PackedFuncValueConverter<Array<PrimExpr>> {  // element-wise PrimExpr conversion
+  static Array<PrimExpr> From(const TVMPODValue_& val) {
+    if (val.type_code() == kTVMNullptr) return Array<PrimExpr>(nullptr);  // null -> null Array
+    Array<ObjectRef> vals = val.AsObjectRef<Array<ObjectRef>>();
+    Array<PrimExpr> exprs;
+    for (const ObjectRef& v : vals) {
+      TVMValue value;  // non-owning view of `v`; `vals` keeps the element alive
+      value.v_handle = const_cast<void*>(static_cast<const void*>(v.get()));
+      exprs.push_back(  // object handle only: the kDLInt/kDLFloat paths are never taken
+          PackedFuncValueConverter<PrimExpr>::From(TVMArgValue(value, kTVMObjectHandle)));
+    }
+    return exprs;
+  }
+};
+
 template <>
 struct PackedFuncValueConverter<tvm::Integer> {
   static tvm::Integer From(const TVMPODValue_& val) {
diff --git a/include/tvm/ir/ir_builder.h b/include/tvm/ir/ir_builder.h
new file mode 100644
index 0000000000..5e099e7310
--- /dev/null
+++ b/include/tvm/ir/ir_builder.h
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_IR_IR_BUILDER_H_
+#define TVM_IR_IR_BUILDER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+namespace tvm {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame and IRBuilder //////////////////////////////
+
+class IRBuilderFrameNode : public runtime::Object {  // base node of every builder frame
+ public:
+  std::vector<runtime::TypedPackedFunc<void()>> callbacks;  // presumably filled by AddCallback
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `callbacks` holds packed functions and is intentionally not visited.
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRBuilderFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+  virtual ~IRBuilderFrameNode() = default;
+  virtual void EnterWithScope();
+  virtual void ExitWithScope();  // concrete frames override this as `final`
+
+  void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+class IRBuilderFrame : public runtime::ObjectRef {  // reference to an IRBuilderFrameNode
+ public:
+  virtual ~IRBuilderFrame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+  IRBuilderFrame() = default;  // empty reference; only subclasses may create one
+
+ public:
+  inline void EnterWithScope();  // forwards to the node's EnterWithScope()
+  inline void ExitWithScope();   // forwards to the node, then resets this reference
+};
+
+class IRBuilderNode : public runtime::Object {  // state held by an IRBuilder
+ public:
+  runtime::Array<IRBuilderFrame> frames;  // open frames; lookups scan from back()
+  Optional<ObjectRef> result;             // built object; Get<T>() requires it defined
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("frames", &frames);
+    v->Visit("result", &result);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+  template <typename TFrame>
+  inline Optional<TFrame> FindFrame() const;  // nearest TFrame, searching from back()
+  template <typename TFrame>
+  inline Optional<TFrame> GetLastFrame() const;  // back() only, if it is a TFrame
+  template <typename TObjectRef>
+  inline TObjectRef Get() const;  // `result` downcast to TObjectRef; CHECK-fails otherwise
+};
+
+class IRBuilder : public runtime::ObjectRef {  // reference to an IRBuilderNode
+ public:
+  explicit IRBuilder();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+  void EnterWithScope();
+  void ExitWithScope();
+  static IRBuilder Current();  // presumably the innermost entered builder; see ir_builder.cc
+  template <class TObjectRef>
+  inline static TObjectRef Name(String name, TObjectRef obj);  // names obj via details::Namer
+};
+
+////////////////////////////// Generic IRModule //////////////////////////////
+
+class IRModuleFrameNode : public IRBuilderFrameNode {  // frame that gathers module functions
+ public:
+  Array<GlobalVar> global_vars;  // presumably index-aligned with `functions`
+  Array<BaseFunc> functions;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("global_vars", &global_vars);
+    v->Visit("functions", &functions);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRModuleFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IRModuleFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+                                                    IRModuleFrameNode);
+};
+
+TVM_DLL IRModuleFrame IRModule();  // constructs an IRModuleFrame to be entered as a scope
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+class Namer {  // naming hooks dispatched per node type through a NodeFunctor
+ public:
+  using FType = NodeFunctor<void(const ObjectRef&, String)>;
+  static FType& vtable();  // table of per-type naming functions
+  static void Name(ObjectRef node, String name);
+};
+
+}  // namespace details
+
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {  // returns obj after naming
+  details::Namer::Name(obj, name);
+  return Downcast<TObjectRef>(obj);
+}
+
+inline void IRBuilderFrame::EnterWithScope() {
+  ICHECK(data_ != nullptr);  // e.g. null after ExitWithScope() reset this reference
+  static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
+}
+
+inline void IRBuilderFrame::ExitWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
+  data_.reset();  // this reference cannot be entered or exited again afterwards
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {  // first TFrame from back()
+  using TFrameNode = typename TFrame::ContainerType;
+  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {  // back() towards front()
+    if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
+      return GetRef<TFrame>(p);
+    }
+  }
+  return NullOpt;
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {  // no deeper search than back()
+  using TFrameNode = typename TFrame::ContainerType;
+  if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
+    return Downcast<TFrame>(frames.back());
+  }
+  return NullOpt;
+}
+
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {  // `result` as TObjectRef, or CHECK-fail
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key;
+  return GetRef<TObjectRef>(n);
+}
+
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_IR_IR_BUILDER_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 3651e05e74..bbc36419ff 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -75,6 +75,8 @@ class With {
   ContextType& operator*() { return *get(); }
   const ContextType* operator*() const { return *get(); }
 
+  ContextType operator()() { return ctx_; }
+
  private:
   /*! \brief internal context type. */
   ContextType ctx_;
diff --git a/include/tvm/tir/ir_builder.h b/include/tvm/tir/ir_builder.h
new file mode 100644
index 0000000000..5d4c61635c
--- /dev/null
+++ b/include/tvm/tir/ir_builder.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_H_
+#define TVM_TIR_IR_BUILDER_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/ir_builder_frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;  // short aliases used by the declarations below
+using tvm::tir::Buffer;
+using tvm::tir::IterVar;
+using tvm::tir::Var;
+
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators);  // no defaults: pass every field
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+BlockFrame Block(String name, bool no_realize = false);  // frame to enter as a block scope
+BlockInitFrame Init();
+void Where(PrimExpr predicate);  // this and the next 3 presumably edit the current BlockFrame
+void Reads(Array<ObjectRef> buffer_slices);
+void Writes(Array<ObjectRef> buffer_slices);
+void BlockAttrs(Map<String, ObjectRef> attrs);
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+                   int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+
+namespace axis {  // IterVar constructors, presumably for BlockFrame iter_vars
+IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Array<IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+}  // namespace axis
+
+ForFrame Serial(PrimExpr start, PrimExpr stop,  // ForFrame constructors, one per loop flavor
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+                  Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+                    Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,  // `thread`: thread tag
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Grid(Array<PrimExpr> extents);  // presumably one nested loop per extent
+
+PrimFuncFrame PrimFunc();  // frame to enter while defining a PrimFunc
+Var Arg(String name, Var var);  // Arg is overloaded for scalar Vars and for Buffers
+Buffer Arg(String name, Buffer buffer);
+void FuncName(String name);
+void FuncAttrs(Map<String, ObjectRef> attrs);
+Type FuncRet(Type ret_type);
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
+                   int align = -1, int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+                        DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+                        Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+                        String storage_scope = "global", int align = -1, int offset_factor = 0,
+                        String buffer_type = "default", Array<IntImm> axis_separators = {});
+
+AssertFrame Assert(PrimExpr condition, String message);  // frame keeps message as PrimExpr
+LetFrame Let(Var var, PrimExpr value);
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                       Optional<PrimExpr> condition = NullOpt,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+AllocateConstFrame AllocateConst(
+    NDArray data, DataType dtype, Array<PrimExpr> extents,
+    Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+WhileFrame While(PrimExpr condition);
+IfFrame If(PrimExpr condition);
+ThenFrame Then();  // Then/Else presumably fill the branches of the enclosing IfFrame
+ElseFrame Else();
+LaunchThreadFrame LaunchThread(IterVar iter_var, PrimExpr extent);
+IterVar EnvThread(String thread_tag);  // presumably consumed by LaunchThread
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+void Prefetch(Buffer buffer, Array<Range> bounds);
+void Evaluate(PrimExpr value);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                             \
+  inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                        \
+    DataType dtype = DType;                                                            \
+    return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
+  }
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));  // F(x): cast x; F(): new unnamed Var
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST  // keep the helper macro local to this header
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_H_
diff --git a/include/tvm/tir/ir_builder_frame.h b/include/tvm/tir/ir_builder_frame.h
new file mode 100644
index 0000000000..549847da6c
--- /dev/null
+++ b/include/tvm/tir/ir_builder_frame.h
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_FRAME_H_
+#define TVM_TIR_IR_BUILDER_FRAME_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+class TIRFrameNode : public IRBuilderFrameNode {  // base node of all TIR builder frames
+ public:
+  Array<tvm::tir::Stmt> stmts;  // statements collected in this frame (presumably its body)
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.TIRFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+class TIRFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+  TIRFrame() = default;  // empty reference; only subclasses may create one
+};
+
+class BlockFrameNode : public TIRFrameNode {  // collects the parts of a block while in scope
+ public:
+  String name;
+  Array<tvm::tir::IterVar> iter_vars;
+  Optional<Array<tvm::tir::BufferRegion>> reads;   // presumably unset until Reads() runs
+  Optional<Array<tvm::tir::BufferRegion>> writes;  // presumably unset until Writes() runs
+  Optional<tvm::tir::Stmt> init;
+  Array<tvm::tir::Buffer> alloc_buffers;
+  Array<tvm::tir::MatchBufferRegion> match_buffers;
+  Map<String, ObjectRef> annotations;
+
+  Array<PrimExpr> iter_values;  // presumably index-aligned with iter_vars
+  Optional<PrimExpr> predicate;
+  bool no_realize = false;  // explicit default: never read uninitialized
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("iter_vars", &iter_vars);
+    v->Visit("reads", &reads);
+    v->Visit("writes", &writes);
+    v->Visit("init", &init);
+    v->Visit("alloc_buffers", &alloc_buffers);
+    v->Visit("match_buffers", &match_buffers);
+    v->Visit("annotations", &annotations);
+    v->Visit("iter_values", &iter_values);
+    v->Visit("predicate", &predicate);
+    v->Visit("no_realize", &no_realize);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.BlockFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class BlockFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+class BlockInitFrameNode : public TIRFrameNode {  // from Init(); presumably sets Block init
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class BlockInitFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+class ForFrameNode : public TIRFrameNode {  // frame for one or more loops (vars/doms arrays)
+ public:
+  using FMakeForLoop =  // maps (vars, doms, stmt) to the resulting loop Stmt
+      runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+  Array<tvm::tir::Var> vars;  // presumably index-aligned with doms
+  Array<Range> doms;
+  FMakeForLoop f_make_for_loop;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("vars", &vars);
+    v->Visit("doms", &doms);
+    // `f_make_for_loop` is not visited.
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.ForFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class ForFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+class PrimFuncFrameNode : public TIRFrameNode {  // collects a PrimFunc's signature and attrs
+ public:
+  Optional<String> name;  // presumably set by FuncName()
+  Array<tvm::tir::Var> args;
+  Optional<Type> ret_type;  // presumably set by FuncRet()
+  Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+  Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+  Map<String, ObjectRef> attrs;
+  Array<tvm::tir::Buffer> root_alloc_buffers;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+    v->Visit("attrs", &attrs);
+    v->Visit("root_alloc_buffers", &root_alloc_buffers);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.PrimFuncFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class PrimFuncFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+class AssertFrameNode : public TIRFrameNode {  // frame returned by Assert()
+ public:
+  PrimExpr condition;
+  PrimExpr message;  // PrimExpr here, although Assert() takes the message as a String
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AssertFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AssertFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+class LetFrameNode : public TIRFrameNode {  // frame returned by Let(); presumably binds var
+ public:
+  tvm::tir::Var var;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.LetFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LetFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+class AllocateFrameNode : public TIRFrameNode {  // frame returned by Allocate()
+ public:
+  Array<PrimExpr> extents;
+  DataType dtype;
+  String storage_scope;
+  PrimExpr condition;
+  Map<String, ObjectRef> annotations;
+  tvm::tir::Buffer buffer;  // presumably the buffer exposed for this allocation
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extents", &extents);
+    v->Visit("dtype", &dtype);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+    v->Visit("annotations", &annotations);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AllocateFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+class AllocateConstFrameNode : public TIRFrameNode {  // frame returned by AllocateConst()
+ public:
+  DataType dtype;
+  Array<PrimExpr> extents;
+  tvm::runtime::NDArray data;
+  tvm::tir::Buffer buffer;  // presumably the buffer exposed for the constant data
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("data", &data);
+    v->Visit("buffer", &buffer);
+    v->Visit("annotations", &annotations);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AllocateConstFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateConstFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+                                                    AllocateConstFrameNode);
+};
+
+class LaunchThreadFrameNode : public TIRFrameNode {  // frame returned by LaunchThread()
+ public:
+  PrimExpr extent;
+  String attr_key;
+  tvm::tir::IterVar iter_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extent", &extent);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("iter_var", &iter_var);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.LaunchThreadFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LaunchThreadFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+                                                    LaunchThreadFrameNode);
+};
+
+class RealizeFrameNode : public TIRFrameNode {  // frame returned by Realize()
+ public:
+  tvm::tir::BufferRegion buffer_slice;
+  String storage_scope;
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer_slice", &buffer_slice);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.RealizeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class RealizeFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+class AttrFrameNode : public TIRFrameNode {  // frame returned by Attr()
+ public:
+  ObjectRef node;
+  String attr_key;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AttrFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AttrFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+class WhileFrameNode : public TIRFrameNode {  // frame returned by While()
+ public:
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.WhileFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class WhileFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+class IfFrameNode : public TIRFrameNode {  // frame returned by If()
+ public:
+  PrimExpr condition;
+  Optional<Array<tvm::tir::Stmt>> then_stmts;  // presumably set when a ThenFrame exits
+  Optional<Array<tvm::tir::Stmt>> else_stmts;  // presumably set when an ElseFrame exits
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("then_stmts", &then_stmts);
+    v->Visit("else_stmts", &else_stmts);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.IfFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IfFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+class ThenFrameNode : public TIRFrameNode {  // from Then(); no own fields, inherits VisitAttrs
+ public:
+  static constexpr const char* _type_key = "ir_builder.tir.ThenFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ThenFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+class ElseFrameNode : public TIRFrameNode {  // from Else(); no own fields, inherits VisitAttrs
+ public:
+  static constexpr const char* _type_key = "ir_builder.tir.ElseFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ElseFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+class DeclBufferFrameNode : public TIRFrameNode {  // frame holding a single declared buffer
+ public:
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class DeclBufferFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_FRAME_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
  * \return The result expression.
  */
 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check whether x is a null pointer.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
 /*!
  * \brief Check if x is infinite.
  * \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span());
 
+/*!
+ * \brief Calculate fmod(x, y), the floating-point remainder of x / y.
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                   Span span = Span());
 
+/*!
+ * \brief Get the address of the buffer element referenced by a BufferLoad.
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of the element that \p buffer_load refers to.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Look up a parameter by its name.
+ * \param param_name The name of the parameter.
+ * \param span The location of this operation in the source.
+ * \return The handle of the parameter.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);  // log1p(x) = log(1 + x)
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4e847c0310..51dff32ac6 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -16,29 +16,46 @@
 # under the License.
 # pylint: disable=unused-import
 """Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .base import structural_equal, assert_structural_equal, structural_hash
-from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .affine_type import TensorAffineType, TupleAffineType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .op import Op, register_op_attr, register_intrin_lowering
-from .function import CallingConv, BaseFunc
+from . import diagnostics, instrument, ir_builder, transform
 from .adt import Constructor, TypeData
-from .module import IRModule
+from .affine_type import TensorAffineType, TupleAffineType
 from .attrs import Attrs, DictAttrs, make_node
+from .base import (
+    EnvFunc,
+    Node,
+    SourceName,
+    Span,
+    assert_structural_equal,
+    load_json,
+    save_json,
+    structural_equal,
+    structural_hash,
+)
 from .container import Array, Map
+from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
+from .function import BaseFunc, CallingConv
 from .memory_pools import (
-    PoolInfo,
-    WorkspacePoolInfo,
-    ConstantPoolInfo,
-    WorkspaceMemoryPools,
     ConstantMemoryPools,
+    ConstantPoolInfo,
+    PoolInfo,
     PoolInfoProperties,
+    WorkspaceMemoryPools,
+    WorkspacePoolInfo,
 )
-
-from . import transform
-from . import instrument
-from . import diagnostics
+from .module import IRModule
+from .op import Op, register_intrin_lowering, register_op_attr
+from .tensor_type import TensorType
+from .type import (
+    FuncType,
+    GlobalTypeVar,
+    IncompleteType,
+    PointerType,
+    PrimType,
+    RelayRefType,
+    TupleType,
+    Type,
+    TypeConstraint,
+    TypeKind,
+    TypeVar,
+)
+from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/ir/_ffi_ir_builder_api.py
similarity index 88%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/ir/_ffi_ir_builder_api.py
index 926d17b166..9d08bc9b70 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/ir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir.ir_builder (global functions under the "ir_builder" prefix)"""
 import tvm._ffi
 
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/ir/ir_builder.py b/python/tvm/ir/ir_builder.py
new file mode 100644
index 0000000000..df05bf8361
--- /dev/null
+++ b/python/tvm/ir/ir_builder.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_ir_builder_api as _ffi_api
+
+
+@_register_object("ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+    """Base IRBuilder frame; ``with frame:`` calls the FFI IRBuilderFrameEnter/Exit."""
+
+    def __enter__(self) -> "IRBuilderFrame":
+        _ffi_api.IRBuilderFrameEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderFrameExit(self)  # pylint: disable=no-member # type: ignore
+
+    def add_callback(self, callback) -> None:
+        """Attach ``callback`` to this frame via the FFI IRBuilderFrameAddCallback.
+
+        NOTE(review): presumably run when the frame exits -- confirm on the C++ side.
+        """
+        _ffi_api.IRBuilderFrameAddCallback(  # pylint: disable=no-member # type: ignore
+            self, callback
+        )
+
+
+@_register_object("ir_builder.IRBuilder")
+class IRBuilder(_Object):
+    """Handle to an FFI IRBuilder; ``with builder:`` calls IRBuilderEnter/IRBuilderExit."""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.IRBuilder  # pylint: disable=no-member # type: ignore
+        )
+
+    def __enter__(self) -> "IRBuilder":
+        _ffi_api.IRBuilderEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderExit(self)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def current() -> "IRBuilder":
+        """Return the builder given by the FFI IRBuilderCurrent (presumably the innermost)."""
+        return _ffi_api.IRBuilderCurrent()  # pylint: disable=no-member # type: ignore
+
+    def get(self) -> _Object:
+        """Return the object that the FFI IRBuilderGet yields for this builder."""
+        return _ffi_api.IRBuilderGet(self)  # pylint: disable=no-member # type: ignore
+
+
+DefType = TypeVar("DefType", bound=_Object)
+
+
+def name(s: str, v: DefType) -> DefType:
+    """Return FFI ``IRBuilderName(s, v)``; presumably ``v`` with its name set to ``s``."""
+    return _ffi_api.IRBuilderName(s, v)  # pylint: disable=no-member # type: ignore
+
+
+def name_many(
+    s: List[str],
+    vs: List[DefType],  # pylint: disable=invalid-name
+) -> List[DefType]:
+    """Apply :func:`name` to each pair taken from ``s`` and ``vs``.
+
+    Raises ValueError if the two lists differ in length (checked even under ``python -O``).
+    """
+    if len(s) != len(vs):
+        raise ValueError(f"name_many: got {len(s)} names but {len(vs)} values")
+    return [name(i, v) for i, v in zip(s, vs)]
+
+
+@_register_object("ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):
+    """Frame created by :func:`ir_module` (type key ``ir_builder.IRModuleFrame``)."""
+
+
+def ir_module() -> IRModuleFrame:
+    """Create an IRModuleFrame through the FFI IRModule function."""
+    return _ffi_api.IRModule()  # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..d9605b2e70 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
+from . import parser, parser_v1
 
-from . import tir
+# v2 parser entry points; the public defaults below point here.
+from .parser import ir as ir_v2
+from .parser import ir_module as ir_module_v2
+from .parser import parse as from_source_v2
+from .parser import tir as tir_v2
 
-from .parser import ir_module, from_source
+# v1 parser entry points, still importable under their *_v1 names.
+from .parser_v1 import from_source as from_source_v1
+from .parser_v1 import ir_module as ir_module_v1
+from .parser_v1 import tir as tir_v1
+# Public names default to the v2 parser.
+ir = ir_v2
+ir_module = ir_module_v2
+tir = tir_v2
+from_source = from_source_v2
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 77%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..d8530e0ab1 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,13 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+# under the License.
+"""The parser"""
+from . import dispatch as _dispatch
+from . import doc as _doc
+from . import ir
+from . import parser as _parser
 from . import tir
-
-from .parser import ir_module, from_source
+from .entry import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/parser/diagnostics.py b/python/tvm/script/parser/diagnostics.py
new file mode 100644
index 0000000000..fd06d20e61
--- /dev/null
+++ b/python/tvm/script/parser/diagnostics.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+from .source import Source
+
+
+class Diagnostics:
+    """Report parser diagnostics at doc node locations within a ``Source``."""
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        mod = IRModule()  # only used to carry the source map into the DiagnosticContext
+        mod.source_map.add(source.source_name, source.full_source)
+        self.source = source
+        self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        lineno = node.lineno
+        col_offset = node.col_offset
+        end_lineno = node.end_lineno or lineno  # fall back if the end position is missing
+        end_col_offset = node.end_col_offset or col_offset
+        lineno += self.source.start_line - 1  # offset by Source.start_line (1-based assumed)
+        end_lineno += self.source.start_line - 1
+        col_offset += self.source.start_column + 1  # +1: presumably 0- to 1-based; confirm
+        end_col_offset += self.source.start_column + 1
+        self.ctx.emit(
+            diagnostics.Diagnostic(
+                level=level,
+                span=Span(
+                    source_name=SourceName(self.source.source_name),
+                    line=lineno,
+                    end_line=end_lineno,
+                    column=col_offset,
+                    end_column=end_col_offset,
+                ),
+                message=message,
+            )
+        )
+
+    def error(self, node: doc.AST, message: str) -> None:
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()  # NOTE(review): presumably raises once an error was emitted; confirm
diff --git a/python/tvm/script/parser/dispatch.py b/python/tvm/script/parser/dispatch.py
new file mode 100644
index 0000000000..6237371a30
--- /dev/null
+++ b/python/tvm/script/parser/dispatch.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, Type[AST], int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Meant to be used as a decorator: the method is stored in ``ParseVTable`` under
+    ``(token, type_name)`` and returned unchanged.
+    """
+
+    def f(method: ParseMethod) -> ParseMethod:
+        ParseVTable[(token, type_name)] = method
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Return the method registered for ``(token, type_name)``, or ``default``."""
+    return ParseVTable.get((token, type_name), default)
+
+
+def register_op(ty: Type, op: Type[AST], operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator method in ``OpVTable`` under ``(ty, op, operand_index)``.
+
+    Meant to be used as a decorator; the method is returned unchanged.
+    """
+
+    def f(method: OpMethod) -> OpMethod:
+        OpVTable[(ty, op, operand_index)] = method
+        return method
+
+    return f
+
+
+def get_op(
+    ty: Type,  # pylint: disable=invalid-name
+    op: Type[AST],
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Return the method registered for ``(ty, op, operand_index)``, or ``default``."""
+    return OpVTable.get((ty, op, operand_index), default)
diff --git a/python/tvm/script/parser/doc.py b/python/tvm/script/parser/doc.py
new file mode 100644
index 0000000000..15ad166c33
--- /dev/null
+++ b/python/tvm/script/parser/doc.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Conversion between Python ``ast`` nodes and the parser's own doc nodes (``doc_core``)."""
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import *  # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+    """Converters registered for one node class name (either may be None)."""
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        self.to_doc = None
+        self.from_doc = None
+
+
+class Registry:
+    """Table mapping node class names to Entry; the instance lives in ``Registry._inst``."""
+    _inst: typing.Optional["Registry"] = None
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        self.table = defaultdict(Entry)
+
+
+def register_to_doc(name: str):
+    """Return a decorator registering an ast -> doc converter for class ``name``."""
+
+    def f(func: FnToDoc):
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].to_doc = func
+        return func
+
+    return f
+
+
+def register_from_doc(name: str):
+    """Return a decorator registering a doc -> ast converter for class ``name``."""
+
+    def f(func: FnFromDoc):
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = func
+        return func
+
+    return f
+
+
+def _is_atomic_type(node):
+    """Whether ``node`` is a leaf value (None, ``...``, bool, number, str, bytes) kept as-is."""
+    return (
+        node is None
+        or node in [..., True, False]
+        or isinstance(
+            node,
+            (
+                int,
+                float,
+                str,
+                bool,
+                bytes,
+                complex,
+            ),
+        )
+    )
+
+
+def _get_registry_entry(cls_name, attr):
+    """Return ``attr`` of the Entry for ``cls_name`` (any dotted prefix dropped), or None."""
+    cls_name = cls_name.split(".")[-1]
+    reg = Registry._inst  # pylint: disable=protected-access
+    if cls_name in reg.table:
+        entry = reg.table[cls_name]
+        return getattr(entry, attr, None)
+    return None
+
+
+def from_doc(node):
+    """Convert doc node(s) to ``ast`` nodes; raise NotImplementedError if unregistered."""
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(from_doc(n) for n in node)
+    if isinstance(node, list):
+        return [from_doc(n) for n in node]
+    func = _get_registry_entry(node.__class__.__name__, "from_doc")
+    if not func:
+        raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}")
+    return func(node)
+
+
+def to_doc(node):
+    """Convert ``ast`` node(s) to doc nodes; raise NotImplementedError if unregistered."""
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(to_doc(n) for n in node)
+    if isinstance(node, list):
+        return [to_doc(n) for n in node]
+    func = _get_registry_entry(node.__class__.__name__, "to_doc")
+    if not func:
+        raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}")
+    return func(node)
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    """Parse ``source`` with :func:`ast.parse` and convert the result to doc nodes.
+
+    Parsing is first attempted with ``feature_version=(3, 8)``; if that raises (e.g. the
+    running interpreter does not accept ``feature_version``), parsing is retried with the
+    interpreter's default grammar.
+    """
+    try:
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except Exception:  # pylint: disable=broad-except
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    """Traverse doc nodes, dispatching to ``visit_<ClassName>`` or ``generic_visit``."""
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        getattr(
+            self,
+            "visit_" + node.__class__.__name__.split(".")[-1],
+            self.generic_visit,
+        )(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            if value is None:
+                pass
+            elif isinstance(value, (doc.AST, list, tuple)):
+                self.visit(value)
+
+
+class NodeTransformer:
+    """Like NodeVisitor, but ``generic_visit`` rebuilds each node from visited fields."""
+    def visit(self, node: doc.AST) -> doc.AST:
+        if isinstance(node, list):
+            return [self.visit(item) for item in node]
+        if isinstance(node, tuple):
+            return tuple(self.visit(item) for item in node)
+        if not isinstance(node, doc.AST):
+            return node
+        return getattr(
+            self,
+            "visit_" + node.__class__.__name__.split(".")[-1],
+            self.generic_visit,
+        )(node)
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        kv: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            if value is None:
+                pass
+            elif isinstance(value, (doc.AST, list, tuple)):
+                value = self.visit(value)
+            kv[field] = value
+        return node.__class__(**kv)
+
+
+def _register_default():
+    """Create the registry; register field-wise converters for doc.AST classes also in ast."""
+
+    class DefaultTranslator:
+        def __init__(self, doc_cls, func, fields):
+            self.doc_cls = doc_cls  # getattr(doc, name)
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields}
+            return self.doc_cls(**kv)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue
+        if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST):
+            assert "." not in cls_name
+            register_to_doc(cls_name)(
+                DefaultTranslator(
+                    getattr(doc, cls_name),
+                    to_doc,
+                    doc_cls._FIELDS,  # pylint: disable=protected-access
+                )
+            )
+            register_from_doc(cls_name)(
+                DefaultTranslator(
+                    getattr(ast, cls_name),
+                    from_doc,
+                    doc_cls._FIELDS,  # pylint: disable=protected-access
+                )
+            )
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    return (sys.version_info.major, sys.version_info.minor)
+
+
+def _register_constant_handling():
+    """On Python 3.6/3.7, convert legacy constant nodes (Str, Num, ...) to doc.Constant."""
+    if _py_version() not in [(3, 6), (3, 7)]:
+        return
+
+    def as_constant(f) -> doc.Constant:
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            return doc.Constant(
+                value=getattr(x, f) if isinstance(f, str) else f(x),
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                end_lineno=x.lineno,
+                end_col_offset=x.col_offset,
+            )
+
+        return to_doc_func
+
+    register_to_doc("Str")(as_constant("s"))
+    register_to_doc("NameConstant")(as_constant("value"))
+    register_to_doc("Num")(as_constant("n"))
+    register_to_doc("Bytes")(as_constant("s"))
+    register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+    """Before Python 3.9, translate Index/ExtSlice subscripts to and from the doc layout."""
+    if _py_version() >= (3, 9):
+        return
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        if isinstance(x.slice, ast.Slice):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Slice(
+                    lower=to_doc(x.slice.lower),
+                    upper=to_doc(x.slice.upper),
+                    step=to_doc(x.slice.step),
+                    lineno=getattr(x.slice, "lineno", None),
+                    col_offset=getattr(x.slice, "col_offset", None),
+                    end_lineno=getattr(x.slice, "end_lineno", None),
+                    end_col_offset=getattr(x.slice, "end_col_offset", None),
+                ),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        if isinstance(x.slice, ast.ExtSlice):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Tuple(
+                    elts=[to_doc(i) for i in x.slice.dims],
+                    ctx=doc.Load(
+                        lineno=None,
+                        col_offset=None,
+                        end_lineno=None,
+                        end_col_offset=None,
+                    ),
+                    lineno=getattr(x, "lineno", None),
+                    col_offset=getattr(x, "col_offset", None),
+                    end_lineno=getattr(x, "end_lineno", None),
+                    end_col_offset=getattr(x, "end_col_offset", None),
+                ),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        if isinstance(x.slice, ast.Index):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=to_doc(x.slice.value),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        if isinstance(x.slice, doc.Slice):
+            result = ast.Subscript(
+                value=from_doc(x.value),
+                slice=from_doc(x.slice),
+                ctx=from_doc(x.ctx),
+            )
+        elif isinstance(x.slice, doc.Tuple):
+            result = ast.Subscript(
+                value=from_doc(x.value),
+                slice=ast.ExtSlice(
+                    dims=[from_doc(i) for i in x.slice.elts],
+                ),
+                ctx=from_doc(x.ctx),
+            )
+        else:
+            result = ast.Subscript(
+                value=from_doc(x.value),
+                slice=ast.Index(value=from_doc(x.slice)),
+                ctx=from_doc(x.ctx),
+            )
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
diff --git a/python/tvm/script/parser/doc_core.py b/python/tvm/script/parser/doc_core.py
new file mode 100644
index 0000000000..b88eef9a0e
--- /dev/null
+++ b/python/tvm/script/parser/doc_core.py
@@ -0,0 +1,1140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-outer-name,missing-docstring,invalid-name
+# pylint: disable=useless-super-delegation,redefined-builtin
+# pylint: disable=too-few-public-methods,too-many-arguments
+class AST:
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__()
+        self.lineno = lineno
+        self.col_offset = col_offset
+        self.end_lineno = end_lineno
+        self.end_col_offset = end_col_offset
+
+
+class mod(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Module(mod):
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Interactive(mod):
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Expression(mod):
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class stmt(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class FunctionDef(stmt):
+    _FIELDS = [
+        "name",
+        "args",
+        "body",
+        "decorator_list",
+        "returns",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self,
+        name,
+        args,
+        body,
+        decorator_list,
+        returns,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name = name
+        self.args = args
+        self.body = body
+        self.decorator_list = decorator_list
+        self.returns = returns
+
+
+class ClassDef(stmt):
+    _FIELDS = [
+        "name",
+        "bases",
+        "keywords",
+        "body",
+        "decorator_list",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self,
+        name,
+        bases,
+        keywords,
+        body,
+        decorator_list,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name = name
+        self.bases = bases
+        self.keywords = keywords
+        self.body = body
+        self.decorator_list = decorator_list
+
+
+class Return(stmt):
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Delete(stmt):
+    _FIELDS = ["targets", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, targets, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.targets = targets
+
+
+class Assign(stmt):
+    _FIELDS = ["targets", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, targets, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.targets = targets
+        self.value = value
+
+
+class AugAssign(stmt):
+    _FIELDS = ["target", "op", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, target, op, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.op = op
+        self.value = value
+
+
+class AnnAssign(stmt):
+    _FIELDS = [
+        "target",
+        "annotation",
+        "value",
+        "simple",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self, target, annotation, value, simple, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.annotation = annotation
+        self.value = value
+        self.simple = simple
+
+
+class For(stmt):
+    _FIELDS = [
+        "target",
+        "iter",
+        "body",
+        "orelse",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.iter = iter
+        self.body = body
+        self.orelse = orelse
+
+
+class While(stmt):
+    _FIELDS = ["test", "body", "orelse", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.body = body
+        self.orelse = orelse
+
+
+class If(stmt):
+    _FIELDS = ["test", "body", "orelse", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.body = body
+        self.orelse = orelse
+
+
+class With(stmt):
+    _FIELDS = ["items", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.items = items
+        self.body = body
+
+
+class Raise(stmt):
+    _FIELDS = ["exc", "cause", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, exc, cause, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.exc = exc
+        self.cause = cause
+
+
+class Try(stmt):
+    _FIELDS = [
+        "body",
+        "handlers",
+        "orelse",
+        "finalbody",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self, body, handlers, orelse, finalbody, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+        self.handlers = handlers
+        self.orelse = orelse
+        self.finalbody = finalbody
+
+
+class Assert(stmt):
+    _FIELDS = ["test", "msg", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, msg, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.msg = msg
+
+
+class Import(stmt):
+    _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class ImportFrom(stmt):
+    _FIELDS = ["module", "names", "level", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, module, names, level, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.module = module
+        self.names = names
+        self.level = level
+
+
+class Global(stmt):
+    _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class Nonlocal(stmt):
+    _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class Expr(stmt):
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Pass(stmt):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Break(stmt):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Continue(stmt):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class expr(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BoolOp(expr):
+    _FIELDS = ["op", "values", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, op, values, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.op = op
+        self.values = values
+
+
+class BinOp(expr):
+    _FIELDS = ["left", "op", "right", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, left, op, right, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.left = left
+        self.op = op
+        self.right = right
+
+
+class UnaryOp(expr):
+    _FIELDS = ["op", "operand", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, op, operand, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.op = op
+        self.operand = operand
+
+
+class Lambda(expr):
+    _FIELDS = ["args", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, args, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.args = args
+        self.body = body
+
+
+class IfExp(expr):
+    _FIELDS = ["test", "body", "orelse", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.body = body
+        self.orelse = orelse
+
+
+class Dict(expr):
+    _FIELDS = ["keys", "values", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, keys, values, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.keys = keys
+        self.values = values
+
+
+class Set(expr):
+    _FIELDS = ["elts", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, elts, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elts = elts
+
+
+class ListComp(expr):
+    _FIELDS = ["elt", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elt = elt
+        self.generators = generators
+
+
+class SetComp(expr):
+    _FIELDS = ["elt", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elt = elt
+        self.generators = generators
+
+
+class DictComp(expr):
+    _FIELDS = ["key", "value", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, key, value, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.key = key
+        self.value = value
+        self.generators = generators
+
+
+class GeneratorExp(expr):
+    _FIELDS = ["elt", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elt = elt
+        self.generators = generators
+
+
+class Yield(expr):
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class YieldFrom(expr):
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Compare(expr):
+    _FIELDS = ["left", "ops", "comparators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, left, ops, comparators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.left = left
+        self.ops = ops
+        self.comparators = comparators
+
+
+class Call(expr):
+    _FIELDS = ["func", "args", "keywords", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, func, args, keywords, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.func = func
+        self.args = args
+        self.keywords = keywords
+
+
+class FormattedValue(expr):
+    _FIELDS = [
+        "value",
+        "conversion",
+        "format_spec",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self, value, conversion, format_spec, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+        self.conversion = conversion
+        self.format_spec = format_spec
+
+
+class JoinedStr(expr):
+    _FIELDS = ["values", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, values, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.values = values
+
+
+class Constant(expr):
+    _FIELDS = ["value", "kind", "s", "n", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, kind, s, n, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+        self.kind = kind
+        self.s = s
+        self.n = n
+
+
+class NamedExpr(expr):
+    _FIELDS = ["target", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, target, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.value = value
+
+
+class Attribute(expr):
+    _FIELDS = ["value", "attr", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, attr, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+        self.attr = attr
+        self.ctx = ctx
+
+
+class slice(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Slice(slice):
+    _FIELDS = ["lower", "upper", "step", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lower, upper, step, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.lower = lower
+        self.upper = upper
+        self.step = step
+
+
+class ExtSlice(slice):
+    _FIELDS = ["dims", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, dims, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.dims = dims
+
+
+class Index(slice):
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Subscript(expr):
+    _FIELDS = ["value", "slice", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, slice, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+        self.slice = slice
+        self.ctx = ctx
+
+
+class Starred(expr):
+    _FIELDS = ["value", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+        self.ctx = ctx
+
+
+class Name(expr):
+    _FIELDS = ["id", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, id, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.id = id
+        self.ctx = ctx
+
+
+class List(expr):
+    _FIELDS = ["elts", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elts = elts
+        self.ctx = ctx
+
+
+class Tuple(expr):
+    _FIELDS = ["elts", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elts = elts
+        self.ctx = ctx
+
+
+class expr_context(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class AugLoad(expr_context):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class AugStore(expr_context):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Param(expr_context):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Suite(mod):
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Del(expr_context):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Load(expr_context):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Store(expr_context):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class boolop(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class And(boolop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Or(boolop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class operator(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Add(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitAnd(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitOr(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitXor(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Div(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class FloorDiv(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class LShift(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Mod(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Mult(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class MatMult(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Pow(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class RShift(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Sub(operator):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class unaryop(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Invert(unaryop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Not(unaryop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class UAdd(unaryop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class USub(unaryop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class cmpop(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Eq(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Gt(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class GtE(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class In(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Is(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class IsNot(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Lt(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class LtE(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class NotEq(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class NotIn(cmpop):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class comprehension(AST):
+    _FIELDS = [
+        "target",
+        "iter",
+        "ifs",
+        "is_async",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(self, target, iter, ifs, is_async, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.iter = iter
+        self.ifs = ifs
+        self.is_async = is_async
+
+
+class excepthandler(AST):
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class ExceptHandler(excepthandler):
+    """Counterpart of ``ast.ExceptHandler``: ``except <type> as <name>: <body>``.
+
+    ``type`` and ``name`` may be absent for a bare ``except:`` clause.
+    """
+
+    _FIELDS = ["type", "name", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, type, name, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.type, self.name, self.body = type, name, body
+
+
+class arguments(AST):
+    """Counterpart of ``ast.arguments``: the parameter list of a function or lambda."""
+
+    _FIELDS = [
+        "args",
+        "vararg",
+        "kwonlyargs",
+        "kw_defaults",
+        "kwarg",
+        "defaults",
+        "posonlyargs",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self,
+        args,
+        vararg,
+        kwonlyargs,
+        kw_defaults,
+        kwarg,
+        defaults,
+        posonlyargs,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.args, self.vararg = args, vararg
+        self.kwonlyargs, self.kw_defaults = kwonlyargs, kw_defaults
+        self.kwarg, self.defaults, self.posonlyargs = kwarg, defaults, posonlyargs
+
+
+class arg(AST):
+    """Counterpart of ``ast.arg``: one parameter name with its optional annotation."""
+
+    _FIELDS = ["arg", "annotation", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, arg, annotation, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.arg, self.annotation = arg, annotation
+
+
+class keyword(AST):
+    """Counterpart of ``ast.keyword``: a ``name=value`` argument in a call or class header."""
+
+    _FIELDS = ["arg", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, arg, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.arg, self.value = arg, value
+
+
+class alias(AST):
+    """Counterpart of ``ast.alias``: ``name [as asname]`` inside an import statement."""
+
+    _FIELDS = ["name", "asname", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, name, asname, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name, self.asname = name, asname
+
+
+class withitem(AST):
+    """Counterpart of ``ast.withitem``: ``context_expr [as optional_vars]`` in a ``with``."""
+
+    _FIELDS = [
+        "context_expr",
+        "optional_vars",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(self, context_expr, optional_vars, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.context_expr, self.optional_vars = context_expr, optional_vars
+
+
+# Names exported by ``from <this module> import *``: the node classes that
+# mirror Python's ``ast`` module, grouped as in the ``ast`` grammar
+# (mod, stmt, expr, contexts, operators, helper nodes).
+__all__ = [
+    "AST",
+    "mod",
+    "Module",
+    "Interactive",
+    "Expression",
+    "stmt",
+    "FunctionDef",
+    "ClassDef",
+    "Return",
+    "Delete",
+    "Assign",
+    "AugAssign",
+    "AnnAssign",
+    "For",
+    "While",
+    "If",
+    "With",
+    "Raise",
+    "Try",
+    "Assert",
+    "Import",
+    "ImportFrom",
+    "Global",
+    "Nonlocal",
+    "Expr",
+    "Pass",
+    "Break",
+    "Continue",
+    "expr",
+    "BoolOp",
+    "BinOp",
+    "UnaryOp",
+    "Lambda",
+    "IfExp",
+    "Dict",
+    "Set",
+    "ListComp",
+    "SetComp",
+    "DictComp",
+    "GeneratorExp",
+    "Yield",
+    "YieldFrom",
+    "Compare",
+    "Call",
+    "FormattedValue",
+    "JoinedStr",
+    "Constant",
+    "NamedExpr",
+    "Attribute",
+    "slice",
+    "Slice",
+    "ExtSlice",
+    "Index",
+    "Subscript",
+    "Starred",
+    "Name",
+    "List",
+    "Tuple",
+    "expr_context",
+    "AugLoad",
+    "AugStore",
+    "Param",
+    "Suite",
+    "Del",
+    "Load",
+    "Store",
+    "boolop",
+    "And",
+    "Or",
+    "operator",
+    "Add",
+    "BitAnd",
+    "BitOr",
+    "BitXor",
+    "Div",
+    "FloorDiv",
+    "LShift",
+    "Mod",
+    "Mult",
+    "MatMult",
+    "Pow",
+    "RShift",
+    "Sub",
+    "unaryop",
+    "Invert",
+    "Not",
+    "UAdd",
+    "USub",
+    "cmpop",
+    "Eq",
+    "Gt",
+    "GtE",
+    "In",
+    "Is",
+    "IsNot",
+    "Lt",
+    "LtE",
+    "NotEq",
+    "NotIn",
+    "comprehension",
+    "excepthandler",
+    "ExceptHandler",
+    "arguments",
+    "arg",
+    "keyword",
+    "alias",
+    "withitem",
+]
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/entry.py
similarity index 50%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/entry.py
index 923eb97d27..ed736ba282 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/entry.py
@@ -14,32 +14,30 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
 
-import inspect
-from typing import Callable
+from tvm.ir.ir_builder import IRBuilder
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from . import doc
+from .parser import Parser
+from .source import Source
 
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+    """Parse a TVMScript program and return what the IR builder produced.
+
+    ``program`` is either source text or an object that ``Source`` can read
+    source from (e.g. a function or class). When ``program`` is a string and
+    ``extra_vars`` is None, the script namespaces are exposed by default as
+    ``I``/``ir`` and ``T``/``tir``; otherwise ``extra_vars`` (possibly None)
+    is passed to the parser unchanged.
+    """
+    if isinstance(program, str) and extra_vars is None:
+        # Imported lazily: ``parser.ir.entry`` and ``parser.tir.entry`` import
+        # ``parse`` from this module, so a top-level import would be circular.
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
-
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+        extra_vars = {
+            "I": ir,
+            "ir": ir,
+            "T": tir,
+            "tir": tir,
+        }
+    source = Source(program)
+    parser = Parser(source)
+    # Dispatched handlers emit IR into the active builder; the result is read
+    # back once the builder frame has been exited.
+    with IRBuilder() as builder:
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/evaluator.py b/python/tvm/script/parser/evaluator.py
new file mode 100644
index 0000000000..3899531b21
--- /dev/null
+++ b/python/tvm/script/parser/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+# Plain Python semantics for each operator node type. ``_eval_op`` falls back to
+# this table when no operand type provides a dispatched override. ``And``/``Or``
+# receive already-evaluated operands here, so they do not short-circuit.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+    doc.Add: lambda a, b: a + b,
+    doc.Sub: lambda a, b: a - b,
+    doc.Mult: lambda a, b: a * b,
+    doc.Div: lambda a, b: a / b,
+    doc.FloorDiv: lambda a, b: a // b,
+    doc.Mod: lambda a, b: a % b,
+    doc.LShift: lambda a, b: a << b,
+    doc.RShift: lambda a, b: a >> b,
+    doc.BitOr: lambda a, b: a | b,
+    doc.BitXor: lambda a, b: a ^ b,
+    doc.BitAnd: lambda a, b: a & b,
+    doc.MatMult: lambda a, b: a @ b,
+    doc.Pow: lambda a, b: a**b,
+    doc.Eq: lambda a, b: a == b,
+    doc.NotEq: lambda a, b: a != b,
+    doc.Lt: lambda a, b: a < b,
+    doc.LtE: lambda a, b: a <= b,
+    doc.Gt: lambda a, b: a > b,
+    doc.GtE: lambda a, b: a >= b,
+    doc.Is: lambda a, b: a is b,
+    doc.IsNot: lambda a, b: a is not b,
+    doc.In: lambda a, b: a in b,
+    doc.NotIn: lambda a, b: a not in b,
+    doc.And: lambda a, b: a and b,
+    doc.Or: lambda a, b: a or b,
+    # Unary operators take a single operand.
+    doc.Invert: lambda a: ~a,
+    doc.Not: lambda a: not a,
+    doc.UAdd: lambda a: +a,
+    doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+    """Evaluate a doc AST expression against a table of Python values.
+
+    Sub-expressions are evaluated bottom-up. Each intermediate result is stored
+    in ``value_table`` under a fresh ``__tvm_tmp_value_<k>`` name and replaced by
+    a ``doc.Name`` referring to it, so every parent node is evaluated over
+    names and constants only. Operators go through ``_eval_op`` so operand
+    types that declare ``_dispatch_type`` can override Python's semantics.
+    """
+
+    parser: "Parser"
+    value_table: Dict[str, Any]
+    new_value_count: int
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` and return the resulting Python value.
+
+        ``value_table`` maps visible names to values and is extended in place
+        with intermediate results. Undefined names are reported through
+        ``parser.report_error``. Raises TypeError if the evaluated result is
+        neither a name nor a constant.
+        """
+        evaluator = ExprEvaluator(parser, value_table)
+        result = evaluator._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in evaluator.value_table:
+                evaluator.parser.report_error(result, f"Undefined variable: {result.id}")
+            return evaluator.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        """Store ``value`` under a fresh temporary name and return a load of that name."""
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        """Evaluate ``node``, returning a ``doc.Name``/``doc.Constant`` that stands for it.
+
+        Lists and tuples are mapped element-wise. Names are checked for
+        definition and returned as-is. Constants, contexts and operator nodes
+        are returned unchanged, as are non-expression nodes. Lambdas are
+        evaluated whole, presumably because their bodies reference parameters
+        that are not in ``value_table``.
+        """
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            return self._eval_lambda(node)
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            # If reporting returns normally, surface the original error instead
+            # of failing below with ``value`` unbound.
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        """Evaluate a whole lambda expression and store the resulting function."""
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        """Fold ``and``/``or`` over the operands left to right via ``_eval_op``.
+
+        All operands have already been evaluated, so there is no short-circuiting.
+        """
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate a (possibly chained) comparison with Python semantics.
+
+        ``a < b < c`` means ``(a < b) and (b < c)`` with ``b`` evaluated once;
+        it is NOT ``(a < b) < c``. Pairwise results are combined through
+        ``_eval_op`` with an ``And`` node, so dispatched operand types can
+        supply their own conjunction. As in ``_eval_bool_op`` there is no
+        short-circuiting: every operand has already been evaluated.
+        """
+        lhs = self._eval_expr(fields["left"])
+        pair_results = []
+        for op, rhs_node in zip(fields["ops"], fields["comparators"]):
+            rhs = self._eval_expr(rhs_node)
+            pair_results.append(_eval_op(op, values=[lhs, rhs]))
+            lhs = rhs
+        if not pair_results:
+            return lhs
+        value = pair_results[0]
+        if len(pair_results) > 1:
+            # ``_eval_op`` only looks at ``type(op)``; the location is unused.
+            and_op = doc.And(lineno=0, col_offset=0, end_lineno=None, end_col_offset=None)
+            for pair in pair_results[1:]:
+                value = _eval_op(and_op, values=[value, pair])
+        return value
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply a unary operator to its evaluated operand via ``_eval_op``."""
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply a binary operator to the evaluated operands via ``_eval_op``."""
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        """Build a Python ``slice``; omitted bounds stay None."""
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        """Evaluate a (rewritten) node with Python ``eval`` over ``value_table``."""
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate ``node`` to a Python value with ``dict_globals`` as the visible names.
+
+    ``dict_globals`` is copied first, so intermediate results recorded by the
+    evaluator never leak back into the caller's dictionary.
+    """
+    value_table: Dict[str, Any] = {} if dict_globals is None else dict(dict_globals)
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Destructure ``source`` into ``target`` and return the resulting name bindings.
+
+    Any failure is reported through ``parser.report_error`` against ``target``;
+    if reporting returns normally, the original exception is re-raised.
+    """
+    try:
+        bindings = _eval_assign(target, source)
+    except Exception as err:  # pylint: disable=broad-except
+        parser.report_error(target, f"Failed to evaluate assignment: {str(err)}")
+        raise
+    return bindings
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Convert ``node`` to a Python ``ast.Expression``, compile it and ``eval`` it.
+
+    ``dict_globals`` (or a fresh dict when None) serves as the globals of the
+    evaluation. NOTE: this executes code taken from the parsed program, so it
+    must only be used on trusted TVMScript source.
+    """
+    py_node = doc.from_doc(node)
+    if isinstance(py_node, ast.expr):
+        py_node = ast.Expression(body=py_node)
+    assert isinstance(py_node, ast.Expression), (
+        "Expects an ast.Expression, but gets: " + str(py_node)
+    )
+    scope = {} if dict_globals is None else dict_globals
+    code = compile(ast.fix_missing_locations(py_node), filename="<ast>", mode="eval")
+    return eval(code, scope)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply operator node ``op`` to ``values`` (one operand if unary, two otherwise).
+
+    Operands are inspected left to right. The first one whose type declares
+    ``_dispatch_type`` and has a handler registered for this operator at that
+    operand position decides the result. Otherwise the plain Python semantics
+    in ``DEFAULT_OP`` are used.
+
+    Raises TypeError if ``op`` is neither dispatched nor in ``DEFAULT_OP``.
+    """
+    op_type = type(op)
+    for i, v in enumerate(values):
+        v_type = getattr(type(v), "_dispatch_type", None)
+        if v_type is None:
+            continue
+        f = dispatch.get_op(ty=v_type, op=op_type, operand_index=i, default=None)
+        if f is not None:
+            return f(*values)
+    default_op = DEFAULT_OP.get(op_type)
+    if default_op is None:
+        # A bare KeyError would only surface the class repr as the error message.
+        raise TypeError(f"Unsupported operator: {op_type.__name__}")
+    return default_op(*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Execute ``<target> = <source>`` with real Python assignment semantics.
+
+    ``source`` is bound to a placeholder name, and a one-statement module that
+    assigns the placeholder to ``target`` is compiled and executed. The local
+    names it creates (placeholder excluded) are returned, so Python's own
+    unpacking rules apply to tuple targets.
+    """
+    py_target = doc.from_doc(target)
+    assert isinstance(py_target, ast.expr)
+    rhs_name = "__tvm_rhs_var__"
+    scope = {rhs_name: source}
+    assign_stmt = ast.Assign(
+        targets=[py_target],
+        value=ast.Name(id=rhs_name, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assign_stmt], type_ignores=[]))
+    exec(compile(module, filename="<ast>", mode="exec"), {}, scope)  # pylint: disable=exec-used
+    del scope[rhs_name]
+    return scope
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser/ir/__init__.py
similarity index 89%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser/ir/__init__.py
index 926d17b166..bea08cfb1b 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
-import tvm._ffi
-
-tvm._ffi._init_api("script", __name__)
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 66%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 555659d0c5..353963f29b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,8 +14,21 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
 
-from . import tir
+from tvm.ir import IRModule
 
-from .parser import ir_module, from_source
+from ..entry import parse
+from ..utils import inspect_class_capture
+
+
+def ir_module(f: Type) -> IRModule:
+    """Decorator that parses a class written in TVMScript and returns the result.
+
+    The class source is parsed with the variables captured by
+    ``inspect_class_capture(f)`` in scope. Raises TypeError if ``f`` is not a class.
+    """
+    if inspect.isclass(f):
+        return parse(f, inspect_class_capture(f))
+    raise TypeError(f"Expect a class, but got: {f}")
+
+
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 54%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/parser.py
index 555659d0c5..aec203c7d9 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,8 +14,26 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.ir import ir_builder as I
 
-from . import tir
+from .. import dispatch, doc
+from ..parser import Parser
 
-from .parser import ir_module, from_source
+
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Parse an ``@ir_module`` class body.
+
+    The body is visited in a new var-table frame, inside an ``I.ir_module()``
+    builder frame, with ``"ir"`` as the active dispatch token. The context
+    managers are entered in that order and exited in reverse.
+    """
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    """Ignore plain assignments directly inside an ``@ir_module`` class body."""
+    pass
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    """Ignore expression statements (e.g. a docstring) inside an ``@ir_module`` class body."""
+    pass
diff --git a/python/tvm/script/parser/parser.py b/python/tvm/script/parser/parser.py
new file mode 100644
index 0000000000..c0eb79f144
--- /dev/null
+++ b/python/tvm/script/parser/parser.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from ...error import DiagnosticError
+from . import dispatch, doc
+from .diagnostics import Diagnostics
+from .evaluator import eval_assign, eval_expr
+from .source import Source
+from .utils import deferred
+from .var_table import VarTable
+
+# AST node types that ``Parser.visit`` walks with ``generic_visit`` (visiting
+# their children) instead of looking up a ``visit_<Name>`` method.
+DEFAULT_VISIT = {
+    "Interactive",
+    "Module",
+    "Expression",
+    "Pass",
+}
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap a parse method so unexpected exceptions are reported against the node.
+
+    ``DiagnosticError`` propagates unchanged. Any other exception is reported
+    via ``Parser.report_error`` at ``node`` and then re-raised.
+    """
+
+    def _wrapper(self: "Parser", node: doc.AST) -> Any:
+        # The handler's result is returned: callers such as
+        # ``Parser.visit_tvm_annotation`` use it, so this is not ``-> None``.
+        try:
+            return func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Find the parse method for AST node type ``type_name``.
+
+    The handler registered under the innermost dispatch token is preferred,
+    then the one under ``"default"``. If neither exists, the node is handed to
+    ``generic_visit``. The chosen callable is always wrapped by ``_dispatch_wrapper``.
+    """
+    handler = None
+    for token in (self.dispatch_tokens[-1], "default"):
+        handler = dispatch.get(token=token, type_name=type_name, default=None)
+        if handler is not None:
+            break
+    if handler is None:
+        return _dispatch_wrapper(lambda parser, node: parser.generic_visit(node))
+    return _dispatch_wrapper(handler)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser.
+
+    Walks the doc AST of a ``Source``. Most statement nodes are routed through
+    ``_dispatch``, which prefers the handler registered for the innermost
+    dispatch token (e.g. ``"ir"`` or ``"tir"``) and falls back to ``"default"``.
+    """
+
+    # Diagnostics bound to the source being parsed; used by ``report_error``.
+    diag: Diagnostics
+    # Stack of dispatch tokens; the last entry selects handlers in ``_dispatch``.
+    dispatch_tokens: List[str]
+    # Scoped names visible to ``eval_expr``; frames are opened with ``with_frame``.
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        """Create a parser for ``source`` with only the ``"default"`` token active."""
+        self.diag = Diagnostics(source)
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Parse the source and visit its AST.
+
+        ``extra_vars`` are added to a fresh var-table frame before visiting, so
+        the script can refer to them by name. Nothing is returned; output is
+        produced by the side effects of the dispatched handlers.
+        """
+        if extra_vars is None:
+            extra_vars = {}
+        with self.var_table.with_frame():
+            for k, v in extra_vars.items():
+                self.var_table.add(k, v)
+            node = self.diag.source.as_ast()
+            self.visit(node)
+
+    def with_dispatch_token(self, token: str):
+        """Make ``token`` the active dispatch token.
+
+        The token is pushed immediately, at call time rather than on
+        ``__enter__``. The returned ``deferred(pop_token)`` object is used as a
+        context manager and presumably pops the token on exit (confirm in
+        ``.utils``).
+        """
+
+        def pop_token():
+            self.dispatch_tokens.pop()
+
+        self.dispatch_tokens.append(token)
+        return deferred(pop_token)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate ``node`` with the current variables plus ``extra_vars`` in scope.
+
+        Entries in ``extra_vars`` shadow variables of the same name.
+        NOTE(review): ``extra_vars`` are written into the dict returned by
+        ``self.var_table.get()``; this is only safe if ``get`` returns a fresh
+        dict. Confirm in ``VarTable``.
+        """
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            for k, v in extra_vars.items():
+                var_values[k] = v
+        return eval_expr(self, node, var_values)
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        """Evaluate ``target = source`` and bind the resulting names.
+
+        For every name produced by the assignment,
+        ``bind_value(self, target, name, value)`` returns the object that is
+        recorded in the var table. Returns the raw ``name -> value`` mapping,
+        i.e. the values before ``bind_value``.
+        """
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        """Report ``msg`` as an error located at ``node`` via ``self.diag.error``.
+
+        Callers such as ``visit_FunctionDef`` keep going after calling this, so
+        it is presumably expected to raise (confirm in ``Diagnostics.error``).
+        """
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit ``node``.
+
+        Lists and tuples are visited element by element, and non-AST values are
+        ignored. Node types in ``DEFAULT_VISIT`` use ``generic_visit``; any
+        other node needs a ``visit_<ClassName>`` method on this parser, or
+        NotImplementedError is raised.
+        """
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        func(node)
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        """Visit each statement of a body in order."""
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        """Evaluate a type annotation with the active token's ``tvm_annotation`` handler."""
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        """Parse a decorated function definition.
+
+        The last decorator is evaluated and its ``dispatch_token`` attribute
+        (e.g. ``"tir"``, set on ``prim_func``) selects the ``FunctionDef``
+        handler. Undecorated functions and unknown decorators are reported as errors.
+        """
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        func(self, node)
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        """Parse a class definition with the ``"ir"`` token's ``ClassDef`` handler.
+
+        The decorator is not inspected: every class is treated as an IR module.
+        """
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        func(self, node)
+
+    # The visitors below forward the node to the handler registered for the
+    # active dispatch token (see ``_dispatch``).
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        """Dispatch a function's argument list."""
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a ``for`` statement."""
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a ``while`` statement."""
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a ``with`` statement."""
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a plain assignment."""
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        """Dispatch an expression statement."""
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        """Dispatch an ``if`` statement."""
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        """Dispatch an annotated assignment."""
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        """Dispatch an augmented assignment (``+=`` etc.)."""
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        """Dispatch an ``assert`` statement."""
+        return _dispatch(self, "Assert")(self, node)
diff --git a/python/tvm/script/parser/source.py b/python/tvm/script/parser/source.py
new file mode 100644
index 0000000000..9674dc5494
--- /dev/null
+++ b/python/tvm/script/parser/source.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+import sys
+from typing import Union
+
+from . import doc
+
+
+class Source:
+    """The text of a TVMScript program plus where it came from.
+
+    Built either from a string, or from a Python object (function or class)
+    whose source is recovered with ``inspect``. For an object, ``source`` is
+    its own code with the first line's indentation stripped from every line,
+    and ``full_source`` is the text of the whole defining module when that can
+    be obtained.
+    """
+
+    source_name: str  # source file name, or "<str>" for string input
+    start_line: int  # line number of the first line of ``source`` in that file
+    start_column: int  # indentation removed from each line of ``source``
+    source: str  # the text handed to ``doc.parse``
+    full_source: str  # the whole defining module, or ``source`` as a fallback
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            # Raw text: nothing to locate or dedent.
+            self.source_name = "<str>"
+            self.start_line, self.start_column = 1, 0
+            self.source = self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = inspect.getsourcelines(program)  # type: ignore
+        first_line = lines[0] if lines else ""
+        self.start_column = len(first_line) - len(first_line.lstrip())
+        if lines and self.start_column:
+            indent = self.start_column
+            # NOTE(review): lines indented less than the first one (e.g. inside a
+            # multi-line string literal) lose characters here.
+            self.source = "\n".join(line[indent:].rstrip() for line in lines)
+        else:
+            self.source = "".join(lines)
+        try:
+            # In Jupyter, ``mod`` is <module '__main__'>, a built-in module for
+            # which ``inspect.getsource`` raises TypeError.
+            mod = inspect.getmodule(program)
+            self.full_source = inspect.getsource(mod) if mod else self.source
+        except TypeError:
+            # Jupyter workaround: fall back to ``inspect.findsource``, an
+            # internal helper of ``inspect``.
+            module_lines, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(module_lines)
+
+    def as_ast(self) -> doc.AST:
+        """Parse ``source`` into a doc AST."""
+        return doc.parse(self.source)
+
+
+# The original ``inspect.getfile``, kept so ``_patched_inspect_getfile`` can
+# delegate non-class objects to it.
+_getfile = inspect.getfile
+
+
+def _patched_inspect_getfile(obj):
+    """``inspect.getfile`` replacement that can also locate classes whose module lacks ``__file__``.
+
+    Non-class objects are delegated to the original ``inspect.getfile``. For a
+    class, the defining module's ``__file__`` is used when that module is in
+    ``sys.modules`` and has one (it does not, e.g., for ``__main__`` in a
+    Jupyter kernel). Otherwise the file of a function attribute whose
+    ``__qualname__`` shows it was defined directly in this class is returned.
+
+    Raises TypeError if no source file can be determined.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod_name = getattr(obj, "__module__", None)
+    # ``.get``: a module missing from sys.modules should fall through to the
+    # method scan below rather than raise KeyError.
+    module = sys.modules.get(mod_name) if mod_name is not None else None
+    file = getattr(module, "__file__", None)
+    if file is not None:
+        return file
+    for _, member in inspect.getmembers(obj):
+        if inspect.isfunction(member):
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    # ``!r`` is a conversion, not a format spec: ``{obj:!r}`` would make
+    # ``type.__format__`` raise an unrelated TypeError.
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+# NOTE(review): this replaces ``inspect.getfile`` for the whole process as a
+# side effect of importing this module. It affects every ``inspect`` user, not
+# only the parser.
+inspect.getfile = _patched_inspect_getfile
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 78%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..caa51744c8 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.tir.ir_builder_v2 import *  # pylint: disable=redefined-builtin
 
-from . import tir
-
-from .parser import ir_module, from_source
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..b92341aecd
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+from tvm.tir.ir_builder_v2 import buffer_decl, ptr
+
+from ..entry import parse
+from ..utils import inspect_function_capture
+
+
+def _is_defined_in_class(frames):
+    if len(frames) > 2:
+        maybe_class_frame = frames[2]
+        statement_list = maybe_class_frame[4]
+        first_statement = statement_list[0]
+        line = first_statement.strip()
+        if line.startswith("class "):
+            return True
+        if line.startswith("@") and "ir_module" in line:
+            return True
+    return False
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    if _is_defined_in_class(inspect.stack()):
+        return f
+    return parse(f, inspect_function_capture(f))
+
+
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        return buffer_decl(
+            shape,
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+
+    def __getitem__(self, keys) -> Buffer:
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    def __call__(self, dtype, storage_scope="global"):
+        if callable(dtype):
+            dtype = dtype().dtype
+        return ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        if not isinstance(keys, tuple):
+            return self(keys)
+        return self(*keys)
+
+
+# Module-level singletons re-exported by the package ``__init__`` as
+# ``Buffer`` / ``Ptr``. Note: this rebinding shadows ``tvm.tir.Buffer`` imported
+# above; the ``-> Buffer`` annotations in ``BufferProxy`` were evaluated earlier,
+# while the name still referred to ``tvm.tir.Buffer``.
+Buffer = BufferProxy()
+Ptr = PtrProxy()
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..11ee92ad29
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .. import doc
+from ..dispatch import OpMethod, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _and(a, b):
+        if isinstance(a, bool):
+            a = IntImm("bool", a)
+        if isinstance(b, bool):
+            b = IntImm("bool", b)
+        return tir.And(a, b)
+
+    def _or(a, b):
+        if isinstance(a, bool):
+            a = IntImm("bool", a)
+        if isinstance(b, bool):
+            b = IntImm("bool", b)
+        return tir.Or(a, b)
+
+    def r(op: Type, i: int, m: OpMethod):  # pylint: disable=invalid-name
+        register_op(ty, op, i)(m)
+
+    for i in [0, 1]:
+        # Case 1. binop
+        r(doc.Add, i, tir.Add)
+        r(doc.Sub, i, tir.Sub)
+        r(doc.Mult, i, tir.Mul)
+        r(doc.Div, i, tir.Div)
+        r(doc.FloorDiv, i, tir.FloorDiv)
+        r(doc.Mod, i, tir.FloorMod)
+        r(doc.LShift, i, lambda a, b: a << b)
+        r(doc.RShift, i, lambda a, b: a >> b)
+        r(doc.BitOr, i, lambda a, b: a | b)
+        r(doc.BitXor, i, lambda a, b: a ^ b)
+        r(doc.BitAnd, i, lambda a, b: a & b)
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop
+        r(doc.Eq, i, tir.EQ)
+        r(doc.NotEq, i, tir.NE)
+        r(doc.Lt, i, tir.LT)
+        r(doc.LtE, i, tir.LE)
+        r(doc.Gt, i, tir.GT)
+        r(doc.GtE, i, tir.GE)
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop
+        r(doc.And, i, _and)
+        r(doc.Or, i, _or)
+    for i in [0]:
+        #  Case 4. unaryop
+        r(doc.Invert, i, lambda a: ~a)
+        r(doc.Not, i, tir.Not)
+        r(doc.UAdd, i, lambda a: +a)
+        r(doc.USub, i, lambda a: -a)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..c66abf2013
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.ir.ir_builder import IRBuilderFrame as Frame
+from tvm.ir.ir_builder import name
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+from tvm.tir import ir_builder_v2 as T
+
+from .. import dispatch, doc
+from ..parser import Parser
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+        raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Var):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    for_frame = self.eval_expr(node.iter)
+    if not isinstance(for_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame():
+        with for_frame as iters:
+            self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value)
+            self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    with self.var_table.with_frame():
+        cond = self.eval_expr(node.test)
+        with T.While(cond):
+            self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    lhs = node.targets[0]
+    rhs = self.eval_expr(node.value)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    lhs_pos = (
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)
+    lhs = node.target
+    lhs.ctx = doc.Store(*lhs_pos)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    lhs = node.target
+    rhs = self.eval_expr(node.value)
+    ann_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(ann_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
+    frame = T.let(ann_var, rhs)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(self.var_table.with_frame())
+        for item in node.items:
+            frame = self.eval_expr(item.context_expr)
+            if not isinstance(frame, Frame):
+                self.report_error(
+                    item.context_expr, "Invalid context expression in the with-statement."
+                )
+            rhs = stack.enter_context(frame)
+            if item.optional_vars is not None:
+                self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                if callable(ret_type):
+                    ret_type = PrimType(ret_type().dtype)
+                T.func_ret(ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    arg: doc.arg
+    for arg in node.args:
+        if arg.annotation is None:
+            self.report_error(arg, "Type annotation is required for function parameters.")
+        param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation))
+        self.var_table.add(arg.arg, param)
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    annotation = self.eval_expr(node)
+    if callable(annotation):
+        annotation = annotation()
+    return annotation
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    res = self.eval_expr(node.value)
+    if isinstance(res, Frame):
+        res.add_callback(partial(res.__exit__, None, None, None))
+        res.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    with self.var_table.with_frame():
+        with T.If(self.eval_expr(node.test)):
+            with T.Then():
+                self.visit_body(node.body)
+            if len(node.orelse):
+                with T.Else():
+                    self.visit_body(node.orelse)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    cond = self.eval_expr(node.test)
+    msg = self.eval_expr(node.msg)
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
diff --git a/python/tvm/script/parser/utils.py b/python/tvm/script/parser/utils.py
new file mode 100644
index 0000000000..9d681236c8
--- /dev/null
+++ b/python/tvm/script/parser/utils.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from contextlib import contextmanager
+from typing import Any, Callable, Dict
+
+
+def deferred(f: Callable[[], None]):
+    @contextmanager
+    def context():
+        try:
+            yield
+        finally:
+            f()
+
+    return context()
+
+
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    prefix = "tvm."
+    result = {}
+    captured = {
+        **inspect.getclosurevars(func).nonlocals,
+        **func.__globals__,
+    }
+    for k, v in captured.items():
+        # Case 1: a module like `T` or `tvm.tir.ir_builder`
+        if inspect.ismodule(v) and v.__name__.startswith(prefix):
+            result[k] = v
+            continue
+        # Case 2: a function like `T.match_buffer`
+        if hasattr(v, "__module__") and v.__module__.startswith(prefix):
+            result[k] = v
+            continue
+        # Case 3: atomic types
+        if v is None or isinstance(v, (int, float, str, bool)):
+            result[k] = v
+            continue
+    return result
+
+
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    result: Dict[str, Any] = {}
+    for _, v in cls.__dict__.items():
+        if inspect.isfunction(v):
+            func_vars = inspect_function_capture(v)
+            result.update(**func_vars)
+    return result
diff --git a/python/tvm/script/parser/var_table.py b/python/tvm/script/parser/var_table.py
new file mode 100644
index 0000000000..32fced625a
--- /dev/null
+++ b/python/tvm/script/parser/var_table.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The symbol table of variable values"""
+
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Set
+
+from .utils import deferred
+
+
+class VarTableFrame:
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        if var in self.vars:
+            raise ValueError(f"Variable {var} already defined in current scope")
+        self.vars.add(var)
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        for var in self.vars:
+            fn_pop(var)
+        self.vars.clear()
+
+
+class VarTable:
+
+    frames: List[VarTableFrame]
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        for v in self.name2value.values():
+            if v is value:
+                return True
+        return False
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser_v1/__init__.py
similarity index 95%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser_v1/__init__.py
index 555659d0c5..004e947bf6 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser_v1/__init__.py
@@ -17,5 +17,4 @@
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
 
 from . import tir
-
-from .parser import ir_module, from_source
+from .parser import from_source, ir_module
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py
similarity index 100%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser_v1/_ffi_api.py
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
similarity index 98%
rename from python/tvm/script/context_maintainer.py
rename to python/tvm/script/parser_v1/context_maintainer.py
index f7f16855c7..400baacc4b 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/parser_v1/context_maintainer.py
@@ -16,16 +16,16 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
+from typing import Callable, Dict, List, Mapping, Optional, Union
 
+import synr
 import tvm
 from tvm.ir import Span
 from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
 from tvm.runtime import Object
+from tvm.tir import Buffer, MatchBufferRegion, PrimExpr, Stmt, Var
 from tvm.tir.expr import IterVar
+
 from .tir.node import BufferSlice
 
 
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py
similarity index 95%
rename from python/tvm/script/diagnostics.py
rename to python/tvm/script/parser_v1/diagnostics.py
index e676461ab3..b15997552f 100644
--- a/python/tvm/script/diagnostics.py
+++ b/python/tvm/script/parser_v1/diagnostics.py
@@ -17,11 +17,11 @@
 """Bridge from synr's (the library used for parsing the python AST)
    DiagnosticContext to TVM's diagnostics
 """
-from synr import DiagnosticContext, ast
-
 import tvm
+from synr import DiagnosticContext, ast
+from tvm.ir.diagnostics import Diagnostic
 from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
+from tvm.ir.diagnostics import DiagnosticLevel, get_renderer
 
 
 class TVMDiagnosticCtx(DiagnosticContext):
diff --git a/python/tvm/script/highlight.py b/python/tvm/script/parser_v1/highlight.py
similarity index 100%
rename from python/tvm/script/highlight.py
rename to python/tvm/script/parser_v1/highlight.py
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py
similarity index 100%
rename from python/tvm/script/meta_unparser.py
rename to python/tvm/script/parser_v1/meta_unparser.py
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py
similarity index 99%
rename from python/tvm/script/parser.py
rename to python/tvm/script/parser_v1/parser.py
index 908af081c9..b2b7e388f1 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser_v1/parser.py
@@ -20,35 +20,34 @@ We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
 different python versions. Synr also provides an error handling context that we
 use for error reporting.
 """
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
+import inspect
 import json
 import operator
-import inspect
+
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
+import types
 from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
 
 import tvm
+from synr import Transformer, ast, to_ast
 from tvm import IRModule
 from tvm._ffi.base import TVMError
 from tvm.ir import GlobalVar
 from tvm.ir.function import BaseFunc
 from tvm.tir import buffer
 from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
 
+from . import _ffi_api, tir
 from .context_maintainer import ContextMaintainer
+from .diagnostics import TVMDiagnosticCtx
 from .meta_unparser import MetaUnparser
 from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
+from .tir import ty
 from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
+from .tir.node import BufferSlice, Slice
+from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler
 from .tir.special_stmt import SpecialStmt
-from .tir import ty
+from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr
 
 
 class CallArgumentReader(object):
diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py
similarity index 97%
rename from python/tvm/script/registry.py
rename to python/tvm/script/parser_v1/registry.py
index e7d90dd515..e816b90f5d 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/parser_v1/registry.py
@@ -17,7 +17,7 @@
 """TVM Script Parser Function Registry """
 # pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
 import types
-from typing import Union, Callable, Dict, Optional, Any
+from typing import Any, Callable, Dict, Optional, Union
 
 
 class Registry(object):
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 100%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
similarity index 100%
rename from python/tvm/script/tir/__init__.pyi
rename to python/tvm/script/parser_v1/tir/__init__.pyi
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 99%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index 627b89086a..008d7495a0 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -17,11 +17,12 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
+from tvm.target import codegen
+
 from ..registry import register
-from ...target import codegen
 from ..utils import get_param_list, tvm_span_from_synr
 
 
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 98%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
index 29e79607fb..cfaf9df476 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/parser_v1/tir/node.py
@@ -17,12 +17,13 @@
 # pylint: disable=redefined-builtin
 """TVM Script nodes."""
 
-from typing import Optional, Union, List, Callable
+from typing import Callable, List, Optional, Union
+
 import synr
 from tvm.arith import Analyzer
+from tvm.ir import Range, Span
 from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
+from tvm.tir import Buffer, BufferLoad, BufferRegion, IntImm, PrimExpr, Ramp
 
 
 class Slice:
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 97%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
index 923eb97d27..a5fdcc15c5 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser_v1/tir/prim_func.py
@@ -19,7 +19,8 @@
 import inspect
 from typing import Callable
 
-from tvm.tir.function import PrimFunc
+from tvm.tir import PrimFunc
+
 from ..parser import from_source
 
 
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 98%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
index da7545c9a9..cd2167f4ab 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/parser_v1/tir/scope_handler.py
@@ -16,24 +16,19 @@
 # under the License.
 """TVM Script Parser Scope Handler Classes"""
 # pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
-import synr
 import numpy as np
+import synr
 import tvm.tir
+from tvm.ir import Range, Span
 from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
+from tvm.tir import Buffer, BufferRegion, ForKind, IterVar, PrimExpr, Stmt, Var
 
 from ..context_maintainer import ContextMaintainer
 from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
 
 
 class ScopeHandler:
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 99%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
index 15502055b7..42a90f647f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -17,27 +17,21 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
 import synr
+import tvm.tir
 from synr import ast
+from tvm.ir import Span
 from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
 from tvm.runtime import Object, String
 from tvm.target import Target
-from tvm.ir import Span
 from tvm.tir import IntImm, IterVar, Var
 
-from .node import BufferSlice
-
 from ..context_maintainer import BlockInfo, ContextMaintainer
 from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
 
 
 def convert_to_int(
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 99%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
index a64485b215..8a048c07b5 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/parser_v1/tir/ty.py
@@ -23,6 +23,7 @@ a wrapper for uniform Type system in IR
 from numbers import Integral
 
 import tvm
+
 from .special_stmt import SpecialStmt, convert_to_int
 
 
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 97%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
index c655a62237..f358a90081 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/parser_v1/utils.py
@@ -16,13 +16,12 @@
 # under the License.
 """Helper functions in TVM Script Parser"""
 
-from typing import Callable, List, Any, Optional, Tuple
-
 import inspect
-import synr
+from typing import Any, Callable, List, Optional, Tuple
 
-from tvm.ir import Span, SourceName
+import synr
 from tvm.error import DiagnosticError
+from tvm.ir import SourceName, Span
 
 
 def get_param_list(
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..0d949234d8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,180 @@
 from tvm.ir import PrimExpr
 from tvm.runtime import const
 
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis
+from . import ir_builder_v1 as ir_builder
+from . import schedule, stmt_functor, transform, usmp
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Any,
+    Broadcast,
+    BufferLoad,
+    Call,
+    CallEffectKind,
+    Cast,
+    CommReducer,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
+    IntImm,
+    IterVar,
+    Let,
+    Load,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
+    Select,
+    Shuffle,
+    SizeVar,
+    StringImm,
+    Sub,
+    Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+    TVMBackendAllocWorkspace,
+    TVMBackendFreeWorkspace,
+    abs,
+    acos,
+    acosh,
+    address_of,
+    all,
+    any,
+    asin,
+    asinh,
+    atan,
+    atan2,
+    atanh,
+    call_cpacked,
+    call_cpacked_lowered,
+    call_extern,
+    call_intrin,
+    call_llvm_intrin,
+    call_llvm_pure_intrin,
+    call_packed,
+    call_packed_lowered,
+    call_pure_extern,
+    ceil,
+    ceildiv,
+    clz,
+    comm_reducer,
+    copysign,
+    cos,
+    cosh,
+    div,
+    erf,
+    exp,
+    exp2,
+    exp10,
+    floor,
+    floordiv,
+    floormod,
+    fmod,
+    hypot,
+    if_then_else,
+    indexdiv,
+    indexmod,
+    isfinite,
+    isinf,
+    isnan,
+    isnullptr,
+    ldexp,
+    likely,
+    log,
+    log1p,
+    log2,
+    log10,
+    lookup_param,
+    max,
+    max_value,
+    min,
+    min_value,
+    mma_fill,
+    mma_store,
+    nearbyint,
+    nextafter,
+    popcount,
+    power,
+    ptx_commit_group,
+    ptx_cp_async,
+    ptx_ldmatrix,
+    ptx_mma,
+    ptx_mma_sp,
+    ptx_wait_group,
+    q_multiply_shift,
+    ret,
+    round,
+    rsqrt,
+    shift_left,
+    shift_right,
+    sigmoid,
+    sin,
+    sinh,
+    sqrt,
+    sum,
+    tan,
+    tanh,
+    trace,
+    trunc,
+    truncdiv,
+    truncmod,
+    tvm_access_ptr,
+    tvm_bmma_sync,
+    tvm_fill_fragment,
+    tvm_load_matrix_sync,
+    tvm_mma_sync,
+    tvm_stack_alloca,
+    tvm_stack_make_array,
+    tvm_stack_make_shape,
+    tvm_store_matrix_sync,
+    tvm_struct_get,
+    tvm_struct_set,
+    tvm_thread_allreduce,
+    tvm_throw_last_error,
+    tvm_tuple,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
 from .stmt import (
-    BufferStore,
-    BufferRealize,
-    Store,
-    ProducerStore,
     Allocate,
     AllocateConst,
+    AssertStmt,
     AttrStmt,
+    Block,
+    BlockRealize,
+    BufferRealize,
+    BufferRegion,
+    BufferStore,
     DeclBuffer,
+    Evaluate,
+    For,
+    ForKind,
+    IfThenElse,
+    LetStmt,
+    MatchBufferRegion,
+    Prefetch,
+    ProducerRealize,
+    ProducerStore,
+    SeqStmt,
+    Stmt,
+    Store,
+    While,
+    stmt_list,
+    stmt_seq,
+    type_annotation,
 )
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/tir/_ffi_ir_builder_api.py
similarity index 88%
rename from python/tvm/script/_ffi_api.py
rename to python/tvm/tir/_ffi_ir_builder_api.py
index 926d17b166..61b288d498 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/tir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for the TIR IRBuilder (``ir_builder.tir``)"""
 import tvm._ffi
 
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder.tir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..ea220fea22 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Union
 from tvm import Object
 from tvm.ir import IRModule
 from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
+from tvm.tir.stmt import Block, BufferRegion, PrimExpr, Stmt
 
-from .. import Buffer, Stmt
+from ..buffer import Buffer
 from ..function import PrimFunc
 from . import _ffi_api
 
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index d9b0aec76a..6752082fa1 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -20,7 +20,7 @@ import tvm._ffi
 
 from tvm._ffi.base import string_types
 from tvm.runtime import Object, convert
-from tvm.ir import PrimExpr, PointerType, PrimType
+from tvm.ir import PrimExpr, PointerType, PrimType, Range
 from . import _ffi_api
 
 
@@ -176,6 +176,40 @@ class Buffer(Object):
         """
         return _ffi_api.BufferOffsetOf(self, indices)  # type: ignore
 
+    def __getitem__(self, indices):
+        """Index or slice the buffer.
+
+        If any index is a slice without an explicit step, the result is a
+        ``BufferRegion``; integer/expression indices become unit-extent ranges.
+        Otherwise the result is a ``BufferLoad``, where each stepped slice
+        becomes a ``Ramp`` index (or its start when it covers a single lane).
+
+        Omitted slice bounds default to ``0`` (start) and the extent of that
+        dimension (stop), so ``A[:, i]`` and ``A[2:]`` are accepted.
+        """
+        from .expr import BufferLoad, Ramp  # pylint: disable=import-outside-toplevel
+        from .stmt import BufferRegion  # pylint: disable=import-outside-toplevel
+        from ..arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+        analyzer = Analyzer()
+
+        def _bounds(dim, index):
+            # Fill in omitted slice bounds from the buffer's shape.
+            start = 0 if index.start is None else index.start
+            stop = self.shape[dim] if index.stop is None else index.stop
+            return start, stop
+
+        if any(isinstance(index, slice) and index.step is None for index in indices):
+            region = []
+            for dim, index in enumerate(indices):
+                if isinstance(index, slice):
+                    start, stop = _bounds(dim, index)
+                    region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
+                else:
+                    region.append(Range.from_min_extent(index, 1))
+            return BufferRegion(self, region)
+
+        expr_indices = []
+        for dim, index in enumerate(indices):
+            if isinstance(index, slice):
+                start, stop = _bounds(dim, index)
+                # Number of lanes covered by start:stop:step (ceil division).
+                lanes = analyzer.simplify((stop - start + index.step - 1) // index.step)
+                if lanes == 1:
+                    expr_indices.append(start)
+                else:
+                    expr_indices.append(Ramp(start, index.step, int(lanes)))
+            else:
+                expr_indices.append(index)
+        return BufferLoad(self, expr_indices)
+
 
 def decl_buffer(
     shape,
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 from typing import Optional, Union
-from tvm import ir
+
 import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
 from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
 
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
 from . import _ffi_api
+from . import generic as _generic
 
 
 def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
 
-    # TODO(tkonolige): use inspect to add source information to these objects
-
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
     """
 
     def __init__(self, condition, true_value, false_value, span=None):
+        # A Python bool condition is wrapped as a boolean IntImm before it is
+        # handed to the FFI constructor.
+        if isinstance(condition, bool):
+            condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
             _ffi_api.Select, condition, true_value, false_value, span  # type: ignore
         )
diff --git a/python/tvm/tir/ir_builder_frame.py b/python/tvm/tir/ir_builder_frame.py
new file mode 100644
index 0000000000..a1f457aad2
--- /dev/null
+++ b/python/tvm/tir/ir_builder_frame.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder frames for TIR"""
+
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.ir.ir_builder import IRBuilderFrame
+
+from .buffer import Buffer
+from .expr import Var
+
+
+@_register_object("ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+    """Base class of the TIR IRBuilder frames defined in this module."""
+
+
+@_register_object("ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.BlockFrame``."""
+
+
+@_register_object("ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.BlockInitFrame``."""
+
+
+@_register_object("ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.ForFrame``."""
+
+    def __enter__(self) -> Union[Var, List[Var]]:
+        """Enter the frame and return its loop variable(s).
+
+        A frame with exactly one variable returns that ``Var`` directly, so
+        ``with serial(n) as i:`` binds a single variable; otherwise the list
+        of variables is returned.
+        """
+        super().__enter__()
+        return self.vars if len(self.vars) > 1 else self.vars[0]
+
+
+@_register_object("ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.PrimFuncFrame``."""
+
+
+@_register_object("ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.AssertFrame``."""
+
+
+@_register_object("ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.LetFrame``."""
+
+
+@_register_object("ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.AllocateFrame``."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame and return its ``buffer`` attribute."""
+        super().__enter__()
+        allocated: Buffer = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.AllocateConstFrame``."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame and return its ``buffer`` attribute."""
+        super().__enter__()
+        allocated: Buffer = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.LaunchThreadFrame``."""
+
+
+@_register_object("ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.RealizeFrame``."""
+
+
+@_register_object("ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.AttrFrame``."""
+
+
+@_register_object("ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.WhileFrame``."""
+
+
+@_register_object("ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.IfFrame``."""
+
+
+@_register_object("ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.ThenFrame``."""
+
+
+@_register_object("ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.ElseFrame``."""
+
+
+@_register_object("ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    """IRBuilder frame registered as ``ir_builder.tir.DeclBufferFrame``."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame and return its ``buffer`` attribute."""
+        super().__enter__()
+        declared: Buffer = self.buffer
+        return declared
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder_v1.py
similarity index 99%
rename from python/tvm/tir/ir_builder.py
rename to python/tvm/tir/ir_builder_v1.py
index ce8cd1b403..8a68faaac6 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder_v1.py
@@ -17,13 +17,13 @@
 """Developer API of IR node builder make function."""
 import tvm
 from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, convert, const
 from tvm.ir import container as _container
+from tvm.runtime import ObjectGeneric, const, convert
 
-from . import stmt as _stmt
-from . import expr as _expr
 from . import buffer as _buffer
+from . import expr as _expr
 from . import op
+from . import stmt as _stmt
 
 
 class WithScope(object):
diff --git a/python/tvm/tir/ir_builder_v2.py b/python/tvm/tir/ir_builder_v2.py
new file mode 100644
index 0000000000..f297ec69ff
--- /dev/null
+++ b/python/tvm/tir/ir_builder_v2.py
@@ -0,0 +1,901 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.target import Target as target
+
+from . import _ffi_ir_builder_api as _ffi_api
+from . import ir_builder_frame as frame
+from . import op as _tir_op
+from .buffer import Buffer
+from .expr import Broadcast as broadcast
+from .expr import BufferLoad, CommReducer, IntImm, IterVar, Let, PrimExpr
+from .expr import Ramp as ramp
+from .expr import Select, Shuffle, StringImm, Var
+from .generic import cast
+from .stmt import BufferRegion, type_annotation
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    """Declare a buffer through the IRBuilder ``BufferDecl`` FFI entry point.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is promoted to a 1-tuple.
+    The buffer name is always passed as the empty string.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    buffer_name = ""
+    return _ffi_api.BufferDecl(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        buffer_name,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def ptr(dtype, storage_scope="global"):
+    """Forward ``dtype`` and ``storage_scope`` to ``_ffi_api.Ptr``."""
+    return _ffi_api.Ptr(  # pylint: disable=no-member # type: ignore
+        dtype, storage_scope
+    )
+
+
+# ``buffer_var`` is an alias of ``ptr``.
+buffer_var = ptr
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
+    """Create a ``BlockFrame`` named ``name``; ``no_realize`` is forwarded as-is."""
+    return _ffi_api.Block(  # pylint: disable=no-member # type: ignore
+        name, no_realize
+    )
+
+
+def init() -> frame.BlockInitFrame:
+    """Create a ``BlockInitFrame`` via ``_ffi_api.Init``."""
+    return _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+
+
+def where(predicate) -> None:
+    """Forward a block predicate to ``_ffi_api.Where``.
+
+    A Python ``bool`` is wrapped as a boolean ``IntImm`` first.
+    """
+    wrapped = IntImm("bool", predicate) if isinstance(predicate, bool) else predicate
+    _ffi_api.Where(wrapped)  # pylint: disable=no-member # type: ignore
+
+
+_BufferSlice = Union[BufferRegion, BufferLoad]
+
+
+def _to_buffer_slice_list(buffer_slices) -> List[_BufferSlice]:
+    """Normalize the positional arguments of ``reads``/``writes`` into a list.
+
+    Accepts several slices given as separate arguments, or a single argument
+    that is one slice, a tuple of slices, or a list of slices.
+    """
+    if len(buffer_slices) != 1:
+        return list(buffer_slices)
+    (only,) = buffer_slices
+    if isinstance(only, (list, tuple)):
+        return list(only)
+    return [only]
+
+
+def reads(*buffer_slices: Union[_BufferSlice, List[_BufferSlice]]) -> None:
+    """Forward the given read regions (normalized to a list) to ``_ffi_api.Reads``."""
+    _ffi_api.Reads(  # pylint: disable=no-member # type: ignore
+        _to_buffer_slice_list(buffer_slices)
+    )
+
+
+def writes(*buffer_slices: Union[_BufferSlice, List[_BufferSlice]]) -> None:
+    """Forward the given write regions (normalized to a list) to ``_ffi_api.Writes``."""
+    _ffi_api.Writes(  # pylint: disable=no-member # type: ignore
+        _to_buffer_slice_list(buffer_slices)
+    )
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """Forward a dict of block attributes to ``_ffi_api.BlockAttrs``."""
+    result = _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Allocate a buffer through the IRBuilder ``AllocBuffer`` FFI entry point.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is promoted to a 1-tuple,
+    and ``strides=None`` is passed on as an empty list.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    return _ffi_api.AllocBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def _as_range(dom) -> Range:
+    """Coerce ``dom`` into a ``Range``.
+
+    A ``Range`` is returned unchanged, a list/tuple ``(begin, end)`` becomes
+    ``Range(begin, end)``, and anything else is treated as ``Range(0, dom)``.
+    """
+    if isinstance(dom, (list, tuple)):
+        begin, end = dom[0], dom[1]
+        return Range(begin, end)
+    return dom if isinstance(dom, Range) else Range(0, dom)
+
+
+class axis:  # pylint: disable=invalid-name
+    """Constructors for block iteration variables, each forwarding to an FFI call."""
+
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        """Forward ``(_as_range(dom), binding, dtype)`` to ``_ffi_api.AxisSpatial``."""
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisSpatial(  # pylint: disable=no-member # type: ignore
+            dom_range, binding, dtype
+        )
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        """Forward ``(_as_range(dom), binding, dtype)`` to ``_ffi_api.AxisReduce``."""
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisReduce(  # pylint: disable=no-member # type: ignore
+            dom_range, binding, dtype
+        )
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        """Forward ``(_as_range(dom), binding, dtype)`` to ``_ffi_api.AxisScan``."""
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisScan(  # pylint: disable=no-member # type: ignore
+            dom_range, binding, dtype
+        )
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        """Forward ``(_as_range(dom), binding, dtype)`` to ``_ffi_api.AxisOpaque``."""
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisOpaque(  # pylint: disable=no-member # type: ignore
+            dom_range, binding, dtype
+        )
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        """Call ``_ffi_api.AxisRemap``; a one-element result is unwrapped."""
+        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: ignore
+            kinds, bindings, dtype
+        )
+        if len(iter_vars) == 1:
+            return iter_vars[0]
+        return iter_vars
+
+    # Short aliases.
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def _loop_bounds(start, stop):
+    """Interpret a single bound ``n`` as ``(0, n)``; otherwise return ``(start, stop)``."""
+    return (0, start) if stop is None else (start, stop)
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Create a ``ForFrame`` via ``_ffi_api.Serial``; ``serial(n)`` means ``serial(0, n)``."""
+    begin, end = _loop_bounds(start, stop)
+    return _ffi_api.Serial(begin, end, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Create a ``ForFrame`` via ``_ffi_api.Parallel``; ``parallel(n)`` means ``parallel(0, n)``."""
+    begin, end = _loop_bounds(start, stop)
+    return _ffi_api.Parallel(begin, end, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Create a ``ForFrame`` via ``_ffi_api.Vectorized``; one bound means ``(0, n)``."""
+    begin, end = _loop_bounds(start, stop)
+    return _ffi_api.Vectorized(begin, end, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Create a ``ForFrame`` via ``_ffi_api.Unroll``; ``unroll(n)`` means ``unroll(0, n)``."""
+    begin, end = _loop_bounds(start, stop)
+    return _ffi_api.Unroll(begin, end, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    """Create a thread-bound ``ForFrame`` via ``_ffi_api.ThreadBinding``.
+
+    Accepted call forms: ``(start, stop, thread)``, ``(n, thread)`` meaning
+    ``(0, n, thread)``, and ``(n, thread=...)`` meaning ``(0, n, thread)``.
+
+    Raises
+    ------
+    ValueError
+        If no thread tag string is supplied.
+    """
+    if thread is not None:
+        if stop is None:
+            start, stop = 0, start
+    elif isinstance(stop, str):
+        # Two positional arguments: (extent, thread_tag).
+        start, stop, thread = 0, start, stop
+    else:
+        raise ValueError("Thread cannot be None for thread_binding")
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents) -> frame.ForFrame:
+    """Create a ``ForFrame`` by forwarding the ``extents`` tuple to ``_ffi_api.Grid``."""
+    frame_ = _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+    return frame_
+
+
+def prim_func() -> frame.PrimFuncFrame:
+    """Create a ``PrimFuncFrame`` via ``_ffi_api.PrimFunc``."""
+    frame_ = _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+    return frame_
+
+
+def arg(name, obj):
+    """Forward ``(name, obj)`` to ``_ffi_api.Arg`` and return its result."""
+    result = _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_name(name: str) -> str:
+    """Forward ``name`` to ``_ffi_api.FuncName`` and return its result."""
+    result = _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:
+    """Forward a dict of function attributes to ``_ffi_api.FuncAttrs``."""
+    result = _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_ret(ret_type) -> Type:
+    """Forward ``ret_type`` to ``_ffi_api.FuncRet`` and return its result."""
+    result = _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Match ``param`` against a buffer description via ``_ffi_api.MatchBuffer``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is promoted to a 1-tuple,
+    and ``strides=None`` is passed on as an empty list.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    return _ffi_api.MatchBuffer(  # pylint: disable=no-member # type: ignore
+        param,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def preflattened_buffer(
+    postflattened,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> None:
+    """Forward a pre-flattened buffer description to ``_ffi_api.PreflattenedBuffer``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is promoted to a 1-tuple,
+    and ``strides=None`` is passed on as an empty list.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    _ffi_api.PreflattenedBuffer(  # pylint: disable=no-member # type: ignore
+        postflattened,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    """Create an ``AssertFrame`` via ``_ffi_api.Assert``.
+
+    A Python ``bool`` condition is wrapped as a boolean ``IntImm``, matching
+    ``While``/``If``/``where``.
+    """
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Assert(condition, message)  # pylint: disable=no-member # type: ignore
+
+
+def let(
+    v: Var,
+    value: PrimExpr,
+    body: Optional[PrimExpr] = None,
+) -> Union[frame.LetFrame, Let]:
+    """Bind ``v`` to ``value``.
+
+    Without ``body`` a ``LetFrame`` is created via ``_ffi_api.Let``; with a
+    ``body`` a ``Let`` expression node is returned instead.
+    """
+    if body is None:
+        return _ffi_api.Let(v, value)  # pylint: disable=no-member # type: ignore
+    return Let(v, value, body)
+
+
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: PrimExpr = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    """Create an ``AllocateFrame`` via ``_ffi_api.Allocate``.
+
+    A Python ``bool`` condition is wrapped as a boolean ``IntImm`` first.
+    """
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.Allocate(  # pylint: disable=no-member # type: ignore
+        extents, dtype, scope, cond, annotations
+    )
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+    """Create an ``AllocateConstFrame`` via ``_ffi_api.AllocateConst``.
+
+    ``data`` is converted with ``np.asarray(data, dtype)`` and then to a TVM
+    NDArray before being passed to the FFI.
+    """
+    host_array = np.asarray(data, dtype)
+    return _ffi_api.AllocateConst(  # pylint: disable=no-member # type: ignore
+        ndarray.array(host_array), dtype, extents, annotations
+    )
+
+
+def realize(
+    buffer_slice: BufferRegion,
+    storage_scope: str,
+    condition: Union[PrimExpr, bool] = True,
+) -> frame.RealizeFrame:
+    """Create a ``RealizeFrame`` via ``_ffi_api.Realize``.
+
+    A Python ``bool`` condition (including the default ``True``) is wrapped as
+    a boolean ``IntImm``, matching ``where``/``allocate``/``While``/``If``,
+    instead of being passed to the FFI as a raw Python integer.
+    """
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Realize(  # pylint: disable=no-member # type: ignore
+        buffer_slice, storage_scope, condition
+    )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    """Create an ``AttrFrame``; ``node`` and ``value`` go through ``convert`` first."""
+    return _ffi_api.Attr(  # pylint: disable=no-member # type: ignore
+        convert(node), attr_key, convert(value)
+    )
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    """Create a ``WhileFrame``; a Python ``bool`` condition becomes a boolean ``IntImm``."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.While(cond)  # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    """Create an ``IfFrame``; a Python ``bool`` condition becomes a boolean ``IntImm``."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.If(cond)  # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    """Create a ``ThenFrame`` via ``_ffi_api.Then``."""
+    then_frame = _ffi_api.Then()  # pylint: disable=no-member # type: ignore
+    return then_frame
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    """Create an ``ElseFrame`` via ``_ffi_api.Else``."""
+    else_frame = _ffi_api.Else()  # pylint: disable=no-member # type: ignore
+    return else_frame
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+    """Create a ``DeclBufferFrame`` via ``_ffi_api.DeclBuffer``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is promoted to a 1-tuple.
+    The buffer name is always passed as the empty string.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    buffer_name = ""
+    return _ffi_api.DeclBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        buffer_name,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def launch_thread(
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    """Create a ``LaunchThreadFrame`` for ``iter_var`` with the given ``extent``."""
+    return _ffi_api.LaunchThread(  # pylint: disable=no-member # type: ignore
+        iter_var, extent
+    )
+
+
+def env_thread(thread_tag: str) -> IterVar:
+    """Forward ``thread_tag`` to ``_ffi_api.EnvThread`` and return the resulting ``IterVar``."""
+    return _ffi_api.EnvThread(  # pylint: disable=no-member # type: ignore
+        thread_tag
+    )
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    """Store ``value`` into ``buffer`` at ``indices`` via ``_ffi_api.BufferStore``.
+
+    Each slice index becomes a ``Ramp`` (or just its start when it covers a
+    single lane). Omitted slice parts default to start ``0``, stop equal to
+    that dimension's extent, and step ``1``, so ``A[:] = v`` is accepted.
+    A Python ``bool`` stored into a ``bool`` buffer is wrapped as a boolean
+    ``IntImm``.
+    """
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    analyzer = Analyzer()
+    expr_indices = []
+    for dim, index in enumerate(indices):
+        if isinstance(index, slice):
+            start = 0 if index.start is None else index.start
+            stop = buffer.shape[dim] if index.stop is None else index.stop
+            step = 1 if index.step is None else index.step
+            # Number of lanes covered by start:stop:step (ceil division).
+            lanes = analyzer.simplify((stop - start + step - 1) // step)
+            if lanes == 1:
+                expr_indices.append(start)
+            else:
+                expr_indices.append(ramp(start, step, int(lanes)))
+        else:
+            expr_indices.append(index)
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    return _ffi_api.BufferStore(  # pylint: disable=no-member # type: ignore
+        buffer, value, expr_indices
+    )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+    """Forward ``(buffer, indices)`` to ``_ffi_api.Prefetch``."""
+    result = _ffi_api.Prefetch(buffer, indices)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def evaluate(value: PrimExpr) -> None:
+    """Forward ``value`` to ``_ffi_api.Evaluate``; a Python ``str`` becomes a ``StringImm``."""
+    expr = StringImm(value) if isinstance(value, str) else value
+    return _ffi_api.Evaluate(expr)  # pylint: disable=no-member # type: ignore
+
+
+# Dtype helpers: each forwards ``expr`` (default ``None``) to the matching FFI
+# function. The float variants first pass non-PrimExpr values through ``convert``.
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Int8``."""
+    result = _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Int16``."""
+    result = _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Int32``."""
+    result = _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Int64``."""
+    result = _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.UInt8``."""
+    result = _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.UInt16``."""
+    result = _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.UInt32``."""
+    result = _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.UInt64``."""
+    result = _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float8``, converting non-PrimExpr values first."""
+    expr = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float16``, converting non-PrimExpr values first."""
+    expr = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float32``, converting non-PrimExpr values first."""
+    expr = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float64``, converting non-PrimExpr values first."""
+    expr = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Boolean``."""
+    result = _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Handle``."""
+    result = _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Void``."""
+    result = _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def min(a, b):  # pylint: disable=redefined-builtin
+    """Return the smaller of two expressions, built by ``_ffi_api.min``.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        First operand.
+
+    b : PrimExpr
+        Second operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        Expression for the minimum of ``a`` and ``b``.
+    """
+    res = _ffi_api.min(a, b)  # pylint: disable=no-member # type: ignore
+    return res
+
+
+def max(a, b):  # pylint: disable=redefined-builtin
+    """Return the larger of two expressions, built by ``_ffi_api.max``.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        First operand.
+
+    b : PrimExpr
+        Second operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        Expression for the maximum of ``a`` and ``b``.
+    """
+    res = _ffi_api.max(a, b)  # pylint: disable=no-member # type: ignore
+    return res
+
+
+def var(dtype, name="") -> Var:
+    return Var(name, dtype)  # pylint: disable=no-member # type: ignore
+
+
+def iter_var(v, dom, iter_type, thread_tag):
+    iter_type = getattr(IterVar, iter_type)
+    return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+    """Create a CommReducer from lambda inputs/outputs and the identities"""
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    args = []
+    for name, i in zip(params.keys(), identity + identity):
+        if isinstance(i, int):
+            args.append(Var(name, "int32"))
+        else:
+            args.append(Var(name, i.dtype))
+    res = combiner(*args)
+    if not isinstance(res, tuple):
+        res = (res,)
+    return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+    # pylint: disable=import-outside-toplevel
+    from tvm.target.codegen import llvm_lookup_intrinsic_id as f
+
+    # pylint: enable=import-outside-toplevel
+    return f(name)
+
+
+def _op_wrapper(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            kwargs.pop("dtype")
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            args = (kwargs.pop("dtype"),) + args
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+__all__ = [
+    "Assert",
+    "Else",
+    "If",
+    "Let",
+    "Select",
+    "Shuffle",
+    "TVMBackendAllocWorkspace",
+    "TVMBackendFreeWorkspace",
+    "Then",
+    "While",
+    "abs",
+    "acos",
+    "acosh",
+    "address_of",
+    "alloc_buffer",
+    "allocate",
+    "allocate_const",
+    "arg",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "attr",
+    "axis",
+    "block",
+    "block_attr",
+    "boolean",
+    "broadcast",
+    "buffer_decl",
+    "buffer_store",
+    "buffer_var",
+    "call_cpacked",
+    "call_cpacked_lowered",
+    "call_extern",
+    "call_intrin",
+    "call_llvm_intrin",
+    "call_llvm_pure_intrin",
+    "call_packed",
+    "call_packed_lowered",
+    "call_pure_extern",
+    "cast",
+    "ceil",
+    "ceildiv",
+    "clz",
+    "comm_reducer",
+    "copysign",
+    "cos",
+    "cosh",
+    "env_thread",
+    "erf",
+    "evaluate",
+    "exp",
+    "exp10",
+    "exp2",
+    "decl_buffer",
+    "fabs",
+    "float16",
+    "float32",
+    "float64",
+    "float8",
+    "floor",
+    "floordiv",
+    "floormod",
+    "fmod",
+    "func_attr",
+    "func_name",
+    "func_ret",
+    "grid",
+    "handle",
+    "hypot",
+    "if_then_else",
+    "infinity",
+    "init",
+    "int16",
+    "int32",
+    "int64",
+    "int8",
+    "isfinite",
+    "isinf",
+    "isnan",
+    "isnullptr",
+    "iter_var",
+    "launch_thread",
+    "ldexp",
+    "let",
+    "likely",
+    "llvm_lookup_intrinsic_id",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "lookup_param",
+    "match_buffer",
+    "max",
+    "max_value",
+    "min",
+    "min_value",
+    "mma_fill",
+    "mma_store",
+    "nearbyint",
+    "nextafter",
+    "parallel",
+    "popcount",
+    "power",
+    "prefetch",
+    "preflattened_buffer",
+    "prim_func",
+    "ptr",
+    "ptx_commit_group",
+    "ptx_cp_async",
+    "ptx_ldmatrix",
+    "ptx_mma",
+    "ptx_mma_sp",
+    "ptx_wait_group",
+    "q_multiply_shift",
+    "ramp",
+    "reads",
+    "realize",
+    "reinterpret",
+    "ret",
+    "round",
+    "rsqrt",
+    "serial",
+    "shift_left",
+    "shift_right",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "target",
+    "thread_binding",
+    "trunc",
+    "truncdiv",
+    "truncmod",
+    "tvm_access_ptr",
+    "tvm_bmma_sync",
+    "tvm_call_cpacked",
+    "tvm_call_cpacked_lowered",
+    "tvm_call_packed",
+    "tvm_call_packed_lowered",
+    "tvm_fill_fragment",
+    "tvm_load_matrix_sync",
+    "tvm_mma_sync",
+    "tvm_stack_alloca",
+    "tvm_stack_make_array",
+    "tvm_stack_make_shape",
+    "tvm_store_matrix_sync",
+    "tvm_struct_get",
+    "tvm_struct_set",
+    "tvm_thread_allreduce",
+    "tvm_throw_last_error",
+    "tvm_tuple",
+    "type_annotation",
+    "uint16",
+    "uint32",
+    "uint64",
+    "uint8",
+    "unroll",
+    "var",
+    "vectorized",
+    "void",
+    "where",
+    "writes",
+]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..7b5f55fd5f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -100,6 +100,64 @@ def call_cpacked(*args, span=None):
     return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
 
 
+def call_packed_lowered(*args, span=None):
+    """Lowered version of call packed.
+
+    The argument to packed function can be Expr or Buffer.
+    The argument is the corresponding POD type when Expr is presented.
+
+    When the argument is Buffer, the corresponding PackedFunc
+    will recieve an TVMArrayHandle whose content is valid during the callback period.
+    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+    """Lowered version of call c-packed.
+
+    Same as call_packed, except that the first argument is the function name
+    (as in call_extern), and the last argument is the resource handle.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
+
+
 def call_intrin(dtype, func_name, *args, span=None):
     """Build expression by calling an intrinsic function.
 
@@ -151,7 +209,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+        dtype,
+        Op.get("tir.call_pure_extern"),
+        convert((StringImm(func_name),) + args),
+        span,
     )
 
 
@@ -178,7 +239,10 @@ def call_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+        dtype,
+        Op.get("tir.call_extern"),
+        convert((StringImm(func_name),) + args),
+        span=span,
     )
 
 
@@ -206,11 +270,21 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     """
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
-
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -238,8 +312,14 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     """
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
-
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
     return call_intrin(
         dtype,
@@ -250,6 +330,310 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+    return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):
+    return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):
+    return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
+    return call_intrin(
+        "handle", "tir.tvm_stack_make_array", data, shape, strides, ndim, arr_dtype, elem_offset
+    )
+
+
+def address_of(buffer_load, span=None):
+    """Returns the address of an element in the buffer
+
+    Parameters
+    ----------
+    buffer_load: BufferLoad
+        The buffer load.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.address_of", buffer_load)
+
+
+def lookup_param(param_name, span=None):
+    """Returns the param by name
+
+    Parameters
+    ----------
+    param_name : str
+        The name of param.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.lookup_param", param_name)
+
+
+def tvm_access_ptr(dtype, data, offset, extent, rw_mask):
+    return call_intrin("handle", "tir.tvm_access_ptr", dtype, data, offset, extent, rw_mask)
+
+
+def tvm_tuple(*value):
+    return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):
+    return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field_id)
+
+
+def tvm_struct_set(arr, index, field_id, value):
+    return call_intrin("handle", "tir.tvm_struct_set", arr, index, field_id, value)
+
+
+def tvm_thread_allreduce(*freduce_args):
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    return call_intrin(
+        "handle",
+        "tir.tvm_load_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def tvm_mma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_mma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_bmma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_bmma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):
+    return call_intrin(
+        "handle",
+        "tir.tvm_fill_fragment",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        value,
+    )
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    return call_intrin(
+        "handle",
+        "tir.tvm_store_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def ptx_mma(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    saturate,
+    operator=None,
+):
+    if operator is None:
+        return call_intrin(
+            dtype,
+            "tir.ptx_mma",
+            shape,
+            A_layout,
+            B_layout,
+            A_dtype,
+            B_dtype,
+            C_dtype,
+            multiplicand_a,
+            a_index,
+            multiplicand_b,
+            b_index,
+            accumulator,
+            c_index,
+            saturate,
+        )
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        saturate,
+        operator,
+    )
+
+
+def ptx_mma_sp(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    metadata,
+    meta_index,
+    sparse_selector,
+    saturate,
+):
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma_sp",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        metadata,
+        meta_index,
+        sparse_selector,
+        saturate,
+    )
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+    return call_intrin(
+        dtype,
+        "tir.ptx_ldmatrix",
+        trans,
+        num,
+        type,
+        local_ptr,
+        local_offset,
+        smem_ptr,
+        smem_offset,
+    )
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+    return call_intrin(
+        dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
+    )
+
+
+def ptx_commit_group():
+    return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):
+    return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+    return call_intrin(
+        dtype,
+        "tir.mma_store",
+        m,
+        n,
+        dst_ptr,
+        src_ptr,
+        src_offset,
+        dst_stride,
+    )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+    return call_intrin(
+        dtype,
+        "tir.mma_fill",
+        local_size,
+        local_ptr,
+        offset,
+    )
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -394,6 +778,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
     return _ffi_api.max_value(dtype, span)  # type: ignore
 
 
+def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+    """infinity value of dtype
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The infinity value of dtype.
+    """
+    return _ffi_api.infinity(dtype, span)  # type: ignore
+
+
+def reinterpret(dtype, value, span=None) -> Any:
+    """infinity value of dtype
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    value : PrimExpr
+        The input value.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The reinterpret cast value of dtype.
+    """
+    return _ffi_api.reinterpret(dtype, value, span)  # type: ignore
+
+
 def exp(x):
     """Take exponential of input x.
 
@@ -998,6 +1423,25 @@ def ldexp(x1, x2):
     return call_intrin(x1.dtype, "tir.ldexp", x1, x2)  # type: ignore
 
 
+def likely(cond, span=None):
+    """Mark condition as likely.
+
+    Parameters
+    ----------
+    cond : PrimExpr
+        Input argument.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        The marked expression.
+    """
+    return _ffi_api.likely(cond, span)  # type: ignore
+
+
 def isnan(x, span=None):
     """Check if input value is Nan.
 
@@ -1017,6 +1461,25 @@ def isnan(x, span=None):
     return _ffi_api.isnan(x, span)  # type: ignore
 
 
+def isnullptr(x, span=None):
+    """Check if input value is nullptr.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_intrin("bool", "tir.isnullptr", x)  # type: ignore
+
+
 def isfinite(x, span=None):
     """Check if input value is finite.
 
@@ -1122,6 +1585,42 @@ def q_multiply_shift(x, y, q, s):
     return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
 
 
+def shift_left(x, y, span=None):
+    """Return the result of x left shifted by y bits.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+    y : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api.left_shift(x, y, span)
+
+
+def shift_right(x, y, span=None):
+    """Return the result of x right shifted by y bits.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+    y : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api.right_shift(x, y, span)
+
+
 def fmod(x, y):
     """Return the remainder of x divided by y with the same sign as x.
 
@@ -1306,6 +1805,28 @@ def truncmod(a, b, span=None):
     return _ffi_api._OpTruncMod(a, b, span)  # type: ignore
 
 
+def ceildiv(a, b, span=None):
+    """Compute the ceildiv of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    span : Optional[Span]
+        The location of this operator in the source.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpCeilDiv(a, b, span)  # type: ignore
+
+
 def floordiv(a, b, span=None):
     """Compute the floordiv of two expressions.
 
@@ -1523,6 +2044,22 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     return reducer
 
 
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+    return call_intrin(
+        "handle",
+        "tir.TVMBackendAllocWorkspace",
+        device_type,
+        device_id,
+        nbytes,
+        dtype_code_hint,
+        dtype_bits_hint,
+    )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+    call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
 
 from tvm._ffi import register_object
 from tvm.runtime import Object
-from tvm.tir import Block, For
 
+from ..stmt import Block, For
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index cf031c014c..f26f954d51 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
 from . import _ffi_api
 from ._type_checker import type_checked
 from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
 from tvm._ffi import register_object
 from tvm.ir import IRModule
 from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
 
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
 from . import _ffi_api
 from .block_scope import BlockScope, StmtSRef
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,7 @@ def stmt_list(stmt):
             res += stmt_list(x)
         return res
     return [stmt]
+
+
+def type_annotation(dtype, span=None):
+    return _ffi_api.TypeAnnotation(dtype, span)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
 from typing import Dict
 
 import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
 from . import _ffi_api
 
 
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index a3318bf94f..a7d95848da 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -21,6 +21,7 @@
  * \file src/ir/expr.cc
  * \brief The expression AST nodes for the common IR infra.
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/function.h>
 #include <tvm/runtime/registry.h>
@@ -49,6 +50,18 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
   if (auto* ptr = ref.as<runtime::StringObj>()) {
     return tir::StringImm(GetRef<runtime::String>(ptr));
   }
+  if (auto* ptr = ref.as<tir::BufferRegionNode>()) {
+    tir::BufferRegion buffer_region = GetRef<tir::BufferRegion>(ptr);
+    Array<PrimExpr> indices;
+    for (Range r : buffer_region->region) {
+      if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+        indices.push_back(r->min);
+      } else {
+        indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+      }
+    }
+    return tir::BufferLoad(buffer_region->buffer, indices);
+  }
   Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get());
   ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName()
                                  << " but got " << actual_type.value();
diff --git a/src/ir/ir_builder.cc b/src/ir/ir_builder.cc
new file mode 100644
index 0000000000..9f42cdb168
--- /dev/null
+++ b/src/ir/ir_builder.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ir/ir_builder.h>
+#include <tvm/ir/module.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ir_builder {
+
+void IRBuilderFrameNode::EnterWithScope() {
+  // Push this frame on top of the current builder's frame stack.
+  IRBuilder builder = IRBuilder::Current();
+  builder->frames.push_back(GetRef<IRBuilderFrame>(this));
+}
+
+void IRBuilderFrameNode::ExitWithScope() {
+  // Detach the callbacks before running them. A callback may call AddCallback, which appends to
+  // the innermost frame's list -- still this frame at this point -- and would otherwise
+  // invalidate the iterators used below.
+  auto pending = std::move(this->callbacks);
+  this->callbacks.clear();
+  // Run in reverse registration order: the last callback added runs first.
+  for (auto it = pending.rbegin(); it != pending.rend(); ++it) {
+    (*it)();
+  }
+  // Callbacks registered on this frame while the pending ones ran are discarded.
+  this->callbacks.clear();
+  IRBuilder builder = IRBuilder::Current();
+  ICHECK(!builder->frames.empty()) << "ValueError: No frame to exit in the current builder";
+  builder->frames.pop_back();
+}
+
+void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
+  // NOTE: the callback is attached to the innermost frame of the current builder, which is not
+  // necessarily `this`.
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    LOG(FATAL) << "ValueError: No frames in Builder to add callback";
+  }
+  builder->frames.back()->callbacks.push_back(callback);
+}
+
+IRBuilder::IRBuilder() {
+  // A fresh builder: no open frames and no result yet.
+  auto node = make_object<IRBuilderNode>();
+  node->frames.clear();
+  node->result = NullOpt;
+  data_ = std::move(node);
+}
+
+std::vector<IRBuilder>* ThreadLocalBuilderStack() {
+  // Each thread keeps its own stack of active builders.
+  thread_local std::vector<IRBuilder> builder_stack;
+  return &builder_stack;
+}
+
+void IRBuilder::EnterWithScope() {
+  IRBuilderNode* self = this->get();
+  // A builder may only be entered with no open frames; any previous result is discarded.
+  CHECK(self->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
+                              << self->frames.size()
+                              << ". Please use a fresh new builder every time building IRs";
+  self->result = NullOpt;
+  ThreadLocalBuilderStack()->push_back(*this);
+}
+
+void IRBuilder::ExitWithScope() {
+  std::vector<IRBuilder>& builder_stack = *ThreadLocalBuilderStack();
+  ICHECK(!builder_stack.empty());
+  builder_stack.pop_back();
+}
+
+IRBuilder IRBuilder::Current() {
+  // The innermost entered builder on this thread.
+  const std::vector<IRBuilder>& builder_stack = *ThreadLocalBuilderStack();
+  CHECK(!builder_stack.empty()) << "ValueError: No builder in current scope";
+  return builder_stack.back();
+}
+
+IRModuleFrame IRModule() {
+  // Create a module frame with no global vars and no functions.
+  ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
+  n->global_vars.clear();
+  n->functions.clear();
+  return IRModuleFrame(n);
+}
+
+void IRModuleFrameNode::ExitWithScope() {
+  // NOTE(review): unlike IRBuilderFrameNode::ExitWithScope, this override does not run the
+  // frame's callbacks or pop it from the builder's frame stack; confirm that is intended.
+  // `global_vars[i]` is the name bound to `functions[i]`, so the two lists must match in length.
+  ICHECK_EQ(functions.size(), global_vars.size());
+  Map<GlobalVar, BaseFunc> func_map;
+  for (size_t i = 0; i < functions.size(); ++i) {
+    func_map.Set(global_vars[i], functions[i]);
+  }
+  // Publish the assembled module as the builder's (single) result.
+  IRBuilder builder = IRBuilder::Current();
+  ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+  builder->result = tvm::IRModule(func_map);
+}
+
+namespace details {
+
+Namer::FType& Namer::vtable() {
+  // Function-local static: the dispatch table is created on first use.
+  static FType inst;
+  return inst;
+}
+
+void Namer::Name(ObjectRef node, String name) {
+  // Dispatch on the runtime type of `node` to the naming function registered in vtable().
+  static const FType& f = vtable();
+  CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
+  CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
+                              << node->GetTypeKey() << "\"";
+  f(node, name);
+}
+
+}  // namespace details
+
+// Node-type registration for the builder objects defined in this file.
+TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode);
+TVM_REGISTER_NODE_TYPE(IRBuilderNode);
+TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
+// FFI entry points under the "ir_builder." prefix.
+// Frame scoping and callback registration.
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameEnter")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameExit")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameAddCallback")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::AddCallback);
+// Builder construction, scoping, and access to the innermost builder on this thread.
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); });
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+// Result retrieval, object naming, and the IRModule frame constructor.
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderGet")
+    .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRModule").set_body_typed(IRModule);
+
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 7979c9f47a..9bbd9068e2 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -20,6 +20,7 @@
 /*!
  * \file expr.cc
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
@@ -828,12 +829,26 @@ TVM_REGISTER_GLOBAL("tir.Call")
     .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
       Array<PrimExpr> prim_expr_args;
       for (const auto& it : args) {
-        ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>())
+        ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
+               it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>())
             << "Argument " << it << " is not a string or primexpr";
         if (const auto* str = it.as<runtime::StringObj>()) {
           prim_expr_args.push_back(StringImm(str->data));
+        } else if (const auto* expr = it.as<PrimExprNode>()) {
+          prim_expr_args.push_back(GetRef<PrimExpr>(expr));
+        } else if (const auto* br = it.as<BufferRegionNode>()) {
+          BufferRegion buffer_region = GetRef<BufferRegion>(br);
+          Array<PrimExpr> indices;
+          for (Range r : buffer_region->region) {
+            if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+              indices.push_back(r->min);
+            } else {
+              indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+            }
+          }
+          prim_expr_args.push_back(BufferLoad(buffer_region->buffer, indices));
         } else {
-          prim_expr_args.push_back(Downcast<PrimExpr>(it));
+          prim_expr_args.push_back(Downcast<IterVar>(it).operator PrimExpr());
         }
       }
       return Call(type, op, prim_expr_args, span);
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 7e3d3d1075..b11ca6650a 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -22,11 +22,10 @@
  * \brief Used by TVM Script parser to expand incomplete TIR input
  */
 
+#include "./script_complete.h"
+
 #include <tvm/arith/int_set.h>
-#include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/stmt_functor.h>
 
 #include <utility>
 
diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h
new file mode 100644
index 0000000000..5392ea309d
--- /dev/null
+++ b/src/tir/ir/script/script_complete.h
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.h
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+
+#ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Expand an incomplete PrimFunc produced by the TVM Script parser.
+ * \param func The PrimFunc to complete.
+ * \param root_allocates Buffers to treat as root-level allocations
+ *        (NOTE(review): presumably the PrimFunc frame's root_alloc_buffers — confirm).
+ * \return The completed PrimFunc.
+ */
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates);
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..6dd5bf5886 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -1091,5 +1091,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
 TVM_REGISTER_OP("tir.type_annotation")
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder.cc b/src/tir/ir_builder/ir_builder.cc
new file mode 100644
index 0000000000..d312f014e5
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder.cc
@@ -0,0 +1,637 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Build a TIR Buffer from its declaration components.
+ *
+ * When \p data is absent, a fresh pointer Var named \p buffer_name is created in
+ * \p storage_scope (boolean buffers get an int8 pointee). When \p elem_offset is absent and
+ * \p offset_factor is non-zero, a fresh "elem_offset" Var is created with the dtype of the first
+ * shape dimension, or int32 for an empty shape. \p buffer_type "auto_broadcast" selects
+ * kAutoBroadcast; any other value selects kDefault.
+ */
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators) {
+  Var buffer_data = [&]() -> Var {
+    if (data.defined()) {
+      return data.value();
+    }
+    // Booleans are backed by int8 storage.
+    DataType pointee_dtype = (dtype == DataType::Bool()) ? DataType::Int(8) : dtype;
+    return tvm::tir::Var(buffer_name, PointerType(PrimType(pointee_dtype), storage_scope));
+  }();
+  if (offset_factor != 0 && !elem_offset.defined()) {
+    DataType offset_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
+    elem_offset = tvm::tir::Var("elem_offset", offset_dtype);
+  }
+  const auto kind =
+      (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault;
+  return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
+                elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, kind,
+                axis_separators.value_or(Array<IntImm>()));
+}
+
+/*!
+ * \brief Create a DeclBufferFrame whose buffer is built by BufferDecl from the given components.
+ */
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  Buffer declared = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset,
+                               storage_scope, align, offset_factor, buffer_type, axis_separators);
+  ObjectPtr<DeclBufferFrameNode> frame_node = make_object<DeclBufferFrameNode>();
+  frame_node->buffer = declared;
+  return DeclBufferFrame(frame_node);
+}
+
+/*! \brief Create an unnamed pointer Var to \p dtype in \p storage_scope. */
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
+  tvm::PointerType ptr_type(PrimType(dtype), storage_scope);
+  return tvm::tir::Var("", ptr_type);
+}
+
+/*!
+ * \brief Create a BlockFrame named \p name with every optional component unset.
+ * \param name The block name.
+ * \param no_realize Stored on the frame; BlockFrameNode::ExitWithScope branches on it.
+ */
+BlockFrame Block(String name, bool no_realize) {
+  ObjectPtr<BlockFrameNode> frame_node = make_object<BlockFrameNode>();
+  frame_node->name = name;
+  frame_node->no_realize = no_realize;
+  // Optional components start out undeclared.
+  frame_node->reads = NullOpt;
+  frame_node->writes = NullOpt;
+  frame_node->init = NullOpt;
+  frame_node->predicate = NullOpt;
+  // Collections start out empty.
+  frame_node->iter_vars.clear();
+  frame_node->iter_values.clear();
+  frame_node->alloc_buffers.clear();
+  frame_node->match_buffers.clear();
+  frame_node->annotations.clear();
+  return BlockFrame(frame_node);
+}
+
+/*! \brief Create a frame for a block's init section. */
+BlockInitFrame Init() {
+  ObjectPtr<BlockInitFrameNode> init_node = make_object<BlockInitFrameNode>();
+  return BlockInitFrame(init_node);
+}
+
+/*! \brief Set the predicate of the enclosing block (T.where); only one may be declared. */
+void Where(PrimExpr predicate) {
+  BlockFrame frame = FindBlockFrame("T.where");
+  if (!frame->predicate.defined()) {
+    frame->predicate = predicate;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+             << frame->predicate.value();
+}
+
+/*!
+ * \brief Convert the arguments of T.reads / T.writes into buffer regions.
+ *
+ * BufferRegion arguments are kept as-is; BufferLoad arguments go through BufferRegionFromLoad.
+ * Any other argument type is a fatal error.
+ * \param buffer_slices The user-provided region arguments.
+ * \param kind "reads" or "writes"; only used in the error message.
+ */
+static Array<tvm::tir::BufferRegion> BufferSlicesToRegions(const Array<ObjectRef>& buffer_slices,
+                                                          const char* kind) {
+  using namespace tvm::tir;
+  Array<BufferRegion> regions;
+  regions.reserve(buffer_slices.size());
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      regions.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      regions.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer " << kind << ".";
+    }
+  }
+  return regions;
+}
+
+/*! \brief Declare the read regions of the enclosing block (T.reads); allowed once per block. */
+void Reads(Array<ObjectRef> buffer_slices) {
+  BlockFrame frame = FindBlockFrame("T.reads");
+  if (frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
+  }
+  frame->reads = BufferSlicesToRegions(buffer_slices, "reads");
+}
+
+/*! \brief Declare the write regions of the enclosing block (T.writes); allowed once per block. */
+void Writes(Array<ObjectRef> buffer_slices) {
+  BlockFrame frame = FindBlockFrame("T.writes");
+  if (frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << frame->writes;
+  }
+  frame->writes = BufferSlicesToRegions(buffer_slices, "writes");
+}
+
+/*! \brief Set the annotations of the enclosing block (T.block_attr); fails if already set. */
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  BlockFrame frame = FindBlockFrame("T.block_attr");
+  if (frame->annotations.empty()) {
+    frame->annotations = attrs;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+}
+
+/*!
+ * \brief Declare an unnamed buffer allocated in the innermost enclosing scope (T.alloc_buffer).
+ *
+ * If the builder's last frame is a BlockFrame (per GetLastFrame), the buffer is appended to its
+ * alloc_buffers. Otherwise, if it is a PrimFuncFrame, the buffer is appended to
+ * root_alloc_buffers. Any other context is a fatal error.
+ * \return The declared buffer.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<BlockFrame> frame = builder->GetLastFrame<BlockFrame>()) {
+    frame.value()->alloc_buffers.push_back(buffer);
+  } else if (Optional<PrimFuncFrame> frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    frame.value()->root_alloc_buffers.push_back(buffer);
+  } else {
+    LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not found. Please ensure "
+                  "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  }
+  return buffer;
+}
+
+namespace axis {
+
+/*!
+ * \brief Append \p iter_var and its \p binding to the last frame, which must be a BlockFrame.
+ * \return \p iter_var unchanged.
+ */
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+  Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  if (!block_frame.defined()) {
+    LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+  }
+  block_frame.value()->iter_vars.push_back(iter_var);
+  block_frame.value()->iter_values.push_back(binding);
+  return iter_var;
+}
+
+// Defines axis::Method(dom, binding, dtype). Each generated function:
+//  - creates an unnamed block iteration variable of kind `Kind` over `dom`;
+//  - gives it `dtype`'s type code and the widest bit width among dom->min, dom->extent and dtype;
+//  - registers it with `binding` on the last BlockFrame via PushBlockVar, and returns it.
+// `Name` is only used in the error message for an undefined domain.
+#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name)                                           \
+  IterVar Method(Range dom, PrimExpr binding, DataType dtype) {                               \
+    ICHECK(dom.defined()) << Name << " axis must have a domain";                              \
+    int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
+    return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)),          \
+                                /*iter_type=*/Kind, /*thread_tag=*/""),                       \
+                        binding);                                                             \
+  }
+TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
+TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
+TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
+TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
+#undef TVM_TIR_IR_BUILDER_AXIS
+
+/*!
+ * \brief Find the domain of \p loop_var among the ForFrames in IRBuilder::Current()->frames.
+ * \return The domain from the first ForFrame that owns \p loop_var, or an undefined Range when
+ *         no ForFrame owns it.
+ */
+static Range FindLoopDomain(const tvm::tir::VarNode* loop_var) {
+  for (const auto& frame : IRBuilder::Current()->frames) {
+    if (const auto* for_frame = frame.as<ForFrameNode>()) {
+      ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+      int num_loop_vars = for_frame->doms.size();
+      for (int j = 0; j < num_loop_vars; ++j) {
+        if (for_frame->vars[j].get() == loop_var) {
+          return for_frame->doms[j];
+        }
+      }
+    }
+  }
+  return Range{nullptr};
+}
+
+/*!
+ * \brief Implement T.axis.remap: create one block iteration variable per binding.
+ *
+ * Each binding must be a loop variable of an enclosing ForFrame. Its loop domain becomes the
+ * domain of the new block variable, which PushBlockVar registers together with the binding.
+ * \param kinds One character per binding: 'S' -> kDataPar, 'R' -> kCommReduce.
+ * \param bindings The loop variables to remap.
+ * \param dtype Currently unused: each new variable takes the dtype of its binding.
+ * \return The created block iteration variables, in binding order.
+ */
+Array<IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+  using namespace tvm::tir;
+  Array<IterVar> results;
+  ICHECK_EQ(kinds.size(), bindings.size());
+  int n = bindings.size();
+  results.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    char kind = kinds.c_str()[i];
+    PrimExpr binding = bindings[i];
+    const VarNode* loop_var = binding.as<VarNode>();
+    ICHECK(loop_var) << "TypeError: Only Var is supported in T.axis.remap";
+    Range dom = FindLoopDomain(loop_var);
+    ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(loop_var);
+    // The block variable mirrors the loop variable's dtype (the `dtype` parameter is ignored).
+    auto make_block_var = [&](IterVarType iter_type) {
+      return PushBlockVar(IterVar(/*dom=*/dom,
+                                  /*var=*/Var("", loop_var->dtype),
+                                  /*iter_type=*/iter_type,
+                                  /*thread_tag=*/""),
+                          binding);
+    };
+    if (kind == 'S') {
+      results.push_back(make_block_var(IterVarType::kDataPar));
+    } else if (kind == 'R') {
+      results.push_back(make_block_var(IterVarType::kCommReduce));
+    } else {
+      LOG(FATAL) << "Unknown axis kind: " << kind;
+    }
+  }
+  return results;
+}
+
+}  // namespace axis
+
+// Defines Method(start, stop, annotations), which returns a ForFrame for one loop of kind `Kind`.
+//  - The loop starts at `start`; its extent is the simplified `stop - start`.
+//  - The loop variable is a signed integer Var named "v", as wide as the wider of the min and
+//    extent dtypes.
+//  - `annotations` (an empty map if absent) are attached to the generated For.
+#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                                \
+  ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {  \
+    PrimExpr min = start;                                                                         \
+    PrimExpr extent = arith::Analyzer().Simplify(stop - start);                                   \
+    ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();                                      \
+    int bits = std::max(min.dtype().bits(), extent.dtype().bits());                               \
+    n->vars = {Var("v", DataType::Int(bits))};                                                    \
+    n->doms = {Range::FromMinExtent(min, extent)};                                                \
+    n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \
+      ICHECK_EQ(vars.size(), 1);                                                                  \
+      ICHECK_EQ(doms.size(), 1);                                                                  \
+      return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt,           \
+                           annotations.value_or(Map<String, ObjectRef>()));                       \
+    };                                                                                            \
+    return ForFrame(n);                                                                           \
+  }
+
+TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
+
+#undef TVM_TIR_IR_BUILDER_FOR_FRAME
+
+/*!
+ * \brief Create a ForFrame for a loop from \p start to \p stop bound to thread tag \p thread.
+ *
+ * The extent is the simplified `stop - start`. The generated For has kind kThreadBinding and
+ * carries an int32 kThreadIndex IterVar named "iter" and tagged with \p thread.
+ */
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations) {
+  using namespace tvm::tir;
+  PrimExpr extent = arith::Analyzer().Simplify(stop - start);
+  int bits = std::max(start.dtype().bits(), extent.dtype().bits());
+  ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();
+  frame_node->vars = {Var("v", DataType::Int(bits))};
+  frame_node->doms = {Range::FromMinExtent(start, extent)};
+  frame_node->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms,
+                                                      Stmt body) -> For {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    Range dom = doms[0];
+    IterVar thread_iter(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
+                        thread);
+    Map<String, ObjectRef> loop_annotations = annotations.value_or(Map<String, ObjectRef>());
+    return For(vars[0], dom->min, dom->extent, ForKind::kThreadBinding, body, thread_iter,
+               loop_annotations);
+  };
+  return ForFrame(frame_node);
+}
+
+/*!
+ * \brief Create a ForFrame for a nest of serial loops, one per entry of \p extents.
+ *
+ * Loop i covers [0, extents[i]) with a loop variable of extents[i]'s dtype. The first extent
+ * becomes the outermost loop.
+ */
+ForFrame Grid(Array<PrimExpr> extents) {
+  using namespace tvm::tir;
+  ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();
+  frame_node->vars.reserve(extents.size());
+  frame_node->doms.reserve(extents.size());
+  for (const PrimExpr& extent : extents) {
+    frame_node->vars.push_back(Var("v", extent.dtype()));
+    frame_node->doms.push_back(Range(make_const(extent.dtype(), 0), extent));
+  }
+  frame_node->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    // Wrap from the innermost loop outwards so that vars[0] ends up outermost.
+    for (int i = static_cast<int>(vars.size()) - 1; i >= 0; --i) {
+      body = For(vars[i], doms[i]->min, doms[i]->extent, ForKind::kSerial, std::move(body),
+                 /*thread_binding=*/NullOpt, /*annotations=*/{});
+    }
+    return body;
+  };
+  return ForFrame(frame_node);
+}
+
+/*! \brief Create a PrimFuncFrame with no name, args, return type, buffers or attributes. */
+PrimFuncFrame PrimFunc() {
+  ObjectPtr<PrimFuncFrameNode> frame_node = make_object<PrimFuncFrameNode>();
+  // Optional components start out undeclared.
+  frame_node->name = NullOpt;
+  frame_node->ret_type = NullOpt;
+  // Collections start out empty.
+  frame_node->args.clear();
+  frame_node->buffer_map.clear();
+  frame_node->preflattened_buffer_map.clear();
+  frame_node->attrs.clear();
+  frame_node->root_alloc_buffers.clear();
+  return PrimFuncFrame(frame_node);
+}
+
+/*!
+ * \brief Name \p var as \p name and append it to the enclosing PrimFuncFrame's args.
+ * \return \p var.
+ */
+Var Arg(String name, Var var) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(var, name);
+  func_frame->args.push_back(var);
+  return var;
+}
+
+/*!
+ * \brief Name \p buffer as \p name and bind it to a fresh handle argument of the enclosing
+ *        PrimFuncFrame. The handle Var is named "<buffer name>_handle".
+ * \return \p buffer.
+ */
+Buffer Arg(String name, Buffer buffer) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(buffer, name);
+  // buffer->name was just updated by the Namer, so the handle picks up the new name.
+  Var handle_var(buffer->name + "_handle", DataType::Handle());
+  func_frame->args.push_back(handle_var);
+  func_frame->buffer_map.Set(handle_var, buffer);
+  return buffer;
+}
+
+/*! \brief Set the name of the enclosing PrimFunc (T.func_name); may be set only once. */
+void FuncName(String name) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.func_name");
+  if (!frame->name.defined()) {
+    frame->name = name;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value();
+}
+
+/*! \brief Set the attributes of the enclosing PrimFunc (T.func_attr); fails if already set. */
+void FuncAttrs(Map<String, ObjectRef> attrs) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr");
+  if (frame->attrs.empty()) {
+    frame->attrs = attrs;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs;
+}
+
+/*!
+ * \brief Set the return type of the enclosing PrimFunc (T.ret_type); may be set only once.
+ * \return \p ret_type.
+ */
+tvm::Type FuncRet(tvm::Type ret_type) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type");
+  if (frame->ret_type.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
+               << frame->ret_type.value();
+  } else {
+    frame->ret_type = ret_type;
+  }
+  return ret_type;
+}
+
+/*!
+ * \brief Declare an unnamed buffer and bind it to \p param (T.match_buffer).
+ *
+ * If \p param is a Var, it must be one of the enclosing PrimFuncFrame's args; the buffer is
+ * recorded in buffer_map under \p param.
+ *
+ * If \p param is a BufferLoad or a BufferRegion, a MatchBufferRegion is appended to the
+ * enclosing BlockFrame's match_buffers (a BufferLoad is first converted with
+ * BufferRegionFromLoad).
+ *
+ * Any other \p param type is a fatal error.
+ * \return The declared buffer.
+ */
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  if (const auto* var = param.as<tvm::tir::VarNode>()) {
+    // Function-level match: `param` must be an argument of the enclosing PrimFunc.
+    PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
+    Var v = GetRef<Var>(var);
+    for (auto const& arg : frame->args) {
+      if (arg.same_as(v)) {
+        frame->buffer_map.Set(v, buffer);
+        return buffer;
+      }
+    }
+    LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
+  } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) {
+    // Block-level match against the region denoted by a BufferLoad.
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(tvm::tir::MatchBufferRegion(
+        buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(buffer_load))));
+  } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
+    // Block-level match against an explicit BufferRegion.
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(
+        tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(buffer_region)));
+  } else {
+    LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
+  }
+  return buffer;
+}
+
+/*!
+ * \brief Attach a pre-flattened view to a buffer already in the enclosing PrimFunc's buffer_map.
+ *
+ * Finds the buffer_map entry whose buffer is \p postflattened_buffer. It then declares a buffer
+ * named "<name>_preflatten", which reuses the post-flattened buffer's data Var unless \p data is
+ * given, and records it in preflattened_buffer_map under the same parameter Var.
+ * Fails if \p postflattened_buffer is not in buffer_map.
+ */
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
+                        Optional<Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
+                        String storage_scope, int align, int offset_factor, String buffer_type_str,
+                        Array<IntImm> axis_separators) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer");
+  for (auto const& p : frame->buffer_map) {
+    if (p.second.same_as(postflattened_buffer)) {
+      String buffer_name(postflattened_buffer->name + "_preflatten");
+      Buffer buffer =
+          BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset,
+                     storage_scope, align, offset_factor, buffer_type_str, axis_separators);
+      // NOTE(review): naming a Buffer also renames its data Var to "<buffer_name>_data". When
+      // `data` is not given, that Var is shared with the post-flattened buffer, so the
+      // post-flattened buffer's data Var is renamed too. Confirm this is intended.
+      details::Namer::Name(buffer, buffer_name);
+      frame->preflattened_buffer_map.Set(p.first, buffer);
+      return;
+    }
+  }
+  LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
+             << " does not exist.";
+}
+
+/*! \brief Create an AssertFrame on \p condition, storing \p message as a StringImm. */
+AssertFrame Assert(PrimExpr condition, String message) {
+  ObjectPtr<AssertFrameNode> frame_node = make_object<AssertFrameNode>();
+  frame_node->message = tvm::tir::StringImm(message);
+  frame_node->condition = condition;
+  return AssertFrame(frame_node);
+}
+
+/*! \brief Create a LetFrame binding \p var to \p value. */
+LetFrame Let(Var var, PrimExpr value) {
+  ObjectPtr<LetFrameNode> frame_node = make_object<LetFrameNode>();
+  frame_node->value = value;
+  frame_node->var = var;
+  return LetFrame(frame_node);
+}
+
+/*!
+ * \brief Create an AllocateFrame for \p extents of \p dtype in \p storage_scope.
+ *
+ * \p condition defaults to true and \p annotations to an empty map. An unnamed buffer over the
+ * allocation is declared with BufferDecl.
+ */
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  ObjectPtr<AllocateFrameNode> frame_node = make_object<AllocateFrameNode>();
+  frame_node->buffer = BufferDecl(extents, dtype, /*buffer_name=*/"", /*data=*/NullOpt,
+                                  /*strides=*/NullOpt, /*elem_offset=*/NullOpt, storage_scope,
+                                  /*align=*/0, /*offset_factor=*/0, /*buffer_type=*/"default",
+                                  /*axis_separators=*/NullOpt);
+  frame_node->extents = extents;
+  frame_node->dtype = dtype;
+  frame_node->storage_scope = storage_scope;
+  frame_node->condition = condition.value_or(tvm::Bool(true));
+  frame_node->annotations = annotations.value_or(Map<String, ObjectRef>());
+  return AllocateFrame(frame_node);
+}
+
+/*!
+ * \brief Create an AllocateConstFrame holding constant \p data of \p dtype with \p extents.
+ *
+ * An unnamed buffer with an empty storage scope is declared over it with BufferDecl.
+ */
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> frame_node = make_object<AllocateConstFrameNode>();
+  frame_node->buffer = BufferDecl(extents, dtype, /*buffer_name=*/"", /*data=*/NullOpt,
+                                  /*strides=*/NullOpt, /*elem_offset=*/NullOpt,
+                                  /*storage_scope=*/"", /*align=*/0, /*offset_factor=*/0,
+                                  /*buffer_type=*/"default", /*axis_separators=*/NullOpt);
+  frame_node->data = data;
+  frame_node->dtype = dtype;
+  frame_node->extents = extents;
+  frame_node->annotations = annotations;
+  return AllocateConstFrame(frame_node);
+}
+
+/*!
+ * \brief Create a LaunchThreadFrame launching \p iter_var with \p extent.
+ *
+ * If \p iter_var has no domain yet, its domain is set in place to Range(0, extent). Otherwise
+ * the existing extent must be provably equal to \p extent. The attribute key is
+ * "virtual_thread" for the "vthread" tag and "thread_extent" otherwise.
+ */
+LaunchThreadFrame LaunchThread(IterVar iter_var, PrimExpr extent) {
+  ObjectPtr<LaunchThreadFrameNode> frame_node = make_object<LaunchThreadFrameNode>();
+  if (iter_var->dom.defined()) {
+    if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+      LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+                 << iter_var->dom->extent << " vs " << extent;
+    }
+  } else {
+    // get() yields a const node, so const_cast is used to fill in the domain in place; every
+    // holder of this IterVar observes the change.
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+  }
+  const bool is_virtual_thread = (iter_var->thread_tag == "vthread");
+  frame_node->iter_var = iter_var;
+  frame_node->extent = extent;
+  frame_node->attr_key = is_virtual_thread ? "virtual_thread" : "thread_extent";
+  return LaunchThreadFrame(frame_node);
+}
+
+/*! \brief Create a RealizeFrame for \p buffer_slice in \p storage_scope under \p condition. */
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                     PrimExpr condition) {
+  ObjectPtr<RealizeFrameNode> frame_node = make_object<RealizeFrameNode>();
+  frame_node->condition = condition;
+  frame_node->storage_scope = storage_scope;
+  frame_node->buffer_slice = buffer_slice;
+  return RealizeFrame(frame_node);
+}
+
+/*! \brief Create an AttrFrame attaching attribute \p attr_key = \p value to \p node. */
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  ObjectPtr<AttrFrameNode> frame_node = make_object<AttrFrameNode>();
+  frame_node->value = value;
+  frame_node->attr_key = attr_key;
+  frame_node->node = node;
+  return AttrFrame(frame_node);
+}
+
+/*! \brief Create a WhileFrame with loop condition \p condition. */
+WhileFrame While(PrimExpr condition) {
+  ObjectPtr<WhileFrameNode> frame_node = make_object<WhileFrameNode>();
+  frame_node->condition = condition;
+  return WhileFrame(frame_node);
+}
+
+/*! \brief Create an IfFrame on \p condition with neither branch populated yet. */
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> frame_node = make_object<IfFrameNode>();
+  frame_node->condition = condition;
+  frame_node->then_stmts = NullOpt;
+  frame_node->else_stmts = NullOpt;
+  return IfFrame(frame_node);
+}
+
+/*! \brief Create an empty ThenFrame. */
+ThenFrame Then() { return ThenFrame(make_object<ThenFrameNode>()); }
+
+/*! \brief Create an empty ElseFrame. */
+ElseFrame Else() { return ElseFrame(make_object<ElseFrameNode>()); }
+
+/*! \brief Create an unnamed int32 kThreadIndex IterVar tagged \p thread_tag, with no domain. */
+IterVar EnvThread(String thread_tag) {
+  Var thread_var("", DataType::Int(32));
+  return IterVar(Range{nullptr}, thread_var, tvm::tir::IterVarType::kThreadIndex, thread_tag);
+}
+
+/*! \brief Emit a BufferStore of \p value into \p buffer at \p indices via AddToParent. */
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  AddToParent(tvm::tir::BufferStore(buffer, value, indices));
+}
+
+/*! \brief Emit a Prefetch of \p buffer over \p bounds via AddToParent. */
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+  AddToParent(tvm::tir::Prefetch(buffer, bounds));
+}
+
+/*! \brief Emit an Evaluate of \p value via AddToParent. */
+void Evaluate(PrimExpr value) {
+  AddToParent(tvm::tir::Evaluate(value));
+}
+
+using tvm::ir_builder::details::Namer;
+
+// Naming a Buffer does three things: sets its name, names its data Var "<name>_data", and names
+// each stride that is a Var "<name>_s<i>". Nodes are mutated in place via const_cast, so every
+// reference observes the new names. Shape and elem_offset are not renamed here.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
+      tvm::tir::BufferNode* buffer =
+          const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>());
+      buffer->name = name;
+      Namer::Name(buffer->data, name + "_data");
+      int n = buffer->strides.size();
+      for (int i = 0; i < n; ++i) {
+        PrimExpr e = buffer->strides[i];
+        if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) {
+          Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i));
+        }
+      }
+    });
+
+// Naming a SizeVar overwrites its name_hint in place.
+// NOTE(review): registered separately from VarNode, presumably because dispatch is keyed on the
+// exact node type rather than on VarNode subclasses. Confirm against NodeFunctor.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
+      using namespace tvm::tir;
+      SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>());
+      var->name_hint = name;
+    });
+
+// Naming a Var overwrites its name_hint in place.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
+      using namespace tvm::tir;
+      VarNode* var = const_cast<VarNode*>(node.as<VarNode>());
+      var->name_hint = name;
+    });
+
+// Naming an IterVar forwards the name to its underlying Var.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
+      using namespace tvm::tir;
+      IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>());
+      Namer::Name(var->var, name);
+    });
+
+// Buffer and pointer declaration helpers.
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Ptr").set_body_typed(Ptr);
+
+// Block frames and block-level declarations (init, where, reads/writes, attrs, axes).
+TVM_REGISTER_GLOBAL("ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Init").set_body_typed(Init);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Where").set_body_typed(Where);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Reads").set_body_typed(Reads);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Writes").set_body_typed(Writes);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
+// Loop frames.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Serial").set_body_typed(Serial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Parallel").set_body_typed(Parallel);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Vectorized").set_body_typed(Vectorized);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Unroll").set_body_typed(Unroll);
+TVM_REGISTER_GLOBAL("ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Grid").set_body_typed(Grid);
+
+// PrimFunc frame and function-level declarations.
+TVM_REGISTER_GLOBAL("ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
+// Arg is overloaded on Var and Buffer, so the FFI entry dispatches on the runtime type. The
+// trailing `throw;` is not expected to be reached after LOG(FATAL); it satisfies the compiler's
+// return-path check.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Arg")
+    .set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
+      using namespace tvm::tir;
+      if (const auto* var = obj.as<VarNode>()) {
+        return Arg(name, GetRef<tvm::tir::Var>(var));
+      }
+      if (const auto* buffer = obj.as<BufferNode>()) {
+        return Arg(name, GetRef<Buffer>(buffer));
+      }
+      LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
+      throw;
+    });
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncName").set_body_typed(FuncName);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncRet").set_body_typed(FuncRet);
+TVM_REGISTER_GLOBAL("ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
+
+// Statement-level frames and statement emitters.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Prefetch").set_body_typed(Prefetch);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+
+// Scalar type helpers (Int8 ... Void, defined outside this view) and min/max wrappers.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int8").set_body_typed(Int8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int16").set_body_typed(Int16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32").set_body_typed(Int32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int64").set_body_typed(Int64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt8").set_body_typed(UInt8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt16").set_body_typed(UInt16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt32").set_body_typed(UInt32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt64").set_body_typed(UInt64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float8").set_body_typed(Float8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float16").set_body_typed(Float16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float32").set_body_typed(Float32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float64").set_body_typed(Float64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Boolean").set_body_typed(Boolean);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Void").set_body_typed(Void);
+TVM_REGISTER_GLOBAL("ir_builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+  return tvm::min(a, b);
+});
+TVM_REGISTER_GLOBAL("ir_builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+  return tvm::max(a, b);
+});
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder_frame.cc b/src/tir/ir_builder/ir_builder_frame.cc
new file mode 100644
index 0000000000..cd0cd46b50
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder_frame.cc
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/function.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "../../tir/ir/script/script_complete.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+void BlockFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Copy the collected allocations into an Array of tir::Buffer for the Block constructor.
+  Array<tvm::tir::Buffer> buffers_to_alloc;
+  for (const tvm::tir::Buffer& buf : alloc_buffers) {
+    buffers_to_alloc.push_back(buf);
+  }
+  // Bit 0 set: no explicit T.reads; bit 1 set: no explicit T.writes. A non-zero mask is stored
+  // as an annotation so the missing access regions can be filled in later.
+  int detect_access = 0;
+  if (!reads.defined()) {
+    detect_access |= 1;
+  }
+  if (!writes.defined()) {
+    detect_access |= 2;
+  }
+  if (detect_access != 0) {
+    annotations.Set("tir.script_parsing_detect_access",
+                    tvm::IntImm(DataType::Int(64), detect_access));
+  }
+  tvm::tir::Block block(iter_vars,
+                        reads.value_or(Array<tvm::tir::BufferRegion>()),
+                        writes.value_or(Array<tvm::tir::BufferRegion>()),
+                        name,
+                        AsStmt(stmts),
+                        init,
+                        buffers_to_alloc,
+                        match_buffers,
+                        annotations);
+  if (!no_realize) {
+    // Default path: wrap the block in a BlockRealize with its bindings and predicate.
+    AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
+    return;
+  }
+  // `no_realize=True`: the bare Block is emitted, so bindings and `T.where` are meaningless.
+  CHECK(iter_values.empty())
+      << "ValueError: Block bindings are not allowed when `no_realize=True`";
+  CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
+  AddToParent(block);
+}
+
+void BlockInitFrameNode::EnterWithScope() {
+  // A block may carry at most one `T.init()` section.
+  const bool already_has_init = FindBlockFrame("T.init")->init.defined();
+  if (already_has_init) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The statements collected under `T.init()` become the init body of the owning block frame.
+  BlockFrame owner = FindBlockFrame("T.init");
+  tvm::tir::Stmt init_body = AsStmt(stmts);
+  owner->init = init_body;
+}
+
+void ForFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The stored factory `f_make_for_loop` builds the loop statement from the loop vars, their
+  // domains and the collected body.
+  tvm::tir::Stmt loop_body = AsStmt(stmts);
+  AddToParent(f_make_for_loop(vars, doms, loop_body));
+}
+
+void PrimFuncFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Function attributes stay null when none were recorded.
+  DictAttrs func_attrs = NullValue<DictAttrs>();
+  if (!attrs.empty()) {
+    func_attrs = DictAttrs(attrs);
+  }
+  tvm::tir::PrimFunc func(/*params=*/args,
+                          /*body=*/AsStmt(stmts),
+                          /*ret_type=*/ret_type.value_or(TupleType::Empty()),
+                          /*buffer_map=*/buffer_map,
+                          /*preflattened_buffer_map=*/preflattened_buffer_map,
+                          /*attrs=*/func_attrs);
+  func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    // Top level: the PrimFunc itself is the builder's result.
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = func;
+    return;
+  }
+  Optional<IRModuleFrame> opt_module = builder->FindFrame<IRModuleFrame>();
+  if (!opt_module.defined()) {
+    LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
+  }
+  // Inside an IRModule frame: append a GlobalVar (named from `name`, or "" when unset) and the
+  // function as parallel entries.
+  IRModuleFrame module_frame = opt_module.value();
+  module_frame->global_vars.push_back(GlobalVar(name.value_or("")));
+  module_frame->functions.push_back(func);
+}
+
+void AssertFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The collected statements become the body that follows the assertion.
+  tvm::tir::Stmt body_after_check = AsStmt(stmts);
+  AddToParent(tvm::tir::AssertStmt(condition, message, body_after_check));
+}
+
+void LetFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Bind `var` to `value` over the collected statements.
+  tvm::tir::Stmt scope_body = AsStmt(stmts);
+  AddToParent(tvm::tir::LetStmt(var, value, scope_body));
+}
+
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Emit an Allocate for the buffer's data var, using the buffer's dtype and shape.
+  tvm::tir::Stmt scope_body = AsStmt(stmts);
+  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition,
+                                 scope_body, annotations));
+}
+
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Emit an AllocateConst binding the constant `data` to the buffer's data var.
+  tvm::tir::Stmt scope_body = AsStmt(stmts);
+  AddToParent(tvm::tir::AllocateConst(buffer->data, dtype, extents, data, scope_body,
+                                      annotations));
+}
+
+void LaunchThreadFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Thread launch is expressed as an AttrStmt on the thread IterVar with the given extent.
+  tvm::tir::Stmt scope_body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, scope_body));
+}
+
+void RealizeFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  const tvm::tir::Buffer& realized = buffer_slice->buffer;
+  // Inner BufferRealize over the sliced region ...
+  tvm::tir::Stmt realize_stmt =
+      tvm::tir::BufferRealize(realized, buffer_slice->region, condition, AsStmt(stmts));
+  // ... wrapped in a "realize_scope" attribute carrying the storage scope string.
+  AddToParent(tvm::tir::AttrStmt(realized, "realize_scope", tvm::tir::StringImm(storage_scope),
+                                 realize_stmt));
+}
+
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Attach (attr_key, value) to `node` for the duration of the collected statements.
+  tvm::tir::Stmt scope_body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, scope_body));
+}
+
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The collected statements form the loop body guarded by `condition`.
+  tvm::tir::Stmt loop_body = AsStmt(stmts);
+  AddToParent(tvm::tir::While(condition, loop_body));
+}
+
+void IfFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Statements must be placed inside T.then_()/T.else_(), never directly under T.if_().
+  if (!stmts.empty()) {
+    LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
+  }
+  if (!then_stmts.defined()) {
+    LOG(FATAL) << "IfThenElse frame should have at least one then branch";
+  }
+  tvm::tir::Stmt then_case = AsStmt(then_stmts.value());
+  // A missing else branch is represented by a null Stmt.
+  tvm::tir::Stmt else_case(nullptr);
+  if (else_stmts.defined()) {
+    else_case = AsStmt(else_stmts.value());
+  }
+  AddToParent(tvm::tir::IfThenElse(condition, then_case, else_case));
+}
+
+void ThenFrameNode::EnterWithScope() {
+  IfFrame owner_if = FindIfFrame("T.then_");
+  // Only one T.then_() is allowed per T.if_().
+  if (!owner_if->then_stmts.defined()) {
+    TIRFrameNode::EnterWithScope();
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+             << owner_if->then_stmts.value();
+}
+
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the collected statements to the owning if-frame as its then branch.
+  IfFrame owner_if = FindIfFrame("T.then_");
+  owner_if->then_stmts = stmts;
+}
+
+void ElseFrameNode::EnterWithScope() {
+  IfFrame owner_if = FindIfFrame("T.else_");
+  const bool then_seen = owner_if->then_stmts.defined();
+  const bool else_seen = owner_if->else_stmts.defined();
+  // T.else_() must come after a T.then_() and may appear at most once.
+  if (!then_seen) {
+    LOG(FATAL) << "The else branch should follow then branch";
+  }
+  if (else_seen) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << owner_if->else_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the collected statements to the owning if-frame as its else branch.
+  IfFrame owner_if = FindIfFrame("T.else_");
+  owner_if->else_stmts = stmts;
+}
+
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Declare `buffer` over the collected statements.
+  tvm::tir::Stmt scope_body = AsStmt(stmts);
+  AddToParent(tvm::tir::DeclBuffer(buffer, scope_body));
+}
+
+// Register every TIR IR-builder frame node type defined above (plus the TIRFrameNode base) with
+// the TVM object system.
+TVM_REGISTER_NODE_TYPE(TIRFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
+TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir_builder/utils.h b/src/tir/ir_builder/utils.h
new file mode 100644
index 0000000000..6b6271ab0e
--- /dev/null
+++ b/src/tir/ir_builder/utils.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tir/ir_builder/utils.h
+ * \brief Private helpers shared by the TIR IR-builder frame implementations.
+ */
+// The guard follows this file's path. The previous guard (TVM_SCRIPT_BUILDER_TIR_BASE_H_) named a
+// different header, so including both would have silently dropped one of them.
+#ifndef TVM_TIR_IR_BUILDER_UTILS_H_
+#define TVM_TIR_IR_BUILDER_UTILS_H_
+
+#include <tvm/tir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Append a statement to whatever the current IRBuilder is building.
+ *
+ * With no open frame, the statement becomes the builder's result, which must not already be
+ * set. If the last frame is a TIR frame, the statement is appended to its `stmts`. Any other
+ * frame type is reported as a TypeError.
+ * \param stmt The statement to add.
+ */
+inline void AddToParent(tvm::tir::Stmt stmt) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = stmt;
+  } else if (const auto* tir_frame = builder->frames.back().as<TIRFrameNode>()) {
+    GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt);
+  } else {
+    LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
+  }
+}
+
+/*!
+ * \brief Collapse a list of statements into one statement.
+ * \param stmts The statements, in program order.
+ * \return `Evaluate(0)` when empty, the sole element for a singleton, otherwise a `SeqStmt`.
+ */
+inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmts) {
+  if (stmts.empty()) {
+    return tvm::tir::Evaluate(0);
+  }
+  if (stmts.size() == 1) {
+    return stmts[0];
+  }
+  return tvm::tir::SeqStmt(stmts);
+}
+
+/*!
+ * \brief Return the frame given by `IRBuilder::Current()->GetLastFrame<BlockFrame>()`.
+ * \param method User-facing API name (e.g. "T.init"), used only in the error message.
+ */
+inline BlockFrame FindBlockFrame(const String& method) {
+  if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: Block frame not found. Please ensure '" << method
+             << "' is called under T.block()";
+  throw;  // Not reached after LOG(FATAL); silences missing-return warnings.
+}
+
+/*!
+ * \brief Return the frame given by `IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()`.
+ * \param method User-facing API name, used only in the error message.
+ */
+inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
+  if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: PrimFunc frame not found. Please ensure '" << method
+             << "' is called under T.prim_func()";
+  throw;  // Not reached after LOG(FATAL); silences missing-return warnings.
+}
+
+/*!
+ * \brief Return the frame given by `IRBuilder::Current()->GetLastFrame<IfFrame>()`.
+ * \param method User-facing API name (e.g. "T.then_"), used only in the error message.
+ */
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+             << "' is called under T.if_()";
+  throw;  // Not reached after LOG(FATAL); silences missing-return warnings.
+}
+
+/*!
+ * \brief Turn a point access into a region: every index `i` becomes the unit range [i, i + 1).
+ * \param buffer_load The load whose buffer and indices describe the accessed point.
+ * \return A BufferRegion on the same buffer covering exactly that element.
+ */
+inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
+  Array<Range> ranges;
+  for (const PrimExpr& index : buffer_load->indices) {
+    ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1)));
+  }
+  return tvm::tir::BufferRegion(buffer_load->buffer, ranges);
+}
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_UTILS_H_
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 114571218b..dd3c7a0b0f 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
                    {x, y, q, s}, span);
 }
 
+// address_of: a Handle-typed call to the builtin::address_of() intrinsic on one buffer element.
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+  Array<PrimExpr> call_args{buffer_load};
+  return tir::Call(DataType::Handle(), tir::builtin::address_of(), call_args, span);
+}
+
+// lookup_param: a Handle-typed call to builtin::lookup_param() keyed by the parameter's name.
+PrimExpr lookup_param(String param_name, Span span) {
+  tir::StringImm key(param_name);
+  return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {key}, span);
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -702,6 +713,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
   }
 }
 
+// isnullptr: a scalar (single-lane) boolean call to builtin::isnullptr() on `x` (presumably a
+// handle).
+PrimExpr isnullptr(PrimExpr x, Span span) {
+  const DataType result_type = DataType::Bool(1);
+  return tir::Call(result_type, tir::builtin::isnullptr(), {x}, span);
+}
+
 // isinf
 PrimExpr isinf(PrimExpr x, Span span) {
   DataType t = DataType::Bool(x.dtype().lanes());
@@ -931,6 +947,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
 
 TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
 
+// Expose tvm::likely to the FFI as "tir.likely".
+TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
+
 TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
 
 TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
@@ -949,6 +967,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
 
 TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
 
+// Expose tvm::infinity and tvm::reinterpret to the FFI.
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);
+
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                                                \
   TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -997,6 +1019,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
 REGISTER_MAKE_BIT_OP(left_shift, left_shift);  // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, right_shift);
 
+// Logical NOT operator used by the Python-side operator overloading, backed by logical_not.
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);
+
 TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
     .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
       return if_then_else(cond, true_value, false_value, span);
diff --git a/tests/python/tvmscript/test_builder_basic.py b/tests/python/tvmscript/test_builder_basic.py
new file mode 100644
index 0000000000..7224d1b307
--- /dev/null
+++ b/tests/python/tvmscript/test_builder_basic.py
@@ -0,0 +1,227 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script.builder import Builder, name, name_many
+from tvm.script.builder import tir as T
+from tvm.ir import Range
+
+
+def test_builder_root_block():
+    """Build PrimFuncs with an implicit and with an explicit "root" block and print each script.
+
+    Each variant is built three ways: a single block, an allocated buffer plus a block, and an
+    allocated buffer plus two sibling blocks. Nothing is asserted; the test passes as long as
+    building and printing do not raise.
+    """
+    print("test_builder_root_block")
+    # implicit root block: buffers and blocks are placed directly under T.prim_func()
+    with Builder() as b0:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"key": "value"})
+            with T.block(name="block"):
+                pass
+    print(b0.get().script())
+    with Builder() as b1:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"key": "value"})
+            # `A` is not referenced again; the allocation itself is what is being built.
+            A = name("A", T.alloc_buffer((128,)))
+            with T.block(name="block"):
+                pass
+    print(b1.get().script())
+    with Builder() as b2:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"key": "value"})
+            A = name("A", T.alloc_buffer((128,)))
+            with T.block(name="block0"):
+                pass
+            with T.block(name="block1"):
+                pass
+    print(b2.get().script())
+    # explicit root block: the same three shapes, wrapped in T.block(name="root")
+    with Builder() as b0_r:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"key": "value"})
+            with T.block(name="root"):
+                with T.block(name="block"):
+                    pass
+    print(b0_r.get().script())
+    with Builder() as b1_r:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"key": "value"})
+            with T.block(name="root"):
+                A = name("A", T.alloc_buffer((128,)))
+                with T.block(name="block"):
+                    pass
+    print(b1_r.get().script())
+    with Builder() as b2_r:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"key": "value"})
+            with T.block(name="root"):
+                A = name("A", T.alloc_buffer((128,)))
+                with T.block(name="block0"):
+                    pass
+                with T.block(name="block1"):
+                    pass
+    print(b2_r.get().script())
+
+
+def test_builder_axis():
+    """Declare block iter vars of every kind (spatial, reduce, scan, opaque) and via remap."""
+    print("test_builder_axis")
+    with Builder() as b:
+        with T.prim_func():
+            T.func_name("main")
+            with T.grid(128, 128, 128, 128, 128) as (i, j, k, m, n):
+                # presumably gives the loop vars their printed names
+                name_many(["i", "j", "k", "m", "n"], [i, j, k, m, n])
+                with T.block(name="block"):
+                    vi = name("vi", T.axis.spatial(128, i))
+                    vj = name("vj", T.axis.spatial(128, j))
+                    vk = name("vk", T.axis.reduce(128, k))
+                    vm = name("vm", T.axis.scan(128, m))
+                    vn = name("vn", T.axis.opaque(128, n))
+                    # "SSR" presumably maps i, j to spatial and k to reduce axes
+                    x, y, z = name_many(["x", "y", "z"], T.axis.remap("SSR", [i, j, k]))
+    print(b.get().script())
+
+
+def test_builder_prim_func():
+    """Build a PrimFunc with attrs, handle and Buffer params, a return type, and buffer decls."""
+    print("test_builder_prim_func")
+    with Builder() as b:
+        with T.prim_func():
+            T.func_name("main")
+            T.func_attr({"global_symbol": "main"})
+            # Two handle params (matched to buffers below) and two Buffer params.
+            arg_a = T.arg("a", T.handle())
+            arg_b = T.arg("b", T.handle())
+            buffer_c = T.Buffer((128,), "float32")
+            buffer_d = T.Buffer((128,), "float32")
+            arg_c = T.arg("c", buffer_c)
+            arg_d = T.arg("d", buffer_d)
+            T.func_ret(tvm.ir.PrimType("int8"))
+            A = name("A", T.match_buffer(arg_a, (128, 128, 128), "int32"))
+            B = name("B", T.match_buffer(arg_b, (128, 128, 128), "int32"))
+            # Preflattened views share each Buffer param's data pointer.
+            T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data)
+            T.preflattened_buffer(buffer_d, (128,), data=buffer_d.data)
+    print(b.get().script())
+
+
+def test_builder_block():
+    """Build a block using attrs, T.where, T.init, remapped iter vars, reads/writes, and allocs."""
+    print("test_builder_block")
+    with Builder() as b:
+        with T.prim_func():
+            arg_a = T.arg("a", T.handle())
+            arg_b = T.arg("b", T.handle())
+            A = name("A", T.match_buffer(arg_a, (128, 128, 128), "int32"))
+            B = name("B", T.match_buffer(arg_b, (128, 128, 128), "int32"))
+            with T.grid(128, 128, 128) as (i, j, k):
+                name_many(["i", "j", "k"], [i, j, k])
+                with T.block(name="block"):
+                    T.block_attr({"axis": 1})
+                    T.where(i > 1)
+                    with T.init():
+                        pass
+                    vi, vj, vk = name_many(["vi", "vj", "vk"], T.axis.remap("SSR", [i, j, k]))
+                    # Region bounds and indices themselves contain buffer loads.
+                    T.reads(A[vi, vj, vk : vk + B[1, 2, A[3, 4, 5]]])
+                    T.writes(A[100, A[50, 51, 52], 102])
+                    E = name("E", T.alloc_buffer((128, 128)))
+                    F = name("F", T.alloc_buffer((128, 128)))
+    print(b.get().script())
+
+
+def test_builder_for():
+    """Build each loop form: grid, serial, parallel, vectorized, unroll, and nested bindings."""
+    print("test_builder_for")
+    with Builder() as b:
+        with T.prim_func():
+            with T.grid(128, 128, 128) as (i, j, k):
+                name_many(["i", "j", "k"], [i, j, k])
+            with T.serial(0, 128) as w:
+                w = name("w", w)
+            with T.parallel(0, 128) as x:
+                x = name("x", x)
+            with T.vectorized(0, 128) as y:
+                y = name("y", y)
+            with T.unroll(0, 128) as z:
+                z = name("z", z)
+            # Three nested thread bindings: blockIdx.x > vthread.y > threadIdx.z.
+            with T.thread_binding(0, 32, thread="blockIdx.x") as bx:
+                bx = name("bx", bx)
+                with T.thread_binding(0, 2, thread="vthread.y") as vy:
+                    vy = name("vy", vy)
+                    with T.thread_binding(0, 8, thread="threadIdx.z") as tz:
+                        tz = name("tz", tz)
+    print(b.get().script())
+
+
+def test_builder_stmt():
+    """Build (mostly nested pairs of) every statement-level construct and print the script.
+
+    Covers assert, let, allocate, allocate_const, realize, attr, launch_thread, while, if with
+    then/else branches, prefetch, buffer_store, and evaluate.
+    """
+    print("test_builder_stmt")
+    with Builder() as b:
+        with T.prim_func():
+            thread_x = name("thread_x", T.env_thread("threadIdx.x"))
+            thread_y = name("thread_y", T.env_thread("threadIdx.y"))
+            buffer_x = name("buffer_x", T.Buffer([128, 128]))
+            buffer_y = name("buffer_y", T.Buffer([128, 128]))
+            var_x = name("var_x", tvm.tir.Var("", dtype="int32"))
+            var_y = name("var_y", tvm.tir.Var("", dtype="int32"))
+            with T.Assert(var_x < var_y, ""):
+                with T.Assert(1, "true"):
+                    pass
+            with T.let(var_x, var_y):
+                pass
+            with T.allocate([128], "uint8", "global") as alloc_x:
+                with T.allocate([128], "uint8", "global") as alloc_y:
+                    alloc_x, alloc_y = name_many(["alloc_x", "alloc_y"], [alloc_x, alloc_y])
+            with T.allocate_const([1, 1, 1, 1, 1], "int32", [5]) as alloc_const_x:
+                with T.allocate_const([10, 10, 10], "float32", [3]) as alloc_const_y:
+                    alloc_const_x, alloc_const_y = name_many(
+                        ["alloc_const_x", "alloc_const_y"], [alloc_const_x, alloc_const_y]
+                    )
+            with T.realize(buffer_x[0:var_x, 0:var_y], ""):
+                with T.realize(buffer_x[var_x:128, var_y:128], ""):
+                    pass
+            with T.attr(buffer_x, "key_x", "value_x"):
+                with T.attr(buffer_y, "key_y", "value_y"):
+                    pass
+            with T.launch_thread(thread_x, 4):
+                with T.launch_thread(thread_y, 4):
+                    pass
+            with T.while_(var_x < var_y):
+                with T.while_(var_x > 0):
+                    pass
+            # if with both branches, then an if with only a then branch
+            with T.if_(var_x < var_y):
+                with T.then_():
+                    T.evaluate(0)
+                    T.evaluate(1)
+                with T.else_():
+                    T.evaluate(0)
+                    T.evaluate(1)
+            with T.if_(1):
+                with T.then_():
+                    T.evaluate(1)
+            T.prefetch(buffer_x, [Range(0, 64), Range(64, 128)])
+            T.prefetch(buffer_y, [Range(0, var_x), Range(var_y, 128)])
+            T.buffer_store(buffer_x, 1, [0, 0])
+            T.buffer_store(buffer_x, var_x + var_y, [var_x, var_y])
+            T.evaluate(var_x + var_y)
+            T.evaluate(1)
+
+    print(b.get().script())
+
+
+if __name__ == "__main__":
+    # Allow running the file as a plain script; each test prints the scripts it builds.
+    test_builder_root_block()
+    test_builder_axis()
+    test_builder_prim_func()
+    test_builder_block()
+    test_builder_for()
+    test_builder_stmt()
diff --git a/tests/python/tvmscript/test_parse_basic.py b/tests/python/tvmscript/test_parse_basic.py
new file mode 100644
index 0000000000..ca3e908ce1
--- /dev/null
+++ b/tests/python/tvmscript/test_parse_basic.py
@@ -0,0 +1,118 @@
+import inspect
+
+import pytest
+import tvm
+from tvm.ir import structural_equal
+from tvm.script.builder import ir as I
+from tvm.script.builder import tir as T
+
+
+def test_parse_elementwise():
+    """Parse a PrimFunc with both Buffer annotation syntaxes and star-unpacked loop vars.
+
+    `for i, j, *vvv, k in T.grid(...)` over 7 extents binds the 4 middle loop vars to `vvv`.
+    Nothing is asserted; the parsed function's script is printed.
+    """
+    # pylint: disable=unused-argument,unused-variable,invalid-name
+    @T.prim_func
+    def elementwise(
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer(shape=(128, 128, 128), dtype="float32"),  # type: ignore
+    ) -> None:
+        for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128):
+            with T.block("inner_block"):
+                # vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                vi = T.axis.S(128, i + 1)
+                vj = T.axis.S(128, j + 20)
+                vk = T.axis.R(128, k - i)
+                A[vi + 1, vj] = A[vi, vk] * B[vvv[0], vvv[1], vvv[2]] + 2
+                B[vi, vj, vk] = A[vvv[0], vvv[-1]]
+
+    # pylint: enable=unused-argument,unused-variable,invalid-name
+
+    result = elementwise
+    print(result.script())
+
+
+def test_parse_skip():
+    """`@T.prim_func` inside a plain (non-`ir_module`) class body must leave a plain function."""
+
+    class Skip:
+        @T.prim_func
+        def f():  # type: ignore
+            ...
+
+    assert inspect.isfunction(Skip.f)
+
+
+def test_parse_class():
+    """Parse an `@I.ir_module` class holding one PrimFunc and print the resulting module."""
+    # pylint: disable=unused-argument,unused-variable,invalid-name
+    @I.ir_module
+    class C:
+        @T.prim_func
+        def elementwise(
+            A: T.Buffer(shape=(128, 128, 128), dtype="float32"),  # type: ignore
+            B: T.Buffer(shape=(128, 128, 128), dtype="float32"),  # type: ignore
+        ) -> None:
+            for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128):
+                with T.block("inner_block"):
+                    # vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    vi = T.axis.S(128, i + 1)
+                    vj = T.axis.S(128, j + 20)
+                    vk = T.axis.R(128, k - i)
+
+    # pylint: enable=unused-argument,unused-variable,invalid-name
+
+    print(C.script())
+
+
+def test_parse_atomic():
+    """Scalar- and handle-annotated params become params with the matching name and dtype."""
+
+    @T.prim_func
+    def f(A: T.int32, B: T.int64, C: T.handle) -> None:
+        pass
+
+    assert f.params[0].name == "A"
+    assert f.params[0].dtype == "int32"
+    assert f.params[1].name == "B"
+    assert f.params[1].dtype == "int64"
+    assert f.params[2].name == "C"
+    assert f.params[2].dtype == "handle"
+
+
+def test_parse_report_error():
+    """A parse-time error must surface as `tvm.error.DiagnosticError`.
+
+    `vvv` collects all 7 loop vars of the grid, so `vvv[10]` is out of range.
+    """
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @T.prim_func
+        def elementwise() -> None:
+            for (*vvv,) in T.grid(128, 128, 128, 128, 128, 128, 128):
+                with T.block("inner_block"):
+                    vj = T.axis.S(128, vvv[10] + 20)
+
+
+def test_parse_concise_scope():
+    """Assignment-style `X = T.allocate(...)` must parse to the same IR as nested `with` scopes.
+
+    In `concise_scope`, the statements after the allocations are expected to end up in the
+    innermost allocation's body, matching `normal_scope` structurally.
+    """
+    # pylint: disable=unused-argument,unused-variable,invalid-name
+    @T.prim_func
+    def concise_scope(
+        A: T.handle,
+    ) -> None:
+        A_local = T.allocate([64], "float32", "local")
+        B_local = T.allocate([64], "float32", "local")
+        C_local = T.allocate([64], "float32", "local")
+        T.evaluate(1)
+        T.evaluate(2)
+        T.evaluate(3)
+
+    @T.prim_func
+    def normal_scope(
+        A: T.handle,
+    ) -> None:
+        with T.allocate([64], "float32", "local") as A_local:
+            with T.allocate([64], "float32", "local") as B_local:
+                with T.allocate([64], "float32", "local") as C_local:
+                    T.evaluate(1)
+                    T.evaluate(2)
+                    T.evaluate(3)
+
+    assert structural_equal(normal_scope, concise_scope)
+
+
+if __name__ == "__main__":
+    # Run every test in order when the file is executed directly (pytest is still imported).
+    test_parse_elementwise()
+    test_parse_skip()
+    test_parse_class()
+    test_parse_atomic()
+    test_parse_report_error()
+    test_parse_concise_scope()
diff --git a/tests/python/tvmscript/test_parser_capture.py b/tests/python/tvmscript/test_parser_capture.py
new file mode 100644
index 0000000000..b61d35137f
--- /dev/null
+++ b/tests/python/tvmscript/test_parser_capture.py
@@ -0,0 +1,43 @@
+from tvm.script.builder import ir as I
+from tvm.script.builder import tir as T
+
+
+def test_capture_func():
+    """Parse a PrimFunc that uses names imported only into this test's local scope.
+
+    `ax`, `block` and `match_buffer` are local imports used inside the parsed function,
+    presumably to check that the parser resolves names captured from the enclosing scope.
+    """
+    from tvm.script.builder.tir import axis as ax
+    from tvm.script.builder.tir import block, match_buffer
+
+    @T.prim_func
+    def scalar_func(a: T.handle, b: T.handle, c: T.Buffer((128,))):
+        A = match_buffer(a, (128, 128))
+        B = match_buffer(b, (128, 128))
+        with block():
+            for i, j in T.grid(128, 128):
+                with block("inner_block"):
+                    vi, vj = ax.remap("SR", [i, j])
+                    A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]
+
+    print(scalar_func.script())
+
+
+def test_capture_class():
+    """Same as `test_capture_func`, but the PrimFunc is defined inside an `@I.ir_module` class.
+
+    The locally imported `ax`, `block` and `match_buffer` are used inside the parsed function.
+    """
+    from tvm.script.builder.tir import axis as ax
+    from tvm.script.builder.tir import block, match_buffer
+
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def scalar_func(a: T.handle, b: T.handle, c: T.Buffer((128,))):
+            A = match_buffer(a, (128, 128))
+            B = match_buffer(b, (128, 128))
+            with block():
+                for i, j in T.grid(128, 128):
+                    with block("inner_block"):
+                        vi, vj = ax.remap("SR", [i, j])
+                        A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]
+
+    print(Module.script())
+
+
+if __name__ == "__main__":
+    # Run both capture tests when the file is executed directly.
+    test_capture_func()
+    test_capture_class()