You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/09 00:36:30 UTC

[tvm] branch ir-builder updated (d0ef7c26e9 -> 202cd15c9e)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a change to branch ir-builder
in repository https://gitbox.apache.org/repos/asf/tvm.git


 discard d0ef7c26e9 Squashed
     new 202cd15c9e Squashed

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user force-pushes (`git push --force`) a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (d0ef7c26e9)
            \
             N -- N -- N   refs/heads/ir-builder (202cd15c9e)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 tests/python/tvmscript/test_builder_basic.py  | 227 --------------------------
 tests/python/tvmscript/test_parse_basic.py    | 118 -------------
 tests/python/tvmscript/test_parser_capture.py |  43 -----
 3 files changed, 388 deletions(-)
 delete mode 100644 tests/python/tvmscript/test_builder_basic.py
 delete mode 100644 tests/python/tvmscript/test_parse_basic.py
 delete mode 100644 tests/python/tvmscript/test_parser_capture.py


[tvm] 01/01: Squashed

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch ir-builder
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 202cd15c9e72ec8ea187fa38eff06db4f90b74cc
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon May 23 17:06:45 2022 -0700

    Squashed
---
 include/tvm/ir/expr.h                              |   20 +-
 include/tvm/ir/ir_builder.h                        |  188 ++++
 include/tvm/support/with.h                         |    2 +
 include/tvm/tir/ir_builder.h                       |  138 +++
 include/tvm/tir/ir_builder_frame.h                 |  454 ++++++++
 include/tvm/tir/op.h                               |   34 +-
 python/tvm/ir/__init__.py                          |   55 +-
 .../_ffi_api.py => ir/_ffi_ir_builder_api.py}      |    4 +-
 python/tvm/ir/ir_builder.py                        |   85 ++
 python/tvm/script/__init__.py                      |   19 +-
 python/tvm/script/{ => parser}/__init__.py         |   14 +-
 python/tvm/script/parser/diagnostics.py            |   60 ++
 python/tvm/script/parser/dispatch.py               |   63 ++
 python/tvm/script/parser/doc.py                    |  341 ++++++
 python/tvm/script/parser/doc_core.py               | 1140 ++++++++++++++++++++
 .../script/{tir/prim_func.py => parser/entry.py}   |   46 +-
 python/tvm/script/parser/evaluator.py              |  282 +++++
 .../script/{_ffi_api.py => parser/ir/__init__.py}  |    7 +-
 .../tvm/script/{__init__.py => parser/ir/entry.py} |   19 +-
 .../script/{__init__.py => parser/ir/parser.py}    |   24 +-
 python/tvm/script/parser/parser.py                 |  182 ++++
 python/tvm/script/parser/source.py                 |   89 ++
 python/tvm/script/{ => parser/tir}/__init__.py     |    9 +-
 python/tvm/script/parser/tir/entry.py              |   97 ++
 python/tvm/script/parser/tir/operation.py          |   85 ++
 python/tvm/script/parser/tir/parser.py             |  262 +++++
 python/tvm/script/parser/utils.py                  |   63 ++
 python/tvm/script/parser/var_table.py              |   71 ++
 python/tvm/script/{ => parser_v1}/__init__.py      |    3 +-
 python/tvm/script/{ => parser_v1}/_ffi_api.py      |    0
 .../script/{ => parser_v1}/context_maintainer.py   |    8 +-
 python/tvm/script/{ => parser_v1}/diagnostics.py   |    6 +-
 python/tvm/script/{ => parser_v1}/highlight.py     |    0
 python/tvm/script/{ => parser_v1}/meta_unparser.py |    0
 python/tvm/script/{ => parser_v1}/parser.py        |   23 +-
 python/tvm/script/{ => parser_v1}/registry.py      |    2 +-
 python/tvm/script/{ => parser_v1}/tir/__init__.py  |    0
 python/tvm/script/{ => parser_v1}/tir/__init__.pyi |    0
 python/tvm/script/{ => parser_v1}/tir/intrin.py    |    5 +-
 python/tvm/script/{ => parser_v1}/tir/node.py      |    7 +-
 python/tvm/script/{ => parser_v1}/tir/prim_func.py |    3 +-
 .../script/{ => parser_v1}/tir/scope_handler.py    |   17 +-
 .../tvm/script/{ => parser_v1}/tir/special_stmt.py |   16 +-
 python/tvm/script/{ => parser_v1}/tir/ty.py        |    1 +
 python/tvm/script/{ => parser_v1}/utils.py         |    7 +-
 python/tvm/tir/__init__.py                         |  212 +++-
 .../_ffi_api.py => tir/_ffi_ir_builder_api.py}     |    4 +-
 python/tvm/tir/analysis/analysis.py                |    4 +-
 python/tvm/tir/buffer.py                           |   39 +-
 python/tvm/tir/expr.py                             |   15 +-
 python/tvm/tir/ir_builder_frame.py                 |  118 ++
 python/tvm/tir/{ir_builder.py => ir_builder_v1.py} |    6 +-
 python/tvm/tir/ir_builder_v2.py                    |  903 ++++++++++++++++
 python/tvm/tir/op.py                               |  578 +++++++++-
 python/tvm/tir/schedule/block_scope.py             |    2 +-
 python/tvm/tir/schedule/schedule.py                |    6 +-
 python/tvm/tir/schedule/state.py                   |    3 +-
 python/tvm/tir/stmt.py                             |    4 +
 python/tvm/tir/usmp/transform/transform.py         |    5 +-
 src/ir/expr.cc                                     |   13 +
 src/ir/ir_builder.cc                               |  134 +++
 src/tir/ir/expr.cc                                 |   19 +-
 src/tir/ir/script/script_complete.cc               |    5 +-
 src/tir/ir/script/script_complete.h                |   35 +
 src/tir/ir/stmt.cc                                 |    2 +
 src/tir/ir_builder/ir_builder.cc                   |  637 +++++++++++
 src/tir/ir_builder/ir_builder_frame.cc             |  207 ++++
 src/tir/ir_builder/utils.h                         |   92 ++
 src/tir/op/op.cc                                   |   24 +
 69 files changed, 6796 insertions(+), 222 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 5e358ed50e..dcabd7d3f1 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -764,16 +764,32 @@ struct PackedFuncValueConverter<PrimExpr> {
       return PrimExpr(ObjectPtr<Object>(nullptr));
     }
     if (val.type_code() == kDLInt) {
-      return PrimExpr(val.operator int());
+      return IntImm(runtime::DataType::Int(32), val.operator int());
     }
     if (val.type_code() == kDLFloat) {
-      return PrimExpr(static_cast<float>(val.operator double()));
+      return FloatImm(runtime::DataType::Float(32), val.operator double());
     }
 
     return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
   }
 };
 
+template <>
+struct PackedFuncValueConverter<Array<PrimExpr>> {
+  static Array<PrimExpr> From(const TVMPODValue_& val) {
+    if (val.type_code() == kTVMNullptr) return Array<PrimExpr>(nullptr);
+    Array<ObjectRef> vals = val.AsObjectRef<Array<ObjectRef>>();
+    Array<PrimExpr> exprs;
+    for (const ObjectRef& v : vals) {
+      TVMValue value;
+      value.v_handle = const_cast<void*>(static_cast<const void*>(v.get()));
+      exprs.push_back(
+          PackedFuncValueConverter<PrimExpr>::From(TVMArgValue(value, kTVMObjectHandle)));
+    }
+    return exprs;
+  }
+};
+
 template <>
 struct PackedFuncValueConverter<tvm::Integer> {
   static tvm::Integer From(const TVMPODValue_& val) {
diff --git a/include/tvm/ir/ir_builder.h b/include/tvm/ir/ir_builder.h
new file mode 100644
index 0000000000..5e099e7310
--- /dev/null
+++ b/include/tvm/ir/ir_builder.h
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_IR_IR_BUILDER_H_
+#define TVM_IR_IR_BUILDER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+namespace tvm {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame and IRBuilder //////////////////////////////
+
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+  std::vector<runtime::TypedPackedFunc<void()>> callbacks;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `callbacks` is not visited.
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRBuilderFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+  virtual ~IRBuilderFrameNode() = default;
+  virtual void EnterWithScope();
+  virtual void ExitWithScope();
+
+  void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+class IRBuilderFrame : public runtime::ObjectRef {
+ public:
+  virtual ~IRBuilderFrame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+  IRBuilderFrame() = default;
+
+ public:
+  inline void EnterWithScope();
+  inline void ExitWithScope();
+};
+
+class IRBuilderNode : public runtime::Object {
+ public:
+  runtime::Array<IRBuilderFrame> frames;
+  Optional<ObjectRef> result;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("frames", &frames);
+    v->Visit("result", &result);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+  template <typename TFrame>
+  inline Optional<TFrame> FindFrame() const;
+  template <typename TFrame>
+  inline Optional<TFrame> GetLastFrame() const;
+  template <typename TObjectRef>
+  inline TObjectRef Get() const;
+};
+
+class IRBuilder : public runtime::ObjectRef {
+ public:
+  explicit IRBuilder();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+  void EnterWithScope();
+  void ExitWithScope();
+  static IRBuilder Current();
+  template <class TObjectRef>
+  inline static TObjectRef Name(String name, TObjectRef obj);
+};
+
+////////////////////////////// Generic IRModule //////////////////////////////
+
+class IRModuleFrameNode : public IRBuilderFrameNode {
+ public:
+  Array<GlobalVar> global_vars;
+  Array<BaseFunc> functions;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("global_vars", &global_vars);
+    v->Visit("functions", &functions);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRModuleFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IRModuleFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+                                                    IRModuleFrameNode);
+};
+
+TVM_DLL IRModuleFrame IRModule();
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+class Namer {
+ public:
+  using FType = NodeFunctor<void(const ObjectRef&, String)>;
+  static FType& vtable();
+  static void Name(ObjectRef node, String name);
+};
+
+}  // namespace details
+
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+  details::Namer::Name(obj, name);
+  return Downcast<TObjectRef>(obj);
+}
+
+inline void IRBuilderFrame::EnterWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
+}
+
+inline void IRBuilderFrame::ExitWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
+  data_.reset();
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
+    if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
+      return GetRef<TFrame>(p);
+    }
+  }
+  return NullOpt;
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
+    return Downcast<TFrame>(frames.back());
+  }
+  return NullOpt;
+}
+
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  CHECK(n != nullptr) << "IndexError: IRBuilder result is not of type: " << TObject::_type_key;
+  return GetRef<TObjectRef>(n);
+}
+
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_IR_IR_BUILDER_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 3651e05e74..bbc36419ff 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -75,6 +75,8 @@ class With {
   ContextType& operator*() { return *get(); }
   const ContextType* operator*() const { return *get(); }
 
+  ContextType operator()() { return ctx_; }
+
  private:
   /*! \brief internal context type. */
   ContextType ctx_;
diff --git a/include/tvm/tir/ir_builder.h b/include/tvm/tir/ir_builder.h
new file mode 100644
index 0000000000..5d4c61635c
--- /dev/null
+++ b/include/tvm/tir/ir_builder.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_H_
+#define TVM_TIR_IR_BUILDER_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/ir_builder_frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;
+using tvm::tir::Buffer;
+using tvm::tir::IterVar;
+using tvm::tir::Var;
+
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators);
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+BlockFrame Block(String name, bool no_realize = false);
+BlockInitFrame Init();
+void Where(PrimExpr predicate);
+void Reads(Array<ObjectRef> buffer_slices);
+void Writes(Array<ObjectRef> buffer_slices);
+void BlockAttrs(Map<String, ObjectRef> attrs);
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+                   int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+
+namespace axis {
+IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Array<IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+}  // namespace axis
+
+ForFrame Serial(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+                  Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+                    Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Grid(Array<PrimExpr> extents);
+
+PrimFuncFrame PrimFunc();
+Var Arg(String name, Var var);
+Buffer Arg(String name, Buffer buffer);
+void FuncName(String name);
+void FuncAttrs(Map<String, ObjectRef> attrs);
+Type FuncRet(Type ret_type);
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
+                   int align = -1, int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+                        DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+                        Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+                        String storage_scope = "global", int align = -1, int offset_factor = 0,
+                        String buffer_type = "default", Array<IntImm> axis_separators = {});
+
+AssertFrame Assert(PrimExpr condition, String message);
+LetFrame Let(Var var, PrimExpr value);
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                       Optional<PrimExpr> condition = NullOpt,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+AllocateConstFrame AllocateConst(
+    NDArray data, DataType dtype, Array<PrimExpr> extents,
+    Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+WhileFrame While(PrimExpr condition);
+IfFrame If(PrimExpr condition);
+ThenFrame Then();
+ElseFrame Else();
+LaunchThreadFrame LaunchThread(IterVar iter_var, PrimExpr extent);
+IterVar EnvThread(String thread_tag);
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+void Prefetch(Buffer buffer, Array<Range> bounds);
+void Evaluate(PrimExpr value);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                             \
+  inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                        \
+    DataType dtype = DType;                                                            \
+    return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
+  }
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_H_
diff --git a/include/tvm/tir/ir_builder_frame.h b/include/tvm/tir/ir_builder_frame.h
new file mode 100644
index 0000000000..549847da6c
--- /dev/null
+++ b/include/tvm/tir/ir_builder_frame.h
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_FRAME_H_
+#define TVM_TIR_IR_BUILDER_FRAME_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+class TIRFrameNode : public IRBuilderFrameNode {
+ public:
+  Array<tvm::tir::Stmt> stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.TIRFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+class TIRFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+  TIRFrame() = default;
+};
+
+class BlockFrameNode : public TIRFrameNode {
+ public:
+  String name;
+  Array<tvm::tir::IterVar> iter_vars;
+  Optional<Array<tvm::tir::BufferRegion>> reads;
+  Optional<Array<tvm::tir::BufferRegion>> writes;
+  Optional<tvm::tir::Stmt> init;
+  Array<tvm::tir::Buffer> alloc_buffers;
+  Array<tvm::tir::MatchBufferRegion> match_buffers;
+  Map<String, ObjectRef> annotations;
+
+  Array<PrimExpr> iter_values;
+  Optional<PrimExpr> predicate;
+  bool no_realize;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("iter_vars", &iter_vars);
+    v->Visit("reads", &reads);
+    v->Visit("writes", &writes);
+    v->Visit("init", &init);
+    v->Visit("alloc_buffers", &alloc_buffers);
+    v->Visit("match_buffers", &match_buffers);
+    v->Visit("annotations", &annotations);
+    v->Visit("iter_values", &iter_values);
+    v->Visit("predicate", &predicate);
+    v->Visit("no_realize", &no_realize);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.BlockFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class BlockFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class BlockInitFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+class ForFrameNode : public TIRFrameNode {
+ public:
+  using FMakeForLoop =
+      runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+  Array<tvm::tir::Var> vars;
+  Array<Range> doms;
+  FMakeForLoop f_make_for_loop;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("vars", &vars);
+    v->Visit("doms", &doms);
+    // `f_make_for_loop` is not visited.
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.ForFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class ForFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+class PrimFuncFrameNode : public TIRFrameNode {
+ public:
+  Optional<String> name;
+  Array<tvm::tir::Var> args;
+  Optional<Type> ret_type;
+  Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+  Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+  Map<String, ObjectRef> attrs;
+  Array<tvm::tir::Buffer> root_alloc_buffers;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+    v->Visit("attrs", &attrs);
+    v->Visit("root_alloc_buffers", &root_alloc_buffers);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.PrimFuncFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class PrimFuncFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+class AssertFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+  PrimExpr message;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AssertFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AssertFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+class LetFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::Var var;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.LetFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LetFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+  Array<PrimExpr> extents;
+  DataType dtype;
+  String storage_scope;
+  PrimExpr condition;
+  Map<String, ObjectRef> annotations;
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extents", &extents);
+    v->Visit("dtype", &dtype);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+    v->Visit("annotations", &annotations);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AllocateFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+  DataType dtype;
+  Array<PrimExpr> extents;
+  tvm::runtime::NDArray data;
+  tvm::tir::Buffer buffer;
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("data", &data);
+    v->Visit("buffer", &buffer);
+    v->Visit("annotations", &annotations);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AllocateConstFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateConstFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+                                                    AllocateConstFrameNode);
+};
+
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr extent;
+  String attr_key;
+  tvm::tir::IterVar iter_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extent", &extent);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("iter_var", &iter_var);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.LaunchThreadFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LaunchThreadFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+                                                    LaunchThreadFrameNode);
+};
+
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::BufferRegion buffer_slice;
+  String storage_scope;
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer_slice", &buffer_slice);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.RealizeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class RealizeFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+class AttrFrameNode : public TIRFrameNode {
+ public:
+  ObjectRef node;
+  String attr_key;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AttrFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AttrFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+class WhileFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.WhileFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class WhileFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+class IfFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+  Optional<Array<tvm::tir::Stmt>> then_stmts;
+  Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("then_stmts", &then_stmts);
+    v->Visit("else_stmts", &else_stmts);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.IfFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IfFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+class ThenFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "ir_builder.tir.ThenFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ThenFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+class ElseFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "ir_builder.tir.ElseFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ElseFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class DeclBufferFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_FRAME_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
  * \return The result expression.
  */
 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
 
+/*!
+ * \brief Check whether x is a null pointer.
+ * \param x The input expression. NOTE(review): presumably a handle (pointer)
+ *        value — confirm against the implementation.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
+
 /*!
  * \brief Check if x is infinite.
  * \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span());
 
+/*!
+ * \brief Calculate fmod(x, y)
+ * NOTE(review): presumably the floating-point remainder of x / y with the
+ * semantics of C's fmod — confirm against the implementation.
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                   Span span = Span());
 
+/*!
+ * \brief Returns the address of an element in the buffer
+ * \param buffer_load The input BufferLoad. NOTE(review): the element addressed is
+ *        presumably the one this load reads (its buffer and indices) — confirm.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Returns the param by name
+ * NOTE(review): which parameter table is searched is not visible from this
+ * declaration — confirm against the implementation.
+ * \param param_name The param name.
+ * \param span The location of this operation in the source.
+ * \return The handle of param.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
+// NOTE(review): presumably log(1 + x), following the C function of the same name.
+TVM_DECLARE_INTRIN_UNARY(log1p);
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4e847c0310..51dff32ac6 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -16,29 +16,46 @@
 # under the License.
 # pylint: disable=unused-import
 """Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .base import structural_equal, assert_structural_equal, structural_hash
-from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .affine_type import TensorAffineType, TupleAffineType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .op import Op, register_op_attr, register_intrin_lowering
-from .function import CallingConv, BaseFunc
+from . import diagnostics, instrument, ir_builder, transform
 from .adt import Constructor, TypeData
-from .module import IRModule
+from .affine_type import TensorAffineType, TupleAffineType
 from .attrs import Attrs, DictAttrs, make_node
+from .base import (
+    EnvFunc,
+    Node,
+    SourceName,
+    Span,
+    assert_structural_equal,
+    load_json,
+    save_json,
+    structural_equal,
+    structural_hash,
+)
 from .container import Array, Map
+from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
+from .function import BaseFunc, CallingConv
 from .memory_pools import (
-    PoolInfo,
-    WorkspacePoolInfo,
-    ConstantPoolInfo,
-    WorkspaceMemoryPools,
     ConstantMemoryPools,
+    ConstantPoolInfo,
+    PoolInfo,
     PoolInfoProperties,
+    WorkspaceMemoryPools,
+    WorkspacePoolInfo,
 )
-
-from . import transform
-from . import instrument
-from . import diagnostics
+from .module import IRModule
+from .op import Op, register_intrin_lowering, register_op_attr
+from .tensor_type import TensorType
+from .type import (
+    FuncType,
+    GlobalTypeVar,
+    IncompleteType,
+    PointerType,
+    PrimType,
+    RelayRefType,
+    TupleType,
+    Type,
+    TypeConstraint,
+    TypeKind,
+    TypeVar,
+)
+from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/ir/_ffi_ir_builder_api.py
similarity index 88%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/ir/_ffi_ir_builder_api.py
index 926d17b166..9d08bc9b70 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/ir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir.ir_builder"""
 import tvm._ffi
 
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/ir/ir_builder.py b/python/tvm/ir/ir_builder.py
new file mode 100644
index 0000000000..db0b1ead6d
--- /dev/null
+++ b/python/tvm/ir/ir_builder.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_ir_builder_api as _ffi_api
+
+
+@_register_object("ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+    """Python handle for an ``ir_builder.IRBuilderFrame`` object.
+
+    Usable as a context manager: ``__enter__`` and ``__exit__`` forward to the
+    ``IRBuilderFrameEnter`` / ``IRBuilderFrameExit`` FFI functions.
+    """
+
+    def __enter__(self) -> "IRBuilderFrame":
+        ffi_enter = _ffi_api.IRBuilderFrameEnter  # pylint: disable=no-member # type: ignore
+        ffi_enter(self)
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        # Exception details are not forwarded, and returning None never suppresses them.
+        ffi_exit = _ffi_api.IRBuilderFrameExit  # pylint: disable=no-member # type: ignore
+        ffi_exit(self)
+
+    def add_callback(self, callback) -> None:
+        """Attach ``callback`` to this frame through ``IRBuilderFrameAddCallback``.
+
+        NOTE(review): when the callback runs is decided on the C++ side
+        (presumably on frame exit) — confirm.
+        """
+        ffi_add = _ffi_api.IRBuilderFrameAddCallback  # pylint: disable=no-member # type: ignore
+        ffi_add(self, callback)
+
+
+@_register_object("ir_builder.IRBuilder")
+class IRBuilder(_Object):
+    """Python handle for an ``ir_builder.IRBuilder`` object.
+
+    Entering it as a context manager calls ``IRBuilderEnter``, and leaving it calls
+    ``IRBuilderExit``; :meth:`get` returns the result of ``IRBuilderGet``.
+    """
+
+    def __init__(self) -> None:
+        ctor = _ffi_api.IRBuilder  # pylint: disable=no-member # type: ignore
+        self.__init_handle_by_constructor__(ctor)
+
+    def __enter__(self) -> "IRBuilder":
+        _ffi_api.IRBuilderEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        # Returning None means exceptions raised in the ``with`` body propagate.
+        _ffi_api.IRBuilderExit(self)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def current() -> "IRBuilder":
+        """Return the builder reported by the ``IRBuilderCurrent`` FFI function."""
+        return _ffi_api.IRBuilderCurrent()  # pylint: disable=no-member # type: ignore
+
+    def get(self) -> _Object:
+        """Return the result of the ``IRBuilderGet`` FFI function for this builder."""
+        return _ffi_api.IRBuilderGet(self)  # pylint: disable=no-member # type: ignore
+
+
+# Type of the value passed through name() / name_many(); must be a TVM Object.
+DefType = TypeVar("DefType", bound=_Object)
+
+
+def name(s: str, v: DefType) -> DefType:
+    """Give ``v`` the name ``s`` via the ``IRBuilderName`` FFI function.
+
+    The naming semantics live on the C++ side; whatever the FFI call returns is
+    returned unchanged.
+    """
+    named = _ffi_api.IRBuilderName(s, v)  # pylint: disable=no-member # type: ignore
+    return named
+
+
+def name_many(  # pylint: disable=invalid-name
+    s: List[str],
+    vs: List[DefType],
+) -> List[DefType]:
+    """Apply :func:`name` pairwise to ``s`` and ``vs``; both lists must be equally long."""
+    assert len(s) == len(vs)
+    named = []
+    for var_name, value in zip(s, vs):
+        named.append(name(var_name, value))
+    return named
+
+
+@_register_object("ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):
+    """Python handle for an ``ir_builder.IRModuleFrame`` object.
+
+    It adds no Python-side behavior beyond :class:`IRBuilderFrame`.
+    """
+
+
+def ir_module() -> IRModuleFrame:
+    """Create an :class:`IRModuleFrame` through the ``IRModule`` FFI function."""
+    frame = _ffi_api.IRModule()  # pylint: disable=no-member # type: ignore
+    return frame
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..8b132dcdf0 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
+from . import parser, parser_v1
 
-from . import tir
+#############
+from .parser import ir as ir_v2
+from .parser import ir_module as ir_module_v2
+from .parser import parse as from_source_v2
+from .parser import tir as tir_v2
 
-from .parser import ir_module, from_source
+#############
+from .parser_v1 import from_source as from_source_v1
+from .parser_v1 import ir_module as ir_module_v1
+from .parser_v1 import tir as tir_v1
+
+# pylint: disable=invalid-name
+
+# The public names default to the v2 parser; the v1 entry points remain reachable
+# through the *_v1 names imported above.
+ir, ir_module, tir, from_source = ir_v2, ir_module_v2, tir_v2, from_source_v2
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 77%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..d8530e0ab1 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,13 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+# under the License.
+"""The parser"""
+from . import dispatch as _dispatch
+from . import doc as _doc
+from . import ir
+from . import parser as _parser
 from . import tir
-
-from .parser import ir_module, from_source
+from .entry import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/parser/diagnostics.py b/python/tvm/script/parser/diagnostics.py
new file mode 100644
index 0000000000..fd06d20e61
--- /dev/null
+++ b/python/tvm/script/parser/diagnostics.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+from .source import Source
+
+
+class Diagnostics:
+    """Report diagnostics for a parsed :class:`Source` through TVM's diagnostic context.
+
+    The full source text is added to a fresh IRModule's source map (presumably so
+    rendered diagnostics can quote it).
+    """
+
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        self.source = source
+        module = IRModule()
+        module.source_map.add(source.source_name, source.full_source)
+        self.ctx = diagnostics.DiagnosticContext(module, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        first_line = node.lineno
+        first_col = node.col_offset
+        # A missing (or falsy) end position collapses onto the start position.
+        last_line = node.end_lineno or first_line
+        last_col = node.end_col_offset or first_col
+        # Node positions are relative to the parsed snippet: shift lines by its start
+        # line, and columns by its start column plus one (presumably 0- to 1-based).
+        line_shift = self.source.start_line - 1
+        col_shift = self.source.start_column + 1
+        span = Span(
+            source_name=SourceName(self.source.source_name),
+            line=first_line + line_shift,
+            end_line=last_line + line_shift,
+            column=first_col + col_shift,
+            end_column=last_col + col_shift,
+        )
+        self.ctx.emit(diagnostics.Diagnostic(level=level, span=span, message=message))
+
+    def error(self, node: doc.AST, message: str) -> None:
+        """Emit an ERROR-level diagnostic located at ``node``, then render the context."""
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()
diff --git a/python/tvm/script/parser/dispatch.py b/python/tvm/script/parser/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+
+# Parse methods keyed by (dispatch token, node type name); filled by `register`.
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+# Operator implementations keyed by (operand type, operator node type, operand index);
+# filled by `register_op` and read by `get_op`, which looks up with a type in slot two.
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, Type, int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Parameters
+    ----------
+    token : str
+        The dispatch token.
+    type_name : str
+        The node type name the method handles.
+
+    Returns
+    -------
+    Callable[[ParseMethod], ParseMethod]
+        A decorator that stores the method in ``ParseVTable`` under
+        ``(token, type_name)``, replacing any previous entry, and returns the method
+        unchanged so the decorated name stays bound to the function, not ``None``.
+    """
+
+    def f(method: ParseMethod) -> ParseMethod:
+        ParseVTable[(token, type_name)] = method
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Return the parse method registered for ``(token, type_name)``, else ``default``."""
+    key = (token, type_name)
+    return ParseVTable.get(key, default)
+
+
+def register_op(ty: Type, op: Type, operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator implementation.
+
+    Parameters
+    ----------
+    ty : Type
+        The operand type the implementation is keyed on.
+    op : Type
+        The operator node type; ``get_op`` looks entries up with a type here.
+    operand_index : int
+        The operand position this entry is keyed on (presumably which operand
+        ``ty`` describes).
+
+    Returns
+    -------
+    Callable[[OpMethod], OpMethod]
+        A decorator that stores the method in ``OpVTable``, replacing any previous
+        entry, and returns the method unchanged instead of ``None``.
+    """
+
+    def f(method: OpMethod) -> OpMethod:
+        OpVTable[(ty, op, operand_index)] = method
+        return method
+
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type,
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Return the method registered for ``(ty, op, operand_index)``, else ``default``."""
+    key = (ty, op, operand_index)
+    return OpVTable.get(key, default)
diff --git a/python/tvm/script/parser/doc.py b/python/tvm/script/parser/doc.py
new file mode 100644
index 0000000000..929b574b6a
--- /dev/null
+++ b/python/tvm/script/parser/doc.py
@@ -0,0 +1,341 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import *  # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+# Translator from a Python ``ast`` node to its ``doc`` counterpart.
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+# Translator from a ``doc`` node back to a Python ``ast`` node.
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+    """The pair of translators registered for one AST node class name.
+
+    ``to_doc`` converts a Python ``ast`` node into a doc node and ``from_doc``
+    converts back; each stays ``None`` until something registers it.
+    """
+
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        self.to_doc, self.from_doc = None, None
+
+
+class Registry:
+    """Translator table keyed by AST node class name.
+
+    ``_inst`` is the shared instance the register/lookup helpers use; it stays
+    ``None`` until ``_register_default`` assigns it.
+    """
+
+    _inst: typing.Optional["Registry"] = None
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        # Indexing an unknown name with [] inserts a fresh, empty Entry.
+        self.table = defaultdict(Entry)
+
+
+def register_to_doc(name: str):
+    """Return a decorator that installs an ast -> doc translator for class ``name``.
+
+    The translator is stored as ``Registry._inst.table[name].to_doc``, replacing any
+    previous one, and is returned unchanged. This lets ``@register_to_doc(...)`` leave
+    the decorated name bound to the function rather than to ``None``.
+    """
+
+    def f(func: FnToDoc):
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].to_doc = func
+        return func
+
+    return f
+
+
+def register_from_doc(name: str):
+    """Return a decorator that installs a doc -> ast translator for class ``name``.
+
+    The translator is stored as ``Registry._inst.table[name].from_doc``, replacing any
+    previous one, and is returned unchanged. This lets ``@register_from_doc(...)``
+    leave the decorated name bound to the function rather than to ``None``.
+    """
+
+    def f(func: FnFromDoc):
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = func
+        return func
+
+    return f
+
+
+def _is_atomic_type(node):
+    """Return True for leaf values that are copied verbatim between ast and doc trees.
+
+    Leaves are ``None``, anything comparing equal to ``...``, ``True`` or ``False``,
+    and instances of int, float, str, bool, bytes or complex.
+    """
+    if node is None or node in [..., True, False]:
+        return True
+    return isinstance(node, (int, float, str, bool, bytes, complex))
+
+
+def _get_registry_entry(cls_name, attr):
+    """Look up translator ``attr`` ("to_doc" or "from_doc") for class ``cls_name``.
+
+    Only the last dotted component of ``cls_name`` is used. Returns ``None`` when the
+    name has no entry (without inserting one) or the entry lacks ``attr``.
+    """
+    short_name = cls_name.split(".")[-1]
+    table = Registry._inst.table  # pylint: disable=protected-access
+    if short_name not in table:
+        return None
+    return getattr(table[short_name], attr, None)
+
+
+def from_doc(node):
+    """Convert a doc node, or a list/tuple of them, back into Python ``ast`` nodes.
+
+    Atomic values are returned unchanged. Lists and tuples are converted element-wise
+    into a plain list or tuple. Any other node goes to the ``from_doc`` translator
+    registered for its class name; NotImplementedError is raised if there is none.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (tuple, list)):
+        converted = [from_doc(item) for item in node]
+        return tuple(converted) if isinstance(node, tuple) else converted
+    cls_name = node.__class__.__name__
+    translator = _get_registry_entry(cls_name, "from_doc")
+    if not translator:
+        raise NotImplementedError(f"from_doc is not implemented for: {cls_name}")
+    return translator(node)
+
+
+def to_doc(node):
+    """Convert a Python ``ast`` node, or a list/tuple of them, into doc nodes.
+
+    Atomic values are returned unchanged. Lists and tuples are converted element-wise
+    into a plain list or tuple. Any other node goes to the ``to_doc`` translator
+    registered for its class name; NotImplementedError is raised if there is none.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (tuple, list)):
+        converted = [to_doc(item) for item in node]
+        return tuple(converted) if isinstance(node, tuple) else converted
+    cls_name = node.__class__.__name__
+    translator = _get_registry_entry(cls_name, "to_doc")
+    if not translator:
+        raise NotImplementedError(f"to_doc is not implemented for: {cls_name}")
+    return translator(node)
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    """Parse Python source text and convert the resulting ``ast`` tree into doc nodes.
+
+    Parsing is first attempted with ``feature_version=(3, 8)``. If that raises an
+    ordinary exception (for example, an interpreter whose ``ast.parse`` lacks the
+    ``feature_version`` keyword), the source is parsed again without it. Any error
+    from the second attempt propagates to the caller.
+    """
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except Exception:  # pylint: disable=broad-except
+        # Only ordinary errors trigger the retry: a bare ``except`` would also
+        # swallow KeyboardInterrupt/SystemExit and silently parse a second time.
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    """Read-only traversal of doc trees dispatching to ``visit_<ClassName>`` methods.
+
+    Lists and tuples are walked element by element; non-doc values are ignored.
+    Classes without a ``visit_<ClassName>`` method fall back to
+    :meth:`generic_visit`, which recurses into the fields named by ``_FIELDS``.
+    """
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for child in node:
+                self.visit(child)
+        elif isinstance(node, doc.AST):
+            method_name = "visit_" + node.__class__.__name__.split(".")[-1]
+            visitor = getattr(self, method_name, self.generic_visit)
+            visitor(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            # None is neither a doc node nor a container, so it is skipped here.
+            if isinstance(child, (doc.AST, list, tuple)):
+                self.visit(child)
+
+
+class NodeTransformer:
+    """Rebuilding traversal of doc trees dispatching to ``visit_<ClassName>`` methods.
+
+    ``visit`` returns the transformed value. Lists and tuples are rebuilt
+    element-wise, and non-doc values are returned unchanged. :meth:`generic_visit`
+    constructs a new node of the same class from the (transformed) ``_FIELDS``.
+    """
+
+    def visit(self, node: doc.AST) -> doc.AST:
+        if isinstance(node, list):
+            return [self.visit(item) for item in node]
+        if isinstance(node, tuple):
+            return tuple(self.visit(item) for item in node)
+        if isinstance(node, doc.AST):
+            method_name = "visit_" + node.__class__.__name__.split(".")[-1]
+            return getattr(self, method_name, self.generic_visit)(node)
+        return node
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        new_fields: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            # Only doc nodes and containers are transformed; None and atoms pass through.
+            if isinstance(value, (doc.AST, list, tuple)):
+                value = self.visit(value)
+            new_fields[field] = value
+        return node.__class__(**new_fields)
+
+
+def _register_default():
+    """Create the shared Registry and install field-by-field translators.
+
+    Every ``doc`` class that subclasses ``doc.AST`` and has a namesake in Python's
+    ``ast`` module gets two translators: a ``to_doc`` one (ast -> doc) and a
+    ``from_doc`` one (doc -> ast). Both copy the fields listed in the doc class's
+    ``_FIELDS``, converting each value recursively.
+    """
+
+    class DefaultTranslator:
+        """Rebuild a node as ``doc_cls``, converting each listed field with ``func``."""
+
+        def __init__(self, doc_cls, func, fields):
+            # Despite its name, ``doc_cls`` is the *target* class: a doc class for
+            # to_doc translators and an ``ast`` class for from_doc translators.
+            self.doc_cls = doc_cls
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            converted = {}
+            for field in self.fields:
+                converted[field] = self.func(getattr(node, field, None))
+            return self.doc_cls(**converted)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue
+        if not (inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST)):
+            continue
+        assert "." not in cls_name
+        fields = doc_cls._FIELDS  # pylint: disable=protected-access
+        register_to_doc(cls_name)(DefaultTranslator(doc_cls, to_doc, fields))
+        register_from_doc(cls_name)(
+            DefaultTranslator(getattr(ast, cls_name), from_doc, fields)
+        )
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    """Return the running interpreter's ``(major, minor)`` version as a tuple."""
+    major, minor = sys.version_info[:2]
+    return (major, minor)
+
+
+def _register_constant_handling():
+    """On Python 3.6/3.7, map the legacy literal node classes onto ``doc.Constant``.
+
+    Those versions produce ``Str``, ``NameConstant``, ``Num``, ``Bytes`` and
+    ``Ellipsis`` nodes rather than ``Constant``; each gets a to_doc translator here.
+    Nothing is registered on any other version.
+    """
+    if _py_version() not in [(3, 6), (3, 7)]:
+        return
+
+    def as_constant(f) -> FnToDoc:
+        """Build a to_doc translator producing ``doc.Constant``.
+
+        ``f`` is either the name of the attribute holding the literal value or a
+        callable computing the value from the node.
+        """
+
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            return doc.Constant(
+                value=getattr(x, f) if isinstance(f, str) else f(x),
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                # These versions record no end positions; the end is set to the start.
+                end_lineno=x.lineno,
+                end_col_offset=x.col_offset,
+            )
+
+        return to_doc_func
+
+    register_to_doc("Str")(as_constant("s"))
+    register_to_doc("NameConstant")(as_constant("value"))
+    register_to_doc("Num")(as_constant("n"))
+    register_to_doc("Bytes")(as_constant("s"))
+    register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+    """On Python < 3.9, translate ``ast.Subscript`` slices to and from the 3.9+ layout.
+
+    Before 3.9 the ``slice`` field is wrapped in one of three node types:
+    ``ast.Index`` (one expression or a plain tuple such as ``A[i, j]``), ``ast.Slice``,
+    or ``ast.ExtSlice`` (a tuple holding at least one slice, such as ``A[i, 0:2]``).
+    The doc tree always stores the bare expression, a ``doc.Slice`` or a
+    ``doc.Tuple``, as 3.9+ does. Nothing is registered on 3.9+.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def _loc(node) -> typing.Dict[str, typing.Any]:
+        # Location keyword arguments; pre-3.9 slice nodes carry no location at all.
+        return {
+            "lineno": getattr(node, "lineno", None),
+            "col_offset": getattr(node, "col_offset", None),
+            "end_lineno": getattr(node, "end_lineno", None),
+            "end_col_offset": getattr(node, "end_col_offset", None),
+        }
+
+    def _dim_to_doc(dim: ast.AST) -> doc.AST:
+        # ExtSlice dims are ast.Slice or ast.Index; unwrap Index so the doc tuple holds
+        # bare expressions, exactly like the ast.Index branch of subscript_to_doc.
+        if isinstance(dim, ast.Index):
+            return to_doc(dim.value)
+        return to_doc(dim)
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        if isinstance(x.slice, ast.Slice):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Slice(
+                    lower=to_doc(x.slice.lower),
+                    upper=to_doc(x.slice.upper),
+                    step=to_doc(x.slice.step),
+                    **_loc(x.slice),
+                ),
+                ctx=to_doc(x.ctx),
+                **_loc(x),
+            )
+        if isinstance(x.slice, ast.ExtSlice):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Tuple(
+                    elts=[_dim_to_doc(dim) for dim in x.slice.dims],
+                    ctx=doc.Load(
+                        lineno=None,
+                        col_offset=None,
+                        end_lineno=None,
+                        end_col_offset=None,
+                    ),
+                    # ExtSlice has no location of its own; reuse the Subscript's.
+                    **_loc(x),
+                ),
+                ctx=to_doc(x.ctx),
+                **_loc(x),
+            )
+        if isinstance(x.slice, ast.Index):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=to_doc(x.slice.value),
+                ctx=to_doc(x.ctx),
+                **_loc(x),
+            )
+        raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+    def _dim_from_doc(dim: doc.AST) -> ast.AST:
+        # ast.ExtSlice dims must be ast.Slice or ast.Index, never a bare expression.
+        if isinstance(dim, doc.Slice):
+            return from_doc(dim)
+        return ast.Index(value=from_doc(dim))
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        if isinstance(x.slice, doc.Slice):
+            py_slice = from_doc(x.slice)
+        elif isinstance(x.slice, doc.Tuple) and any(
+            isinstance(elt, doc.Slice) for elt in x.slice.elts
+        ):
+            # A tuple holding at least one slice, e.g. ``A[i, 0:2]``.
+            py_slice = ast.ExtSlice(dims=[_dim_from_doc(elt) for elt in x.slice.elts])
+        else:
+            # Any other expression, including a plain tuple index such as ``A[i, j]``.
+            # Python < 3.9 expects ``Index(value=Tuple(...))`` here; an ExtSlice whose
+            # dims are bare expressions is rejected by ``compile``.
+            py_slice = ast.Index(value=from_doc(x.slice))
+        result = ast.Subscript(
+            value=from_doc(x.value),
+            slice=py_slice,
+            ctx=from_doc(x.ctx),
+        )
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+# Order matters: _register_default creates Registry._inst, which the two helpers
+# below write into. They then add or replace translators for node classes whose
+# layout differs on older Python versions (they do nothing on newer ones).
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
diff --git a/python/tvm/script/parser/doc_core.py b/python/tvm/script/parser/doc_core.py
new file mode 100644
index 0000000000..b88eef9a0e
--- /dev/null
+++ b/python/tvm/script/parser/doc_core.py
@@ -0,0 +1,1140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-outer-name,missing-docstring,invalid-name
+# pylint: disable=useless-super-delegation,redefined-builtin
+# pylint: disable=too-few-public-methods,too-many-arguments
+class AST:
+    """Root of the doc node hierarchy.
+
+    Every node stores four source-location attributes. ``_FIELDS`` names the
+    attributes of a node class: payload fields first, location fields last.
+    """
+
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__()
+        self.lineno, self.col_offset = lineno, col_offset
+        self.end_lineno, self.end_col_offset = end_lineno, end_col_offset
+
+
+class mod(AST):
+    """Counterpart of ``ast.mod``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Module(mod):
+    """Counterpart of ``ast.Module``."""
+
+    _FIELDS = ["body"] + mod._FIELDS
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Interactive(mod):
+    """Counterpart of ``ast.Interactive``."""
+
+    _FIELDS = ["body"] + mod._FIELDS
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Expression(mod):
+    """Counterpart of ``ast.Expression``."""
+
+    _FIELDS = ["body"] + mod._FIELDS
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class stmt(AST):
+    """Counterpart of ``ast.stmt``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class FunctionDef(stmt):
+    """Counterpart of ``ast.FunctionDef``."""
+
+    _FIELDS = ["name", "args", "body", "decorator_list", "returns"] + stmt._FIELDS
+
+    def __init__(
+        self,
+        name,
+        args,
+        body,
+        decorator_list,
+        returns,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name, self.args, self.body = name, args, body
+        self.decorator_list, self.returns = decorator_list, returns
+
+
+class ClassDef(stmt):
+    """Counterpart of ``ast.ClassDef``."""
+
+    _FIELDS = ["name", "bases", "keywords", "body", "decorator_list"] + stmt._FIELDS
+
+    def __init__(
+        self,
+        name,
+        bases,
+        keywords,
+        body,
+        decorator_list,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name, self.bases, self.keywords = name, bases, keywords
+        self.body, self.decorator_list = body, decorator_list
+
+
+class Return(stmt):
+    """Counterpart of ``ast.Return``."""
+
+    _FIELDS = ["value"] + stmt._FIELDS
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Delete(stmt):
+    """Counterpart of ``ast.Delete``."""
+
+    _FIELDS = ["targets"] + stmt._FIELDS
+
+    def __init__(self, targets, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.targets = targets
+
+
+class Assign(stmt):
+    """Counterpart of ``ast.Assign``."""
+
+    _FIELDS = ["targets", "value"] + stmt._FIELDS
+
+    def __init__(self, targets, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.targets, self.value = targets, value
+
+
+class AugAssign(stmt):
+    """Counterpart of ``ast.AugAssign``."""
+
+    _FIELDS = ["target", "op", "value"] + stmt._FIELDS
+
+    def __init__(self, target, op, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target, self.op, self.value = target, op, value
+
+
+class AnnAssign(stmt):
+    """Counterpart of ``ast.AnnAssign``."""
+
+    _FIELDS = ["target", "annotation", "value", "simple"] + stmt._FIELDS
+
+    def __init__(
+        self, target, annotation, value, simple, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target, self.annotation = target, annotation
+        self.value, self.simple = value, simple
+
+
+class For(stmt):
+    """Counterpart of ``ast.For``."""
+
+    _FIELDS = ["target", "iter", "body", "orelse"] + stmt._FIELDS
+
+    def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target, self.iter = target, iter
+        self.body, self.orelse = body, orelse
+
+
+class While(stmt):
+    """Counterpart of ``ast.While``."""
+
+    _FIELDS = ["test", "body", "orelse"] + stmt._FIELDS
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test, self.body, self.orelse = test, body, orelse
+
+
+class If(stmt):
+    """Counterpart of ``ast.If``."""
+
+    _FIELDS = ["test", "body", "orelse"] + stmt._FIELDS
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test, self.body, self.orelse = test, body, orelse
+
+
+class With(stmt):
+    """Counterpart of ``ast.With``."""
+
+    _FIELDS = ["items", "body"] + stmt._FIELDS
+
+    def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.items, self.body = items, body
+
+
+class Raise(stmt):
+    """Counterpart of ``ast.Raise``."""
+
+    _FIELDS = ["exc", "cause"] + stmt._FIELDS
+
+    def __init__(self, exc, cause, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.exc, self.cause = exc, cause
+
+
+class Try(stmt):
+    """Counterpart of ``ast.Try``."""
+
+    _FIELDS = ["body", "handlers", "orelse", "finalbody"] + stmt._FIELDS
+
+    def __init__(
+        self, body, handlers, orelse, finalbody, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body, self.handlers = body, handlers
+        self.orelse, self.finalbody = orelse, finalbody
+
+
+class Assert(stmt):
+    """Counterpart of ``ast.Assert``."""
+
+    _FIELDS = ["test", "msg"] + stmt._FIELDS
+
+    def __init__(self, test, msg, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test, self.msg = test, msg
+
+
+class Import(stmt):
+    """Counterpart of ``ast.Import``."""
+
+    _FIELDS = ["names"] + stmt._FIELDS
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class ImportFrom(stmt):
+    """Counterpart of ``ast.ImportFrom``."""
+
+    _FIELDS = ["module", "names", "level"] + stmt._FIELDS
+
+    def __init__(self, module, names, level, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.module, self.names, self.level = module, names, level
+
+
+class Global(stmt):
+    """Counterpart of ``ast.Global``."""
+
+    _FIELDS = ["names"] + stmt._FIELDS
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class Nonlocal(stmt):
+    """Counterpart of ``ast.Nonlocal``."""
+
+    _FIELDS = ["names"] + stmt._FIELDS
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class Expr(stmt):
+    """Counterpart of ``ast.Expr``."""
+
+    _FIELDS = ["value"] + stmt._FIELDS
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Pass(stmt):
+    """Counterpart of ``ast.Pass``; adds no fields of its own."""
+
+    _FIELDS = list(stmt._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Break(stmt):
+    """Counterpart of ``ast.Break``; adds no fields of its own."""
+
+    _FIELDS = list(stmt._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Continue(stmt):
+    """Counterpart of ``ast.Continue``; adds no fields of its own."""
+
+    _FIELDS = list(stmt._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class expr(AST):
+    """Counterpart of ``ast.expr``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BoolOp(expr):
+    """Counterpart of ``ast.BoolOp``."""
+
+    _FIELDS = ["op", "values"] + expr._FIELDS
+
+    def __init__(self, op, values, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.op, self.values = op, values
+
+
+class BinOp(expr):
+    """Counterpart of ``ast.BinOp``."""
+
+    _FIELDS = ["left", "op", "right"] + expr._FIELDS
+
+    def __init__(self, left, op, right, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.left, self.op, self.right = left, op, right
+
+
+class UnaryOp(expr):
+    """Counterpart of ``ast.UnaryOp``."""
+
+    _FIELDS = ["op", "operand"] + expr._FIELDS
+
+    def __init__(self, op, operand, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.op, self.operand = op, operand
+
+
+class Lambda(expr):
+    """Counterpart of ``ast.Lambda``."""
+
+    _FIELDS = ["args", "body"] + expr._FIELDS
+
+    def __init__(self, args, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.args, self.body = args, body
+
+
+class IfExp(expr):
+    """Counterpart of ``ast.IfExp``."""
+
+    _FIELDS = ["test", "body", "orelse"] + expr._FIELDS
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test, self.body, self.orelse = test, body, orelse
+
+
+class Dict(expr):
+    """Counterpart of ``ast.Dict``."""
+
+    _FIELDS = ["keys", "values"] + expr._FIELDS
+
+    def __init__(self, keys, values, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.keys, self.values = keys, values
+
+
+class Set(expr):
+    """Counterpart of ``ast.Set``."""
+
+    _FIELDS = ["elts"] + expr._FIELDS
+
+    def __init__(self, elts, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elts = elts
+
+
+class ListComp(expr):
+    """Counterpart of ``ast.ListComp``."""
+
+    _FIELDS = ["elt", "generators"] + expr._FIELDS
+
+    def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elt, self.generators = elt, generators
+
+
+class SetComp(expr):
+    """Counterpart of ``ast.SetComp``."""
+
+    _FIELDS = ["elt", "generators"] + expr._FIELDS
+
+    def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elt, self.generators = elt, generators
+
+
+class DictComp(expr):
+    """Counterpart of ``ast.DictComp``."""
+
+    _FIELDS = ["key", "value", "generators"] + expr._FIELDS
+
+    def __init__(self, key, value, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.key, self.value, self.generators = key, value, generators
+
+
+class GeneratorExp(expr):
+    """Counterpart of ``ast.GeneratorExp``."""
+
+    _FIELDS = ["elt", "generators"] + expr._FIELDS
+
+    def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elt, self.generators = elt, generators
+
+
+class Yield(expr):
+    """Counterpart of ``ast.Yield``."""
+
+    _FIELDS = ["value"] + expr._FIELDS
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class YieldFrom(expr):
+    """Counterpart of ``ast.YieldFrom``."""
+
+    _FIELDS = ["value"] + expr._FIELDS
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Compare(expr):
+    """Counterpart of ``ast.Compare``."""
+
+    _FIELDS = ["left", "ops", "comparators"] + expr._FIELDS
+
+    def __init__(self, left, ops, comparators, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.left, self.ops, self.comparators = left, ops, comparators
+
+
+class Call(expr):
+    """Counterpart of ``ast.Call``."""
+
+    _FIELDS = ["func", "args", "keywords"] + expr._FIELDS
+
+    def __init__(self, func, args, keywords, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.func, self.args, self.keywords = func, args, keywords
+
+
+class FormattedValue(expr):
+    """Counterpart of ``ast.FormattedValue``."""
+
+    _FIELDS = ["value", "conversion", "format_spec"] + expr._FIELDS
+
+    def __init__(
+        self, value, conversion, format_spec, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value, self.conversion, self.format_spec = value, conversion, format_spec
+
+
+class JoinedStr(expr):
+    """Counterpart of ``ast.JoinedStr``."""
+
+    _FIELDS = ["values"] + expr._FIELDS
+
+    def __init__(self, values, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.values = values
+
+
+class Constant(expr):
+    """Counterpart of ``ast.Constant``."""
+
+    _FIELDS = ["value", "kind", "s", "n"] + expr._FIELDS
+
+    def __init__(self, value, kind, s, n, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value, self.kind = value, kind
+        self.s, self.n = s, n
+
+
+class NamedExpr(expr):
+    """Counterpart of ``ast.NamedExpr``."""
+
+    _FIELDS = ["target", "value"] + expr._FIELDS
+
+    def __init__(self, target, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target, self.value = target, value
+
+
+class Attribute(expr):
+    """Counterpart of ``ast.Attribute``."""
+
+    _FIELDS = ["value", "attr", "ctx"] + expr._FIELDS
+
+    def __init__(self, value, attr, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value, self.attr, self.ctx = value, attr, ctx
+
+
+class slice(AST):
+    """Counterpart of ``ast.slice``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Slice(slice):
+    """Counterpart of ``ast.Slice``."""
+
+    _FIELDS = ["lower", "upper", "step"] + slice._FIELDS
+
+    def __init__(self, lower, upper, step, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.lower, self.upper, self.step = lower, upper, step
+
+
+class ExtSlice(slice):
+    """Counterpart of ``ast.ExtSlice``."""
+
+    _FIELDS = ["dims"] + slice._FIELDS
+
+    def __init__(self, dims, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.dims = dims
+
+
+class Index(slice):
+    """Counterpart of ``ast.Index``."""
+
+    _FIELDS = ["value"] + slice._FIELDS
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Subscript(expr):
+    """Counterpart of ``ast.Subscript``."""
+
+    _FIELDS = ["value", "slice", "ctx"] + expr._FIELDS
+
+    def __init__(self, value, slice, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value, self.slice, self.ctx = value, slice, ctx
+
+
+class Starred(expr):
+    """Counterpart of ``ast.Starred``."""
+
+    _FIELDS = ["value", "ctx"] + expr._FIELDS
+
+    def __init__(self, value, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value, self.ctx = value, ctx
+
+
+class Name(expr):
+    """Counterpart of ``ast.Name``."""
+
+    _FIELDS = ["id", "ctx"] + expr._FIELDS
+
+    def __init__(self, id, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.id, self.ctx = id, ctx
+
+
+class List(expr):
+    """Counterpart of ``ast.List``."""
+
+    _FIELDS = ["elts", "ctx"] + expr._FIELDS
+
+    def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elts, self.ctx = elts, ctx
+
+
+class Tuple(expr):
+    """Counterpart of ``ast.Tuple``."""
+
+    _FIELDS = ["elts", "ctx"] + expr._FIELDS
+
+    def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.elts, self.ctx = elts, ctx
+
+
+class expr_context(AST):
+    """Counterpart of ``ast.expr_context``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class AugLoad(expr_context):
+    """Counterpart of ``ast.AugLoad``; adds no fields of its own."""
+
+    _FIELDS = list(expr_context._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class AugStore(expr_context):
+    """Counterpart of ``ast.AugStore``; adds no fields of its own."""
+
+    _FIELDS = list(expr_context._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Param(expr_context):
+    """Counterpart of ``ast.Param``; adds no fields of its own."""
+
+    _FIELDS = list(expr_context._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Suite(mod):
+    """Counterpart of ``ast.Suite``."""
+
+    _FIELDS = ["body"] + mod._FIELDS
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Del(expr_context):
+    """Counterpart of ``ast.Del``; adds no fields of its own."""
+
+    _FIELDS = list(expr_context._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Load(expr_context):
+    """Counterpart of ``ast.Load``; adds no fields of its own."""
+
+    _FIELDS = list(expr_context._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Store(expr_context):
+    """Counterpart of ``ast.Store``; adds no fields of its own."""
+
+    _FIELDS = list(expr_context._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class boolop(AST):
+    """Counterpart of ``ast.boolop``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class And(boolop):
+    """Counterpart of ``ast.And``; adds no fields of its own."""
+
+    _FIELDS = list(boolop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Or(boolop):
+    """Counterpart of ``ast.Or``; adds no fields of its own."""
+
+    _FIELDS = list(boolop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class operator(AST):
+    """Counterpart of ``ast.operator``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Add(operator):
+    """Counterpart of ``ast.Add``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitAnd(operator):
+    """Counterpart of ``ast.BitAnd``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitOr(operator):
+    """Counterpart of ``ast.BitOr``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitXor(operator):
+    """Counterpart of ``ast.BitXor``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Div(operator):
+    """Counterpart of ``ast.Div``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class FloorDiv(operator):
+    """Counterpart of ``ast.FloorDiv``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class LShift(operator):
+    """Counterpart of ``ast.LShift``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Mod(operator):
+    """Counterpart of ``ast.Mod``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Mult(operator):
+    """Counterpart of ``ast.Mult``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class MatMult(operator):
+    """Counterpart of ``ast.MatMult``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Pow(operator):
+    """Counterpart of ``ast.Pow``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class RShift(operator):
+    """Counterpart of ``ast.RShift``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Sub(operator):
+    """Counterpart of ``ast.Sub``; adds no fields of its own."""
+
+    _FIELDS = list(operator._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class unaryop(AST):
+    """Counterpart of ``ast.unaryop``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Invert(unaryop):
+    """Counterpart of ``ast.Invert``; adds no fields of its own."""
+
+    _FIELDS = list(unaryop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Not(unaryop):
+    """Counterpart of ``ast.Not``; adds no fields of its own."""
+
+    _FIELDS = list(unaryop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class UAdd(unaryop):
+    """Counterpart of ``ast.UAdd``; adds no fields of its own."""
+
+    _FIELDS = list(unaryop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class USub(unaryop):
+    """Counterpart of ``ast.USub``; adds no fields of its own."""
+
+    _FIELDS = list(unaryop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class cmpop(AST):
+    """Counterpart of ``ast.cmpop``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Eq(cmpop):
+    """Counterpart of ``ast.Eq``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Gt(cmpop):
+    """Counterpart of ``ast.Gt``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class GtE(cmpop):
+    """Counterpart of ``ast.GtE``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class In(cmpop):
+    """Counterpart of ``ast.In``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Is(cmpop):
+    """Counterpart of ``ast.Is``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class IsNot(cmpop):
+    """Counterpart of ``ast.IsNot``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Lt(cmpop):
+    """Counterpart of ``ast.Lt``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class LtE(cmpop):
+    """Counterpart of ``ast.LtE``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class NotEq(cmpop):
+    """Counterpart of ``ast.NotEq``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class NotIn(cmpop):
+    """Counterpart of ``ast.NotIn``; adds no fields of its own."""
+
+    _FIELDS = list(cmpop._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class comprehension(AST):
+    """Counterpart of ``ast.comprehension``."""
+
+    _FIELDS = ["target", "iter", "ifs", "is_async"] + AST._FIELDS
+
+    def __init__(self, target, iter, ifs, is_async, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target, self.iter = target, iter
+        self.ifs, self.is_async = ifs, is_async
+
+
+class excepthandler(AST):
+    """Counterpart of ``ast.excepthandler``; adds no fields of its own."""
+
+    _FIELDS = list(AST._FIELDS)
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class ExceptHandler(excepthandler):
+    _FIELDS = ["type", "name", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, type, name, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.type = type
+        self.name = name
+        self.body = body
+
+
+class arguments(AST):
+    _FIELDS = [
+        "args",
+        "vararg",
+        "kwonlyargs",
+        "kw_defaults",
+        "kwarg",
+        "defaults",
+        "posonlyargs",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self,
+        args,
+        vararg,
+        kwonlyargs,
+        kw_defaults,
+        kwarg,
+        defaults,
+        posonlyargs,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.args = args
+        self.vararg = vararg
+        self.kwonlyargs = kwonlyargs
+        self.kw_defaults = kw_defaults
+        self.kwarg = kwarg
+        self.defaults = defaults
+        self.posonlyargs = posonlyargs
+
+
+class arg(AST):
+    _FIELDS = ["arg", "annotation", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, arg, annotation, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.arg = arg
+        self.annotation = annotation
+
+
+class keyword(AST):
+    _FIELDS = ["arg", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, arg, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.arg = arg
+        self.value = value
+
+
+class alias(AST):
+    _FIELDS = ["name", "asname", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, name, asname, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name = name
+        self.asname = asname
+
+
+class withitem(AST):
+    _FIELDS = [
+        "context_expr",
+        "optional_vars",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(self, context_expr, optional_vars, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.context_expr = context_expr
+        self.optional_vars = optional_vars
+
+
+__all__ = [
+    "AST",
+    "mod",
+    "Module",
+    "Interactive",
+    "Expression",
+    "stmt",
+    "FunctionDef",
+    "ClassDef",
+    "Return",
+    "Delete",
+    "Assign",
+    "AugAssign",
+    "AnnAssign",
+    "For",
+    "While",
+    "If",
+    "With",
+    "Raise",
+    "Try",
+    "Assert",
+    "Import",
+    "ImportFrom",
+    "Global",
+    "Nonlocal",
+    "Expr",
+    "Pass",
+    "Break",
+    "Continue",
+    "expr",
+    "BoolOp",
+    "BinOp",
+    "UnaryOp",
+    "Lambda",
+    "IfExp",
+    "Dict",
+    "Set",
+    "ListComp",
+    "SetComp",
+    "DictComp",
+    "GeneratorExp",
+    "Yield",
+    "YieldFrom",
+    "Compare",
+    "Call",
+    "FormattedValue",
+    "JoinedStr",
+    "Constant",
+    "NamedExpr",
+    "Attribute",
+    "slice",
+    "Slice",
+    "ExtSlice",
+    "Index",
+    "Subscript",
+    "Starred",
+    "Name",
+    "List",
+    "Tuple",
+    "expr_context",
+    "AugLoad",
+    "AugStore",
+    "Param",
+    "Suite",
+    "Del",
+    "Load",
+    "Store",
+    "boolop",
+    "And",
+    "Or",
+    "operator",
+    "Add",
+    "BitAnd",
+    "BitOr",
+    "BitXor",
+    "Div",
+    "FloorDiv",
+    "LShift",
+    "Mod",
+    "Mult",
+    "MatMult",
+    "Pow",
+    "RShift",
+    "Sub",
+    "unaryop",
+    "Invert",
+    "Not",
+    "UAdd",
+    "USub",
+    "cmpop",
+    "Eq",
+    "Gt",
+    "GtE",
+    "In",
+    "Is",
+    "IsNot",
+    "Lt",
+    "LtE",
+    "NotEq",
+    "NotIn",
+    "comprehension",
+    "excepthandler",
+    "ExceptHandler",
+    "arguments",
+    "arg",
+    "keyword",
+    "alias",
+    "withitem",
+]
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/entry.py
similarity index 50%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/entry.py
index 923eb97d27..7b052b6b32 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/entry.py
@@ -14,32 +14,30 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
 
-import inspect
-from typing import Callable
+from tvm.ir.ir_builder import IRBuilder
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from . import doc
+from .parser import Parser
+from .source import Source
 
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+    """Parse ``program`` and return what the IRBuilder constructed.
+
+    Parameters
+    ----------
+    program : Union[doc.AST, Any, str]
+        Either source text, or an object (e.g. a function or a class) whose
+        source ``Source`` retrieves through ``inspect``.
+    extra_vars : Optional[Dict[str, Any]]
+        Variables visible to the parsed program. When ``program`` is a string
+        and this is ``None``, the ``ir`` and ``tir`` parser modules are exposed
+        as ``I``/``ir`` and ``T``/``tir``.
+
+    Returns
+    -------
+    The value of ``IRBuilder.get()`` after the whole program has been visited.
+    """
+    if isinstance(program, str) and extra_vars is None:
+        # Imported lazily rather than at module level -- presumably to avoid a
+        # circular import with the ir/tir parser packages (TODO confirm).
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
-
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+        # Default names for raw-text input: both the short aliases and the full
+        # module names resolve to the same parser modules.
+        extra_vars = {
+            "I": ir,
+            "ir": ir,
+            "T": tir,
+            "tir": tir,
+        }
+    source = Source(program)
+    parser = Parser(source)
+    # Visit the program inside an IRBuilder scope, then return its result.
+    with IRBuilder() as builder:
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/evaluator.py b/python/tvm/script/parser/evaluator.py
new file mode 100644
index 0000000000..3899531b21
--- /dev/null
+++ b/python/tvm/script/parser/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+# Fallback operator implementations used by ``_eval_op`` when no operand type
+# supplies a dispatched implementation. Keys are ``doc`` operator node classes;
+# values apply the plain Python operator to already-evaluated operands.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+    # binary arithmetic and bitwise operators
+    doc.Add: lambda a, b: a + b,
+    doc.Sub: lambda a, b: a - b,
+    doc.Mult: lambda a, b: a * b,
+    doc.Div: lambda a, b: a / b,
+    doc.FloorDiv: lambda a, b: a // b,
+    doc.Mod: lambda a, b: a % b,
+    doc.LShift: lambda a, b: a << b,
+    doc.RShift: lambda a, b: a >> b,
+    doc.BitOr: lambda a, b: a | b,
+    doc.BitXor: lambda a, b: a ^ b,
+    doc.BitAnd: lambda a, b: a & b,
+    doc.MatMult: lambda a, b: a @ b,
+    doc.Pow: lambda a, b: a**b,
+    # comparison operators (each compares exactly two operands)
+    doc.Eq: lambda a, b: a == b,
+    doc.NotEq: lambda a, b: a != b,
+    doc.Lt: lambda a, b: a < b,
+    doc.LtE: lambda a, b: a <= b,
+    doc.Gt: lambda a, b: a > b,
+    doc.GtE: lambda a, b: a >= b,
+    doc.Is: lambda a, b: a is b,
+    doc.IsNot: lambda a, b: a is not b,
+    doc.In: lambda a, b: a in b,
+    doc.NotIn: lambda a, b: a not in b,
+    # boolean operators: both operands arrive already evaluated, so there is
+    # no short-circuiting of the right-hand side here
+    doc.And: lambda a, b: a and b,
+    doc.Or: lambda a, b: a or b,
+    # unary operators
+    doc.Invert: lambda a: ~a,
+    doc.Not: lambda a: not a,
+    doc.UAdd: lambda a: +a,
+    doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+
+    parser: "Parser"
+    value_table: Dict[str, Any]
+    new_value_count: int
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            return self._eval_lambda(node)
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        value = self._eval_expr(fields["left"])
+        for op, rhs in zip(fields["ops"], fields["comparators"]):
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    value_table = {}
+    if dict_globals is not None:
+        value_table.update(dict_globals)
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    try:
+        return _eval_assign(target, source)
+    except Exception as e:  # pylint: disable=broad-except,invalid-name
+        parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
+        raise
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    node = doc.from_doc(node)
+    if isinstance(node, ast.expr):
+        node = ast.Expression(body=node)
+    assert isinstance(node, ast.Expression), "Expects an ast.Expression, but gets: " + str(node)
+    if dict_globals is None:
+        dict_globals = {}
+    node = ast.fix_missing_locations(node)
+    exe = compile(node, filename="<ast>", mode="eval")
+    return eval(exe, dict_globals)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    op_type = type(op)  # pylint: disable=protected-access
+    for i, v in enumerate(values):
+        v_type = getattr(type(v), "_dispatch_type", None)
+        if v_type is None:
+            continue
+        f = dispatch.get_op(ty=v_type, op=op_type, operand_index=i, default=None)
+        if f is not None:
+            return f(*values)
+    return DEFAULT_OP[op_type](*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    target = doc.from_doc(target)
+    assert isinstance(target, ast.expr)
+    RHS_VAR_NAME = "__tvm_rhs_var__"  # pylint: disable=invalid-name
+    rhs_var_name = RHS_VAR_NAME
+    dict_locals = {rhs_var_name: source}
+    mod = ast.fix_missing_locations(
+        ast.Module(
+            body=[
+                ast.Assign(
+                    targets=[target],
+                    value=ast.Name(
+                        id=rhs_var_name,
+                        ctx=ast.Load(),
+                    ),
+                )
+            ],
+            type_ignores=[],
+        )
+    )
+    exe = compile(mod, filename="<ast>", mode="exec")
+    exec(exe, {}, dict_locals)  # pylint: disable=exec-used
+    del dict_locals[rhs_var_name]
+    return dict_locals
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser/ir/__init__.py
similarity index 89%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser/ir/__init__.py
index 926d17b166..bea08cfb1b 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
-import tvm._ffi
-
-tvm._ffi._init_api("script", __name__)
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 66%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 555659d0c5..353963f29b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,8 +14,21 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
 
-from . import tir
+from tvm.ir import IRModule
 
-from .parser import ir_module, from_source
+from ..entry import parse
+from ..utils import inspect_class_capture
+
+
+def ir_module(f: Type) -> IRModule:
+    if not inspect.isclass(f):
+        raise TypeError(f"Expect a class, but got: {f}")
+
+    return parse(f, inspect_class_capture(f))
+
+
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 54%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/parser.py
index 555659d0c5..aec203c7d9 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,8 +14,26 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.ir import ir_builder as I
 
-from . import tir
+from .. import dispatch, doc
+from ..parser import Parser
 
-from .parser import ir_module, from_source
+
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    with self.var_table.with_frame():
+        with I.ir_module():
+            with self.with_dispatch_token("ir"):
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    pass
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    pass
diff --git a/python/tvm/script/parser/parser.py b/python/tvm/script/parser/parser.py
new file mode 100644
index 0000000000..c0eb79f144
--- /dev/null
+++ b/python/tvm/script/parser/parser.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from ...error import DiagnosticError
+from . import dispatch, doc
+from .diagnostics import Diagnostics
+from .evaluator import eval_assign, eval_expr
+from .source import Source
+from .utils import deferred
+from .var_table import VarTable
+
+DEFAULT_VISIT = {
+    "Interactive",
+    "Module",
+    "Expression",
+    "Pass",
+}
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            return func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    for token in [self.dispatch_tokens[-1], "default"]:
+        func = dispatch.get(token=token, type_name=type_name, default=None)
+        if func is not None:
+            return _dispatch_wrapper(func)
+    return _dispatch_wrapper(lambda self, node: self.generic_visit(node))
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser"""
+
+    diag: Diagnostics
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        self.diag = Diagnostics(source)
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        if extra_vars is None:
+            extra_vars = {}
+        with self.var_table.with_frame():
+            for k, v in extra_vars.items():
+                self.var_table.add(k, v)
+            node = self.diag.source.as_ast()
+            self.visit(node)
+
+    def with_dispatch_token(self, token: str):
+        def pop_token():
+            self.dispatch_tokens.pop()
+
+        self.dispatch_tokens.append(token)
+        return deferred(pop_token)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            for k, v in extra_vars.items():
+                var_values[k] = v
+        return eval_expr(self, node, var_values)
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        func(node)
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        func(self, node)
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        func(self, node)
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
diff --git a/python/tvm/script/parser/source.py b/python/tvm/script/parser/source.py
new file mode 100644
index 0000000000..936d20e610
--- /dev/null
+++ b/python/tvm/script/parser/source.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+import sys
+from typing import Union
+
+from . import doc
+
+
+class Source:
+    source_name: str
+    start_line: int
+    start_column: int
+    source: str
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = inspect.getsourcelines(program)  # type: ignore
+        if lines:
+            self.start_column = len(lines[0]) - len(lines[0].lstrip())
+        else:
+            self.start_column = 0
+        if self.start_column and lines:
+            self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
+        else:
+            self.source = "".join(lines)
+        try:
+            # It will cause a problem when running in Jupyter Notebook.
+            # `mod` will be <module '__main__'>, which is a built-in module
+            # and `getsource` will throw a TypeError
+            mod = inspect.getmodule(program)
+            if mod:
+                self.full_source = inspect.getsource(mod)
+            else:
+                self.full_source = self.source
+        except TypeError:
+            # It's a work around for Jupyter problem.
+            # Since `findsource` is an internal API of inspect, we just use it
+            # as a fallback method.
+            src, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src)
+
+    def as_ast(self) -> doc.AST:
+        return doc.parse(self.source)
+
+
+_getfile = inspect.getfile  # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod = getattr(obj, "__module__", None)
+    if mod is not None:
+        file = getattr(sys.modules[mod], "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        if inspect.isfunction(member):
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    raise TypeError(f"Source for {obj:!r} not found")
+
+
+inspect.getfile = _patched_inspect_getfile
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 78%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..caa51744c8 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.tir.ir_builder_v2 import *  # pylint: disable=redefined-builtin
 
-from . import tir
-
-from .parser import ir_module, from_source
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..397d014360
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+from tvm.tir.ir_builder_v2 import buffer_decl, ptr
+
+from ..entry import parse
+from ..utils import inspect_function_capture
+
+
+def _is_defined_in_class(frames):
+    if len(frames) > 2:
+        maybe_class_frame = frames[2]
+        statement_list = maybe_class_frame[4]
+        first_statement = statement_list[0]
+        line = first_statement.strip()
+        if line.startswith("class "):
+            return True
+        if line.startswith("@") and "ir_module" in line:
+            return True
+    return False
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    if _is_defined_in_class(inspect.stack()):
+        return f
+    return parse(f, inspect_function_capture(f))
+
+
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        return buffer_decl(
+            shape,
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+
+    def __getitem__(self, keys) -> Buffer:
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    def __call__(self, dtype, storage_scope="global"):
+        if callable(dtype):
+            dtype = dtype().dtype
+        return ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        if not isinstance(keys, tuple):
+            return self(keys)
+        return self(*keys)
+
+
+# Module-level singletons exported as ``Buffer`` / ``Ptr``. ``Buffer`` here rebinds the
+# ``tvm.tir.Buffer`` name imported at the top of this module; the ``-> Buffer``
+# annotations in ``BufferProxy`` were evaluated before this rebinding.
+Buffer = BufferProxy()  # pylint: disable=invalid-name
+Ptr = PtrProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..11ee92ad29
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .. import doc
+from ..dispatch import OpMethod, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _and(a, b):
+        if isinstance(a, bool):
+            a = IntImm("bool", a)
+        if isinstance(b, bool):
+            b = IntImm("bool", b)
+        return tir.And(a, b)
+
+    def _or(a, b):
+        if isinstance(a, bool):
+            a = IntImm("bool", a)
+        if isinstance(b, bool):
+            b = IntImm("bool", b)
+        return tir.Or(a, b)
+
+    def r(op: Type, i: int, m: OpMethod):  # pylint: disable=invalid-name
+        register_op(ty, op, i)(m)
+
+    for i in [0, 1]:
+        # Case 1. binop
+        r(doc.Add, i, tir.Add)
+        r(doc.Sub, i, tir.Sub)
+        r(doc.Mult, i, tir.Mul)
+        r(doc.Div, i, tir.Div)
+        r(doc.FloorDiv, i, tir.FloorDiv)
+        r(doc.Mod, i, tir.FloorMod)
+        r(doc.LShift, i, lambda a, b: a << b)
+        r(doc.RShift, i, lambda a, b: a >> b)
+        r(doc.BitOr, i, lambda a, b: a | b)
+        r(doc.BitXor, i, lambda a, b: a ^ b)
+        r(doc.BitAnd, i, lambda a, b: a & b)
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop
+        r(doc.Eq, i, tir.EQ)
+        r(doc.NotEq, i, tir.NE)
+        r(doc.Lt, i, tir.LT)
+        r(doc.LtE, i, tir.LE)
+        r(doc.Gt, i, tir.GT)
+        r(doc.GtE, i, tir.GE)
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop
+        r(doc.And, i, _and)
+        r(doc.Or, i, _or)
+    for i in [0]:
+        #  Case 4. unaryop
+        r(doc.Invert, i, lambda a: ~a)
+        r(doc.Not, i, tir.Not)
+        r(doc.UAdd, i, lambda a: +a)
+        r(doc.USub, i, lambda a: -a)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..884056d62b
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.ir.ir_builder import IRBuilderFrame as Frame
+from tvm.ir.ir_builder import name
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+from tvm.tir import ir_builder_v2 as T
+
+from .. import dispatch, doc
+from ..parser import Parser
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+        raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Var):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    for_frame = self.eval_expr(node.iter)
+    if not isinstance(for_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame():
+        with for_frame as iters:
+            self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value)
+            self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    with self.var_table.with_frame():
+        cond = self.eval_expr(node.test)
+        with T.While(cond):
+            self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    lhs = node.targets[0]
+    rhs = self.eval_expr(node.value)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """Parse ``target op= value`` by desugaring it to ``target = target op value``.
+
+    The current values of the target and the right-hand side are stored under
+    temporary names in a throw-away var-table scope, a ``BinOp`` over those names is
+    synthesized and evaluated with ``self.eval_expr``, and the result is written back:
+    as a ``T.buffer_store`` for subscript targets, otherwise via ``bind_assign_value``.
+    """
+    # Source spans (lineno, col_offset, end_lineno, end_col_offset) reused for the
+    # synthesized AST nodes below.
+    lhs_pos = (
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    # An augmented-assignment target carries a Store context; switch it to Load so the
+    # target can be evaluated as an expression to read its current value.
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        # Synthesized ``<lhs_name> <op> <rhs_name>`` expression.
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)
+    lhs = node.target
+    # Restore the Store context before writing the result back to the target.
+    lhs.ctx = doc.Store(*lhs_pos)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    lhs = node.target
+    rhs = self.eval_expr(node.value)
+    ann_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(ann_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
+    frame = T.let(ann_var, rhs)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(self.var_table.with_frame())
+        for item in node.items:
+            frame = self.eval_expr(item.context_expr)
+            if not isinstance(frame, Frame):
+                self.report_error(
+                    item.context_expr, "Invalid context expression in the with-statement."
+                )
+            rhs = stack.enter_context(frame)
+            if item.optional_vars is not None:
+                self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                if callable(ret_type):
+                    ret_type = PrimType(ret_type().dtype)
+                T.func_ret(ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    arg: doc.arg
+    for arg in node.args:
+        if arg.annotation is None:
+            self.report_error(arg, "Type annotation is required for function parameters.")
+        param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation))
+        self.var_table.add(arg.arg, param)
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    annotation = self.eval_expr(node)
+    if callable(annotation):
+        annotation = annotation()
+    return annotation
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    res = self.eval_expr(node.value)
+    if isinstance(res, Frame):
+        res.add_callback(partial(res.__exit__, None, None, None))
+        res.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    with self.var_table.with_frame():
+        with T.If(self.eval_expr(node.test)):
+            with T.Then():
+                self.visit_body(node.body)
+            if node.orelse:
+                with T.Else():
+                    self.visit_body(node.orelse)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    cond = self.eval_expr(node.test)
+    msg = self.eval_expr(node.msg)
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
diff --git a/python/tvm/script/parser/utils.py b/python/tvm/script/parser/utils.py
new file mode 100644
index 0000000000..9d681236c8
--- /dev/null
+++ b/python/tvm/script/parser/utils.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from contextlib import contextmanager
+from typing import Any, Callable, Dict
+
+
+def deferred(f: Callable[[], None]):
+    @contextmanager
+    def context():
+        try:
+            yield
+        finally:
+            f()
+
+    return context()
+
+
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    prefix = "tvm."
+    result = {}
+    captured = {
+        **inspect.getclosurevars(func).nonlocals,
+        **func.__globals__,
+    }
+    for k, v in captured.items():
+        # Case 1: a module like `T` or `tvm.tir.ir_builder`
+        if inspect.ismodule(v) and v.__name__.startswith(prefix):
+            result[k] = v
+            continue
+        # Case 2: a function like `T.match_buffer`
+        if hasattr(v, "__module__") and v.__module__.startswith(prefix):
+            result[k] = v
+            continue
+        # Case 3: atomic types
+        if v is None or isinstance(v, (int, float, str, bool)):
+            result[k] = v
+            continue
+    return result
+
+
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    result: Dict[str, Any] = {}
+    for _, v in cls.__dict__.items():
+        if inspect.isfunction(v):
+            func_vars = inspect_function_capture(v)
+            result.update(**func_vars)
+    return result
diff --git a/python/tvm/script/parser/var_table.py b/python/tvm/script/parser/var_table.py
new file mode 100644
index 0000000000..32fced625a
--- /dev/null
+++ b/python/tvm/script/parser/var_table.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The symbol table of variable values"""
+
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Set
+
+from .utils import deferred
+
+
+class VarTableFrame:
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        if var in self.vars:
+            raise ValueError(f"Variable {var} already defined in current scope")
+        self.vars.add(var)
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        for var in self.vars:
+            fn_pop(var)
+        self.vars.clear()
+
+
+class VarTable:
+
+    frames: List[VarTableFrame]
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        for v in self.name2value.values():
+            if v is value:
+                return True
+        return False
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser_v1/__init__.py
similarity index 95%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser_v1/__init__.py
index 555659d0c5..004e947bf6 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser_v1/__init__.py
@@ -17,5 +17,4 @@
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
 
 from . import tir
-
-from .parser import ir_module, from_source
+from .parser import from_source, ir_module
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py
similarity index 100%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser_v1/_ffi_api.py
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
similarity index 98%
rename from python/tvm/script/context_maintainer.py
rename to python/tvm/script/parser_v1/context_maintainer.py
index f7f16855c7..400baacc4b 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/parser_v1/context_maintainer.py
@@ -16,16 +16,16 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
+from typing import Callable, Dict, List, Mapping, Optional, Union
 
+import synr
 import tvm
 from tvm.ir import Span
 from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
 from tvm.runtime import Object
+from tvm.tir import Buffer, MatchBufferRegion, PrimExpr, Stmt, Var
 from tvm.tir.expr import IterVar
+
 from .tir.node import BufferSlice
 
 
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py
similarity index 95%
rename from python/tvm/script/diagnostics.py
rename to python/tvm/script/parser_v1/diagnostics.py
index e676461ab3..b15997552f 100644
--- a/python/tvm/script/diagnostics.py
+++ b/python/tvm/script/parser_v1/diagnostics.py
@@ -17,11 +17,11 @@
 """Bridge from synr's (the library used for parsing the python AST)
    DiagnosticContext to TVM's diagnostics
 """
-from synr import DiagnosticContext, ast
-
 import tvm
+from synr import DiagnosticContext, ast
+from tvm.ir.diagnostics import Diagnostic
 from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
+from tvm.ir.diagnostics import DiagnosticLevel, get_renderer
 
 
 class TVMDiagnosticCtx(DiagnosticContext):
diff --git a/python/tvm/script/highlight.py b/python/tvm/script/parser_v1/highlight.py
similarity index 100%
rename from python/tvm/script/highlight.py
rename to python/tvm/script/parser_v1/highlight.py
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py
similarity index 100%
rename from python/tvm/script/meta_unparser.py
rename to python/tvm/script/parser_v1/meta_unparser.py
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py
similarity index 99%
rename from python/tvm/script/parser.py
rename to python/tvm/script/parser_v1/parser.py
index 908af081c9..b2b7e388f1 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser_v1/parser.py
@@ -20,35 +20,34 @@ We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
 different python versions. Synr also provides an error handling context that we
 use for error reporting.
 """
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
+import inspect
 import json
 import operator
-import inspect
+
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
+import types
 from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
 
 import tvm
+from synr import Transformer, ast, to_ast
 from tvm import IRModule
 from tvm._ffi.base import TVMError
 from tvm.ir import GlobalVar
 from tvm.ir.function import BaseFunc
 from tvm.tir import buffer
 from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
 
+from . import _ffi_api, tir
 from .context_maintainer import ContextMaintainer
+from .diagnostics import TVMDiagnosticCtx
 from .meta_unparser import MetaUnparser
 from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
+from .tir import ty
 from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
+from .tir.node import BufferSlice, Slice
+from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler
 from .tir.special_stmt import SpecialStmt
-from .tir import ty
+from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr
 
 
 class CallArgumentReader(object):
diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py
similarity index 97%
rename from python/tvm/script/registry.py
rename to python/tvm/script/parser_v1/registry.py
index e7d90dd515..e816b90f5d 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/parser_v1/registry.py
@@ -17,7 +17,7 @@
 """TVM Script Parser Function Registry """
 # pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
 import types
-from typing import Union, Callable, Dict, Optional, Any
+from typing import Any, Callable, Dict, Optional, Union
 
 
 class Registry(object):
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 100%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
similarity index 100%
rename from python/tvm/script/tir/__init__.pyi
rename to python/tvm/script/parser_v1/tir/__init__.pyi
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 99%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index 627b89086a..008d7495a0 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -17,11 +17,12 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
+from tvm.target import codegen
+
 from ..registry import register
-from ...target import codegen
 from ..utils import get_param_list, tvm_span_from_synr
 
 
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 98%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
index 29e79607fb..cfaf9df476 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/parser_v1/tir/node.py
@@ -17,12 +17,13 @@
 # pylint: disable=redefined-builtin
 """TVM Script nodes."""
 
-from typing import Optional, Union, List, Callable
+from typing import Callable, List, Optional, Union
+
 import synr
 from tvm.arith import Analyzer
+from tvm.ir import Range, Span
 from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
+from tvm.tir import Buffer, BufferLoad, BufferRegion, IntImm, PrimExpr, Ramp
 
 
 class Slice:
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 97%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
index 923eb97d27..a5fdcc15c5 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser_v1/tir/prim_func.py
@@ -19,7 +19,8 @@
 import inspect
 from typing import Callable
 
-from tvm.tir.function import PrimFunc
+from tvm.tir import PrimFunc
+
 from ..parser import from_source
 
 
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 98%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
index da7545c9a9..cd2167f4ab 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/parser_v1/tir/scope_handler.py
@@ -16,24 +16,19 @@
 # under the License.
 """TVM Script Parser Scope Handler Classes"""
 # pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
-import synr
 import numpy as np
+import synr
 import tvm.tir
+from tvm.ir import Range, Span
 from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
+from tvm.tir import Buffer, BufferRegion, ForKind, IterVar, PrimExpr, Stmt, Var
 
 from ..context_maintainer import ContextMaintainer
 from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
 
 
 class ScopeHandler:
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 99%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
index 15502055b7..42a90f647f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -17,27 +17,21 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
 import synr
+import tvm.tir
 from synr import ast
+from tvm.ir import Span
 from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
 from tvm.runtime import Object, String
 from tvm.target import Target
-from tvm.ir import Span
 from tvm.tir import IntImm, IterVar, Var
 
-from .node import BufferSlice
-
 from ..context_maintainer import BlockInfo, ContextMaintainer
 from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
 
 
 def convert_to_int(
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 99%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
index a64485b215..8a048c07b5 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/parser_v1/tir/ty.py
@@ -23,6 +23,7 @@ a wrapper for uniform Type system in IR
 from numbers import Integral
 
 import tvm
+
 from .special_stmt import SpecialStmt, convert_to_int
 
 
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 97%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
index c655a62237..f358a90081 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/parser_v1/utils.py
@@ -16,13 +16,12 @@
 # under the License.
 """Helper functions in TVM Script Parser"""
 
-from typing import Callable, List, Any, Optional, Tuple
-
 import inspect
-import synr
+from typing import Any, Callable, List, Optional, Tuple
 
-from tvm.ir import Span, SourceName
+import synr
 from tvm.error import DiagnosticError
+from tvm.ir import SourceName, Span
 
 
 def get_param_list(
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..0d949234d8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,180 @@
 from tvm.ir import PrimExpr
 from tvm.runtime import const
 
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis
+from . import ir_builder_v1 as ir_builder
+from . import schedule, stmt_functor, transform, usmp
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Any,
+    Broadcast,
+    BufferLoad,
+    Call,
+    CallEffectKind,
+    Cast,
+    CommReducer,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
+    IntImm,
+    IterVar,
+    Let,
+    Load,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
+    Select,
+    Shuffle,
+    SizeVar,
+    StringImm,
+    Sub,
+    Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+    TVMBackendAllocWorkspace,
+    TVMBackendFreeWorkspace,
+    abs,
+    acos,
+    acosh,
+    address_of,
+    all,
+    any,
+    asin,
+    asinh,
+    atan,
+    atan2,
+    atanh,
+    call_cpacked,
+    call_cpacked_lowered,
+    call_extern,
+    call_intrin,
+    call_llvm_intrin,
+    call_llvm_pure_intrin,
+    call_packed,
+    call_packed_lowered,
+    call_pure_extern,
+    ceil,
+    ceildiv,
+    clz,
+    comm_reducer,
+    copysign,
+    cos,
+    cosh,
+    div,
+    erf,
+    exp,
+    exp2,
+    exp10,
+    floor,
+    floordiv,
+    floormod,
+    fmod,
+    hypot,
+    if_then_else,
+    indexdiv,
+    indexmod,
+    isfinite,
+    isinf,
+    isnan,
+    isnullptr,
+    ldexp,
+    likely,
+    log,
+    log1p,
+    log2,
+    log10,
+    lookup_param,
+    max,
+    max_value,
+    min,
+    min_value,
+    mma_fill,
+    mma_store,
+    nearbyint,
+    nextafter,
+    popcount,
+    power,
+    ptx_commit_group,
+    ptx_cp_async,
+    ptx_ldmatrix,
+    ptx_mma,
+    ptx_mma_sp,
+    ptx_wait_group,
+    q_multiply_shift,
+    ret,
+    round,
+    rsqrt,
+    shift_left,
+    shift_right,
+    sigmoid,
+    sin,
+    sinh,
+    sqrt,
+    sum,
+    tan,
+    tanh,
+    trace,
+    trunc,
+    truncdiv,
+    truncmod,
+    tvm_access_ptr,
+    tvm_bmma_sync,
+    tvm_fill_fragment,
+    tvm_load_matrix_sync,
+    tvm_mma_sync,
+    tvm_stack_alloca,
+    tvm_stack_make_array,
+    tvm_stack_make_shape,
+    tvm_store_matrix_sync,
+    tvm_struct_get,
+    tvm_struct_set,
+    tvm_thread_allreduce,
+    tvm_throw_last_error,
+    tvm_tuple,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
 from .stmt import (
-    BufferStore,
-    BufferRealize,
-    Store,
-    ProducerStore,
     Allocate,
     AllocateConst,
+    AssertStmt,
     AttrStmt,
+    Block,
+    BlockRealize,
+    BufferRealize,
+    BufferRegion,
+    BufferStore,
     DeclBuffer,
+    Evaluate,
+    For,
+    ForKind,
+    IfThenElse,
+    LetStmt,
+    MatchBufferRegion,
+    Prefetch,
+    ProducerRealize,
+    ProducerStore,
+    SeqStmt,
+    Stmt,
+    Store,
+    While,
+    stmt_list,
+    stmt_seq,
+    type_annotation,
 )
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/tir/_ffi_ir_builder_api.py
similarity index 88%
rename from python/tvm/script/_ffi_api.py
rename to python/tvm/tir/_ffi_ir_builder_api.py
index 926d17b166..61b288d498 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/tir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for the TIR IRBuilder (functions registered under the ``ir_builder.tir`` prefix)"""
 import tvm._ffi
 
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder.tir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..ea220fea22 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Union
 from tvm import Object
 from tvm.ir import IRModule
 from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
+from tvm.tir.stmt import Block, BufferRegion, PrimExpr, Stmt
 
-from .. import Buffer, Stmt
+from ..buffer import Buffer
 from ..function import PrimFunc
 from . import _ffi_api
 
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index d9b0aec76a..e74eb15453 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -16,11 +16,12 @@
 # under the License.
 """Abstraction for array data structures."""
 from numbers import Integral
-import tvm._ffi
 
+import tvm._ffi
 from tvm._ffi.base import string_types
+from tvm.ir import PointerType, PrimExpr, PrimType, Range
 from tvm.runtime import Object, convert
-from tvm.ir import PrimExpr, PointerType, PrimType
+
 from . import _ffi_api
 
 
@@ -176,6 +177,40 @@ class Buffer(Object):
         """
         return _ffi_api.BufferOffsetOf(self, indices)  # type: ignore
 
+    def __getitem__(self, indices):
+        """Index or slice the buffer.
+
+        Parameters
+        ----------
+        indices : index, slice, or tuple/list of them
+            One entry per buffer dimension. A bare (non tuple/list) entry is
+            treated as a one-element index list. Slices follow Python defaults:
+            an omitted start is 0 and an omitted stop is that dimension's extent.
+
+        Returns
+        -------
+        result : Union[BufferRegion, BufferLoad]
+            A BufferRegion when any entry is a slice without an explicit step;
+            otherwise a BufferLoad in which stepped slices become Ramp indices
+            (or a scalar index when the slice covers a single lane).
+        """
+        from ..arith import Analyzer  # pylint: disable=import-outside-toplevel
+        from .expr import BufferLoad, Ramp  # pylint: disable=import-outside-toplevel
+        from .stmt import BufferRegion  # pylint: disable=import-outside-toplevel
+
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+        analyzer = Analyzer()
+
+        def _bounds(dim, index):
+            # Without these defaults `A[:, 0]` would evaluate `None - None`.
+            start = 0 if index.start is None else index.start
+            stop = self.shape[dim] if index.stop is None else index.stop
+            return start, stop
+
+        if any(isinstance(index, slice) and index.step is None for index in indices):
+            region = []
+            for dim, index in enumerate(indices):
+                if isinstance(index, slice):
+                    start, stop = _bounds(dim, index)
+                    region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
+                else:
+                    region.append(Range.from_min_extent(index, 1))
+            return BufferRegion(self, region)
+
+        expr_indices = []
+        for dim, index in enumerate(indices):
+            if isinstance(index, slice):
+                start, stop = _bounds(dim, index)
+                # Number of lanes touched by the stepped slice (ceil division).
+                lanes = analyzer.simplify((stop - start + index.step - 1) // index.step)
+                if lanes == 1:
+                    expr_indices.append(start)
+                else:
+                    expr_indices.append(Ramp(start, index.step, int(lanes)))
+            else:
+                expr_indices.append(index)
+        return BufferLoad(self, expr_indices)
+
 
 def decl_buffer(
     shape,
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 from typing import Optional, Union
-from tvm import ir
+
 import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
 from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
 
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
 from . import _ffi_api
+from . import generic as _generic
 
 
 def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
 
-    # TODO(tkonolige): use inspect to add source information to these objects
-
     def __add__(self, other):
+        """Return ``self + other``, built through ``_generic.add``."""
         return _generic.add(self, other)
 
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
     """
 
     def __init__(self, condition, true_value, false_value, span=None):
+        # A plain Python bool condition is wrapped as a bool-typed IntImm before
+        # it is handed to the FFI constructor.
+        if isinstance(condition, bool):
+            condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
             _ffi_api.Select, condition, true_value, false_value, span  # type: ignore
         )
diff --git a/python/tvm/tir/ir_builder_frame.py b/python/tvm/tir/ir_builder_frame.py
new file mode 100644
index 0000000000..a1f457aad2
--- /dev/null
+++ b/python/tvm/tir/ir_builder_frame.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder frame classes for TIR"""
+
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.ir.ir_builder import IRBuilderFrame
+
+from .buffer import Buffer
+from .expr import Var
+
+
+@_register_object("ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+    """Common base class of every TIR IRBuilder frame defined in this module."""
+
+
+@_register_object("ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+    """Frame for a TIR block (the return type of ``ir_builder_v2.block``)."""
+
+
+@_register_object("ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    """Frame for a block's init section (the return type of ``ir_builder_v2.init``)."""
+
+@_register_object("ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    """Frame for one or more nested for loops.
+
+    Entering the frame yields the loop variable(s): the single ``Var`` when the
+    frame binds exactly one variable, otherwise the whole collection of variables.
+    """
+
+    def __enter__(self) -> Union[Var, List[Var]]:
+        super().__enter__()
+        # Unwrap only the single-variable case so a frame without variables
+        # yields an empty collection rather than raising IndexError.
+        return self.vars[0] if len(self.vars) == 1 else self.vars
+
+
+@_register_object("ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+    """Frame for a PrimFunc (the return type of ``ir_builder_v2.prim_func``)."""
+
+
+@_register_object("ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    """Frame for an assertion (the return type of ``ir_builder_v2.Assert``)."""
+
+
+@_register_object("ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    """Frame for a let binding (returned by ``ir_builder_v2.let`` without a body)."""
+
+@_register_object("ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    """Frame for an allocation; entering it yields the frame's ``buffer``."""
+
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        allocated = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    """Frame for a constant allocation; entering it yields the frame's ``buffer``."""
+
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        allocated = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.launch_thread``."""
+
+
+@_register_object("ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.realize``."""
+
+
+@_register_object("ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.attr``."""
+
+
+@_register_object("ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.While``."""
+
+
+@_register_object("ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.If``."""
+
+
+@_register_object("ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.Then``."""
+
+
+@_register_object("ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.Else``."""
+
+@_register_object("ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    """Frame for a buffer declaration; entering it yields the frame's ``buffer``."""
+
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        declared = self.buffer
+        return declared
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder_v1.py
similarity index 99%
rename from python/tvm/tir/ir_builder.py
rename to python/tvm/tir/ir_builder_v1.py
index ce8cd1b403..8a68faaac6 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder_v1.py
@@ -17,13 +17,13 @@
 """Developer API of IR node builder make function."""
 import tvm
 from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, convert, const
 from tvm.ir import container as _container
+from tvm.runtime import ObjectGeneric, const, convert
 
-from . import stmt as _stmt
-from . import expr as _expr
 from . import buffer as _buffer
+from . import expr as _expr
 from . import op
+from . import stmt as _stmt
 
 
 class WithScope(object):
diff --git a/python/tvm/tir/ir_builder_v2.py b/python/tvm/tir/ir_builder_v2.py
new file mode 100644
index 0000000000..65699743fd
--- /dev/null
+++ b/python/tvm/tir/ir_builder_v2.py
@@ -0,0 +1,903 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.target import Target as target
+
+from . import _ffi_ir_builder_api as _ffi_api
+from . import ir_builder_frame as frame
+from . import op as _tir_op
+from .buffer import Buffer
+from .expr import Broadcast as broadcast
+from .expr import BufferLoad, CommReducer, IntImm, IterVar, Let, PrimExpr
+from .expr import Ramp as ramp
+from .expr import Select, Shuffle, StringImm, Var
+from .generic import cast
+from .stmt import BufferRegion, type_annotation
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    """Declare a Buffer through ``_ffi_api.BufferDecl``.
+
+    A scalar ``shape`` (``PrimExpr`` or integral) is promoted to a 1-tuple. The
+    third FFI argument (presumably the buffer name) is always ``""``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    name = ""
+    return _ffi_api.BufferDecl(  # pylint: disable=no-member # type: ignore
+        shape, dtype, name, data, strides, elem_offset,
+        scope, align, offset_factor, buffer_type, axis_separators,
+    )
+
+
+def ptr(dtype, storage_scope="global"):
+    """Return ``_ffi_api.Ptr(dtype, storage_scope)``."""
+    pointer = _ffi_api.Ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+    return pointer
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
+    """Open a BlockFrame named ``name``; ``no_realize`` is forwarded to the FFI unchanged."""
+    block_frame = _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
+    return block_frame
+
+
+def init() -> frame.BlockInitFrame:
+    """Open a BlockInitFrame via ``_ffi_api.Init``."""
+    init_frame = _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+    return init_frame
+
+
+def where(predicate) -> None:
+    """Forward a block predicate to ``_ffi_api.Where``; Python bools become bool IntImms."""
+    normalized = IntImm("bool", predicate) if isinstance(predicate, bool) else predicate
+    _ffi_api.Where(normalized)  # pylint: disable=no-member # type: ignore
+
+
+def _as_buffer_slice_list(buffer_slices) -> List[Union[BufferRegion, BufferLoad]]:
+    """Normalize the varargs accepted by ``reads``/``writes`` into a flat list.
+
+    A single tuple or list argument is unpacked, a single other argument is
+    wrapped in a list, and several arguments are taken as-is.
+    """
+    if len(buffer_slices) == 1:
+        (only,) = buffer_slices
+        if isinstance(only, (tuple, list)):
+            return list(only)
+        return [only]
+    return list(buffer_slices)
+
+
+def reads(
+    *buffer_slices: Union[BufferRegion, BufferLoad, List[Union[BufferRegion, BufferLoad]]]
+) -> None:
+    """Declare the read regions of the current block via ``_ffi_api.Reads``.
+
+    Accepts either several regions or a single list/tuple of regions.
+    """
+    _ffi_api.Reads(_as_buffer_slice_list(buffer_slices))  # pylint: disable=no-member # type: ignore
+
+
+def writes(
+    *buffer_slices: Union[BufferRegion, BufferLoad, List[Union[BufferRegion, BufferLoad]]]
+) -> None:
+    """Declare the write regions of the current block via ``_ffi_api.Writes``.
+
+    Accepts either several regions or a single list/tuple of regions.
+    """
+    _ffi_api.Writes(_as_buffer_slice_list(buffer_slices))  # pylint: disable=no-member # type: ignore
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """Pass the attribute dict ``attrs`` to ``_ffi_api.BlockAttrs`` and return its result."""
+    outcome = _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Allocate a Buffer via ``_ffi_api.AllocBuffer``.
+
+    A scalar ``shape`` is promoted to a 1-tuple and a missing ``strides`` becomes ``[]``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    return _ffi_api.AllocBuffer(  # pylint: disable=no-member # type: ignore
+        shape, dtype, data, strides, elem_offset,
+        scope, align, offset_factor, buffer_type, axis_separators,
+    )
+
+
+def _as_range(dom) -> Range:
+    """Coerce ``dom`` into a Range.
+
+    A Range is returned unchanged, a list/tuple becomes ``Range(dom[0], dom[1])``
+    and any other value becomes ``Range(0, dom)``.
+    """
+    if isinstance(dom, Range):
+        result = dom
+    elif isinstance(dom, (list, tuple)):
+        result = Range(dom[0], dom[1])
+    else:
+        result = Range(0, dom)
+    return result
+
+
+class axis:  # pylint: disable=invalid-name
+    """Constructors for block iteration axes.
+
+    Each constructor coerces ``dom`` with ``_as_range`` and forwards
+    ``(dom, binding, dtype)`` to the matching ``_ffi_api.Axis*`` function.
+    ``S`` and ``R`` are short aliases of ``spatial`` and ``reduce``.
+    """
+
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisSpatial(dom_range, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisReduce(dom_range, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisScan(dom_range, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        dom_range = _as_range(dom)
+        return _ffi_api.AxisOpaque(dom_range, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        """Create several axes at once; a single resulting axis is returned unwrapped."""
+        created = _ffi_api.AxisRemap(kinds, bindings, dtype)  # pylint: disable=no-member # type: ignore
+        if len(created) == 1:
+            return created[0]
+        return created
+
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def _loop_bounds(start, stop):
+    """Interpret loop bounds like ``range``: a lone bound is the stop and start becomes 0."""
+    if stop is None:
+        return 0, start
+    return start, stop
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open a serial ForFrame; bounds follow ``_loop_bounds``."""
+    lo, hi = _loop_bounds(start, stop)
+    return _ffi_api.Serial(lo, hi, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open a parallel ForFrame; bounds follow ``_loop_bounds``."""
+    lo, hi = _loop_bounds(start, stop)
+    return _ffi_api.Parallel(lo, hi, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open a vectorized ForFrame; bounds follow ``_loop_bounds``."""
+    lo, hi = _loop_bounds(start, stop)
+    return _ffi_api.Vectorized(lo, hi, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open an unrolled ForFrame; bounds follow ``_loop_bounds``."""
+    lo, hi = _loop_bounds(start, stop)
+    return _ffi_api.Unroll(lo, hi, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    """Open a ForFrame bound to the thread tag ``thread``.
+
+    Also accepts ``thread_binding(extent, thread)``: when ``thread`` is omitted,
+    ``stop`` must be the thread string and ``start`` is taken as the stop bound.
+    A lone numeric bound is the stop and start becomes 0.
+    """
+    if thread is None:
+        if not isinstance(stop, str):
+            raise ValueError("Thread cannot be None for thread_binding")
+        start, stop, thread = 0, start, stop
+    elif stop is None:
+        start, stop = 0, start
+    return _ffi_api.ThreadBinding(start, stop, thread, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def grid(*extents) -> frame.ForFrame:
+    """Open one ForFrame spanning all ``extents`` (passed to the FFI as a tuple)."""
+    loop_frame = _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+    return loop_frame
+
+
+def prim_func() -> frame.PrimFuncFrame:
+    """Open a PrimFuncFrame via ``_ffi_api.PrimFunc``."""
+    func_frame = _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+    return func_frame
+
+
+def arg(name, obj):
+    """Forward ``(name, obj)`` to ``_ffi_api.Arg`` and return its result."""
+    result = _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_name(name: str) -> str:
+    """Forward ``name`` to ``_ffi_api.FuncName`` and return its result."""
+    result = _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:
+    """Forward the attribute dict ``attrs`` to ``_ffi_api.FuncAttrs``."""
+    result = _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_ret(ret_type) -> Type:
+    """Forward ``ret_type`` to ``_ffi_api.FuncRet`` and return its result."""
+    result = _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Match ``param`` against a buffer description via ``_ffi_api.MatchBuffer``.
+
+    A scalar ``shape`` is promoted to a 1-tuple and a missing ``strides`` becomes ``[]``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    return _ffi_api.MatchBuffer(  # pylint: disable=no-member # type: ignore
+        param, shape, dtype, data, strides, elem_offset,
+        scope, align, offset_factor, buffer_type, axis_separators,
+    )
+
+
+def preflattened_buffer(
+    postflattened,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> None:
+    """Record the pre-flattened form of ``postflattened`` via ``_ffi_api.PreflattenedBuffer``.
+
+    A scalar ``shape`` is promoted to a 1-tuple and a missing ``strides`` becomes ``[]``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    _ffi_api.PreflattenedBuffer(  # pylint: disable=no-member # type: ignore
+        postflattened, shape, dtype, data, strides, elem_offset,
+        scope, align, offset_factor, buffer_type, axis_separators,
+    )
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    """Open an AssertFrame for ``condition`` with ``message``.
+
+    A plain Python bool ``condition`` is wrapped as a bool-typed IntImm, matching
+    ``While``/``If``/``where``/``allocate`` in this module.
+    """
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Assert(condition, message)  # pylint: disable=no-member # type: ignore
+
+
+def let(
+    v: Var,
+    value: PrimExpr,
+    body: Optional[PrimExpr] = None,
+) -> Union[frame.LetFrame, Let]:
+    """Bind ``v`` to ``value``.
+
+    Without ``body`` this returns a LetFrame from ``_ffi_api.Let``; with ``body``
+    it returns a ``Let`` expression instead of a frame.
+    """
+    if body is None:
+        return _ffi_api.Let(v, value)  # pylint: disable=no-member # type: ignore
+    return Let(v, value, body)
+
+
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: PrimExpr = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    """Open an AllocateFrame; a Python bool ``condition`` becomes a bool IntImm."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.Allocate(extents, dtype, scope, cond, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+    """Open an AllocateConstFrame holding ``data`` packed as a TVM NDArray of ``dtype``."""
+    packed = ndarray.array(np.asarray(data, dtype))
+    return _ffi_api.AllocateConst(packed, dtype, extents, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def realize(
+    buffer_slice: BufferRegion,
+    storage_scope: str,
+    condition: PrimExpr = True,
+) -> frame.RealizeFrame:
+    """Open a RealizeFrame for ``buffer_slice`` in ``storage_scope``.
+
+    A plain Python bool ``condition`` (including the default ``True``) is wrapped
+    as a bool-typed IntImm, as the other condition-taking builders here do.
+    """
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Realize(  # pylint: disable=no-member # type: ignore
+        buffer_slice, storage_scope, condition
+    )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    """Open an AttrFrame after converting ``node`` and ``value`` to TVM objects."""
+    converted_node = convert(node)
+    converted_value = convert(value)
+    return _ffi_api.Attr(converted_node, attr_key, converted_value)  # pylint: disable=no-member # type: ignore
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    """Open a WhileFrame; a Python bool ``condition`` becomes a bool IntImm."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.While(cond)  # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    """Open an IfFrame; a Python bool ``condition`` becomes a bool IntImm."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.If(cond)  # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    """Open a ThenFrame via ``_ffi_api.Then``."""
+    then_frame = _ffi_api.Then()  # pylint: disable=no-member # type: ignore
+    return then_frame
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    """Open an ElseFrame via ``_ffi_api.Else``."""
+    else_frame = _ffi_api.Else()  # pylint: disable=no-member # type: ignore
+    return else_frame
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+    """Open a DeclBufferFrame via ``_ffi_api.DeclBuffer``.
+
+    A scalar ``shape`` is promoted to a 1-tuple. The third FFI argument
+    (presumably the buffer name) is always ``""``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    name = ""
+    return _ffi_api.DeclBuffer(  # pylint: disable=no-member # type: ignore
+        shape, dtype, name, data, strides, elem_offset,
+        scope, align, offset_factor, buffer_type, axis_separators,
+    )
+
+
+def launch_thread(
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    """Open a LaunchThreadFrame for ``iter_var`` with ``extent``."""
+    launch_frame = _ffi_api.LaunchThread(iter_var, extent)  # pylint: disable=no-member # type: ignore
+    return launch_frame
+
+
+def env_thread(thread_tag: str) -> IterVar:
+    """Return the IterVar produced by ``_ffi_api.EnvThread`` for ``thread_tag``."""
+    thread_var = _ffi_api.EnvThread(thread_tag)  # pylint: disable=no-member # type: ignore
+    return thread_var
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    """Emit a store of ``value`` into ``buffer`` at ``indices``.
+
+    Slice indices follow Python defaults (start 0, stop = that dimension's
+    extent, step 1). A multi-lane slice becomes a ``ramp`` index; a single-lane
+    slice collapses to its start. A Python bool stored into a ``bool`` buffer is
+    wrapped as a bool-typed IntImm. Returns the result of ``_ffi_api.BufferStore``.
+    """
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    analyzer = Analyzer()
+    expr_indices = []
+    for dim, index in enumerate(indices):
+        if isinstance(index, slice):
+            # Without these defaults `buf[:]` would evaluate `None - None`.
+            start = 0 if index.start is None else index.start
+            stop = buffer.shape[dim] if index.stop is None else index.stop
+            step = 1 if index.step is None else index.step
+            lanes = analyzer.simplify((stop - start + step - 1) // step)
+            if lanes == 1:
+                expr_indices.append(start)
+            else:
+                expr_indices.append(ramp(start, step, int(lanes)))
+        else:
+            expr_indices.append(index)
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    return _ffi_api.BufferStore(  # pylint: disable=no-member # type: ignore
+        buffer, value, expr_indices
+    )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+    """Forward ``(buffer, indices)`` to ``_ffi_api.Prefetch`` and return its result."""
+    result = _ffi_api.Prefetch(buffer, indices)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def evaluate(value: PrimExpr) -> None:
+    """Evaluate ``value`` via ``_ffi_api.Evaluate``; a Python str becomes a StringImm."""
+    expr = StringImm(value) if isinstance(value, str) else value
+    return _ffi_api.Evaluate(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int8``."""
+    result = _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int16``."""
+    result = _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int32``."""
+    result = _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int64``."""
+    result = _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt8``."""
+    result = _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt16``."""
+    result = _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt32``."""
+    result = _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt64``."""
+    result = _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Pass ``expr`` to ``_ffi_api.Float8``, first converting non-PrimExpr values."""
+    converted = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float8(converted)  # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Pass ``expr`` to ``_ffi_api.Float16``, first converting non-PrimExpr values."""
+    converted = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float16(converted)  # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Pass ``expr`` to ``_ffi_api.Float32``, first converting non-PrimExpr values."""
+    converted = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float32(converted)  # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Pass ``expr`` to ``_ffi_api.Float64``, first converting non-PrimExpr values."""
+    converted = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float64(converted)  # pylint: disable=no-member # type: ignore
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Boolean``."""
+    result = _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Handle``."""
+    result = _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Void``."""
+    result = _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def min(a, b):  # pylint: disable=redefined-builtin
+    """Build the minimum of two expressions via ``_ffi_api.min``.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        First operand.
+
+    b : PrimExpr
+        Second operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        Expression for the smaller of ``a`` and ``b``.
+    """
+    smaller = _ffi_api.min(a, b)  # pylint: disable=no-member # type: ignore
+    return smaller
+
+
+def max(a, b):  # pylint: disable=redefined-builtin
+    """Build the maximum of two expressions via ``_ffi_api.max``.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        First operand.
+
+    b : PrimExpr
+        Second operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        Expression for the larger of ``a`` and ``b``.
+    """
+    larger = _ffi_api.max(a, b)  # pylint: disable=no-member # type: ignore
+    return larger
+
+
+def var(dtype, name="") -> Var:  # note: argument order is reversed w.r.t. Var(name, dtype)
+    return Var(name, dtype)  # pylint: disable=no-member # type: ignore
+
+
+def iter_var(v, dom, iter_type, thread_tag):  # iter_type: name of an IterVar attribute
+    iter_type = getattr(IterVar, iter_type)  # raises AttributeError on an unknown name
+    return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+    """Build a CommReducer from `combiner` (2 * len(identity) params) and `identity`."""
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    if num_args != 2 * len(identity):
+        raise ValueError(f"combiner must take {2 * len(identity)} arguments, got {num_args}")
+    args = []
+    # lhs and rhs parameters share the identities' dtypes, hence identity + identity.
+    for name, i in zip(params.keys(), identity + identity):
+        args.append(Var(name, "int32" if isinstance(i, int) else i.dtype))
+    res = combiner(*args)
+    if not isinstance(res, tuple):
+        res = (res,)
+    return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+    """Look up the numeric LLVM intrinsic id for ``name`` via tvm.target.codegen."""
+    # pylint: disable=import-outside-toplevel
+    from tvm.target import codegen
+
+    return codegen.llvm_lookup_intrinsic_id(name)
+
+
+def _op_wrapper(func):
+    """Wrap ``func`` so that a ``dtype`` keyword argument is silently dropped."""
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        kwargs.pop("dtype", None)
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):
+    """Wrap ``func`` so that a ``dtype`` keyword becomes its first positional argument."""
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        args = (kwargs.pop("dtype"), *args) if "dtype" in kwargs else args
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+# pylint: disable=invalid-name
+# _op_wrapper drops a `dtype` kwarg; _dtype_forward passes it as the leading positional arg.
+
+buffer_var = ptr
+abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+tvm_struct_get = _tir_op.tvm_struct_get  # unwrapped: a dtype kwarg reaches it untouched
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+# pylint: enable=invalid-name
+
+
+__all__ = [
+    "Assert",
+    "Else",
+    "If",
+    "Let",
+    "Select",
+    "Shuffle",
+    "TVMBackendAllocWorkspace",
+    "TVMBackendFreeWorkspace",
+    "Then",
+    "While",
+    "abs",
+    "acos",
+    "acosh",
+    "address_of",
+    "alloc_buffer",
+    "allocate",
+    "allocate_const",
+    "arg",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "attr",
+    "axis",
+    "block",
+    "block_attr",
+    "boolean",
+    "broadcast",
+    "buffer_decl",
+    "buffer_store",
+    "buffer_var",
+    "call_cpacked",
+    "call_cpacked_lowered",
+    "call_extern",
+    "call_intrin",
+    "call_llvm_intrin",
+    "call_llvm_pure_intrin",
+    "call_packed",
+    "call_packed_lowered",
+    "call_pure_extern",
+    "cast",
+    "ceil",
+    "ceildiv",
+    "clz",
+    "comm_reducer",
+    "copysign",
+    "cos",
+    "cosh",
+    "decl_buffer",
+    "env_thread",
+    "erf",
+    "evaluate",
+    "exp",
+    "exp10",
+    "exp2",
+    "fabs",
+    "float16",
+    "float32",
+    "float64",
+    "float8",
+    "floor",
+    "floordiv",
+    "floormod",
+    "fmod",
+    "func_attr",
+    "func_name",
+    "func_ret",
+    "grid",
+    "handle",
+    "hypot",
+    "if_then_else",
+    "infinity",
+    "init",
+    "int16",
+    "int32",
+    "int64",
+    "int8",
+    "isfinite",
+    "isinf",
+    "isnan",
+    "isnullptr",
+    "iter_var",
+    "launch_thread",
+    "ldexp",
+    "let",
+    "likely",
+    "llvm_lookup_intrinsic_id",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "lookup_param",
+    "match_buffer",
+    "max",
+    "max_value",
+    "min",
+    "min_value",
+    "mma_fill",
+    "mma_store",
+    "nearbyint",
+    "nextafter",
+    "parallel",
+    "popcount",
+    "power",
+    "prefetch",
+    "preflattened_buffer",
+    "prim_func",
+    "ptr",
+    "ptx_commit_group",
+    "ptx_cp_async",
+    "ptx_ldmatrix",
+    "ptx_mma",
+    "ptx_mma_sp",
+    "ptx_wait_group",
+    "q_multiply_shift",
+    "ramp",
+    "reads",
+    "realize",
+    "reinterpret",
+    "ret",
+    "round",
+    "rsqrt",
+    "serial",
+    "shift_left",
+    "shift_right",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "target",
+    "thread_binding",
+    "trunc",
+    "truncdiv",
+    "truncmod",
+    "tvm_access_ptr",
+    "tvm_bmma_sync",
+    "tvm_call_cpacked",
+    "tvm_call_cpacked_lowered",
+    "tvm_call_packed",
+    "tvm_call_packed_lowered",
+    "tvm_fill_fragment",
+    "tvm_load_matrix_sync",
+    "tvm_mma_sync",
+    "tvm_stack_alloca",
+    "tvm_stack_make_array",
+    "tvm_stack_make_shape",
+    "tvm_store_matrix_sync",
+    "tvm_struct_get",
+    "tvm_struct_set",
+    "tvm_thread_allreduce",
+    "tvm_throw_last_error",
+    "tvm_tuple",
+    "type_annotation",
+    "uint16",
+    "uint32",
+    "uint64",
+    "uint8",
+    "unroll",
+    "var",
+    "vectorized",
+    "void",
+    "where",
+    "writes",
+]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..a0ecd3daaa 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,17 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin,invalid-name,no-member,protected-access
 """Operators used in TIR expression."""
 from typing import Any, Optional
+
 import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
 from tvm.ir import Array, Op
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
 
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
 from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, PrimExprWithOp, StringImm, Var
 
 
 def _pack_buffer(buf, span=None):
@@ -100,6 +101,64 @@ def call_cpacked(*args, span=None):
     return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
 
 
+def call_packed_lowered(*args, span=None):
+    """Lowered version of call packed.
+
+    The argument to packed function can be Expr or Buffer.
+    The argument is the corresponding POD type when Expr is presented.
+
+    When the argument is Buffer, the corresponding PackedFunc
+    will receive a TVMArrayHandle whose content is valid during the callback period.
+    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x, span) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+    """Lowered version of call c-packed.
+
+    Same as call_packed_lowered, except that the first argument is the function name
+    (as in call_extern), and the last argument is the resource handle.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x, span) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
+
+
 def call_intrin(dtype, func_name, *args, span=None):
     """Build expression by calling an intrinsic function.
 
@@ -151,7 +210,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+        dtype,
+        Op.get("tir.call_pure_extern"),
+        convert((StringImm(func_name),) + args),
+        span,
     )
 
 
@@ -178,7 +240,10 @@ def call_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+        dtype,
+        Op.get("tir.call_extern"),
+        convert((StringImm(func_name),) + args),
+        span=span,
     )
 
 
@@ -207,10 +272,21 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    assert llvm_id != 0, f"{name} is not an LLVM intrinsic"
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -239,8 +315,15 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    assert llvm_id != 0, f"{name} is not an LLVM intrinsic"
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
@@ -250,6 +333,306 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):  # wraps tir.tvm_access_ptr
+    return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():  # wraps tir.tvm_throw_last_error (no operands)
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):  # dtype_str is an operand, not the result dtype
+    return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):  # wraps tir.tvm_stack_make_shape; args forwarded as-is
+    return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):  # handle call
+    return call_intrin(
+        "handle", "tir.tvm_stack_make_array", data, shape, strides, ndim, arr_dtype, elem_offset
+    )
+
+
+def address_of(buffer_load, span=None):
+    """Return the address of the buffer element read by ``buffer_load``.
+
+    Parameters
+    ----------
+    buffer_load : BufferLoad
+        The buffer load whose element address is taken.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        A ``handle``-typed call to ``tir.address_of``.
+    """
+    return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+
+
+def lookup_param(param_name, span=None):
+    """Return a call that looks up a param by name.
+
+    Parameters
+    ----------
+    param_name : str
+        The name of param.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        A ``handle``-typed call to ``tir.lookup_param``.
+    """
+    return call_intrin("handle", "tir.lookup_param", param_name, span=span)
+
+
+def tvm_tuple(*value):  # handle-typed tir.tvm_tuple call over all values
+    return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):  # result dtype is supplied by the caller
+    return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field_id)
+
+
+def tvm_struct_set(arr, index, field_id, value):  # handle-typed tir.tvm_struct_set call
+    return call_intrin("handle", "tir.tvm_struct_set", arr, index, field_id, value)
+
+
+def tvm_thread_allreduce(*freduce_args):  # handle-typed call; args forwarded as-is
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):  # handle call
+    return call_intrin(
+        "handle",
+        "tir.tvm_load_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def tvm_mma_sync(  # handle-typed tir.tvm_mma_sync call; operands forwarded in order
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_mma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_bmma_sync(  # handle-typed tir.tvm_bmma_sync call; operands forwarded in order
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_bmma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):  # handle call
+    return call_intrin(
+        "handle",
+        "tir.tvm_fill_fragment",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        value,
+    )
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):  # handle call
+    return call_intrin(
+        "handle",
+        "tir.tvm_store_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def ptx_mma(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    saturate,
+    operator=None,
+):
+    """Build a ``tir.ptx_mma`` intrinsic call whose result type is ``dtype``.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the resulting call expression.
+    shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype
+        Forwarded unchanged as the leading intrinsic arguments.
+    multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate
+        Forwarded unchanged, in this order, after the arguments above.
+    operator : optional
+        Appended as the final intrinsic argument only when it is not None.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    args = [
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        saturate,
+    ]
+    if operator is not None:
+        args.append(operator)
+    return call_intrin(dtype, "tir.ptx_mma", *args)
+
+
+def ptx_mma_sp(  # pylint: disable=missing-docstring
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    metadata,
+    meta_index,
+    sparse_selector,
+    saturate,
+):
+    # Every operand is forwarded to the intrinsic unchanged, in signature order.
+    operands = (
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        metadata,
+        meta_index,
+        sparse_selector,
+        saturate,
+    )
+    return call_intrin(dtype, "tir.ptx_mma_sp", *operands)
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+    return call_intrin(  # NOTE: parameter `type` shadows the builtin; kept for keyword callers
+        dtype,
+        "tir.ptx_ldmatrix",
+        trans,
+        num,
+        type,
+        local_ptr,
+        local_offset,
+        smem_ptr,
+        smem_offset,
+    )
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+    return call_intrin(  # NOTE: parameter `bytes` shadows the builtin; kept for keyword callers
+        dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
+    )
+
+
+def ptx_commit_group():  # empty dtype string: presumably a void-typed call, confirm
+    return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):  # empty dtype string, as in ptx_commit_group
+    return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):  # wraps tir.mma_store
+    return call_intrin(
+        dtype,
+        "tir.mma_store",
+        m,
+        n,
+        dst_ptr,
+        src_ptr,
+        src_offset,
+        dst_stride,
+    )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):  # wraps tir.mma_fill
+    return call_intrin(
+        dtype,
+        "tir.mma_fill",
+        local_size,
+        local_ptr,
+        offset,
+    )
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -286,9 +669,9 @@ def any(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore
+    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore # pylint: disable=no-member,protected-access
     for i in range(2, len(args)):
-        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore
+        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore # pylint: disable=no-member,protected-access
     return val
 
 
@@ -313,9 +696,9 @@ def all(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    val = _ffi_api._OpAnd(args[0], args[1], span)  # type: ignore
+    val = _ffi_api._OpAnd(args[0], args[1], span)  # type: ignore  # pylint: disable=no-member,protected-access
     for i in range(2, len(args)):
-        val = _ffi_api._OpAnd(val, args[i], span)  # type: ignore
+        val = _ffi_api._OpAnd(val, args[i], span)  # type: ignore  # pylint: disable=no-member,protected-access
     return val
 
 
@@ -394,6 +777,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
     return _ffi_api.max_value(dtype, span)  # type: ignore
 
 
+def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+    """Return the infinity value of dtype.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The infinity value of dtype.
+    """
+    return _ffi_api.infinity(dtype, span)  # type: ignore
+
+
+def reinterpret(dtype, value, span=None) -> Any:
+    """Reinterpret-cast value to dtype.
+
+    Parameters
+    ----------
+    dtype : str
+        The target data type.
+
+    value : PrimExpr
+        The input value.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The reinterpret cast value of dtype.
+    """
+    return _ffi_api.reinterpret(dtype, value, span)  # type: ignore
+
+
 def exp(x):
     """Take exponential of input x.
 
@@ -998,6 +1422,25 @@ def ldexp(x1, x2):
     return call_intrin(x1.dtype, "tir.ldexp", x1, x2)  # type: ignore
 
 
+def likely(cond, span=None):
+    """Mark a condition as likely (delegates to ``_ffi_api.likely``).
+
+    Parameters
+    ----------
+    cond : PrimExpr
+        The condition to mark.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        The marked expression.
+    """
+    return _ffi_api.likely(cond, span)  # type: ignore
+
+
 def isnan(x, span=None):
     """Check if input value is Nan.
 
@@ -1017,6 +1460,25 @@ def isnan(x, span=None):
     return _ffi_api.isnan(x, span)  # type: ignore
 
 
+def isnullptr(x, span=None):
+    """Check if input value is nullptr.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        A ``bool``-typed call to ``tir.isnullptr``.
+    """
+    return call_intrin("bool", "tir.isnullptr", x, span=span)  # type: ignore
+
+
 def isfinite(x, span=None):
     """Check if input value is finite.
 
@@ -1122,6 +1584,42 @@ def q_multiply_shift(x, y, q, s):
     return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
 
 
+def shift_left(x, y, span=None):
+    """Return the result of x left shifted by y bits.
+
+    Parameters
+    ----------
+    x, y : PrimExpr
+        The value to shift and the number of bits to shift it by.
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api.left_shift(x, y, span)  # type: ignore
+
+
+def shift_right(x, y, span=None):
+    """Return the result of x right shifted by y bits.
+
+    Parameters
+    ----------
+    x, y : PrimExpr
+        The value to shift and the number of bits to shift it by.
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api.right_shift(x, y, span)  # type: ignore
+
+
 def fmod(x, y):
     """Return the remainder of x divided by y with the same sign as x.
 
@@ -1306,8 +1804,8 @@ def truncmod(a, b, span=None):
     return _ffi_api._OpTruncMod(a, b, span)  # type: ignore
 
 
-def floordiv(a, b, span=None):
-    """Compute the floordiv of two expressions.
+def ceildiv(a, b, span=None):
+    """Compute the ceildiv of two expressions.
 
     Parameters
     ----------
@@ -1325,11 +1823,11 @@ def floordiv(a, b, span=None):
     res : PrimExpr
         The result expression.
     """
-    return _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
+    return _ffi_api._OpCeilDiv(a, b, span)  # type: ignore
 
 
-def floormod(a, b, span=None):
-    """Compute the floormod of two expressions.
+def floordiv(a, b, span=None):
+    """Compute the floordiv of two expressions.
 
     Parameters
     ----------
@@ -1347,27 +1845,29 @@ def floormod(a, b, span=None):
     res : PrimExpr
         The result expression.
     """
-    return _ffi_api._OpFloorMod(a, b, span)  # type: ignore
+    return _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
 
 
-def ceildiv(lhs, rhs, span=None):
-    """Generic ceildiv operator.
+def floormod(a, b, span=None):
+    """Compute the floormod of two expressions.
 
     Parameters
     ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
     span : Optional[Span]
         The location of this operator in the source.
 
     Returns
     -------
-    op : tvm.Expr
-        The result Expr of ceildiv operaton.
+    res : PrimExpr
+        The result expression.
     """
-    return _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
+    return _ffi_api._OpFloorMod(a, b, span)  # type: ignore
 
 
 def comm_reducer(fcombine, fidentity, name="reduce"):
@@ -1523,6 +2023,22 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     return reducer
 
 
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+    return call_intrin(  # handle-typed call to the tir.TVMBackendAllocWorkspace intrinsic
+        "handle",
+        "tir.TVMBackendAllocWorkspace",
+        device_type,
+        device_id,
+        nbytes,
+        dtype_code_hint,
+        dtype_bits_hint,
+    )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):  # -> int32 call expression
+    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
 
 from tvm._ffi import register_object
 from tvm.runtime import Object
-from tvm.tir import Block, For
 
+from ..stmt import Block, For
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index cf031c014c..f26f954d51 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
 from . import _ffi_api
 from ._type_checker import type_checked
 from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
 from tvm._ffi import register_object
 from tvm.ir import IRModule
 from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
 
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
 from . import _ffi_api
 from .block_scope import BlockScope, StmtSRef
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,7 @@ def stmt_list(stmt):
             res += stmt_list(x)
         return res
     return [stmt]
+
+
+def type_annotation(dtype, span=None):  # wraps _ffi_api.TypeAnnotation(dtype, span)
+    return _ffi_api.TypeAnnotation(dtype, span)  # type: ignore
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
 from typing import Dict
 
 import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
 from . import _ffi_api
 
 
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index a3318bf94f..a7d95848da 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -21,6 +21,7 @@
  * \file src/ir/expr.cc
  * \brief The expression AST nodes for the common IR infra.
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/function.h>
 #include <tvm/runtime/registry.h>
@@ -49,6 +50,18 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
   if (auto* ptr = ref.as<runtime::StringObj>()) {
     return tir::StringImm(GetRef<runtime::String>(ptr));
   }
+  if (auto* ptr = ref.as<tir::BufferRegionNode>()) {
+    tir::BufferRegion buffer_region = GetRef<tir::BufferRegion>(ptr);
+    Array<PrimExpr> indices;
+    for (Range r : buffer_region->region) {
+      if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+        indices.push_back(r->min);
+      } else {
+        indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+      }
+    }
+    return tir::BufferLoad(buffer_region->buffer, indices);
+  }
   Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get());
   ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName()
                                  << " but got " << actual_type.value();
diff --git a/src/ir/ir_builder.cc b/src/ir/ir_builder.cc
new file mode 100644
index 0000000000..9f42cdb168
--- /dev/null
+++ b/src/ir/ir_builder.cc
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/ir/ir_builder.cc
+ * \brief The thread-local IRBuilder stack, generic IRBuilder frames, the IRModule frame and
+ *        the type-dispatched Namer.
+ */
+#include <tvm/ir/ir_builder.h>
+#include <tvm/ir/module.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ir_builder {
+
+/*! \brief Push this frame onto the frame stack of the current IRBuilder. */
+void IRBuilderFrameNode::EnterWithScope() {
+  IRBuilder::Current()->frames.push_back(GetRef<IRBuilderFrame>(this));
+}
+
+/*!
+ * \brief Run this frame's callbacks, most recently added first, clear them, and pop the
+ *        last frame off the current IRBuilder's frame stack (expected to be this frame).
+ */
+void IRBuilderFrameNode::ExitWithScope() {
+  for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
+    (*it)();
+  }
+  this->callbacks.clear();
+  IRBuilder::Current()->frames.pop_back();
+}
+
+/*!
+ * \brief Register a callback to be run when a frame exits.
+ * \note The callback is attached to the last (innermost) frame of the current IRBuilder,
+ *       which is not necessarily `this`.
+ */
+void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    LOG(FATAL) << "ValueError: No frames in Builder to add callback";
+  }
+  builder->frames.back()->callbacks.push_back(callback);
+}
+
+/*! \brief Create a builder with an empty frame stack and no result. */
+IRBuilder::IRBuilder() {
+  ObjectPtr<IRBuilderNode> n = make_object<IRBuilderNode>();
+  n->frames.clear();
+  n->result = NullOpt;
+  data_ = n;
+}
+
+/*! \brief The stack of active builders; `thread_local`, so each thread has its own. */
+std::vector<IRBuilder>* ThreadLocalBuilderStack() {
+  thread_local std::vector<IRBuilder> stack;
+  return &stack;
+}
+
+/*!
+ * \brief Push this builder onto the calling thread's builder stack, making it current.
+ * \note The builder must have no frames left; any previous result is discarded.
+ */
+void IRBuilder::EnterWithScope() {
+  IRBuilderNode* n = this->get();
+  CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
+                           << n->frames.size()
+                           << ". Please use a fresh new builder every time building IRs";
+  n->result = NullOpt;
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  stack->push_back(*this);
+}
+
+/*! \brief Pop the top builder off the calling thread's builder stack. */
+void IRBuilder::ExitWithScope() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  ICHECK(!stack->empty());
+  stack->pop_back();
+}
+
+/*! \brief Return the top builder of the calling thread's stack; fails if there is none. */
+IRBuilder IRBuilder::Current() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  CHECK(!stack->empty()) << "ValueError: No builder in current scope";
+  return stack->back();
+}
+
+/*! \brief Create an IRModuleFrame with no global vars and no functions. */
+IRModuleFrame IRModule() {
+  ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
+  n->global_vars.clear();
+  n->functions.clear();
+  return IRModuleFrame(n);
+}
+
+/*!
+ * \brief Pair global_vars[i] with functions[i] into an IRModule and store it as the current
+ *        builder's result. Fails if the builder already has a result.
+ * \note NOTE(review): this function neither runs callbacks nor pops the frame stack itself;
+ *       confirm that is intended.
+ */
+void IRModuleFrameNode::ExitWithScope() {
+  ICHECK_EQ(functions.size(), global_vars.size());
+  int n = functions.size();
+  Map<GlobalVar, BaseFunc> func_map;
+  for (int i = 0; i < n; ++i) {
+    func_map.Set(global_vars[i], functions[i]);
+  }
+  IRBuilder builder = IRBuilder::Current();
+  ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+  builder->result = tvm::IRModule(func_map);
+}
+
+namespace details {
+
+/*! \brief The per-type dispatch table consulted by Namer::Name. */
+Namer::FType& Namer::vtable() {
+  static FType inst;
+  return inst;
+}
+
+/*!
+ * \brief Give `node` the name `name` using the handler registered for its type.
+ * \note Fails if `node` is undefined or no handler is registered for its type.
+ */
+void Namer::Name(ObjectRef node, String name) {
+  static const FType& f = vtable();
+  CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
+  CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
+                              << node->GetTypeKey() << "\"";
+  f(node, name);
+}
+
+}  // namespace details
+
+TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode);
+TVM_REGISTER_NODE_TYPE(IRBuilderNode);
+TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameEnter")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameExit")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameAddCallback")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::AddCallback);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); });
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderGet")
+    .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRModule").set_body_typed(IRModule);
+
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 7979c9f47a..9bbd9068e2 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -20,6 +20,7 @@
 /*!
  * \file expr.cc
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
@@ -828,12 +829,26 @@ TVM_REGISTER_GLOBAL("tir.Call")
     .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
       Array<PrimExpr> prim_expr_args;
       for (const auto& it : args) {
-        ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>())
+        ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
+               it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>())
             << "Argument " << it << " is not a string or primexpr";
         if (const auto* str = it.as<runtime::StringObj>()) {
           prim_expr_args.push_back(StringImm(str->data));
+        } else if (const auto* expr = it.as<PrimExprNode>()) {
+          prim_expr_args.push_back(GetRef<PrimExpr>(expr));
+        } else if (const auto* br = it.as<BufferRegionNode>()) {
+          BufferRegion buffer_region = GetRef<BufferRegion>(br);
+          Array<PrimExpr> indices;
+          for (Range r : buffer_region->region) {
+            if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+              indices.push_back(r->min);
+            } else {
+              indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+            }
+          }
+          prim_expr_args.push_back(BufferLoad(buffer_region->buffer, indices));
         } else {
-          prim_expr_args.push_back(Downcast<PrimExpr>(it));
+          prim_expr_args.push_back(Downcast<IterVar>(it).operator PrimExpr());
         }
       }
       return Call(type, op, prim_expr_args, span);
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 7e3d3d1075..b11ca6650a 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -22,11 +22,10 @@
  * \brief Used by TVM Script parser to expand incomplete TIR input
  */
 
+#include "./script_complete.h"
+
 #include <tvm/arith/int_set.h>
-#include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/stmt_functor.h>
 
 #include <utility>
 
diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h
new file mode 100644
index 0000000000..5392ea309d
--- /dev/null
+++ b/src/tir/ir/script/script_complete.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.h
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+#ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+
+// registry.h and stmt_functor.h are not needed by the declaration below; they were moved
+// here from script_complete.cc, which now receives them through this header.
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Complete a partially specified PrimFunc coming from the TVMScript parser.
+ * \param func The PrimFunc to complete.
+ * \param root_allocates Buffers allocated at the root of \p func. NOTE(review): meaning
+ *        inferred from the parameter name -- confirm against script_complete.cc.
+ * \return The completed PrimFunc.
+ */
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates);
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..6dd5bf5886 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -1091,5 +1091,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
 TVM_REGISTER_OP("tir.type_annotation")
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder.cc b/src/tir/ir_builder/ir_builder.cc
new file mode 100644
index 0000000000..d312f014e5
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder.cc
@@ -0,0 +1,637 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Construct a Buffer, filling in defaults for the optional fields.
+ *
+ * When `data` is not given, a fresh pointer var named `buffer_name` is created in
+ * `storage_scope`; bool buffers are stored as int8. When `elem_offset` is not given and
+ * `offset_factor` is non-zero, a fresh var "elem_offset" is created, typed like the first
+ * shape dimension (int32 for an empty shape). A `buffer_type` of "auto_broadcast" selects
+ * kAutoBroadcast; any other value selects kDefault.
+ */
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators) {
+  Var buffer_data;
+  if (data.defined()) {
+    buffer_data = data.value();
+  } else {
+    DataType storage_dtype = (dtype == DataType::Bool()) ? DataType::Int(8) : dtype;
+    buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope));
+  }
+  if (!elem_offset.defined() && offset_factor != 0) {
+    DataType offset_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
+    elem_offset = tvm::tir::Var("elem_offset", offset_dtype);
+  }
+  auto kind = (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault;
+  return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
+                elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, kind,
+                axis_separators.value_or(Array<IntImm>()));
+}
+
+/*!
+ * \brief Create a DeclBufferFrame holding the buffer that BufferDecl builds from the same
+ *        arguments.
+ */
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  Buffer declared = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset,
+                               storage_scope, align, offset_factor, buffer_type, axis_separators);
+  ObjectPtr<DeclBufferFrameNode> frame_node = make_object<DeclBufferFrameNode>();
+  frame_node->buffer = declared;
+  return DeclBufferFrame(frame_node);
+}
+
+/*! \brief Create an unnamed var of pointer-to-`dtype` type in `storage_scope`. */
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
+  auto pointer_type = tvm::PointerType(PrimType(dtype), storage_scope);
+  return tvm::tir::Var("", pointer_type);
+}
+
+/*!
+ * \brief Create a BlockFrame called `name`. Reads, writes, init and predicate start unset,
+ *        every list and map starts empty, and `no_realize` is stored as given.
+ */
+BlockFrame Block(String name, bool no_realize) {
+  ObjectPtr<BlockFrameNode> frame_node = make_object<BlockFrameNode>();
+  frame_node->name = name;
+  frame_node->no_realize = no_realize;
+  // Optional fields start out unset.
+  frame_node->reads = NullOpt;
+  frame_node->writes = NullOpt;
+  frame_node->init = NullOpt;
+  frame_node->predicate = NullOpt;
+  // Collections start out empty.
+  frame_node->iter_vars.clear();
+  frame_node->iter_values.clear();
+  frame_node->alloc_buffers.clear();
+  frame_node->match_buffers.clear();
+  frame_node->annotations.clear();
+  return BlockFrame(frame_node);
+}
+
+/*! \brief Create an empty BlockInitFrame. */
+BlockInitFrame Init() {
+  ObjectPtr<BlockInitFrameNode> frame_node = make_object<BlockInitFrameNode>();
+  return BlockInitFrame(frame_node);
+}
+
+/*! \brief Set the predicate of the current block frame; it may be set only once. */
+void Where(PrimExpr predicate) {
+  BlockFrame block = FindBlockFrame("T.where");
+  if (!block->predicate.defined()) {
+    block->predicate = predicate;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+             << block->predicate.value();
+}
+
+/*!
+ * \brief Set the read regions of the current block frame; they may be declared only once.
+ * \param buffer_slices BufferRegion or BufferLoad objects; a BufferLoad is converted to a
+ *        region with BufferRegionFromLoad.
+ */
+void Reads(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame frame = FindBlockFrame("T.reads");
+  if (frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
+  }
+  Array<BufferRegion> reads;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      reads.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      reads.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      // Say what was received, using the same error-kind prefix as the other checks here.
+      LOG(FATAL) << "TypeError: Invalid type for buffer reads: "
+                 << (obj.defined() ? obj->GetTypeKey() : std::string("None"));
+    }
+  }
+  frame->reads = reads;
+}
+
+/*!
+ * \brief Set the write regions of the current block frame; they may be declared only once.
+ * \param buffer_slices BufferRegion or BufferLoad objects; a BufferLoad is converted to a
+ *        region with BufferRegionFromLoad.
+ */
+void Writes(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame frame = FindBlockFrame("T.writes");
+  if (frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << frame->writes;
+  }
+  Array<BufferRegion> writes;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      writes.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      writes.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      // Say what was received, using the same error-kind prefix as the other checks here.
+      LOG(FATAL) << "TypeError: Invalid type for buffer writes: "
+                 << (obj.defined() ? obj->GetTypeKey() : std::string("None"));
+    }
+  }
+  frame->writes = writes;
+}
+
+/*!
+ * \brief Set the annotations of the current block frame.
+ * \note Fails if the frame already has non-empty annotations.
+ */
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  BlockFrame block = FindBlockFrame("T.block_attr");
+  if (block->annotations.empty()) {
+    block->annotations = attrs;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << block->annotations;
+}
+
+/*!
+ * \brief Declare an unnamed buffer and register it for allocation.
+ *
+ * The buffer goes to the alloc_buffers of the frame returned by GetLastFrame<BlockFrame>();
+ * failing that, to the root_alloc_buffers of the frame returned by
+ * GetLastFrame<PrimFuncFrame>(). If neither yields a frame, this is an error.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<BlockFrame> frame = builder->GetLastFrame<BlockFrame>()) {
+    frame.value()->alloc_buffers.push_back(buffer);
+  } else if (Optional<PrimFuncFrame> frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    frame.value()->root_alloc_buffers.push_back(buffer);
+  } else {
+    LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not found. Please ensure "
+                  "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  }
+  return buffer;
+}
+
+namespace axis {
+
+/*!
+ * \brief Append `iter_var` and its `binding` to the frame returned by
+ *        GetLastFrame<BlockFrame>() of the current builder; fails if there is none.
+ * \return `iter_var`, unchanged.
+ */
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+  if (Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    BlockFrame frame = opt_frame.value();
+    frame->iter_vars.push_back(iter_var);
+    frame->iter_values.push_back(binding);
+  } else {
+    LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+  }
+  return iter_var;
+}
+
+/*!
+ * \brief Define `IterVar Method(dom, binding, dtype)`, which pushes a block iterator of type
+ *        `Kind` bound to `binding`. The iterator's var takes `dtype`, widened to the largest
+ *        bit width among `dom->min`, `dom->extent` and `dtype`.
+ */
+#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name)                                           \
+  IterVar Method(Range dom, PrimExpr binding, DataType dtype) {                               \
+    ICHECK(dom.defined()) << Name << " axis must have a domain";                              \
+    int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
+    return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)),          \
+                                /*iter_type=*/Kind, /*thread_tag=*/""),                       \
+                        binding);                                                             \
+  }
+TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
+TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
+TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
+TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
+#undef TVM_TIR_IR_BUILDER_AXIS
+
+/*!
+ * \brief Find the domain of loop variable `v` among the ForFrames of the current builder,
+ *        scanning frames in stack order.
+ * \return The first matching domain, or an undefined Range if `v` is not a loop variable.
+ */
+static Range FindLoopDomain(const tvm::tir::VarNode* v) {
+  for (const auto& frame : IRBuilder::Current()->frames) {
+    if (const auto* for_frame = frame.as<ForFrameNode>()) {
+      ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+      int num_vars = for_frame->doms.size();
+      for (int j = 0; j < num_vars; ++j) {
+        if (for_frame->vars[j].get() == v) {
+          return for_frame->doms[j];
+        }
+      }
+    }
+  }
+  return Range{nullptr};
+}
+
+/*!
+ * \brief Create one block iterator per binding, taking each domain from the enclosing loop.
+ * \param kinds One character per binding: 'S' (spatial, kDataPar) or 'R' (reduction,
+ *        kCommReduce).
+ * \param bindings Loop variables; each must be a Var of a ForFrame of the current builder.
+ * \param dtype NOTE(review): unused -- each iterator takes the dtype of its loop variable.
+ */
+Array<IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+  using namespace tvm::tir;
+  CHECK_EQ(kinds.size(), bindings.size())
+      << "ValueError: T.axis.remap expects exactly one axis kind per binding";
+  Array<IterVar> results;
+  int n = bindings.size();
+  results.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    char c = kinds.c_str()[i];
+    PrimExpr e = bindings[i];
+    const VarNode* v = e.as<VarNode>();
+    ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap";
+    Range dom = FindLoopDomain(v);
+    ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v);
+    if (c != 'S' && c != 'R') {
+      LOG(FATAL) << "ValueError: Unknown axis kind: " << c;
+    }
+    IterVarType iter_type = (c == 'S') ? IterVarType::kDataPar : IterVarType::kCommReduce;
+    // The block var mirrors the loop var's dtype (the `dtype` parameter is not consulted).
+    results.push_back(PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", v->dtype),
+                                           /*iter_type=*/iter_type, /*thread_tag=*/""),
+                                   e));
+  }
+  return results;
+}
+
+}  // namespace axis
+
+/*!
+ * \brief Define `ForFrame Method(start, stop, annotations)`: one loop of kind `Kind` over
+ *        [start, stop), with `annotations` (empty if absent) attached to the For. The loop
+ *        var is an int as wide as the wider of `start` and the simplified `stop - start`.
+ */
+#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                               \
+  ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \
+    PrimExpr extent = arith::Analyzer().Simplify(stop - start);                                  \
+    int bits = std::max(start.dtype().bits(), extent.dtype().bits());                            \
+    ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();                            \
+    frame_node->vars = {Var("v", DataType::Int(bits))};                                          \
+    frame_node->doms = {Range::FromMinExtent(start, extent)};                                    \
+    frame_node->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms,              \
+                                                tvm::tir::Stmt body) {                           \
+      ICHECK_EQ(vars.size(), 1);                                                                 \
+      ICHECK_EQ(doms.size(), 1);                                                                 \
+      Map<String, ObjectRef> loop_annotations =                                                  \
+          annotations.value_or(Map<String, ObjectRef>());                                        \
+      return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt,          \
+                           loop_annotations);                                                    \
+    };                                                                                           \
+    return ForFrame(frame_node);                                                                 \
+  }
+
+TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
+
+#undef TVM_TIR_IR_BUILDER_FOR_FRAME
+
+/*!
+ * \brief Create a single-loop ForFrame over [start, stop) whose For has kind
+ *        kThreadBinding, is bound to `thread`, and carries `annotations` (empty if absent).
+ * \note The bound IterVar has no domain and an int32 var named "iter".
+ */
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations) {
+  using namespace tvm::tir;
+  PrimExpr extent = arith::Analyzer().Simplify(stop - start);
+  int bits = std::max(start.dtype().bits(), extent.dtype().bits());
+  ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();
+  frame_node->vars = {Var("v", DataType::Int(bits))};
+  frame_node->doms = {Range::FromMinExtent(start, extent)};
+  frame_node->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms,
+                                                      Stmt body) -> For {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    IterVar thread_iter(Range(nullptr), Var("iter", DataType::Int(32)),
+                        IterVarType::kThreadIndex, thread);
+    Map<String, ObjectRef> loop_annotations = annotations.value_or(Map<String, ObjectRef>());
+    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body,
+               thread_iter, loop_annotations);
+  };
+  return ForFrame(frame_node);
+}
+
+/*!
+ * \brief Create a ForFrame of nested serial loops, one per extent (extents[0] outermost),
+ *        each over [0, extent) with a loop var of the extent's dtype.
+ */
+ForFrame Grid(Array<PrimExpr> extents) {
+  using namespace tvm::tir;
+  ObjectPtr<ForFrameNode> frame_node = make_object<ForFrameNode>();
+  frame_node->vars.reserve(extents.size());
+  frame_node->doms.reserve(extents.size());
+  for (const PrimExpr& extent : extents) {
+    DataType dtype = extent.dtype();
+    frame_node->vars.push_back(Var("v", dtype));
+    frame_node->doms.push_back(Range(make_const(dtype, 0), extent));
+  }
+  frame_node->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    // Wrap from the innermost loop outwards so that vars[0] ends up outermost.
+    for (int i = static_cast<int>(vars.size()) - 1; i >= 0; --i) {
+      body = For(vars[i], doms[i]->min, doms[i]->extent, ForKind::kSerial, std::move(body),
+                 /*thread_binding=*/NullOpt, /*annotations=*/{});
+    }
+    return body;
+  };
+  return ForFrame(frame_node);
+}
+
+/*! \brief Create a PrimFuncFrame with no name or return type and all collections empty. */
+PrimFuncFrame PrimFunc() {
+  ObjectPtr<PrimFuncFrameNode> frame_node = make_object<PrimFuncFrameNode>();
+  frame_node->name = NullOpt;
+  frame_node->ret_type = NullOpt;
+  frame_node->args.clear();
+  frame_node->attrs.clear();
+  frame_node->buffer_map.clear();
+  frame_node->preflattened_buffer_map.clear();
+  frame_node->root_alloc_buffers.clear();
+  return PrimFuncFrame(frame_node);
+}
+
+/*! \brief Name `var` and append it to the parameters of the current PrimFuncFrame. */
+Var Arg(String name, Var var) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(var, name);
+  func_frame->args.push_back(var);
+  return var;
+}
+
+/*!
+ * \brief Name `buffer`, then add a fresh handle var "<buffer name>_handle" as a parameter
+ *        of the current PrimFuncFrame and map that handle to `buffer` in its buffer_map.
+ */
+Buffer Arg(String name, Buffer buffer) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(buffer, name);
+  // The handle name is read back from the buffer after naming it.
+  Var buffer_handle(buffer->name + "_handle", DataType::Handle());
+  func_frame->args.push_back(buffer_handle);
+  func_frame->buffer_map.Set(buffer_handle, buffer);
+  return buffer;
+}
+
+/*! \brief Set the name of the current PrimFuncFrame; it may be set only once. */
+void FuncName(String name) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_name");
+  if (!func_frame->name.defined()) {
+    func_frame->name = name;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is "
+             << func_frame->name.value();
+}
+
+/*!
+ * \brief Set the attributes of the current PrimFuncFrame.
+ * \note Fails if the frame already has non-empty attributes.
+ */
+void FuncAttrs(Map<String, ObjectRef> attrs) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_attr");
+  if (func_frame->attrs.empty()) {
+    func_frame->attrs = attrs;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is "
+             << func_frame->attrs;
+}
+
+/*! \brief Set the return type of the current PrimFuncFrame (only once) and return it. */
+tvm::Type FuncRet(tvm::Type ret_type) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.ret_type");
+  if (!func_frame->ret_type.defined()) {
+    func_frame->ret_type = ret_type;
+    return ret_type;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
+             << func_frame->ret_type.value();
+  return ret_type;
+}
+
+/*!
+ * \brief Declare an unnamed buffer and bind it to `param`.
+ *
+ * - A Var `param` must already be a parameter of the current PrimFuncFrame; the buffer is
+ *   recorded in that frame's buffer_map under it.
+ * - A BufferLoad `param` (converted with BufferRegionFromLoad) or a BufferRegion `param`
+ *   becomes a MatchBufferRegion of the current block frame.
+ * Any other kind of `param` is an error.
+ */
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  if (const auto* var = param.as<tvm::tir::VarNode>()) {
+    PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
+    Var v = GetRef<Var>(var);
+    for (auto const& arg : frame->args) {
+      if (arg.same_as(v)) {
+        frame->buffer_map.Set(v, buffer);
+        return buffer;
+      }
+    }
+    LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
+  } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) {
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(tvm::tir::MatchBufferRegion(
+        buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(buffer_load))));
+  } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(
+        tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(buffer_region)));
+  } else {
+    LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
+  }
+  return buffer;
+}
+
+/*!
+ * \brief Declare the pre-flattened view of a buffer parameter of the current PrimFuncFrame.
+ *
+ * `postflattened_buffer` must be a value of the frame's buffer_map. The new buffer is named
+ * "<name>_preflatten", reuses that buffer's data var unless `data` is given, and is recorded
+ * in preflattened_buffer_map under the same parameter. An unknown buffer is an error.
+ */
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
+                        Optional<Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
+                        String storage_scope, int align, int offset_factor, String buffer_type_str,
+                        Array<IntImm> axis_separators) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer");
+  for (auto const& p : frame->buffer_map) {
+    if (p.second.same_as(postflattened_buffer)) {
+      String buffer_name(postflattened_buffer->name + "_preflatten");
+      Buffer buffer =
+          BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset,
+                     storage_scope, align, offset_factor, buffer_type_str, axis_separators);
+      details::Namer::Name(buffer, buffer_name);
+      frame->preflattened_buffer_map.Set(p.first, buffer);
+      return;
+    }
+  }
+  LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
+             << " does not exist.";
+}
+
+/*! \brief Create an AssertFrame checking `condition`; `message` is wrapped in a StringImm. */
+AssertFrame Assert(PrimExpr condition, String message) {
+  ObjectPtr<AssertFrameNode> frame_node = make_object<AssertFrameNode>();
+  frame_node->message = tvm::tir::StringImm(message);
+  frame_node->condition = condition;
+  return AssertFrame(frame_node);
+}
+
+/*! \brief Create a LetFrame binding `var` to `value`. */
+LetFrame Let(Var var, PrimExpr value) {
+  ObjectPtr<LetFrameNode> frame_node = make_object<LetFrameNode>();
+  frame_node->value = value;
+  frame_node->var = var;
+  return LetFrame(frame_node);
+}
+
+/*!
+ * \brief Create an AllocateFrame. `condition` defaults to true and `annotations` to an empty
+ *        map; the frame's buffer is an unnamed buffer over `extents` in `storage_scope`.
+ */
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  ObjectPtr<AllocateFrameNode> frame_node = make_object<AllocateFrameNode>();
+  frame_node->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope,
+                                  0, 0, "default", NullOpt);
+  frame_node->extents = extents;
+  frame_node->dtype = dtype;
+  frame_node->storage_scope = storage_scope;
+  frame_node->condition = condition.value_or(tvm::Bool(true));
+  frame_node->annotations = annotations.value_or(Map<String, ObjectRef>());
+  return AllocateFrame(frame_node);
+}
+
+/*!
+ * \brief Create an AllocateConstFrame for constant `data`; the frame's buffer is an unnamed
+ *        buffer over `extents` with an empty storage scope.
+ */
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> frame_node = make_object<AllocateConstFrameNode>();
+  frame_node->buffer =
+      BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+  frame_node->data = data;
+  frame_node->dtype = dtype;
+  frame_node->extents = extents;
+  frame_node->annotations = annotations;
+  return AllocateConstFrame(frame_node);
+}
+
+/*!
+ * \brief Create a LaunchThreadFrame for `iter_var` with the given `extent`.
+ *
+ * If `iter_var` has no domain yet, its domain is set in place to [0, extent); otherwise its
+ * existing extent must provably equal `extent`. The attribute key is "virtual_thread" for
+ * the "vthread" tag and "thread_extent" for any other tag.
+ */
+LaunchThreadFrame LaunchThread(IterVar iter_var, PrimExpr extent) {
+  ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
+  if (!iter_var->dom.defined()) {
+    // Build the zero in the extent's dtype so the domain's min and extent agree (a literal 0
+    // would be int32 even when `extent` is int64).
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
+        Range(tvm::tir::make_zero(extent.dtype()), extent);
+  } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+    LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+               << iter_var->dom->extent << " vs " << extent;
+  }
+  n->iter_var = iter_var;
+  n->extent = extent;
+  n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent";
+  return LaunchThreadFrame(n);
+}
+
+/*! \brief Create a RealizeFrame for `buffer_slice` in `storage_scope` guarded by `condition`. */
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                     PrimExpr condition) {
+  ObjectPtr<RealizeFrameNode> frame_node = make_object<RealizeFrameNode>();
+  frame_node->condition = condition;
+  frame_node->storage_scope = storage_scope;
+  frame_node->buffer_slice = buffer_slice;
+  return RealizeFrame(frame_node);
+}
+
+/*! \brief Create an AttrFrame attaching `attr_key` = `value` to `node`. */
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  ObjectPtr<AttrFrameNode> frame_node = make_object<AttrFrameNode>();
+  frame_node->value = value;
+  frame_node->attr_key = attr_key;
+  frame_node->node = node;
+  return AttrFrame(frame_node);
+}
+
+/*! \brief Create a WhileFrame on `condition`. */
+WhileFrame While(PrimExpr condition) {
+  ObjectPtr<WhileFrameNode> frame_node = make_object<WhileFrameNode>();
+  frame_node->condition = condition;
+  return WhileFrame(frame_node);
+}
+
+/*! \brief Create an IfFrame on `condition`; its then/else statements start unset. */
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> frame_node = make_object<IfFrameNode>();
+  frame_node->then_stmts = NullOpt;
+  frame_node->else_stmts = NullOpt;
+  frame_node->condition = condition;
+  return IfFrame(frame_node);
+}
+
+/*! \brief Create an empty ThenFrame. */
+ThenFrame Then() { return ThenFrame(make_object<ThenFrameNode>()); }
+
+/*! \brief Create an empty ElseFrame. */
+ElseFrame Else() { return ElseFrame(make_object<ElseFrameNode>()); }
+
+/*!
+ * \brief Create a kThreadIndex IterVar tagged `thread_tag`, with no domain and an unnamed
+ *        int32 var.
+ */
+IterVar EnvThread(String thread_tag) {
+  Var thread_var("", DataType::Int(32));
+  return IterVar(Range{nullptr}, thread_var, tvm::tir::IterVarType::kThreadIndex, thread_tag);
+}
+
+/*! \brief Build a BufferStore of `value` into `buffer` at `indices` and pass it to AddToParent. */
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  tvm::tir::BufferStore store(buffer, value, indices);
+  AddToParent(store);
+}
+
+/*! \brief Build a Prefetch of `buffer` over `bounds` and pass it to AddToParent. */
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+  tvm::tir::Prefetch prefetch(buffer, bounds);
+  AddToParent(prefetch);
+}
+
+/*! \brief Build an Evaluate statement of `value` and pass it to AddToParent. */
+void Evaluate(PrimExpr value) {
+  tvm::tir::Evaluate stmt(value);
+  AddToParent(stmt);
+}
+
+using tvm::ir_builder::details::Namer;
+
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
+      tvm::tir::BufferNode* buffer =
+          const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>());
+      buffer->name = name;
+      Namer::Name(buffer->data, name + "_data");
+      int n = buffer->strides.size();
+      for (int i = 0; i < n; ++i) {
+        PrimExpr e = buffer->strides[i];
+        if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) {
+          Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i));
+        }
+      }
+    });
+
+// SizeVar: rename in place by overwriting its `name_hint`.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
+      const auto* size_var = node.as<tvm::tir::SizeVarNode>();
+      const_cast<tvm::tir::SizeVarNode*>(size_var)->name_hint = name;
+    });
+
+// Var: rename in place by overwriting its `name_hint`.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
+      const auto* var_node = node.as<tvm::tir::VarNode>();
+      const_cast<tvm::tir::VarNode*>(var_node)->name_hint = name;
+    });
+
+// IterVar: naming is delegated to the IterVar's underlying `var`.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
+      auto* iter_var = const_cast<tvm::tir::IterVarNode*>(node.as<tvm::tir::IterVarNode>());
+      Namer::Name(iter_var->var, name);
+    });
+
+// Global-function registrations for the TIR IR builder; every name is prefixed `ir_builder.tir.`.
+// Buffer declaration helpers.
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Ptr").set_body_typed(Ptr);
+
+// Block frame, its sub-constructs, and the block-iterator axis constructors (`axis::*`).
+TVM_REGISTER_GLOBAL("ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Init").set_body_typed(Init);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Where").set_body_typed(Where);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Reads").set_body_typed(Reads);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Writes").set_body_typed(Writes);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
+// Loop constructors.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Serial").set_body_typed(Serial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Parallel").set_body_typed(Parallel);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Vectorized").set_body_typed(Vectorized);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Unroll").set_body_typed(Unroll);
+TVM_REGISTER_GLOBAL("ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Grid").set_body_typed(Grid);
+
+// PrimFunc frame.
+TVM_REGISTER_GLOBAL("ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Arg")
+    .set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
+      // Both scalar Vars and Buffers are accepted as PrimFunc arguments; dispatch on the
+      // runtime type of `obj` to the matching Arg overload.
+      if (const auto* var_node = obj.as<tvm::tir::VarNode>()) {
+        tvm::tir::Var var = GetRef<tvm::tir::Var>(var_node);
+        return Arg(name, var);
+      } else if (const auto* buffer_node = obj.as<tvm::tir::BufferNode>()) {
+        tvm::tir::Buffer buffer = GetRef<tvm::tir::Buffer>(buffer_node);
+        return Arg(name, buffer);
+      }
+      LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
+      throw;
+    });
+// Function-level setters and buffer bindings used inside a PrimFunc frame.
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncName").set_body_typed(FuncName);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncRet").set_body_typed(FuncRet);
+TVM_REGISTER_GLOBAL("ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
+
+// Statement-level frames (Assert ... LaunchThread) and single-statement emitters
+// (EnvThread, BufferStore, Prefetch, Evaluate).
+TVM_REGISTER_GLOBAL("ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Prefetch").set_body_typed(Prefetch);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+
+// Helpers named after TIR data types (their definitions are outside this view).
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int8").set_body_typed(Int8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int16").set_body_typed(Int16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32").set_body_typed(Int32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int64").set_body_typed(Int64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt8").set_body_typed(UInt8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt16").set_body_typed(UInt16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt32").set_body_typed(UInt32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt64").set_body_typed(UInt64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float8").set_body_typed(Float8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float16").set_body_typed(Float16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float32").set_body_typed(Float32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float64").set_body_typed(Float64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Boolean").set_body_typed(Boolean);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Void").set_body_typed(Void);
+// tvm::min / tvm::max wrapped in lambdas so the registered signature is exactly
+// (PrimExpr, PrimExpr) -> PrimExpr.
+TVM_REGISTER_GLOBAL("ir_builder.tir.min")
+    .set_body_typed([](PrimExpr lhs, PrimExpr rhs) -> PrimExpr { return tvm::min(lhs, rhs); });
+TVM_REGISTER_GLOBAL("ir_builder.tir.max")
+    .set_body_typed([](PrimExpr lhs, PrimExpr rhs) -> PrimExpr { return tvm::max(lhs, rhs); });
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder_frame.cc b/src/tir/ir_builder/ir_builder_frame.cc
new file mode 100644
index 0000000000..cd0cd46b50
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder_frame.cc
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/function.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "../../tir/ir/script/script_complete.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+// Finalizes a block: builds a tir::Block from the state collected in this frame and adds it to
+// the parent, wrapped in a BlockRealize unless `no_realize` is set.
+void BlockFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  Array<tvm::tir::Buffer> tir_alloc_buffers;
+  for (const tvm::tir::Buffer& alloc_buffer : alloc_buffers) {
+    tir_alloc_buffers.push_back(alloc_buffer);
+  }
+  // Bit 0: reads were not declared; bit 1: writes were not declared. A non-zero mask is stored
+  // as the "tir.script_parsing_detect_access" annotation (its consumer is outside this view).
+  int detect_access = 0;
+  if (!reads.defined()) {
+    detect_access |= 1;
+  }
+  if (!writes.defined()) {
+    detect_access |= 2;
+  }
+  if (detect_access != 0) {
+    annotations.Set("tir.script_parsing_detect_access",
+                    tvm::IntImm(DataType::Int(64), detect_access));
+  }
+  Array<tvm::tir::BufferRegion> read_regions = reads.value_or(Array<tvm::tir::BufferRegion>());
+  Array<tvm::tir::BufferRegion> write_regions = writes.value_or(Array<tvm::tir::BufferRegion>());
+  tvm::tir::Block block(iter_vars, read_regions, write_regions, name, AsStmt(stmts), init,
+                        tir_alloc_buffers, match_buffers, annotations);
+  if (!no_realize) {
+    AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
+    return;
+  }
+  // A bare block has no BlockRealize to carry iterator bindings or a predicate.
+  CHECK(iter_values.empty())
+      << "ValueError: Block bindings are not allowed when `no_realize=True`";
+  CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
+  AddToParent(block);
+}
+
+// Entering `T.init` requires an enclosing block frame that has no init statement yet.
+void BlockInitFrameNode::EnterWithScope() {
+  BlockFrame enclosing_block = FindBlockFrame("T.init");
+  // Only one init section is allowed per block.
+  if (enclosing_block->init.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// On exit, the collected statements become the enclosing block frame's `init`.
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt init_stmt = AsStmt(stmts);
+  FindBlockFrame("T.init")->init = init_stmt;
+}
+
+// On exit, the loop statement is produced by `f_make_for_loop` from the loop vars, their
+// domains and the collected body, then added to the parent.
+void ForFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  tvm::tir::Stmt loop = f_make_for_loop(vars, doms, body);
+  AddToParent(loop);
+}
+
+// Finalizes a PrimFunc. With no open frames left, the function becomes the builder's result;
+// otherwise it is appended (with a GlobalVar named after `name`) to an enclosing IRModuleFrame.
+void PrimFuncFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  DictAttrs func_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+  tvm::tir::PrimFunc func(args, AsStmt(stmts), ret_type.value_or(TupleType::Empty()), buffer_map,
+                          preflattened_buffer_map, func_attrs);
+  // Run ScriptComplete on the assembled function, passing the root-level alloc buffers.
+  func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = func;
+    return;
+  }
+  Optional<IRModuleFrame> module_frame = builder->FindFrame<IRModuleFrame>();
+  if (!module_frame.defined()) {
+    LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
+  }
+  IRModuleFrame frame = module_frame.value();
+  frame->global_vars.push_back(GlobalVar(name.value_or("")));
+  frame->functions.push_back(func);
+}
+
+// Wraps the collected statements in an AssertStmt on `condition` carrying `message`.
+void AssertFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AssertStmt(condition, message, body));
+}
+
+// Binds `var` to `value` around the collected statements with a LetStmt.
+void LetFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::LetStmt(var, value, body));
+}
+
+// Emits an Allocate for the frame's buffer (its data var, dtype and shape) around the
+// collected statements.
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::Allocate(/*buffer_var=*/buffer->data, /*dtype=*/buffer->dtype,
+                                 /*extents=*/buffer->shape, /*condition=*/condition,
+                                 /*body=*/body, /*annotations=*/annotations));
+}
+
+// Emits an AllocateConst binding `data` to the buffer's data var around the collected
+// statements.
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AllocateConst(/*buffer_var=*/buffer->data, /*dtype=*/dtype,
+                                      /*extents=*/extents, /*data_or_idx=*/data, /*body=*/body,
+                                      /*annotations=*/annotations));
+}
+
+// Emits an AttrStmt keyed by `attr_key` on `iter_var` with value `extent` around the body.
+void LaunchThreadFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(/*node=*/iter_var, /*attr_key=*/attr_key, /*value=*/extent,
+                                 /*body=*/body));
+}
+
+// Emits a BufferRealize over the sliced region, wrapped in a "realize_scope" AttrStmt whose
+// value is `storage_scope`.
+void RealizeFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Buffer realized = buffer_slice->buffer;
+  tvm::tir::BufferRealize realize(realized, buffer_slice->region, condition, AsStmt(stmts));
+  AddToParent(tvm::tir::AttrStmt(realized, "realize_scope", tvm::tir::StringImm(storage_scope),
+                                 realize));
+}
+
+// Emits an AttrStmt (`node`, `attr_key`, `value`) around the collected statements.
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, body));
+}
+
+// Emits a tir::While on `condition` whose body is the collected statements.
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::While(condition, body));
+}
+
+// Emits an IfThenElse. Statements must live inside Then/Else frames rather than directly in the
+// If frame, and a then-branch is mandatory; the else-branch is optional.
+void IfFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  if (!stmts.empty()) {
+    LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
+  }
+  if (!then_stmts.defined()) {
+    LOG(FATAL) << "IfThenElse frame should have at least one then branch";
+  }
+  tvm::tir::Stmt then_case = AsStmt(then_stmts.value());
+  // An undefined Stmt means "no else branch".
+  tvm::tir::Stmt else_case(nullptr);
+  if (else_stmts.defined()) {
+    else_case = AsStmt(else_stmts.value());
+  }
+  AddToParent(tvm::tir::IfThenElse(condition, then_case, else_case));
+}
+
+// Entering a then-branch requires an enclosing IfFrame whose then-branch is not yet set.
+void ThenFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if (if_frame->then_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+               << if_frame->then_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// Hands the collected statements to the enclosing IfFrame as its then-branch.
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if_frame->then_stmts = stmts;
+}
+
+// Entering an else-branch requires an enclosing IfFrame that already has a then-branch and
+// does not yet have an else-branch.
+void ElseFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.else_");
+  if (!if_frame->then_stmts.defined()) {
+    LOG(FATAL) << "The else branch should follow then branch";
+  }
+  if (if_frame->else_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << if_frame->else_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// Hands the collected statements to the enclosing IfFrame as its else-branch.
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  IfFrame if_frame = FindIfFrame("T.else_");
+  if_frame->else_stmts = stmts;
+}
+
+// Wraps the collected statements in a DeclBuffer node for `buffer`.
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::DeclBuffer(buffer, body));
+}
+
+// Register every TIR frame node type defined above with TVM_REGISTER_NODE_TYPE.
+TVM_REGISTER_NODE_TYPE(TIRFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
+TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir_builder/utils.h b/src/tir/ir_builder/utils.h
new file mode 100644
index 0000000000..6b6271ab0e
--- /dev/null
+++ b/src/tir/ir_builder/utils.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_BUILDER_TIR_BASE_H_
+#define TVM_SCRIPT_BUILDER_TIR_BASE_H_
+
+#include <tvm/tir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+// Appends `stmt` to the most recently opened frame (`frames.back()`), which must be a TIRFrame.
+// When no frame is open, `stmt` becomes the builder's result, which must not already be set.
+inline void AddToParent(tvm::tir::Stmt stmt) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = stmt;
+    return;
+  }
+  const auto* tir_frame = builder->frames.back().as<TIRFrameNode>();
+  if (tir_frame == nullptr) {
+    LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
+  }
+  GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt);
+}
+
+// Collapses a statement list into a single Stmt: an empty list becomes `Evaluate(0)`, a single
+// statement is returned unchanged, and longer lists are wrapped in a SeqStmt.
+inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmt) {
+  switch (stmt.size()) {
+    case 0:
+      return tvm::tir::Evaluate(0);
+    case 1:
+      return stmt[0];
+    default:
+      return tvm::tir::SeqStmt(stmt);
+  }
+}
+
+// Returns the last BlockFrame on the current builder's frame stack (via GetLastFrame). When none
+// exists, fails with a `ValueError:`-prefixed message naming the offending `method`.
+inline BlockFrame FindBlockFrame(const String& method) {
+  if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: Block frame not found. Please ensure '" << method
+             << "' is called under T.block()";
+  throw;  // Not reached when LOG(FATAL) throws; keeps the missing-return warning quiet.
+}
+
+// Returns the last PrimFuncFrame on the current builder's frame stack (via GetLastFrame). When
+// none exists, fails with a `ValueError:`-prefixed message naming the offending `method`.
+inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
+  if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: PrimFunc frame not found. Please ensure '" << method
+             << "' is called under T.prim_func()";
+  throw;  // Not reached when LOG(FATAL) throws; keeps the missing-return warning quiet.
+}
+
+// Returns the last IfFrame on the current builder's frame stack (via GetLastFrame). When none
+// exists, fails with a `ValueError:`-prefixed message naming the offending `method`.
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+             << "' is called under T.if_()";
+  throw;  // Not reached when LOG(FATAL) throws; keeps the missing-return warning quiet.
+}
+
+// Converts a BufferLoad into the region it touches: each index `i` becomes a range with min `i`
+// and extent 1 (in the index's own dtype).
+// NOTE(review): a vector-typed index (e.g. a Ramp) would produce a vector-dtype IntImm extent,
+// which may be rejected — confirm only scalar indices reach this helper.
+inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
+  Array<Range> point_ranges;
+  for (const PrimExpr& idx : buffer_load->indices) {
+    IntImm unit_extent(idx->dtype, 1);
+    point_ranges.push_back(Range::FromMinExtent(idx, unit_extent));
+  }
+  return tvm::tir::BufferRegion(buffer_load->buffer, point_ranges);
+}
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_BUILDER_TIR_BASE_H_
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 114571218b..dd3c7a0b0f 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
                    {x, y, q, s}, span);
 }
 
+// address_of: a handle-typed call to the `address_of` builtin with the load as its only argument.
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+  const auto& address_of_op = tir::builtin::address_of();
+  return tir::Call(DataType::Handle(), address_of_op, {buffer_load}, span);
+}
+
+// lookup_param: a handle-typed call to the `lookup_param` builtin; the parameter name is passed
+// as a single StringImm argument.
+PrimExpr lookup_param(String param_name, Span span) {
+  tir::StringImm name_arg(param_name);
+  return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {name_arg}, span);
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -702,6 +713,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
   }
 }
 
+// isnullptr: a scalar boolean call to the `isnullptr` builtin on `x`.
+PrimExpr isnullptr(PrimExpr x, Span span) {
+  const DataType result_type = DataType::Bool(1);
+  return tir::Call(result_type, tir::builtin::isnullptr(), {x}, span);
+}
+
 // isinf
 PrimExpr isinf(PrimExpr x, Span span) {
   DataType t = DataType::Bool(x.dtype().lanes());
@@ -931,6 +947,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
 
 TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
 
+// Expose tvm::likely as the global function "tir.likely".
+TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
+
 TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
 
 TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
@@ -949,6 +967,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
 
 TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
 
+// Expose tvm::infinity as the global function "tir.infinity".
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);
+
+// Expose tvm::reinterpret as the global function "tir.reinterpret".
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                                                \
   TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -997,6 +1019,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
 REGISTER_MAKE_BIT_OP(left_shift, left_shift);  // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, right_shift);
 
+// Logical negation, registered next to the other `tir._Op*` helpers; forwards to logical_not.
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);
+
 TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
     .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
       return if_then_else(cond, true_value, false_value, span);