You are viewing a plain-text version of this content; the hyperlink to the canonical version was lost in the conversion.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 22:31:05 UTC

[tvm] 01/01: Squashed

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch ir-builder
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 22477336093aaf860a59107c58caae99c37ad3bb
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon May 23 17:06:45 2022 -0700

    Squashed
---
 include/tvm/ir/expr.h                              |  20 +-
 include/tvm/ir/ir_builder.h                        | 190 +++++
 include/tvm/support/with.h                         |   2 +
 include/tvm/tir/ir_builder.h                       | 140 +++
 include/tvm/tir/ir_builder_frame.h                 | 456 ++++++++++
 include/tvm/tir/op.h                               |  34 +-
 python/tvm/ir/__init__.py                          |  55 +-
 .../_ffi_api.py => ir/_ffi_ir_builder_api.py}      |   4 +-
 python/tvm/ir/ir_builder.py                        |  85 ++
 python/tvm/meta_schedule/default_config.py         |   2 +-
 python/tvm/meta_schedule/testing/schedule_rule.py  |  12 +-
 python/tvm/script/__init__.py                      |  19 +-
 python/tvm/script/{ => parser}/__init__.py         |  14 +-
 python/tvm/script/parser/diagnostics.py            |  60 ++
 python/tvm/script/parser/dispatch.py               |  63 ++
 python/tvm/script/parser/doc.py                    | 361 ++++++++
 python/tvm/script/{printer => parser}/doc_core.py  |   0
 .../script/{tir/prim_func.py => parser/entry.py}   |  45 +-
 python/tvm/script/parser/evaluator.py              | 282 ++++++
 .../script/{_ffi_api.py => parser/ir/__init__.py}  |   7 +-
 .../tvm/script/{__init__.py => parser/ir/entry.py} |  19 +-
 .../{tir/__init__.py => parser/ir/parser.py}       |  30 +-
 python/tvm/script/parser/parser.py                 | 214 +++++
 python/tvm/script/parser/source.py                 | 134 +++
 python/tvm/script/{ => parser/tir}/__init__.py     |   9 +-
 python/tvm/script/parser/tir/entry.py              | 103 +++
 python/tvm/script/parser/tir/operation.py          |  85 ++
 python/tvm/script/parser/tir/parser.py             | 269 ++++++
 .../script/{tir/prim_func.py => parser/utils.py}   |  47 +-
 python/tvm/script/parser/var_table.py              |  71 ++
 python/tvm/script/{ => parser_v1}/__init__.py      |   3 +-
 python/tvm/script/{ => parser_v1}/_ffi_api.py      |   0
 .../script/{ => parser_v1}/context_maintainer.py   |   8 +-
 python/tvm/script/{ => parser_v1}/diagnostics.py   |   6 +-
 python/tvm/script/{ => parser_v1}/meta_unparser.py |   0
 python/tvm/script/{ => parser_v1}/parser.py        |  23 +-
 python/tvm/script/{ => parser_v1}/registry.py      |   2 +-
 python/tvm/script/{ => parser_v1}/tir/__init__.py  |   0
 python/tvm/script/{ => parser_v1}/tir/__init__.pyi |   0
 python/tvm/script/{ => parser_v1}/tir/intrin.py    |   5 +-
 python/tvm/script/{ => parser_v1}/tir/node.py      |   7 +-
 python/tvm/script/{ => parser_v1}/tir/prim_func.py |   3 +-
 .../script/{ => parser_v1}/tir/scope_handler.py    |  17 +-
 .../tvm/script/{ => parser_v1}/tir/special_stmt.py |  16 +-
 python/tvm/script/{ => parser_v1}/tir/ty.py        |   1 +
 python/tvm/script/{ => parser_v1}/utils.py         |   7 +-
 python/tvm/te/operation.py                         |  15 +-
 python/tvm/tir/__init__.py                         | 217 ++++-
 .../_ffi_api.py => tir/_ffi_ir_builder_api.py}     |   4 +-
 python/tvm/tir/analysis/analysis.py                |   4 +-
 python/tvm/tir/buffer.py                           |  39 +-
 python/tvm/tir/expr.py                             |  15 +-
 python/tvm/tir/function.py                         |   2 +-
 python/tvm/tir/ir_builder_frame.py                 | 118 +++
 python/tvm/tir/{ir_builder.py => ir_builder_v1.py} |   6 +-
 python/tvm/tir/ir_builder_v2.py                    | 949 +++++++++++++++++++++
 python/tvm/tir/op.py                               | 601 ++++++++++++-
 python/tvm/tir/schedule/block_scope.py             |   2 +-
 python/tvm/tir/schedule/schedule.py                |   6 +-
 python/tvm/tir/schedule/state.py                   |   3 +-
 python/tvm/tir/stmt.py                             |   4 +
 python/tvm/tir/tensor_intrin/__init__.py           |   6 +-
 python/tvm/tir/tensor_intrin/arm_cpu.py            |   3 +-
 python/tvm/tir/tensor_intrin/cuda.py               |  14 +-
 python/tvm/tir/tensor_intrin/rocm.py               |   2 +-
 python/tvm/tir/usmp/transform/transform.py         |   5 +-
 src/ir/diagnostic.cc                               |   6 +-
 src/ir/expr.cc                                     |  13 +
 src/ir/ir_builder.cc                               | 134 +++
 src/tir/ir/expr.cc                                 |  24 +-
 src/tir/ir/script/script_complete.cc               |   5 +-
 src/tir/ir/script/script_complete.h                |  37 +
 src/tir/ir/stmt.cc                                 |  10 +
 src/tir/ir_builder/ir_builder.cc                   | 664 ++++++++++++++
 src/tir/ir_builder/ir_builder_frame.cc             | 208 +++++
 src/tir/ir_builder/utils.h                         |  92 ++
 src/tir/op/op.cc                                   |  24 +
 src/tir/schedule/primitive/cache_read_write.cc     |   2 +-
 .../test_meta_schedule_auto_tensorize.py           |  10 +-
 .../unittest/test_aot_legalize_packed_call.py      |   6 +-
 ...est_meta_schedule_postproc_rewrite_tensorize.py |   2 +-
 ...ta_schedule_schedule_rule_multi_level_tiling.py |   4 +-
 .../unittest/test_meta_schedule_space_cuda.py      |   2 +-
 .../unittest/test_meta_schedule_tune_relay.py      |   8 +-
 tests/python/unittest/test_target_codegen_llvm.py  |  15 +-
 .../python/unittest/test_tir_lower_match_buffer.py |  44 +-
 .../python/unittest/test_tir_schedule_tensorize.py |  11 +-
 .../python/unittest/test_tir_transform_helpers.py  |   7 +-
 .../test_tir_transform_hoist_expression.py         |  10 +-
 .../test_tir_transform_inject_software_pipeline.py |  14 +-
 .../test_tir_transform_inject_virtual_thread.py    |  18 +-
 .../python/unittest/test_tvmscript_error_report.py |  40 +-
 tests/python/unittest/test_tvmscript_spans.py      |  73 --
 .../python/unittest/test_tvmscript_syntax_sugar.py |  39 +-
 94 files changed, 5992 insertions(+), 475 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 5e358ed50e..cbc7db8686 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -764,16 +764,19 @@ struct PackedFuncValueConverter<PrimExpr> {
       return PrimExpr(ObjectPtr<Object>(nullptr));
     }
+    // Plain numeric arguments become int32 / float32 immediates.
     if (val.type_code() == kDLInt) {
-      return PrimExpr(val.operator int());
+      return IntImm(runtime::DataType::Int(32), val.operator int());
     }
+    // NOTE(review): unlike the previous PrimExpr(float) path, the double is no longer rounded
+    // through `float` before being stored in the float32 FloatImm -- confirm this is intended.
     if (val.type_code() == kDLFloat) {
-      return PrimExpr(static_cast<float>(val.operator double()));
+      return FloatImm(runtime::DataType::Float(32), val.operator double());
     }
 
     return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
   }
 };
 
 template <>
 struct PackedFuncValueConverter<tvm::Integer> {
   static tvm::Integer From(const TVMPODValue_& val) {
diff --git a/include/tvm/ir/ir_builder.h b/include/tvm/ir/ir_builder.h
new file mode 100644
index 0000000000..2ecc774ec5
--- /dev/null
+++ b/include/tvm/ir/ir_builder.h
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ir/ir_builder.h
+ * \brief Core IR builder infrastructure: scoped frames and the builder that tracks them.
+ */
+#ifndef TVM_IR_IR_BUILDER_H_
+#define TVM_IR_IR_BUILDER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+#include <vector>
+
+namespace tvm {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame and IRBuilder //////////////////////////////
+
+/*!
+ * \brief Base node of all IR builder frames.
+ *
+ * A frame is a scope entered and exited through EnterWithScope / ExitWithScope.
+ * Callbacks can be attached with AddCallback; when they run is decided by the implementation
+ * (NOTE(review): presumably when the frame exits -- confirm).
+ */
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+  /*! \brief Callbacks attached through AddCallback. */
+  std::vector<runtime::TypedPackedFunc<void()>> callbacks;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `callbacks` is not visited.
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRBuilderFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+  virtual ~IRBuilderFrameNode() = default;
+  /*! \brief Called by IRBuilderFrame::EnterWithScope. */
+  TVM_DLL virtual void EnterWithScope();
+  /*! \brief Called by IRBuilderFrame::ExitWithScope. */
+  TVM_DLL virtual void ExitWithScope();
+  /*! \brief Attach a callback to this frame. */
+  TVM_DLL void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+/*! \brief Managed reference to IRBuilderFrameNode. */
+class IRBuilderFrame : public runtime::ObjectRef {
+ public:
+  virtual ~IRBuilderFrame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+  IRBuilderFrame() = default;
+
+ public:
+  /*! \brief Forward to the node's EnterWithScope. */
+  inline void EnterWithScope();
+  /*! \brief Forward to the node's ExitWithScope, then release this reference. */
+  inline void ExitWithScope();
+};
+
+/*! \brief The IR builder state: its frames and its result. */
+class IRBuilderNode : public runtime::Object {
+ public:
+  /*! \brief Frames tracked by the builder; frame lookups start from the last element. */
+  runtime::Array<IRBuilderFrame> frames;
+  /*! \brief The builder's result, if any; read it with Get(). */
+  Optional<ObjectRef> result;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("frames", &frames);
+    v->Visit("result", &result);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+  /*! \brief Return the last frame of type TFrame in `frames`, or NullOpt if there is none. */
+  template <typename TFrame>
+  inline Optional<TFrame> FindFrame() const;
+  /*! \brief Return the last element of `frames` if it is a TFrame, otherwise NullOpt. */
+  template <typename TFrame>
+  inline Optional<TFrame> GetLastFrame() const;
+  /*! \brief Return `result` as a TObjectRef; fails if it is undefined or of another type. */
+  template <typename TObjectRef>
+  inline TObjectRef Get() const;
+};
+
+/*! \brief Managed reference to IRBuilderNode. */
+class IRBuilder : public runtime::ObjectRef {
+ public:
+  /*! \brief Construct an IR builder. */
+  TVM_DLL IRBuilder();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+  /*! \brief Enter the scope of this builder. */
+  TVM_DLL void EnterWithScope();
+  /*! \brief Exit the scope of this builder. */
+  TVM_DLL void ExitWithScope();
+  /*! \brief The current builder (NOTE(review): presumably the innermost entered one). */
+  TVM_DLL static IRBuilder Current();
+  /*! \brief Attach `name` to `obj` via details::Namer, then return `obj`. */
+  template <class TObjectRef>
+  inline static TObjectRef Name(String name, TObjectRef obj);
+};
+
+////////////////////////////// Generic IRModule //////////////////////////////
+
+/*! \brief Frame holding the global vars and functions of an IRModule under construction. */
+class IRModuleFrameNode : public IRBuilderFrameNode {
+ public:
+  /*! \brief Global variables held by this frame. */
+  Array<GlobalVar> global_vars;
+  /*! \brief Functions held by this frame. */
+  Array<BaseFunc> functions;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("global_vars", &global_vars);
+    v->Visit("functions", &functions);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.IRModuleFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+  TVM_DLL void ExitWithScope() final;
+};
+
+/*! \brief Managed reference to IRModuleFrameNode. */
+class IRModuleFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+                                                    IRModuleFrameNode);
+};
+
+/*! \brief Create an IRModuleFrame. */
+TVM_DLL IRModuleFrame IRModule();
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+/*! \brief Per-node-type naming hooks; `vtable()` is a NodeFunctor keyed on node type. */
+class Namer {
+ public:
+  using FType = NodeFunctor<void(const ObjectRef&, String)>;
+  TVM_DLL static FType& vtable();
+  TVM_DLL static void Name(ObjectRef node, String name);
+};
+
+}  // namespace details
+
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+  details::Namer::Name(obj, name);
+  return Downcast<TObjectRef>(obj);
+}
+
+inline void IRBuilderFrame::EnterWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
+}
+
+inline void IRBuilderFrame::ExitWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
+  // Drop this handle's reference once the frame has been exited.
+  data_.reset();
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
+    if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
+      return GetRef<TFrame>(p);
+    }
+  }
+  return NullOpt;
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
+    return Downcast<TFrame>(frames.back());
+  }
+  return NullOpt;
+}
+
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  CHECK(n != nullptr) << "IndexError: IRBuilder result is not of type: " << TObject::_type_key;
+  return GetRef<TObjectRef>(n);
+}
+
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_IR_IR_BUILDER_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 3651e05e74..bbc36419ff 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -75,6 +75,9 @@ class With {
   ContextType& operator*() { return *get(); }
   const ContextType* operator*() const { return *get(); }
 
+  /*! \return A copy of the managed context object. */
+  ContextType operator()() const { return ctx_; }
+
  private:
   /*! \brief internal context type. */
   ContextType ctx_;
diff --git a/include/tvm/tir/ir_builder.h b/include/tvm/tir/ir_builder.h
new file mode 100644
index 0000000000..19111b9b20
--- /dev/null
+++ b/include/tvm/tir/ir_builder.h
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/ir_builder.h
+ * \brief Free functions that build TIR through ir_builder frames.
+ */
+#ifndef TVM_TIR_IR_BUILDER_H_
+#define TVM_TIR_IR_BUILDER_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/ir_builder_frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;
+using tvm::tir::Buffer;
+using tvm::tir::Var;
+
+/*! \brief Buffer declaration and pointer helpers. */
+TVM_DLL Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                          Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                          Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                          int offset_factor, String buffer_type,
+                          Optional<Array<IntImm>> axis_separators);
+TVM_DLL PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+/*! \brief Blocks: the block frame, its init frame, and block-level declarations. */
+TVM_DLL BlockFrame Block(String name, bool no_realize = false);
+TVM_DLL BlockInitFrame Init();
+TVM_DLL void Where(PrimExpr predicate);
+TVM_DLL void Reads(Array<ObjectRef> buffer_slices);
+TVM_DLL void Writes(Array<ObjectRef> buffer_slices);
+TVM_DLL void BlockAttrs(Map<String, ObjectRef> attrs);
+TVM_DLL Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                           Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                           PrimExpr elem_offset = PrimExpr(), String storage_scope = "",
+                           int align = -1, int offset_factor = 0, String buffer_type = "default",
+                           Array<IntImm> axis_separators = {});
+
+/*! \brief Block iteration variables, one function per iteration kind. */
+namespace axis {
+TVM_DLL Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+TVM_DLL Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+TVM_DLL Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+TVM_DLL Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+TVM_DLL Array<Var> Remap(String kinds, Array<PrimExpr> bindings,
+                         DataType dtype = DataType::Int(32));
+}  // namespace axis
+
+/*! \brief Loop frames, one function per loop kind; Grid takes several extents at once. */
+TVM_DLL ForFrame Serial(PrimExpr start, PrimExpr stop,
+                        Optional<Map<String, ObjectRef>> annotations = NullOpt);
+TVM_DLL ForFrame Parallel(PrimExpr start, PrimExpr stop,
+                          Optional<Map<String, ObjectRef>> annotations = NullOpt);
+TVM_DLL ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+                            Optional<Map<String, ObjectRef>> annotations = NullOpt);
+TVM_DLL ForFrame Unroll(PrimExpr start, PrimExpr stop,
+                        Optional<Map<String, ObjectRef>> annotations = NullOpt);
+TVM_DLL ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                               Optional<Map<String, ObjectRef>> annotations = NullOpt);
+TVM_DLL ForFrame Grid(Array<PrimExpr> extents);
+
+/*! \brief The PrimFunc frame and function-level declarations. */
+TVM_DLL PrimFuncFrame PrimFunc();
+TVM_DLL Var Arg(String name, Var var);
+TVM_DLL Buffer Arg(String name, Buffer buffer);
+TVM_DLL void FuncName(String name);
+TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);
+TVM_DLL Type FuncRet(Type ret_type);
+TVM_DLL Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape,
+                           DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+                           Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+                           String storage_scope = "global", int align = -1, int offset_factor = 0,
+                           String buffer_type = "default", Array<IntImm> axis_separators = {});
+TVM_DLL void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+                                DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+                                Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+                                String storage_scope = "global", int align = -1,
+                                int offset_factor = 0, String buffer_type = "default",
+                                Array<IntImm> axis_separators = {});
+
+/*! \brief Frames and statements for the remaining TIR statement kinds. */
+TVM_DLL AssertFrame Assert(PrimExpr condition, String message);
+TVM_DLL LetFrame Let(Var var, PrimExpr value);
+TVM_DLL AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                               Optional<PrimExpr> condition = NullOpt,
+                               Optional<Map<String, ObjectRef>> annotations = NullOpt);
+TVM_DLL AllocateConstFrame AllocateConst(
+    NDArray data, DataType dtype, Array<PrimExpr> extents,
+    Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+TVM_DLL RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                             PrimExpr condition);
+TVM_DLL AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+TVM_DLL WhileFrame While(PrimExpr condition);
+TVM_DLL IfFrame If(PrimExpr condition);
+TVM_DLL ThenFrame Then();
+TVM_DLL ElseFrame Else();
+TVM_DLL LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+TVM_DLL Var EnvThread(String thread_tag);
+TVM_DLL void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+TVM_DLL void Prefetch(Buffer buffer, Array<Range> bounds);
+TVM_DLL void Evaluate(PrimExpr value);
+
+// Each FuncName(expr) defined below casts `expr` to DType when it is given, and otherwise
+// returns an unnamed Var of DType.
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                             \
+  inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                        \
+    DataType dtype = DType;                                                            \
+    return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
+  }
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_H_
diff --git a/include/tvm/tir/ir_builder_frame.h b/include/tvm/tir/ir_builder_frame.h
new file mode 100644
index 0000000000..d975710a05
--- /dev/null
+++ b/include/tvm/tir/ir_builder_frame.h
@@ -0,0 +1,478 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/ir_builder_frame.h
+ * \brief Frame types used by the TIR IR builder.
+ */
+#ifndef TVM_TIR_IR_BUILDER_FRAME_H_
+#define TVM_TIR_IR_BUILDER_FRAME_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*! \brief Base node of the TIR builder frames; `stmts` holds the statements of the frame. */
+class TIRFrameNode : public IRBuilderFrameNode {
+ public:
+  Array<tvm::tir::Stmt> stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.TIRFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+class TIRFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+  TIRFrame() = default;
+};
+
+/*! \brief Frame for a TIR block; the fields mirror tir::Block and tir::BlockRealize. */
+class BlockFrameNode : public TIRFrameNode {
+ public:
+  String name;
+  Array<tvm::tir::IterVar> iter_vars;
+  Optional<Array<tvm::tir::BufferRegion>> reads;
+  Optional<Array<tvm::tir::BufferRegion>> writes;
+  Optional<tvm::tir::Stmt> init;
+  Array<tvm::tir::Buffer> alloc_buffers;
+  Array<tvm::tir::MatchBufferRegion> match_buffers;
+  Optional<Map<String, ObjectRef>> annotations;
+
+  Array<PrimExpr> iter_values;
+  Optional<PrimExpr> predicate;
+  bool no_realize = false;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("iter_vars", &iter_vars);
+    v->Visit("reads", &reads);
+    v->Visit("writes", &writes);
+    v->Visit("init", &init);
+    v->Visit("alloc_buffers", &alloc_buffers);
+    v->Visit("match_buffers", &match_buffers);
+    v->Visit("annotations", &annotations);
+    v->Visit("iter_values", &iter_values);
+    v->Visit("predicate", &predicate);
+    v->Visit("no_realize", &no_realize);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.BlockFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class BlockFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+/*! \brief Frame for a block's init section (presumably stored into BlockFrameNode::init). */
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class BlockInitFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+/*! \brief Frame for loops over `vars`, whose domains are given by `doms`. */
+class ForFrameNode : public TIRFrameNode {
+ public:
+  using FMakeForLoop =
+      runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+  Array<tvm::tir::Var> vars;
+  Array<Range> doms;
+  /*! \brief Builds the loop statement from loop vars, their domains and a body statement. */
+  FMakeForLoop f_make_for_loop;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("vars", &vars);
+    v->Visit("doms", &doms);
+    // `f_make_for_loop` is not visited.
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.ForFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class ForFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+/*! \brief Frame collecting the parts of a tir::PrimFunc (name, args, buffer maps, attrs). */
+class PrimFuncFrameNode : public TIRFrameNode {
+ public:
+  Optional<String> name;
+  Array<tvm::tir::Var> args;
+  Optional<Type> ret_type;
+  Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+  Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+  Optional<Map<String, ObjectRef>> attrs;
+  Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
+  Array<tvm::tir::Buffer> root_alloc_buffers;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+    v->Visit("attrs", &attrs);
+    v->Visit("env_threads", &env_threads);
+    v->Visit("root_alloc_buffers", &root_alloc_buffers);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.PrimFuncFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class PrimFuncFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+/*! \brief Frame for an assert statement on `condition` with `message`. */
+class AssertFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+  PrimExpr message;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AssertFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AssertFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+/*! \brief Frame for a let binding of `var` to `value`. */
+class LetFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::Var var;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.LetFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LetFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+/*! \brief Frame for an Allocate statement; `buffer` refers to the allocated storage. */
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+  Array<PrimExpr> extents;
+  DataType dtype;
+  String storage_scope;
+  PrimExpr condition;
+  Map<String, ObjectRef> annotations;
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extents", &extents);
+    v->Visit("dtype", &dtype);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+    v->Visit("annotations", &annotations);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AllocateFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+/*! \brief Frame for an AllocateConst statement whose contents come from `data`. */
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+  DataType dtype;
+  Array<PrimExpr> extents;
+  tvm::runtime::NDArray data;
+  tvm::tir::Buffer buffer;
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("data", &data);
+    v->Visit("buffer", &buffer);
+    v->Visit("annotations", &annotations);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AllocateConstFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateConstFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+                                                    AllocateConstFrameNode);
+};
+
+/*! \brief Frame for launching a thread bound to `iter_var` with the given `extent`. */
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr extent;
+  String attr_key;
+  tvm::tir::IterVar iter_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extent", &extent);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("iter_var", &iter_var);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.LaunchThreadFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LaunchThreadFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+                                                    LaunchThreadFrameNode);
+};
+
+/*! \brief Frame for realizing the region `buffer_slice` in `storage_scope`. */
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::BufferRegion buffer_slice;
+  String storage_scope;
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer_slice", &buffer_slice);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.RealizeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class RealizeFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+/*! \brief Frame for an attribute (`attr_key`, `value`) attached to `node`. */
+class AttrFrameNode : public TIRFrameNode {
+ public:
+  ObjectRef node;
+  String attr_key;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.AttrFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AttrFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+/*! \brief Frame for a while loop guarded by `condition`. */
+class WhileFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.WhileFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class WhileFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+/*! \brief Frame for an if statement; branch bodies are kept in `then_stmts` / `else_stmts`. */
+class IfFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+  Optional<Array<tvm::tir::Stmt>> then_stmts;
+  Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("then_stmts", &then_stmts);
+    v->Visit("else_stmts", &else_stmts);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.IfFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IfFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+/*! \brief Frame for the then-branch of an if statement. */
+class ThenFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "ir_builder.tir.ThenFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ThenFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+/*! \brief Frame for the else-branch of an if statement. */
+class ElseFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "ir_builder.tir.ElseFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ElseFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+/*! \brief Frame declaring `buffer` for the statements it contains. */
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class DeclBufferFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_FRAME_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
  * \return The result expression.
  */
 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check if x is a null pointer (nullptr).
+ * \param x The input data.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
 /*!
  * \brief Check if x is infinite.
  * \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span());
 
+/*!
+ * \brief Calculate fmod(x, y), the floating-point remainder of x / y.
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                   Span span = Span());
 
+/*!
+ * \brief Returns the address of the buffer element accessed by buffer_load.
+ * \param buffer_load The BufferLoad that selects the element.
+ * \param span The location of this operation in the source.
+ * \return The address of that element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Looks up a param by name.
+ * \param param_name The param name.
+ * \param span The location of this operation in the source.
+ * \return The handle of the param.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);  // log1p(x) = log(1 + x)
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4e847c0310..51dff32ac6 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -16,29 +16,46 @@
 # under the License.
 # pylint: disable=unused-import
 """Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .base import structural_equal, assert_structural_equal, structural_hash
-from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .affine_type import TensorAffineType, TupleAffineType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .op import Op, register_op_attr, register_intrin_lowering
-from .function import CallingConv, BaseFunc
+from . import diagnostics, instrument, ir_builder, transform
 from .adt import Constructor, TypeData
-from .module import IRModule
+from .affine_type import TensorAffineType, TupleAffineType
 from .attrs import Attrs, DictAttrs, make_node
+from .base import (
+    EnvFunc,
+    Node,
+    SourceName,
+    Span,
+    assert_structural_equal,
+    load_json,
+    save_json,
+    structural_equal,
+    structural_hash,
+)
 from .container import Array, Map
+from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
+from .function import BaseFunc, CallingConv
 from .memory_pools import (
-    PoolInfo,
-    WorkspacePoolInfo,
-    ConstantPoolInfo,
-    WorkspaceMemoryPools,
     ConstantMemoryPools,
+    ConstantPoolInfo,
+    PoolInfo,
     PoolInfoProperties,
+    WorkspaceMemoryPools,
+    WorkspacePoolInfo,
 )
-
-from . import transform
-from . import instrument
-from . import diagnostics
+from .module import IRModule
+from .op import Op, register_intrin_lowering, register_op_attr
+from .tensor_type import TensorType
+from .type import (
+    FuncType,
+    GlobalTypeVar,
+    IncompleteType,
+    PointerType,
+    PrimType,
+    RelayRefType,
+    TupleType,
+    Type,
+    TypeConstraint,
+    TypeKind,
+    TypeVar,
+)
+from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/ir/_ffi_ir_builder_api.py
similarity index 88%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/ir/_ffi_ir_builder_api.py
index 926d17b166..9d08bc9b70 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/ir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir.ir_builder"""
 import tvm._ffi
 
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/ir/ir_builder.py b/python/tvm/ir/ir_builder.py
new file mode 100644
index 0000000000..db0b1ead6d
--- /dev/null
+++ b/python/tvm/ir/ir_builder.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_ir_builder_api as _ffi_api
+
+
+@_register_object("ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):  # Python wrapper of the FFI object "ir_builder.IRBuilderFrame"
+    def __enter__(self) -> "IRBuilderFrame":
+        _ffi_api.IRBuilderFrameEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderFrameExit(self)  # pylint: disable=no-member # type: ignore
+
+    def add_callback(self, callback) -> None:  # forwards `callback` to the FFI frame
+        _ffi_api.IRBuilderFrameAddCallback(  # pylint: disable=no-member # type: ignore
+            self, callback
+        )
+
+
+@_register_object("ir_builder.IRBuilder")
+class IRBuilder(_Object):  # Python wrapper of the FFI object "ir_builder.IRBuilder"
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.IRBuilder  # pylint: disable=no-member # type: ignore
+        )
+
+    def __enter__(self) -> "IRBuilder":
+        _ffi_api.IRBuilderEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderExit(self)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def current() -> "IRBuilder":  # obtained from IRBuilderCurrent; no instance needed
+        return _ffi_api.IRBuilderCurrent()  # pylint: disable=no-member # type: ignore
+
+    def get(self) -> _Object:  # the built object, as returned by IRBuilderGet
+        return _ffi_api.IRBuilderGet(self)  # pylint: disable=no-member # type: ignore
+
+
+DefType = TypeVar("DefType", bound=_Object)  # name/name_many return their argument's type
+
+
+def name(s: str, v: DefType) -> DefType:  # forwards (s, v) to IRBuilderName
+    return _ffi_api.IRBuilderName(s, v)  # pylint: disable=no-member # type: ignore
+
+
+def name_many(  # pylint: disable=invalid-name
+    s: List[str],  # one name per entry of `vs`
+    vs: List[DefType],
+) -> List[DefType]:
+    assert len(s) == len(vs)  # zip() below would otherwise drop extras silently
+    return [name(i, v) for i, v in zip(s, vs)]
+
+
+@_register_object("ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):  # frame type returned by ir_module()
+    ...
+
+
+def ir_module() -> IRModuleFrame:  # new frame from the FFI constructor IRModule
+    return _ffi_api.IRModule()  # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index 58f82a248b..105b3467de 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -357,7 +357,7 @@ class _DefaultCUDATensorCore:
     @staticmethod
     def schedule_rules():
         from tvm.meta_schedule import schedule_rule as M
-        from tvm.tir.tensor_intrin import get_wmma_intrin_group
+        from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group
 
         return [
             M.MultiLevelTilingTensorCore(
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index f5a936f491..3d90030bcf 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -16,6 +16,7 @@
 # under the License.
 """Default schedule rules"""
 from typing import List, Union
+
 from tvm.meta_schedule.schedule_rule import (
     AddRFactor,
     AutoBind,
@@ -27,8 +28,9 @@ from tvm.meta_schedule.schedule_rule import (
     ReuseType,
     ScheduleRule,
 )
-from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
-from tvm.tir import tensor_intrin
+from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
+    MultiLevelTilingTensorCore,
+)
 from tvm.target import Target
 
 
@@ -130,8 +132,12 @@ def multi_level_tiling_tensor_core(
         trans_b = [trans_b]
 
     if target.kind.name == "cuda":
+        from tvm.tir.tensor_intrin import (  # pylint: disable=import-outside-toplevel
+            cuda,
+        )
+
         intrin_groups = [
-            tensor_intrin.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
+            cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
             for _in_dtype in in_dtype
             for _out_dtype in out_dtype
             for _trans_b in trans_b
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..8b132dcdf0 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
+from . import parser, parser_v1
 
-from . import tir
+#############
+from .parser import ir as ir_v2
+from .parser import ir_module as ir_module_v2
+from .parser import parse as from_source_v2
+from .parser import tir as tir_v2
 
-from .parser import ir_module, from_source
+#############
+from .parser_v1 import from_source as from_source_v1
+from .parser_v1 import ir_module as ir_module_v1
+from .parser_v1 import tir as tir_v1
+
+# pylint: disable=invalid-name
+
+ir = ir_v2  # public names default to the v2 parser; v1 stays reachable via *_v1
+ir_module = ir_module_v2
+tir = tir_v2
+from_source = from_source_v2
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 77%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..d8530e0ab1 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,13 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+# under the License.
+"""The parser"""
+from . import dispatch as _dispatch
+from . import doc as _doc
+from . import ir
+from . import parser as _parser
 from . import tir
-
-from .parser import ir_module, from_source
+from .entry import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/parser/diagnostics.py b/python/tvm/script/parser/diagnostics.py
new file mode 100644
index 0000000000..bb4f05a254
--- /dev/null
+++ b/python/tvm/script/parser/diagnostics.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+from .source import Source
+
+
+class Diagnostics:
+    """Emit diagnostics for doc nodes through a DiagnosticContext over the parsed Source."""
+
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        mod = IRModule()
+        mod.source_map.add(source.source_name, source.full_source)
+        self.source = source
+        self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        """Emit `message` at `level`, with `node`'s location shifted into the full source."""
+        # Node locations are relative to the parsed snippet; absent fields fall back to the
+        # snippet origin. Use `is None`, not `or`: a column offset of 0 is valid.
+        lineno = 1 if node.lineno is None else node.lineno
+        col_offset = 0 if node.col_offset is None else node.col_offset
+        end_lineno = lineno if node.end_lineno is None else node.end_lineno
+        end_col_offset = col_offset if node.end_col_offset is None else node.end_col_offset
+        # Shift by the snippet's start position (+1: AST column offsets are 0-based).
+        line_shift = self.source.start_line - 1
+        col_shift = self.source.start_column + 1
+        span = Span(
+            source_name=SourceName(self.source.source_name),
+            line=lineno + line_shift,
+            end_line=end_lineno + line_shift,
+            column=col_offset + col_shift,
+            end_column=end_col_offset + col_shift,
+        )
+        diagnostic = diagnostics.Diagnostic(level=level, span=span, message=message)
+        self.ctx.emit(diagnostic)
+
+    def error(self, node: doc.AST, message: str) -> None:
+        """Emit an ERROR diagnostic at `node`, then render the diagnostic context."""
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()
diff --git a/python/tvm/script/parser/dispatch.py b/python/tvm/script/parser/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}  # (token, type_name) -> parse method
+
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, Type, int], OpMethod] = {}  # (ty, op, operand_index) -> method
+
+
+def register(token: str, type_name: str):
+    """Return a decorator registering a parse method under (token, type_name)."""
+    def f(method: ParseMethod):
+        ParseVTable[(token, type_name)] = method
+        return method  # keep the decorated name bound to the method, not None
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,  # returned when nothing is registered
+) -> Optional[ParseMethod]:
+    return ParseVTable.get((token, type_name), default)
+
+
+def register_op(ty: Type, op: Type, operand_index: int):  # pylint: disable=invalid-name
+    def f(method: OpMethod):  # decorator: registers `method` for (ty, op, operand_index)
+        OpVTable[(ty, op, operand_index)] = method
+        return method  # keep the decorated name bound to the method, not None
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type,
+    operand_index: int,
+    default: Optional[OpMethod] = None,  # returned when nothing is registered
+) -> Optional[OpMethod]:
+    return OpVTable.get((ty, op, operand_index), default)
diff --git a/python/tvm/script/parser/doc.py b/python/tvm/script/parser/doc.py
new file mode 100644
index 0000000000..f6a641cb64
--- /dev/null
+++ b/python/tvm/script/parser/doc.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import *  # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+FnToDoc = typing.Callable[[ast.AST], doc.AST]  # converts an ast node to a doc node
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]  # converts a doc node back to ast
+
+
+class Entry:  # the converter pair registered for one AST class name
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        self.to_doc = None  # unset until register_to_doc() fills it
+        self.from_doc = None  # unset until register_from_doc() fills it
+
+
+class Registry:  # converter table keyed by AST class name
+    _inst: typing.Optional["Registry"] = None  # created by _register_default()
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        self.table = defaultdict(Entry)  # unknown names get an empty Entry on first access
+
+
+def register_to_doc(name: str):
+    """Return a decorator that registers the ast->doc converter for class `name`."""
+    def f(to_doc_fn: FnToDoc):
+        Registry._inst.table[name].to_doc = to_doc_fn  # pylint: disable=protected-access
+        return to_doc_fn  # returned so `f` also works as a real decorator
+    return f
+
+
+def register_from_doc(name: str):
+    """Return a decorator that registers the doc->ast converter for class `name`."""
+    def f(from_doc_fn: FnFromDoc):
+        Registry._inst.table[name].from_doc = from_doc_fn  # pylint: disable=protected-access
+        return from_doc_fn  # returned so `f` also works as a real decorator
+    return f
+
+
+def _is_atomic_type(node):  # leaf values that to_doc/from_doc return unchanged
+    return (
+        node is None
+        or node in [..., True, False]  # Ellipsis and the bool singletons
+        or isinstance(
+            node,
+            (
+                int,
+                float,
+                str,
+                bool,
+                bytes,
+                complex,
+            ),
+        )
+    )
+
+
+def _get_registry_entry(cls_name, attr):  # converter `attr` registered for cls_name, or None
+    cls_name = cls_name.split(".")[-1]
+    reg = Registry._inst  # pylint: disable=protected-access
+    if cls_name in reg.table:  # membership test avoids inserting into the defaultdict
+        entry = reg.table[cls_name]
+        return getattr(entry, attr, None)
+    return None
+
+
+def from_doc(node):  # doc node (or list/tuple of them) -> ast, via the registry
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(from_doc(n) for n in node)
+    if isinstance(node, list):
+        return [from_doc(n) for n in node]
+    func = _get_registry_entry(node.__class__.__name__, "from_doc")
+    if not func:
+        raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}")
+    return func(node)
+
+
+def to_doc(node):  # ast node (or list/tuple of them) -> doc, via the registry
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(to_doc(n) for n in node)
+    if isinstance(node, list):
+        return [to_doc(n) for n in node]
+    func = _get_registry_entry(node.__class__.__name__, "to_doc")
+    if not func:
+        raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}")
+    return func(node)
+
+
+def parse(source, filename="<unknown>", mode="exec") -> doc.AST:
+    """Parse Python `source` with ``ast.parse`` and convert the result to ``doc`` nodes.
+
+    Parsing with the Python 3.8 grammar (``feature_version``) is tried first.
+    """
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except (TypeError, SyntaxError):  # pre-3.8 lacks feature_version; or 3.8 grammar failed
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:  # traverses a doc tree; visit methods return nothing
+    def visit(self, node: doc.AST) -> None:  # -> visit_<ClassName> or generic_visit
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return  # non-AST leaves are ignored
+        getattr(
+            self,
+            "visit_" + node.__class__.__name__.split(".")[-1],
+            self.generic_visit,
+        )(node)
+
+    def generic_visit(self, node: doc.AST) -> None:  # visit every AST/list/tuple field
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            if value is None:
+                pass
+            elif isinstance(value, (doc.AST, list, tuple)):
+                self.visit(value)
+
+
+class NodeTransformer:  # like NodeVisitor, but each visit returns the (new) node
+    def visit(self, node: doc.AST) -> doc.AST:
+        if isinstance(node, list):
+            return [self.visit(item) for item in node]
+        if isinstance(node, tuple):
+            return tuple(self.visit(item) for item in node)
+        if not isinstance(node, doc.AST):
+            return node  # non-AST leaves pass through unchanged
+        return getattr(
+            self,
+            "visit_" + node.__class__.__name__.split(".")[-1],
+            self.generic_visit,
+        )(node)
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        kv: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            if value is None:
+                pass
+            elif isinstance(value, (doc.AST, list, tuple)):
+                value = self.visit(value)
+            kv[field] = value
+        return node.__class__(**kv)  # builds a fresh node; `node` itself is not mutated
+
+
+def _register_default():  # create Registry._inst plus field-wise converters
+    class DefaultTranslator:
+        def __init__(self, doc_cls, func, fields):
+            self.doc_cls = doc_cls  # class to build: doc.X for to_doc, ast.X for from_doc
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields}
+            return self.doc_cls(**kv)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue  # only names present in both `doc` and `ast` get defaults
+        if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST):
+            assert "." not in cls_name
+            register_to_doc(cls_name)(
+                DefaultTranslator(
+                    getattr(doc, cls_name),
+                    to_doc,
+                    doc_cls._FIELDS,  # pylint: disable=protected-access
+                )
+            )
+            register_from_doc(cls_name)(
+                DefaultTranslator(
+                    getattr(ast, cls_name),
+                    from_doc,
+                    doc_cls._FIELDS,  # pylint: disable=protected-access
+                )
+            )
+
+
+def _py_version() -> typing.Tuple[int, int]:  # (major, minor) of the running interpreter
+    return (sys.version_info.major, sys.version_info.minor)
+
+
+def _register_constant_handling():  # map pre-3.8 literal nodes onto doc.Constant
+    if _py_version() not in [(3, 6), (3, 7)]:
+        return  # from 3.8 on, ast.parse already produces ast.Constant
+
+    def as_constant(f) -> FnToDoc:  # f: attribute name holding the value, or a callable
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            return doc.Constant(
+                value=getattr(x, f) if isinstance(f, str) else f(x),
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                end_lineno=x.lineno,  # pre-3.8 nodes have no end position: reuse the start
+                end_col_offset=x.col_offset,
+            )
+
+        return to_doc_func
+
+    register_to_doc("Str")(as_constant("s"))
+    register_to_doc("NameConstant")(as_constant("value"))
+    register_to_doc("Num")(as_constant("n"))
+    register_to_doc("Bytes")(as_constant("s"))
+    register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+    """Translate ``Subscript`` between ``doc`` and the pre-3.9 ``ast`` layout.
+
+    Before Python 3.9, ``ast.Subscript.slice`` is wrapped in ``ast.Index`` or
+    ``ast.ExtSlice``; the doc tree stores the bare slice expression instead (a
+    ``doc.Slice``, a ``doc.Tuple`` or any other expression). On 3.9+ the default
+    translators already match, so nothing is registered.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def loc(node: ast.AST) -> typing.Dict[str, typing.Any]:
+        # Source-location keyword arguments; `None` for any field `node` lacks.
+        return {
+            "lineno": getattr(node, "lineno", None),
+            "col_offset": getattr(node, "col_offset", None),
+            "end_lineno": getattr(node, "end_lineno", None),
+            "end_col_offset": getattr(node, "end_col_offset", None),
+        }
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        if isinstance(x.slice, ast.Slice):
+            # `a[lo:hi:step]`: keep the slice, located by the Slice node itself.
+            doc_slice = doc.Slice(
+                lower=to_doc(x.slice.lower),
+                upper=to_doc(x.slice.upper),
+                step=to_doc(x.slice.step),
+                **loc(x.slice),
+            )
+        elif isinstance(x.slice, ast.ExtSlice):
+            # `a[lo:hi, i]`: flatten the dims into a Tuple located at the subscript.
+            doc_slice = doc.Tuple(
+                elts=[to_doc(i) for i in x.slice.dims],
+                ctx=doc.Load(
+                    lineno=None,
+                    col_offset=None,
+                    end_lineno=None,
+                    end_col_offset=None,
+                ),
+                **loc(x),
+            )
+        elif isinstance(x.slice, ast.Index):
+            # `a[i]` / `a[i, j]`: drop the Index wrapper.
+            doc_slice = to_doc(x.slice.value)
+        else:
+            raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+        return doc.Subscript(
+            value=to_doc(x.value),
+            slice=doc_slice,
+            ctx=to_doc(x.ctx),
+            **loc(x),
+        )
+
+    def dim_from_doc(item: doc.Expr) -> ast.AST:
+        # An ExtSlice dim must be an ast.Slice or an ast.Index; wrap other expressions.
+        ast_item = from_doc(item)
+        if isinstance(ast_item, ast.Slice):
+            return ast_item
+        return ast.Index(value=ast_item)
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        if isinstance(x.slice, doc.Slice):
+            ast_slice = from_doc(x.slice)
+        elif isinstance(x.slice, doc.Tuple) and any(
+            isinstance(i, doc.Slice) for i in x.slice.elts
+        ):
+            # A mixed form such as `a[lo:hi, i]` needs ExtSlice. Bare expressions are not
+            # valid dims, hence the Index wrapping in `dim_from_doc`.
+            ast_slice = ast.ExtSlice(dims=[dim_from_doc(i) for i in x.slice.elts])
+        else:
+            # A plain index, including a slice-free tuple such as `a[i, j]` or `a[()]`,
+            # is an Index around the expression, which is also what `ast.parse` produces.
+            ast_slice = ast.Index(value=from_doc(x.slice))
+        result = ast.Subscript(
+            value=from_doc(x.value),
+            slice=ast_slice,
+            ctx=from_doc(x.ctx),
+        )
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+def _register_index_handling():
+    """Before Python 3.9, convert ``ast.Index`` to its value and back again."""
+    if _py_version() >= (3, 9):
+        return
+
+    def index_to_doc(x: ast.Index) -> doc.Expr:
+        return to_doc(x.value)
+
+    def index_from_doc(x: doc.Expr) -> ast.Index:
+        # ast.Index has only a `value` field; `x` need not carry `ctx` (e.g. a Constant).
+        result = ast.Index(value=from_doc(x))
+        for field in ("lineno", "col_offset", "end_lineno", "end_col_offset"):
+            setattr(result, field, getattr(x, field))
+        return result
+
+    register_to_doc("Index")(index_to_doc)
+    register_from_doc("Index")(index_from_doc)
+
+
+_register_default()  # must run first: creates Registry._inst, which the others use
+_register_constant_handling()
+_register_subscription_handling()
+_register_index_handling()
diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/parser/doc_core.py
similarity index 100%
rename from python/tvm/script/printer/doc_core.py
rename to python/tvm/script/parser/doc_core.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/entry.py
similarity index 51%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/entry.py
index 923eb97d27..b70e876d43 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/entry.py
@@ -14,32 +14,31 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
 
-import inspect
-from typing import Callable
+from tvm.ir.ir_builder import IRBuilder
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from . import doc
+from .parser import Parser
+from .source import Source
 
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):  # returns builder.get()
+    if extra_vars is None:  # default globals: the ir/tir modules under I/ir and T/tir
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
+        extra_vars = {
+            "I": ir,
+            "ir": ir,
+            "T": tir,
+            "tir": tir,
+        }
 
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+    source = Source(program)
+    parser = Parser(source)
+    with IRBuilder() as builder:  # parsing runs inside a fresh IRBuilder scope
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/evaluator.py b/python/tvm/script/parser/evaluator.py
new file mode 100644
index 0000000000..3899531b21
--- /dev/null
+++ b/python/tvm/script/parser/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+# Fallback Python implementation of every operator the evaluator understands,
+# keyed by doc node type. ``_eval_op`` uses these only when no operand type has
+# registered its own handler through ``dispatch``. Operands reaching these
+# lambdas are already evaluated, so ``and``/``or`` cannot short-circuit here.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+    # Arithmetic and bitwise binary operators.
+    doc.Add: lambda lhs, rhs: lhs + rhs,
+    doc.Sub: lambda lhs, rhs: lhs - rhs,
+    doc.Mult: lambda lhs, rhs: lhs * rhs,
+    doc.Div: lambda lhs, rhs: lhs / rhs,
+    doc.FloorDiv: lambda lhs, rhs: lhs // rhs,
+    doc.Mod: lambda lhs, rhs: lhs % rhs,
+    doc.LShift: lambda lhs, rhs: lhs << rhs,
+    doc.RShift: lambda lhs, rhs: lhs >> rhs,
+    doc.BitOr: lambda lhs, rhs: lhs | rhs,
+    doc.BitXor: lambda lhs, rhs: lhs ^ rhs,
+    doc.BitAnd: lambda lhs, rhs: lhs & rhs,
+    doc.MatMult: lambda lhs, rhs: lhs @ rhs,
+    doc.Pow: lambda lhs, rhs: lhs**rhs,
+    # Comparisons, identity and membership.
+    doc.Eq: lambda lhs, rhs: lhs == rhs,
+    doc.NotEq: lambda lhs, rhs: lhs != rhs,
+    doc.Lt: lambda lhs, rhs: lhs < rhs,
+    doc.LtE: lambda lhs, rhs: lhs <= rhs,
+    doc.Gt: lambda lhs, rhs: lhs > rhs,
+    doc.GtE: lambda lhs, rhs: lhs >= rhs,
+    doc.Is: lambda lhs, rhs: lhs is rhs,
+    doc.IsNot: lambda lhs, rhs: lhs is not rhs,
+    doc.In: lambda lhs, rhs: lhs in rhs,
+    doc.NotIn: lambda lhs, rhs: lhs not in rhs,
+    # Boolean connectives.
+    doc.And: lambda lhs, rhs: lhs and rhs,
+    doc.Or: lambda lhs, rhs: lhs or rhs,
+    # Unary operators.
+    doc.Invert: lambda operand: ~operand,
+    doc.Not: lambda operand: not operand,
+    doc.UAdd: lambda operand: +operand,
+    doc.USub: lambda operand: -operand,
+}
+
+
+class ExprEvaluator:
+    """Evaluate a doc expression against a table of Python values.
+
+    Sub-expressions are evaluated bottom-up. Each computed value is stored in
+    ``value_table`` under a fresh ``__tvm_tmp_value_<n>`` name, and the node is
+    replaced by a ``doc.Name`` that loads it, so the parent node can be
+    evaluated in turn. Boolean, comparison, unary and binary operators go
+    through ``_eval_op``, which lets operand types supply their own
+    implementations.
+    """
+
+    parser: "Parser"
+    # Variable name -> value; generated temporaries are added here as well.
+    value_table: Dict[str, Any]
+    # Number of temporaries generated so far (used to make names unique).
+    new_value_count: int
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` and return its Python value.
+
+        ``value_table`` is extended in place with the intermediate results.
+        Undefined variables are reported through ``parser.report_error``.
+        """
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        """Store ``value`` under a fresh temporary name; return a ``Name`` loading it."""
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        # Synthetic node: it has no real position in the user's source.
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        """Evaluate ``node`` bottom-up.
+
+        Lists and tuples are mapped element-wise. Names, constants, operator
+        tokens, expression contexts and non-expression nodes are returned
+        unchanged. Every other expression is evaluated and replaced by a
+        ``doc.Name`` that refers to its value in ``value_table``.
+        """
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            # The body refers to the lambda's own parameters, which are not in
+            # ``value_table``; visiting it would report them as undefined.
+            return self._eval_lambda(node)
+        # Evaluate all children first, then this node on top of their results.
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                # Any other expression: rebuild it from the visited children
+                # and let Python evaluate it.
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            # Never fall through with ``value`` unbound if reporting returns.
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        """Evaluate a whole lambda expression into a function object."""
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        """Fold ``and``/``or`` left to right over the already-evaluated operands."""
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate a possibly chained comparison.
+
+        Python defines ``a < b < c`` as ``a < b and b < c``. Each adjacent pair
+        is therefore compared, and the partial results are combined with
+        ``doc.And``. The combination goes through ``_eval_op``, so operand
+        types with registered handlers (e.g. TIR expressions) build their own
+        conjunction. Folding the results as ``(a < b) < c`` would be wrong.
+        """
+        lhs = self._eval_expr(fields["left"])
+        result = lhs  # only returned unchanged if there are no operators
+        for index, (op, rhs_node) in enumerate(zip(fields["ops"], fields["comparators"])):
+            rhs = self._eval_expr(rhs_node)
+            pair = _eval_op(op, values=[lhs, rhs])
+            if index == 0:
+                result = pair
+            else:
+                and_op = doc.And(lineno=0, col_offset=0, end_lineno=None, end_col_offset=None)
+                result = _eval_op(and_op, values=[result, pair])
+            lhs = rhs
+        return result
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply a unary operator to its evaluated operand."""
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply a binary operator to its evaluated left and right operands."""
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        """Build a Python ``slice`` from the evaluated bounds (missing ones stay ``None``)."""
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        """Evaluate a (rebuilt) expression node with ``value_table`` as globals."""
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate ``node`` with the variables in ``dict_globals`` in scope.
+
+    ``dict_globals`` is copied first, so the temporaries recorded during
+    evaluation never leak back into the caller's mapping.
+    """
+    value_table: Dict[str, Any] = {} if dict_globals is None else dict(dict_globals)
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Destructure ``source`` into ``target``; return the names it binds.
+
+    Any failure is reported as a diagnostic on ``target`` and then re-raised.
+    """
+    try:
+        bindings = _eval_assign(target, source)
+    except Exception as e:  # pylint: disable=broad-except,invalid-name
+        parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
+        raise
+    return bindings
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Compile ``node`` in ``eval`` mode and run it with ``dict_globals`` as globals."""
+    py_node = doc.from_doc(node)
+    if isinstance(py_node, ast.expr):
+        # A bare expression must be wrapped before it can be compiled in eval mode.
+        py_node = ast.Expression(body=py_node)
+    assert isinstance(py_node, ast.Expression), "Expects an ast.Expression, but gets: " + str(
+        py_node
+    )
+    code = compile(ast.fix_missing_locations(py_node), filename="<ast>", mode="eval")
+    scope = {} if dict_globals is None else dict_globals
+    return eval(code, scope)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply the operator ``op`` to the already-evaluated ``values``.
+
+    Operand types may provide their own implementation. The lookup uses
+    ``dispatch.get_op``, keyed by the type's ``_dispatch_type`` attribute, the
+    operator type and the operand position. The first operand, scanning left
+    to right, that has a registered handler wins. Otherwise the Python
+    fallback in ``DEFAULT_OP`` is used.
+
+    Raises
+    ------
+    TypeError
+        If no handler is registered and ``op`` has no default implementation.
+    """
+    op_type = type(op)
+    for i, v in enumerate(values):
+        v_type = getattr(type(v), "_dispatch_type", None)
+        if v_type is None:
+            continue
+        f = dispatch.get_op(ty=v_type, op=op_type, operand_index=i, default=None)
+        if f is not None:
+            return f(*values)
+    default_impl = DEFAULT_OP.get(op_type)
+    if default_impl is None:
+        # A bare KeyError would surface only as "<class '...'>" in diagnostics.
+        raise TypeError(f"Unsupported operator: {op_type.__name__}")
+    return default_impl(*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Run ``<target> = source`` using Python's own assignment semantics.
+
+    A one-statement module assigning a placeholder variable to ``target`` is
+    executed against a private locals dict. The names it binds are returned,
+    with the placeholder itself removed.
+    """
+    py_target = doc.from_doc(target)
+    assert isinstance(py_target, ast.expr)
+    rhs_name = "__tvm_rhs_var__"
+    assign_stmt = ast.Assign(
+        targets=[py_target],
+        value=ast.Name(id=rhs_name, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assign_stmt], type_ignores=[]))
+    scope = {rhs_name: source}
+    exec(compile(module, filename="<ast>", mode="exec"), {}, scope)  # pylint: disable=exec-used
+    del scope[rhs_name]
+    return scope
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser/ir/__init__.py
similarity index 89%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser/ir/__init__.py
index 926d17b166..bea08cfb1b 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
-import tvm._ffi
-
-tvm._ffi._init_api("script", __name__)
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 66%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 555659d0c5..353963f29b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,8 +14,21 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
 
-from . import tir
+from tvm.ir import IRModule
 
-from .parser import ir_module, from_source
+from ..entry import parse
+from ..utils import inspect_class_capture
+
+
+def ir_module(f: Type) -> IRModule:
+    """Parse the decorated class into an ``IRModule``.
+
+    The class is parsed together with the variables it captures. Raises
+    ``TypeError`` when applied to anything other than a class.
+    """
+    if inspect.isclass(f):
+        return parse(f, inspect_class_capture(f))
+    raise TypeError(f"Expect a class, but got: {f}")
+
+
+# Tag the decorator with its dispatch token ("ir").
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 55%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/parser/ir/parser.py
index 2f2b4bbc25..aec203c7d9 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,18 +14,26 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+from tvm.ir import ir_builder as I
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
+from .. import dispatch, doc
+from ..parser import Parser
 
-from .prim_func import prim_func
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Parse an ``@ir_module`` class body.
+
+    The body is visited inside a new variable frame and an ``I.ir_module()``
+    frame, with ``"ir"`` as the active dispatch token. The contexts are entered
+    in that order and exited in reverse.
+    """
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    """Skip assignments that appear directly in an ``@ir_module`` class body."""
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    """Skip expression statements that appear directly in an ``@ir_module`` class body."""
diff --git a/python/tvm/script/parser/parser.py b/python/tvm/script/parser/parser.py
new file mode 100644
index 0000000000..a89cd10fad
--- /dev/null
+++ b/python/tvm/script/parser/parser.py
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+from ...error import DiagnosticError
+from . import dispatch, doc
+from .diagnostics import Diagnostics
+from .evaluator import eval_assign, eval_expr
+from .source import Source
+from .utils import deferred
+from .var_table import VarTable
+
+# Names of AST node kinds that ``Parser.visit`` hands straight to
+# ``generic_visit`` instead of a dedicated ``visit_<Kind>`` method.
+DEFAULT_VISIT = {"Interactive", "Module", "Expression", "Pass"}
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap a parse method so that unexpected errors become diagnostics.
+
+    A ``DiagnosticError`` propagates unchanged. Any other exception is first
+    reported against the node being visited and then re-raised. The wrapped
+    method's return value is passed through.
+    """
+
+    def _guarded(self: "Parser", node: doc.AST) -> None:
+        try:
+            outcome = func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+        return outcome
+
+    return _guarded
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Pick the parse method for a node of kind ``type_name``.
+
+    The lookup order is:
+
+    1. The handler registered under the innermost dispatch token.
+    2. The handler registered under ``"default"``.
+    3. The parser's ``generic_visit``.
+
+    Whichever is chosen is wrapped by ``_dispatch_wrapper``.
+    """
+    innermost = self.dispatch_tokens[-1]
+    handler = dispatch.get(token=innermost, type_name=type_name, default=None)
+    if handler is None:
+        handler = dispatch.get(token="default", type_name=type_name, default=None)
+    if handler is None:
+        return _dispatch_wrapper(lambda self, node: self.generic_visit(node))
+    return _dispatch_wrapper(handler)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser.
+
+    Walks the doc AST of a ``Source``. Most statement kinds are forwarded to
+    the parse method registered in ``dispatch`` under the innermost dispatch
+    token. If there is none, the ``"default"`` token is tried, then
+    ``generic_visit``. Variables visible to the script are tracked in
+    ``var_table``.
+    """
+
+    diag: Diagnostics
+    # Stack of dispatch tokens; the last entry selects the active handlers.
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        self.diag = Diagnostics(source)
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Parse the whole source, with ``extra_vars`` bound in an outer frame."""
+        if extra_vars is None:
+            extra_vars = {}
+        with self.var_table.with_frame():
+            for k, v in extra_vars.items():
+                self.var_table.add(k, v)
+            node = self.diag.source.as_ast()
+            self.visit(node)
+
+    def with_dispatch_token(self, token: str):
+        """Push ``token`` onto the dispatch stack.
+
+        Returns ``deferred(pop_token)``, which is meant to be used in a
+        ``with`` statement. It presumably calls ``pop_token`` on exit, which
+        removes the token again.
+        """
+
+        def pop_token():
+            self.dispatch_tokens.pop()
+
+        self.dispatch_tokens.append(token)
+        return deferred(pop_token)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate ``node`` with the current variables (plus ``extra_vars``) in scope."""
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            for k, v in extra_vars.items():
+                var_values[k] = v
+        return eval_expr(self, node, var_values)
+
+    def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
+        """Check an assignment target for names that are bound more than once.
+
+        Returns ``True`` if a duplicate is found; otherwise the set of names the
+        target binds. Targets other than names and (nested) tuples/lists are
+        reported as errors.
+        """
+        if isinstance(target, (doc.Tuple, doc.List)):
+            names: Set[str] = set()
+            for elt in target.elts:
+                res = self._duplicate_lhs_check(elt)
+                if isinstance(res, bool) and res:
+                    return True
+                assert isinstance(res, set)
+                if names & res:
+                    return True
+                names = names.union(res)
+            return names
+        if isinstance(target, doc.Name):
+            return {target.id}
+        self.report_error(target, "Invalid type in assign statement")
+        raise NotImplementedError
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        """Evaluate ``target = source`` and bind every resulting name.
+
+        ``bind_value(self, target, name, value)`` produces what is stored in
+        ``var_table`` for each name. The raw name -> value mapping of the
+        assignment is returned.
+        """
+        if self._duplicate_lhs_check(target) is True:
+            self.report_error(target, "Duplicate vars assigned.")
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        """Report ``msg`` as an error located at ``node`` (via ``Diagnostics.error``)."""
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit ``node``, or each element of a list/tuple of nodes.
+
+        Non-AST values are ignored. Node kinds listed in ``DEFAULT_VISIT`` use
+        ``generic_visit``; every other kind needs a ``visit_<Kind>`` method.
+        A ``DiagnosticError`` propagates unchanged, as in ``_dispatch_wrapper``,
+        instead of being reported again at every enclosing node. Any other
+        error is reported against ``node`` and re-raised.
+        """
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        try:
+            func(node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        """Visit a list of statements in order."""
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        """Parse a type annotation with the handler for the active dispatch token."""
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        """Parse a function using the handler selected by its decorator.
+
+        The evaluated decorator must carry a ``dispatch_token`` attribute
+        (e.g. ``prim_func``). The ``FunctionDef`` handler registered for that
+        token does the actual work.
+        """
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        """Parse a class with the ``"ir"`` ``ClassDef`` handler."""
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    # The visitors below all delegate to the handler chosen by ``_dispatch``.
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
+
+    def visit_Return(self, node: doc.Return) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Return")(self, node)
diff --git a/python/tvm/script/parser/source.py b/python/tvm/script/parser/source.py
new file mode 100644
index 0000000000..a7a436d568
--- /dev/null
+++ b/python/tvm/script/parser/source.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name
+import inspect
+import re
+import sys
+from typing import Union
+
+from . import doc
+
+
+class Source:
+    """The source text of a program handed to the TVMScript parser.
+
+    ``program`` is either a string of code or a Python object (function or
+    class) whose source is looked up with ``inspect``. In the latter case the
+    indentation of the definition's first line is sliced off every line, so
+    that the text parses on its own.
+    """
+
+    source_name: str  # file the program came from, or "<str>"
+    start_line: int  # 1-based line number of the definition in that file
+    start_column: int  # number of leading columns removed from each line
+    source: str  # the program text that gets parsed
+    full_source: str  # text of the whole enclosing module, when available
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = self.full_source = program
+        else:
+            self.source_name = inspect.getsourcefile(program)  # type: ignore
+            lines, self.start_line = getsourcelines(program)  # type: ignore
+            first = lines[0] if lines else ""
+            self.start_column = len(first) - len(first.lstrip())
+            if lines and self.start_column:
+                column = self.start_column
+                self.source = "\n".join(text[column:].rstrip() for text in lines)
+            else:
+                self.source = "".join(lines)
+            try:
+                # In Jupyter Notebook the module is <module '__main__'>, a
+                # built-in module, and ``inspect.getsource`` raises TypeError.
+                module = inspect.getmodule(program)
+                self.full_source = inspect.getsource(module) if module else self.source
+            except TypeError:
+                # Jupyter workaround: ``findsource`` is internal to ``inspect``
+                # but still returns the lines of the defining cell/file.
+                all_lines, _ = inspect.findsource(program)  # type: ignore
+                self.full_source = "".join(all_lines)
+
+    def as_ast(self) -> doc.AST:
+        """Parse ``source`` into a doc AST."""
+        return doc.parse(self.source)
+
+
+# Keep references to the original ``inspect`` functions. ``inspect.getfile`` is
+# replaced at the bottom of this module, and the class-aware helpers below
+# delegate to these originals for anything that is not a class.
+_getfile = inspect.getfile  # pylint: disable=invalid-name
+_findsource = inspect.findsource  # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):
+    """Class-aware replacement for ``inspect.getfile``.
+
+    Non-class objects are delegated to the original ``inspect.getfile``. For a
+    class, the ``__file__`` of its defining module is used when available.
+    Otherwise the file of any function defined directly in the class body
+    is returned.
+
+    Raises
+    ------
+    TypeError
+        If no file can be determined for the class.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod_name = getattr(obj, "__module__", None)
+    if mod_name is not None:
+        # The module may be absent from sys.modules (e.g. a class created
+        # dynamically with a made-up ``__module__``); fall through then.
+        file = getattr(sys.modules.get(mod_name), "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        # Only functions whose qualname places them directly in this class.
+        if (
+            inspect.isfunction(member)
+            and obj.__qualname__ + "." + member.__name__ == member.__qualname__
+        ):
+            return inspect.getfile(member)
+    # ``{obj!r}``: the former ``{obj:!r}`` was an invalid format spec that
+    # raised "unsupported format string" instead of this message.
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+def findsource(obj):
+    """Class-aware replacement for ``inspect.findsource``.
+
+    Returns ``(lines, index)``: all lines of the file that defines ``obj`` and
+    the 0-based index of its definition line. Non-class objects are
+    delegated to the original ``inspect.findsource``.
+
+    A class is located by walking its ``__qualname__`` from the outside in.
+    Each enclosing function (a ``<locals>`` scope) must be matched by a
+    ``def`` or ``async def`` line, and each enclosing class by a ``class``
+    line, in order, before the class itself is matched.
+
+    Raises
+    ------
+    OSError
+        If the source file or the class definition cannot be found.
+    """
+    import linecache  # pylint: disable=import-outside-toplevel
+
+    if not inspect.isclass(obj):
+        return _findsource(obj)
+
+    file = inspect.getsourcefile(obj)
+    if file:
+        linecache.checkcache(file)
+    else:
+        file = inspect.getfile(obj)
+        if not (file.startswith("<") and file.endswith(">")):
+            raise OSError("source code not available")
+
+    module = inspect.getmodule(obj, file)
+    if module:
+        lines = linecache.getlines(file, module.__dict__)
+    else:
+        lines = linecache.getlines(file)
+    if not lines:
+        raise OSError("could not get source code")
+    # "f.<locals>.C" -> ["f<locals>", "C"]: a "<locals>" suffix marks a function scope.
+    qual_name = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
+    pat_list = []
+    for qn in qual_name:
+        if qn.endswith("<locals>"):
+            func_name = re.escape(qn[: -len("<locals>")])
+            pat_list.append(re.compile(r"^(\s*)(?:async\s+)?def\s+" + func_name + r"\b"))
+        else:
+            pat_list.append(re.compile(r"^(\s*)class\s+" + re.escape(qn) + r"\b"))
+    for i, line in enumerate(lines):
+        if pat_list[0].match(line):
+            pat_list.pop(0)
+        if not pat_list:
+            return lines, i
+    raise OSError("could not find class definition")
+
+
+def getsourcelines(obj):
+    """Return ``(block_lines, first_line_number)`` for ``obj``.
+
+    Works like ``inspect.getsourcelines`` but relies on the class-aware
+    ``findsource`` above. The returned line number is 1-based.
+    """
+    target = inspect.unwrap(obj)
+    all_lines, start_index = findsource(target)
+    block = inspect.getblock(all_lines[start_index:])
+    return block, start_index + 1
+
+
+# Process-wide monkeypatch, applied as soon as this module is imported. Every
+# caller of ``inspect.getfile``, including ``inspect`` helpers that look the
+# name up at call time, now resolves classes with the version above.
+inspect.getfile = _patched_inspect_getfile
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 78%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..caa51744c8 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.tir.ir_builder_v2 import *  # pylint: disable=redefined-builtin
 
-from . import tir
-
-from .parser import ir_module, from_source
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..4a0c7c40fb
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+from tvm.tir.ir_builder_v2 import buffer_decl, ptr
+
+from ..entry import parse
+from ..utils import inspect_function_capture
+
+
+def _is_defined_in_class(frames):
+    """Heuristically tell whether ``prim_func`` is being applied inside a class body.
+
+    ``frames`` is the ``inspect.stack()`` taken inside ``prim_func``. The
+    frame at index 2 is inspected (element 4 of a frame record is its
+    ``code_context``). For a decorator used in a class body, that frame is
+    presumably the one executing the ``class`` statement. Its current line
+    is then either the ``class`` line or a decorator line that mentions
+    ``ir_module``.
+    """
+    if len(frames) <= 2:
+        return False
+    code_context = frames[2][4]
+    # ``code_context`` is None when no source is available and could be an
+    # empty list; indexing ``[0]`` would then raise, so treat both as "no".
+    if not code_context:
+        return False
+    line = code_context[0].strip()
+    if line.startswith("class "):
+        return True
+    return line.startswith("@") and "ir_module" in line
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    """Parse the decorated function into a ``PrimFunc``.
+
+    When ``_is_defined_in_class`` says the decorator is being applied inside
+    a class body, ``f`` is returned unchanged. Presumably it is then parsed
+    later as part of the enclosing ``ir_module``.
+
+    Raises ``TypeError`` if ``f`` is not a plain Python function.
+    """
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    inside_class = _is_defined_in_class(inspect.stack())
+    return f if inside_class else parse(f, inspect_function_capture(f))
+
+
+# Tag the decorator so the parser selects the "tir" FunctionDef handler.
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Callable and subscriptable helper that declares buffers via ``buffer_decl``.
+
+    Both ``Buffer(shape, dtype, ...)`` and ``Buffer[shape, dtype, ...]`` are
+    supported.
+    """
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        """Declare a buffer; every argument is forwarded to ``buffer_decl`` as-is."""
+        options = dict(
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+        return buffer_decl(shape, **options)
+
+    def __getitem__(self, keys) -> Buffer:
+        """Subscript form of ``__call__``.
+
+        The whole key is the shape when it is not a tuple, or when it is a
+        tuple whose second item is not a string (``Buffer[128, 128]``).
+        Otherwise the tuple is unpacked as ``(shape, dtype, ...)``, as in
+        ``Buffer[(128,), "int8"]``.
+        """
+        key_is_shape = not isinstance(keys, tuple) or (
+            len(keys) >= 2 and not isinstance(keys[1], str)
+        )
+        if key_is_shape:
+            return self(keys)
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    """Callable and subscriptable helper that declares pointer types via ``ptr``."""
+
+    def __call__(self, dtype, storage_scope="global"):
+        """Declare a pointer to ``dtype`` in ``storage_scope``.
+
+        ``dtype`` may be a callable (presumably a dtype helper such as
+        ``int32``). In that case it is called with no arguments and the
+        ``dtype`` attribute of the result is used.
+        """
+        resolved = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        """Subscript form: ``Ptr[dtype]`` or ``Ptr[dtype, storage_scope]``."""
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+# Singleton proxies that let scripts write ``Buffer(...)``/``Buffer[...]`` and
+# ``Ptr(...)``/``Ptr[...]``. Note that this rebinds ``Buffer``, shadowing the
+# ``tvm.tir.Buffer`` class imported at the top of this module from here on.
+Buffer = BufferProxy()  # pylint: disable=invalid-name
+Ptr = PtrProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..11ee92ad29
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .. import doc
+from ..dispatch import OpMethod, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    """Register TIR implementations of Python operators for operands of type ``ty``.
+
+    Binary, comparison and boolean operators are registered with ``ty`` as either
+    operand (index 0 and 1). Unary operators are registered with ``ty`` as operand 0.
+    ``ty._dispatch_type`` is set to ``ty`` itself.
+    """
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool_expr(value):
+        # Python ``bool`` literals become TIR boolean immediates before And/Or.
+        return IntImm("bool", value) if isinstance(value, bool) else value
+
+    def _and(a, b):
+        return tir.And(_as_bool_expr(a), _as_bool_expr(b))
+
+    def _or(a, b):
+        return tir.Or(_as_bool_expr(a), _as_bool_expr(b))
+
+    def _register(op: Type, operand_index: int, method: OpMethod) -> None:
+        register_op(ty, op, operand_index)(method)
+
+    binary_table = (
+        # Case 1. binop -- doc.MatMult and doc.Pow are not implemented.
+        # Python's ``//`` and ``%`` round toward negative infinity, hence the Floor* ops.
+        (doc.Add, tir.Add),
+        (doc.Sub, tir.Sub),
+        (doc.Mult, tir.Mul),
+        (doc.Div, tir.Div),
+        (doc.FloorDiv, tir.FloorDiv),
+        (doc.Mod, tir.FloorMod),
+        (doc.LShift, lambda a, b: a << b),
+        (doc.RShift, lambda a, b: a >> b),
+        (doc.BitOr, lambda a, b: a | b),
+        (doc.BitXor, lambda a, b: a ^ b),
+        (doc.BitAnd, lambda a, b: a & b),
+        # Case 2. cmpop -- doc.Is, doc.IsNot, doc.In and doc.NotIn are not implemented.
+        (doc.Eq, tir.EQ),
+        (doc.NotEq, tir.NE),
+        (doc.Lt, tir.LT),
+        (doc.LtE, tir.LE),
+        (doc.Gt, tir.GT),
+        (doc.GtE, tir.GE),
+        # Case 3. boolop
+        (doc.And, _and),
+        (doc.Or, _or),
+    )
+    unary_table = (
+        # Case 4. unaryop
+        (doc.Invert, lambda a: ~a),
+        (doc.Not, tir.Not),
+        (doc.UAdd, lambda a: +a),
+        (doc.USub, lambda a: -a),
+    )
+
+    for operand_index in (0, 1):
+        for op, method in binary_table:
+            _register(op, operand_index, method)
+    for op, method in unary_table:
+        _register(op, 0, method)
+
+
+# Register the parser's operator implementations for PrimExpr and IterVar operands.
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..38973f6de2
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,269 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.ir.ir_builder import IRBuilderFrame as Frame
+from tvm.ir.ir_builder import name
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+from tvm.tir import ir_builder_v2 as T
+
+from .. import dispatch, doc
+from ..parser import Parser
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name the value(s) produced by a ``with ... as var_name`` item.
+
+    Lists/tuples are handled element-wise, naming each element
+    ``{var_name}_{index}``. Only ``Buffer`` and ``Var`` values can be named;
+    anything else is reported as an error at ``node``.
+    """
+    if isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    if not isinstance(value, (list, tuple)):
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+        raise NotImplementedError
+    for index, item in enumerate(value):
+        bind_with_value(self, node, f"{var_name}_{index}", item)
+    return value
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name the loop variable(s) bound by a ``for`` statement's target.
+
+    Lists/tuples are handled element-wise, naming each element
+    ``{var_name}_{index}``. Only ``Var`` values can be named; anything else is
+    reported as an error at ``node``.
+    """
+    if isinstance(value, (list, tuple)):
+        # Recurse with the *for* rules so nested elements are also restricted to
+        # Vars and errors mention the for statement (not a with-statement).
+        for i, v in enumerate(value):
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Var):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind the right-hand side of ``var_name = value``; return what the name maps to.
+
+    * ``T.inline`` values are unwrapped and returned as-is.
+    * Lists/tuples are bound element-wise (as ``{var_name}_{index}``) with these same
+      assignment rules; the original container is returned.
+    * IR-builder frames are entered now, with their own ``__exit__`` registered via
+      ``add_callback`` instead of being called here; the entered result is named.
+    * Buffers, IterVars, and Vars for which ``var_table.exist`` is False are named.
+    * Any other ``PrimExpr`` is bound to a fresh ``T.var`` through a ``T.let`` frame,
+      registered the same way as above, and that var is returned.
+    * Everything else is returned unchanged.
+    """
+    if isinstance(value, T.inline):
+        return value.value
+    elif isinstance(value, (list, tuple)):
+        # Recurse with the assignment rules (not the with-statement rules) so nested
+        # frames / expressions are handled as assignments and errors are not
+        # misattributed to a with-statement.
+        for i, v in enumerate(value):
+            bind_assign_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        # A Var already present in the variable table is deliberately not renamed here.
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """Parse ``for target in iter: body``; ``iter`` must evaluate to a ``ForFrame``.
+
+    The loop variables are bound with ``bind_for_value`` inside a fresh
+    variable-table scope that also covers the loop body.
+    """
+    loop_frame = self.eval_expr(node.iter)
+    if not isinstance(loop_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame(), loop_frame as loop_vars:
+        self.eval_assign(target=node.target, source=loop_vars, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """Parse ``while test: body`` into a ``T.While`` frame inside a new variable scope."""
+    with self.var_table.with_frame():
+        loop_cond = self.eval_expr(node.test)
+        while_frame = T.While(loop_cond)
+        with while_frame:
+            self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """Parse ``target = value``.
+
+    A subscript target becomes a ``T.buffer_store``; any other target is bound
+    through ``eval_assign`` with ``bind_assign_value``. Chained assignments are
+    rejected.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    target = node.targets[0]
+    value = self.eval_expr(node.value)
+    if not isinstance(target, doc.Subscript):
+        self.eval_assign(target=target, source=value, bind_value=bind_assign_value)
+        return
+    index_nodes = target.slice.elts if isinstance(target.slice, doc.Tuple) else [target.slice]
+    indices = [self.eval_expr(index_node) for index_node in index_nodes]
+    T.buffer_store(self.eval_expr(target.value), value, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """Parse ``target op= value`` as ``target = target op value``.
+
+    The current target and the value are evaluated, stored under temporary names
+    in a throw-away variable-table scope, and combined by evaluating a synthesized
+    ``BinOp`` that refers to those names. The result is then written back like a
+    plain assignment: a ``T.buffer_store`` for subscript targets, otherwise
+    ``eval_assign`` with ``bind_assign_value``.
+    """
+    # Source spans reused for the synthesized AST nodes (presumably so that any
+    # diagnostics still point at the user's code).
+    lhs_pos = (
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    # Temporarily mark the target as a load so it can be evaluated as an expression.
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        # The temporaries only live in this scope and are dropped when it closes.
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        # Build ``<lhs temp> op <rhs temp>`` so the regular BinOp evaluation (and
+        # its operator dispatch) computes the combined value.
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)
+    lhs = node.target
+    # Restore the store context before writing the result back.
+    lhs.ctx = doc.Store(*lhs_pos)
+    if isinstance(lhs, doc.Subscript):
+        # Note: the buffer and index expressions are evaluated a second time here.
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """Parse ``target: annotation = value`` into a ``T.let`` binding.
+
+    The annotation must resolve to a ``Var``; that var is bound to the target name
+    and a ``T.let`` frame binding it to the value is entered, with the frame's own
+    ``__exit__`` registered via ``add_callback``.
+    """
+    value = self.eval_expr(node.value)
+    annotated_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(annotated_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=node.target, source=annotated_var, bind_value=bind_assign_value)
+    let_frame = T.let(annotated_var, value)
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """Parse ``with a as x, b as y: body`` where each context expression is a frame.
+
+    A fresh variable-table scope is opened first; the frames are then entered left
+    to right, and ``ExitStack`` exits everything in reverse order after the body.
+    """
+    with contextlib.ExitStack() as exit_stack:
+        exit_stack.enter_context(self.var_table.with_frame())
+        for with_item in node.items:
+            context = self.eval_expr(with_item.context_expr)
+            if not isinstance(context, Frame):
+                self.report_error(
+                    with_item.context_expr, "Invalid context expression in the with-statement."
+                )
+            entered = exit_stack.enter_context(context)
+            target = with_item.optional_vars
+            if target is not None:
+                self.eval_assign(target=target, source=entered, bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """Parse a ``def`` into a ``T.prim_func`` frame.
+
+    Inside a new variable-table scope ``range`` resolves to ``T.serial``. The name
+    and optional return annotation are recorded; a callable return annotation is
+    called and its ``dtype`` wrapped in ``PrimType``. The parameters and body are
+    then visited under the "tir" dispatch token.
+    """
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            ret_annotation = node.returns
+            if ret_annotation is not None:
+                ret_type = self.eval_expr(ret_annotation)
+                T.func_ret(PrimType(ret_type().dtype) if callable(ret_type) else ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """Declare each positional parameter with ``T.arg`` and bind it by name.
+
+    Every parameter must carry a type annotation, which is resolved through
+    ``visit_tvm_annotation``.
+    """
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    for param_node in node.args:
+        annotation = param_node.annotation
+        if annotation is None:
+            self.report_error(param_node, "Type annotation is required for function parameters.")
+        param = T.arg(param_node.arg, self.visit_tvm_annotation(annotation))
+        self.var_table.add(param_node.arg, param)
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """Evaluate a type annotation, calling it (with no arguments) if it is callable."""
+    annotation = self.eval_expr(node)
+    return annotation() if callable(annotation) else annotation
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """Evaluate an expression statement.
+
+    If the result is an IR-builder frame, it is entered now and its own
+    ``__exit__`` is registered via ``add_callback`` rather than called here; any
+    other result is discarded.
+    """
+    result = self.eval_expr(node.value)
+    if not isinstance(result, Frame):
+        return
+    result.add_callback(partial(result.__exit__, None, None, None))
+    result.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """Parse ``if test: body [else: orelse]`` into ``T.If`` / ``T.Then`` / ``T.Else`` frames."""
+    with self.var_table.with_frame():
+        condition = self.eval_expr(node.test)
+        else_body = node.orelse
+        with T.If(condition):
+            with T.Then():
+                self.visit_body(node.body)
+            if else_body:
+                with T.Else():
+                    self.visit_body(else_body)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """Parse ``assert test, msg`` into a ``T.Assert`` frame.
+
+    The frame is entered immediately and its own ``__exit__`` is registered via
+    ``add_callback`` rather than called here. ``T.Assert`` always receives a message,
+    so the message-less form ``assert test`` (where ``node.msg`` is ``None``) is
+    reported as an error instead of passing ``None`` on.
+    """
+    if node.msg is None:
+        self.report_error(node, "The message of an assert statement is required in TIR.")
+    cond = self.eval_expr(node.test)
+    msg = self.eval_expr(node.msg)
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+    """Reject ``return`` statements, which this TIR parser does not accept."""
+    message = "Return is not allowed."
+    self.report_error(node, message)
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/utils.py
similarity index 52%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/utils.py
index 923eb97d27..4c08a381c0 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/utils.py
@@ -14,32 +14,35 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
-
+# pylint: disable=missing-docstring
 import inspect
-from typing import Callable
+from contextlib import contextmanager
+from typing import Any, Callable, Dict
+
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+def deferred(f: Callable[[], None]):
+    """Return a context manager that calls ``f()`` when its ``with`` block exits.
+
+    ``f`` runs even when the block raises; the exception still propagates.
+    """
+
+    @contextmanager
+    def _call_on_exit():
+        try:
+            yield
+        finally:
+            f()
 
+    return _call_on_exit()
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    """Collect the names visible to ``func``: its module globals and closure variables.
+
+    When a name is both a module global and a closure (nonlocal) variable, the
+    closure variable wins. This mirrors Python's own name resolution inside ``func``,
+    where an enclosing-function binding shadows a global of the same name.
+    """
+    captured = {
+        **func.__globals__,
+        **inspect.getclosurevars(func).nonlocals,
+    }
+    return captured
 
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
 
-    raise TypeError("Only function definitions are supported.")
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    """Merge ``inspect_function_capture`` over every plain function defined on ``cls``.
+
+    Members are visited in ``cls.__dict__`` order, and later functions override
+    names captured by earlier ones. Members for which ``inspect.isfunction`` is
+    False are skipped.
+    NOTE(review): ``staticmethod``/``classmethod`` wrappers in ``cls.__dict__`` are not
+    plain functions, so they are skipped too -- confirm that is intended.
+    """
+    captured: Dict[str, Any] = {}
+    for member in cls.__dict__.values():
+        if not inspect.isfunction(member):
+            continue
+        captured.update(**inspect_function_capture(member))
+    return captured
diff --git a/python/tvm/script/parser/var_table.py b/python/tvm/script/parser/var_table.py
new file mode 100644
index 0000000000..32fced625a
--- /dev/null
+++ b/python/tvm/script/parser/var_table.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The symbol table of variable values"""
+
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Set
+
+from .utils import deferred
+
+
+class VarTableFrame:
+    """One scope of a ``VarTable``: the set of names defined in that scope."""
+
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        """Record ``var`` in this scope; defining it twice in the same scope is an error."""
+        if var not in self.vars:
+            self.vars.add(var)
+            return
+        raise ValueError(f"Variable {var} already defined in current scope")
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        """Call ``fn_pop`` once per recorded name, then forget every name."""
+        for recorded in self.vars:
+            fn_pop(recorded)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped symbol table mapping variable names to values for the parser.
+
+    Each name maps to a stack (list) of values so inner scopes can shadow outer
+    ones. ``with_frame`` opens a scope whose bindings are popped when it closes.
+    """
+
+    frames: List[VarTableFrame]
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        """Open a new scope and return a context manager that closes it on exit.
+
+        The scope is pushed as soon as this method is called, not when the returned
+        context manager is entered.
+        """
+
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        """Bind ``var`` to ``value`` in the innermost scope, shadowing outer bindings."""
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        """Return the currently visible (innermost) value of every bound name."""
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any) -> bool:
+        """Return whether ``value`` (by identity) is bound to any name in any open scope.
+
+        ``name2value`` maps each name to a *list* of values, so the identity check
+        must look inside those lists; comparing the lists themselves never matches.
+        """
+        return any(
+            bound is value for values in self.name2value.values() for bound in values
+        )
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser_v1/__init__.py
similarity index 95%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser_v1/__init__.py
index 555659d0c5..004e947bf6 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser_v1/__init__.py
@@ -17,5 +17,4 @@
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
 
 from . import tir
-
-from .parser import ir_module, from_source
+from .parser import from_source, ir_module
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py
similarity index 100%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser_v1/_ffi_api.py
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
similarity index 98%
rename from python/tvm/script/context_maintainer.py
rename to python/tvm/script/parser_v1/context_maintainer.py
index f7f16855c7..400baacc4b 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/parser_v1/context_maintainer.py
@@ -16,16 +16,16 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
+from typing import Callable, Dict, List, Mapping, Optional, Union
 
+import synr
 import tvm
 from tvm.ir import Span
 from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
 from tvm.runtime import Object
+from tvm.tir import Buffer, MatchBufferRegion, PrimExpr, Stmt, Var
 from tvm.tir.expr import IterVar
+
 from .tir.node import BufferSlice
 
 
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py
similarity index 95%
rename from python/tvm/script/diagnostics.py
rename to python/tvm/script/parser_v1/diagnostics.py
index e676461ab3..b15997552f 100644
--- a/python/tvm/script/diagnostics.py
+++ b/python/tvm/script/parser_v1/diagnostics.py
@@ -17,11 +17,11 @@
 """Bridge from synr's (the library used for parsing the python AST)
    DiagnosticContext to TVM's diagnostics
 """
-from synr import DiagnosticContext, ast
-
 import tvm
+from synr import DiagnosticContext, ast
+from tvm.ir.diagnostics import Diagnostic
 from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
+from tvm.ir.diagnostics import DiagnosticLevel, get_renderer
 
 
 class TVMDiagnosticCtx(DiagnosticContext):
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py
similarity index 100%
rename from python/tvm/script/meta_unparser.py
rename to python/tvm/script/parser_v1/meta_unparser.py
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py
similarity index 99%
rename from python/tvm/script/parser.py
rename to python/tvm/script/parser_v1/parser.py
index 908af081c9..b2b7e388f1 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser_v1/parser.py
@@ -20,35 +20,34 @@ We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
 different python versions. Synr also provides an error handling context that we
 use for error reporting.
 """
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
+import inspect
 import json
 import operator
-import inspect
+
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
+import types
 from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
 
 import tvm
+from synr import Transformer, ast, to_ast
 from tvm import IRModule
 from tvm._ffi.base import TVMError
 from tvm.ir import GlobalVar
 from tvm.ir.function import BaseFunc
 from tvm.tir import buffer
 from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
 
+from . import _ffi_api, tir
 from .context_maintainer import ContextMaintainer
+from .diagnostics import TVMDiagnosticCtx
 from .meta_unparser import MetaUnparser
 from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
+from .tir import ty
 from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
+from .tir.node import BufferSlice, Slice
+from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler
 from .tir.special_stmt import SpecialStmt
-from .tir import ty
+from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr
 
 
 class CallArgumentReader(object):
diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py
similarity index 97%
rename from python/tvm/script/registry.py
rename to python/tvm/script/parser_v1/registry.py
index e7d90dd515..e816b90f5d 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/parser_v1/registry.py
@@ -17,7 +17,7 @@
 """TVM Script Parser Function Registry """
 # pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
 import types
-from typing import Union, Callable, Dict, Optional, Any
+from typing import Any, Callable, Dict, Optional, Union
 
 
 class Registry(object):
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 100%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
similarity index 100%
rename from python/tvm/script/tir/__init__.pyi
rename to python/tvm/script/parser_v1/tir/__init__.pyi
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 98%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index 382431c229..8c51decf14 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -17,11 +17,12 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
+from tvm.target import codegen
+
 from ..registry import register
-from ...target import codegen
 from ..utils import get_param_list, tvm_span_from_synr
 
 
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 98%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
index 29e79607fb..cfaf9df476 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/parser_v1/tir/node.py
@@ -17,12 +17,13 @@
 # pylint: disable=redefined-builtin
 """TVM Script nodes."""
 
-from typing import Optional, Union, List, Callable
+from typing import Callable, List, Optional, Union
+
 import synr
 from tvm.arith import Analyzer
+from tvm.ir import Range, Span
 from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
+from tvm.tir import Buffer, BufferLoad, BufferRegion, IntImm, PrimExpr, Ramp
 
 
 class Slice:
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 97%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
index 923eb97d27..a5fdcc15c5 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser_v1/tir/prim_func.py
@@ -19,7 +19,8 @@
 import inspect
 from typing import Callable
 
-from tvm.tir.function import PrimFunc
+from tvm.tir import PrimFunc
+
 from ..parser import from_source
 
 
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 98%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
index da7545c9a9..cd2167f4ab 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/parser_v1/tir/scope_handler.py
@@ -16,24 +16,19 @@
 # under the License.
 """TVM Script Parser Scope Handler Classes"""
 # pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
-import synr
 import numpy as np
+import synr
 import tvm.tir
+from tvm.ir import Range, Span
 from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
+from tvm.tir import Buffer, BufferRegion, ForKind, IterVar, PrimExpr, Stmt, Var
 
 from ..context_maintainer import ContextMaintainer
 from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
 
 
 class ScopeHandler:
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 99%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
index 15502055b7..42a90f647f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -17,27 +17,21 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
 
 import synr
+import tvm.tir
 from synr import ast
+from tvm.ir import Span
 from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
 from tvm.runtime import Object, String
 from tvm.target import Target
-from tvm.ir import Span
 from tvm.tir import IntImm, IterVar, Var
 
-from .node import BufferSlice
-
 from ..context_maintainer import BlockInfo, ContextMaintainer
 from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
 
 
 def convert_to_int(
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 99%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
index 4548102a9e..d9e4b3388d 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/parser_v1/tir/ty.py
@@ -23,6 +23,7 @@ a wrapper for uniform Type system in IR
 from numbers import Integral
 
 import tvm
+
 from .special_stmt import SpecialStmt, convert_to_int
 
 
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 97%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
index c655a62237..f358a90081 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/parser_v1/utils.py
@@ -16,13 +16,12 @@
 # under the License.
 """Helper functions in TVM Script Parser"""
 
-from typing import Callable, List, Any, Optional, Tuple
-
 import inspect
-import synr
+from typing import Any, Callable, List, Optional, Tuple
 
-from tvm.ir import Span, SourceName
+import synr
 from tvm.error import DiagnosticError
+from tvm.ir import SourceName, Span
 
 
 def get_param_list(
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index ada5c369ad..b0d72d283b 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -22,9 +22,9 @@ from numbers import Integral as _Integral
 from typing import List
 
 import tvm._ffi
+import tvm.arith._ffi_api
 import tvm.tir
 import tvm.tir._ffi_api
-import tvm.arith._ffi_api
 from tvm._ffi.base import string_types
 from tvm.ir import Array
 from tvm.runtime import convert
@@ -420,11 +420,14 @@ def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimF
     )
     for tensor, buffer in zip(input_tensors, input_buffers):
         # TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made?
-        assert tensor.shape == buffer.shape, (
-            "The input input_tensors provided do not match the input buffers in the ",
-            "primfunc. Please check that the order of input te.Input_Tensors and the ",
-            "order of the primfunc variables in the params list agree.",
-        )
+        assert len(tensor.shape) == len(buffer.shape)
+        for d1, d2 in zip(tensor.shape, buffer.shape):
+            assert d1 == d2, (
+                "The input input_tensors provided do not match the input buffers in the ",
+                "primfunc. Please check that the order of input te.Input_Tensors and the ",
+                "order of the primfunc variables in the params list agree.",
+            )
+
     output = extern(
         [buf.shape for buf in outputs],
         input_tensors,
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..41a3f86233 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,185 @@
 from tvm.ir import PrimExpr
 from tvm.runtime import const
 
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis
+from . import ir_builder_v1 as ir_builder
+from . import schedule, stmt_functor, transform, usmp
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Any,
+    Broadcast,
+    BufferLoad,
+    Call,
+    CallEffectKind,
+    Cast,
+    CommReducer,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
+    IntImm,
+    IterVar,
+    Let,
+    Load,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
+    Select,
+    Shuffle,
+    SizeVar,
+    StringImm,
+    Sub,
+    Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+    TVMBackendAllocWorkspace,
+    TVMBackendFreeWorkspace,
+    abs,
+    acos,
+    acosh,
+    address_of,
+    all,
+    any,
+    asin,
+    asinh,
+    assume,
+    atan,
+    atan2,
+    atanh,
+    call_cpacked,
+    call_cpacked_lowered,
+    call_extern,
+    call_intrin,
+    call_llvm_intrin,
+    call_llvm_pure_intrin,
+    call_packed,
+    call_packed_lowered,
+    call_pure_extern,
+    ceil,
+    ceildiv,
+    clz,
+    comm_reducer,
+    copysign,
+    cos,
+    cosh,
+    div,
+    erf,
+    exp,
+    exp2,
+    exp10,
+    floor,
+    floordiv,
+    floormod,
+    fmod,
+    hypot,
+    if_then_else,
+    indexdiv,
+    indexmod,
+    isfinite,
+    isinf,
+    isnan,
+    isnullptr,
+    ldexp,
+    likely,
+    log,
+    log1p,
+    log2,
+    log10,
+    lookup_param,
+    max,
+    max_value,
+    min,
+    min_value,
+    mma_fill,
+    mma_store,
+    nearbyint,
+    nextafter,
+    popcount,
+    power,
+    ptx_commit_group,
+    ptx_cp_async,
+    ptx_ldmatrix,
+    ptx_mma,
+    ptx_mma_sp,
+    ptx_wait_group,
+    q_multiply_shift,
+    ret,
+    round,
+    rsqrt,
+    shift_left,
+    shift_right,
+    sigmoid,
+    sin,
+    sinh,
+    sqrt,
+    sum,
+    tan,
+    tanh,
+    trace,
+    trunc,
+    truncdiv,
+    truncmod,
+    tvm_access_ptr,
+    tvm_bmma_sync,
+    tvm_fill_fragment,
+    tvm_load_matrix_sync,
+    tvm_mma_sync,
+    tvm_stack_alloca,
+    tvm_stack_make_array,
+    tvm_stack_make_shape,
+    tvm_store_matrix_sync,
+    tvm_struct_get,
+    tvm_struct_set,
+    tvm_thread_allreduce,
+    tvm_throw_last_error,
+    tvm_tuple,
+    undef,
+    vectorcombine,
+    vectorhigh,
+    vectorlow,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
 from .stmt import (
-    BufferStore,
-    BufferRealize,
-    Store,
-    ProducerStore,
     Allocate,
     AllocateConst,
+    AssertStmt,
     AttrStmt,
+    Block,
+    BlockRealize,
+    BufferRealize,
+    BufferRegion,
+    BufferStore,
     DeclBuffer,
+    Evaluate,
+    For,
+    ForKind,
+    IfThenElse,
+    LetStmt,
+    MatchBufferRegion,
+    Prefetch,
+    ProducerRealize,
+    ProducerStore,
+    SeqStmt,
+    Stmt,
+    Store,
+    While,
+    stmt_list,
+    stmt_seq,
+    type_annotation,
 )
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/tir/_ffi_ir_builder_api.py
similarity index 88%
rename from python/tvm/script/_ffi_api.py
rename to python/tvm/tir/_ffi_ir_builder_api.py
index 926d17b166..61b288d498 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/tir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir"""
 import tvm._ffi
 
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder.tir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..ea220fea22 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Union
 from tvm import Object
 from tvm.ir import IRModule
 from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
+from tvm.tir.stmt import Block, BufferRegion, PrimExpr, Stmt
 
-from .. import Buffer, Stmt
+from ..buffer import Buffer
 from ..function import PrimFunc
 from . import _ffi_api
 
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index d9b0aec76a..e74eb15453 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -16,11 +16,12 @@
 # under the License.
 """Abstraction for array data structures."""
 from numbers import Integral
-import tvm._ffi
 
+import tvm._ffi
 from tvm._ffi.base import string_types
+from tvm.ir import PointerType, PrimExpr, PrimType, Range
 from tvm.runtime import Object, convert
-from tvm.ir import PrimExpr, PointerType, PrimType
+
 from . import _ffi_api
 
 
@@ -176,6 +177,40 @@ class Buffer(Object):
         """
         return _ffi_api.BufferOffsetOf(self, indices)  # type: ignore
 
+    def __getitem__(self, indices):
+        """Return a BufferRegion if any slice lacks a step, otherwise a BufferLoad."""
+        from ..arith import Analyzer  # pylint: disable=import-outside-toplevel
+        from .expr import BufferLoad, Ramp  # pylint: disable=import-outside-toplevel
+        from .stmt import BufferRegion  # pylint: disable=import-outside-toplevel
+
+        analyzer = Analyzer()  # omitted slice bounds below default to 0 and self.shape[dim]
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+        if any(isinstance(index, slice) and index.step is None for index in indices):
+            region = []
+            for dim, index in enumerate(indices):
+                if isinstance(index, slice):
+                    start = 0 if index.start is None else index.start
+                    stop = self.shape[dim] if index.stop is None else index.stop
+                    region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
+                else:
+                    region.append(Range.from_min_extent(index, 1))
+            return BufferRegion(self, region)
+        expr_indices = []
+        for dim, index in enumerate(indices):
+            if isinstance(index, slice):
+                start = 0 if index.start is None else index.start
+                stop = self.shape[dim] if index.stop is None else index.stop
+                step = index.step
+                lanes = analyzer.simplify((stop - start + step - 1) // step)
+                if lanes == 1:
+                    expr_indices.append(start)
+                else:
+                    expr_indices.append(Ramp(start, step, int(lanes)))
+            else:
+                expr_indices.append(index)
+        return BufferLoad(self, expr_indices)
+
 
 def decl_buffer(
     shape,
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 from typing import Optional, Union
-from tvm import ir
+
 import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
 from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
 
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
 from . import _ffi_api
+from . import generic as _generic
 
 
 def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
 
-    # TODO(tkonolige): use inspect to add source information to these objects
-
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
     """
 
     def __init__(self, condition, true_value, false_value, span=None):
+        if isinstance(condition, bool):  # wrap a plain Python bool as a boolean IntImm
+            condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
             _ffi_api.Select, condition, true_value, false_value, span  # type: ignore
         )
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index f06376147b..6c57e27b82 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -394,7 +394,7 @@ class IndexMap(Object):
                 raise TypeError(
                     "Expected mapping function to return list of "
                     "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR.  "
-                    "Instead received {val} of type {type(val)}."
+                    f"Instead received {val} of type {type(val)}."
                 )
 
         return IndexMap(initial_indices, final_indices), axis_separators
diff --git a/python/tvm/tir/ir_builder_frame.py b/python/tvm/tir/ir_builder_frame.py
new file mode 100644
index 0000000000..a1f457aad2
--- /dev/null
+++ b/python/tvm/tir/ir_builder_frame.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder for TIR"""
+
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.ir.ir_builder import IRBuilderFrame
+
+from .buffer import Buffer
+from .expr import Var
+
+
+@_register_object("ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+    """Base class shared by every TIR IRBuilder frame in this module."""
+
+
+@_register_object("ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.block()``."""
+
+
+@_register_object("ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.init()``."""
+
+
+@_register_object("ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    def __enter__(self) -> Union[Var, List[Var]]:  # one loop var -> Var, several -> list
+        super().__enter__()
+        return self.vars if len(self.vars) > 1 else self.vars[0]
+
+
+@_register_object("ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.prim_func()``."""
+
+
+@_register_object("ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.Assert()``."""
+
+
+@_register_object("ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.let()`` when no ``body`` is given."""
+
+
+@_register_object("ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    def __enter__(self) -> Buffer:  # ``with allocate(...) as buf`` binds the frame's buffer
+        super().__enter__()
+        return self.buffer
+
+
+@_register_object("ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    def __enter__(self) -> Buffer:  # ``with allocate_const(...) as buf`` binds the buffer
+        super().__enter__()
+        return self.buffer
+
+
+@_register_object("ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.launch_thread()``."""
+
+
+@_register_object("ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.realize()``."""
+
+
+@_register_object("ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.attr()``."""
+
+
+@_register_object("ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.While()``."""
+
+
+@_register_object("ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.If()``."""
+
+
+@_register_object("ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.Then()``."""
+
+
+@_register_object("ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    """Frame returned by ``ir_builder_v2.Else()``."""
+
+
+@_register_object("ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    def __enter__(self) -> Buffer:  # ``with decl_buffer(...) as buf`` binds the frame's buffer
+        super().__enter__()
+        return self.buffer
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder_v1.py
similarity index 99%
rename from python/tvm/tir/ir_builder.py
rename to python/tvm/tir/ir_builder_v1.py
index ce8cd1b403..8a68faaac6 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder_v1.py
@@ -17,13 +17,13 @@
 """Developer API of IR node builder make function."""
 import tvm
 from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, convert, const
 from tvm.ir import container as _container
+from tvm.runtime import ObjectGeneric, const, convert
 
-from . import stmt as _stmt
-from . import expr as _expr
 from . import buffer as _buffer
+from . import expr as _expr
 from . import op
+from . import stmt as _stmt
 
 
 class WithScope(object):
diff --git a/python/tvm/tir/ir_builder_v2.py b/python/tvm/tir/ir_builder_v2.py
new file mode 100644
index 0000000000..b4c8aa3a1d
--- /dev/null
+++ b/python/tvm/tir/ir_builder_v2.py
@@ -0,0 +1,949 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.target import Target as target
+
+from . import _ffi_ir_builder_api as _ffi_api
+from . import ir_builder_frame as frame
+from . import op as _tir_op
+from .buffer import Buffer
+from .expr import Broadcast as broadcast
+from .expr import BufferLoad, CommReducer, IntImm, IterVar, Let, PrimExpr
+from .expr import Ramp as ramp, Cast
+from .expr import Select, Shuffle, StringImm, Var
+from .generic import cast
+from .stmt import BufferRegion, type_annotation
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    """Declare a Buffer through ``_ffi_api.BufferDecl``.
+
+    A scalar ``shape`` (a PrimExpr or an integer) is promoted to a 1-tuple.
+
+    Every argument is then forwarded positionally, with an empty string
+    inserted as the third argument (presumably the buffer name -- confirm
+    against the C++ registration).
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+
+    ffi_args = [shape, dtype, "", data, strides, elem_offset]
+    ffi_args += [scope, align, offset_factor, buffer_type, axis_separators]
+    return _ffi_api.BufferDecl(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def ptr(dtype, storage_scope="global"):  # presumably a pointer type; built by _ffi_api.Ptr
+    return _ffi_api.Ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:  # returns a BlockFrame
+    return _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
+
+
+def init() -> frame.BlockInitFrame:  # presumably the init part of the enclosing block
+    return _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+
+
+def where(predicate) -> None:
+    # Python bools (a subclass of int) and the ints 0/1 are wrapped as a boolean
+    # IntImm before reaching the FFI; any other plain int is rejected.
+    if isinstance(predicate, int):
+        if predicate not in (0, 1):
+            message = "Invalid value for predicate: {}".format(predicate)
+            raise ValueError(message)
+        predicate = IntImm("bool", predicate)
+    _ffi_api.Where(predicate)  # pylint: disable=no-member # type: ignore
+
+
+_BufferSlice = Union[BufferRegion, BufferLoad]  # pylint: disable=invalid-name
+
+
+def _normalize_buffer_slices(buffer_slices) -> List[_BufferSlice]:
+    """Flatten the varargs of ``reads``/``writes`` into a list of buffer slices.
+
+    A single tuple or list argument is unpacked, so ``reads(A[i], B[j])``,
+    ``reads([A[i], B[j]])`` and ``reads((A[i], B[j]))`` are equivalent.
+    """
+    if len(buffer_slices) == 1 and isinstance(buffer_slices[0], (tuple, list)):
+        return list(buffer_slices[0])
+    return list(buffer_slices)
+
+
+def reads(*buffer_slices: Union[_BufferSlice, List[_BufferSlice]]) -> None:
+    _ffi_api.Reads(  # pylint: disable=no-member # type: ignore
+        _normalize_buffer_slices(buffer_slices)
+    )
+
+
+def writes(*buffer_slices: Union[_BufferSlice, List[_BufferSlice]]) -> None:
+    _ffi_api.Writes(  # pylint: disable=no-member # type: ignore
+        _normalize_buffer_slices(buffer_slices)
+    )
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    return _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Allocate a Buffer through ``_ffi_api.AllocBuffer``.
+
+    Compared with ``buffer_decl``, the defaults are ``align=-1`` and
+    ``buffer_type="default"``, and no empty string is inserted after ``dtype``.
+
+    A scalar ``shape`` (a PrimExpr or an integer) is promoted to a 1-tuple, and a
+    missing ``strides`` is replaced by an empty list before the FFI call.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+
+    ffi_args = [shape, dtype, data, strides, elem_offset]
+    ffi_args += [scope, align, offset_factor, buffer_type, axis_separators]
+    return _ffi_api.AllocBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def _as_range(dom) -> Range:
+    # A Range is kept; a list/tuple becomes Range(dom[0], dom[1]); else Range(0, dom).
+    if not isinstance(dom, Range):
+        is_pair = isinstance(dom, (list, tuple))
+        dom = Range(dom[0], dom[1]) if is_pair else Range(0, dom)
+    return dom
+
+
+class axis:  # pylint: disable=invalid-name
+    """Namespace of IterVar constructors backed by the ``_ffi_api.Axis*`` functions."""
+
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        rng = _as_range(dom)
+        return _ffi_api.AxisSpatial(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        rng = _as_range(dom)
+        return _ffi_api.AxisReduce(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        rng = _as_range(dom)
+        return _ffi_api.AxisScan(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        rng = _as_range(dom)
+        return _ffi_api.AxisOpaque(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: ignore
+            kinds, bindings, dtype
+        )
+        if len(iter_vars) == 1:
+            return iter_vars[0]  # unwrap a one-element result
+        return iter_vars
+
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def _loop_bounds(start, stop):
+    """Normalize the bounds accepted by the ForFrame constructors below.
+
+    A lone argument is the stop bound, and ``start`` then defaults to 0.
+    """
+    return (0, start) if stop is None else (start, stop)
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    start, stop = _loop_bounds(start, stop)
+    return _ffi_api.Serial(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    start, stop = _loop_bounds(start, stop)
+    return _ffi_api.Parallel(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    start, stop = _loop_bounds(start, stop)
+    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    start, stop = _loop_bounds(start, stop)
+    return _ffi_api.Unroll(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    # Accepted call forms:
+    #   (stop, thread) or (stop, thread=...)  -- start defaults to 0
+    #   (start, stop, thread)
+    if thread is None:
+        if not isinstance(stop, str):
+            raise ValueError("Thread cannot be None for thread_binding")
+        start, stop, thread = 0, start, stop
+    elif stop is None:
+        start, stop = 0, start
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents) -> frame.ForFrame:  # all extents go to the FFI as one tuple
+    return _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+
+
+def prim_func() -> frame.PrimFuncFrame:  # takes no arguments; returns a PrimFuncFrame
+    return _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+
+
+def arg(name, obj):  # the result of _ffi_api.Arg is returned unchanged
+    return _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+
+
+def func_name(name: str) -> str:  # forwards the name to _ffi_api.FuncName
+    return _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:  # returns whatever _ffi_api.FuncAttrs returns
+    return _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+
+
+def func_ret(ret_type) -> Type:  # result of _ffi_api.FuncRet (annotated as tvm.ir.Type)
+    return _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Forward to ``_ffi_api.MatchBuffer`` and return the resulting Buffer.
+
+    ``param`` is passed first, followed by the buffer description. A scalar
+    ``shape`` (a PrimExpr or an integer) is promoted to a 1-tuple and a missing
+    ``strides`` becomes an empty list.
+
+    Compared with ``alloc_buffer``, the default ``scope`` here is ``"global"``
+    rather than ``""``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+
+    ffi_args = [param, shape, dtype, data, strides, elem_offset]
+    ffi_args += [scope, align, offset_factor, buffer_type, axis_separators]
+    return _ffi_api.MatchBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def preflattened_buffer(
+    postflattened,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> None:
+    """Forward a buffer description to ``_ffi_api.PreflattenedBuffer``.
+
+    ``postflattened`` is passed first. A scalar ``shape`` (a PrimExpr or an
+    integer) is promoted to a 1-tuple and a missing ``strides`` becomes an
+    empty list. Nothing is returned.
+
+    The parameters and defaults are the same as those of ``match_buffer``
+    (with ``postflattened`` in place of ``param``).
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+
+    ffi_args = [postflattened, shape, dtype, data, strides, elem_offset]
+    ffi_args += [scope, align, offset_factor, buffer_type, axis_separators]
+    _ffi_api.PreflattenedBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    if isinstance(condition, bool):  # same bool handling as While/If
+        condition = IntImm("bool", condition)
+    return _ffi_api.Assert(condition, message)  # pylint: disable=no-member # type: ignore
+
+
+def let(v: Var, value: PrimExpr, body: Optional[PrimExpr] = None) -> Union[frame.LetFrame, Let]:
+    # Without ``body`` this opens a LetFrame through the FFI; with ``body`` it
+    # builds a Let expression directly.
+    if body is None:
+        return _ffi_api.Let(v, value)  # pylint: disable=no-member # type: ignore
+    return Let(v, value, body)
+
+
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: Optional[PrimExpr] = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    if isinstance(condition, bool):  # wrap Python bools as boolean IntImm for the FFI
+        condition = IntImm("bool", condition)
+    return _ffi_api.Allocate(  # pylint: disable=no-member # type: ignore
+        extents, dtype, scope, condition, annotations
+    )
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+    # ``data`` is packed into a runtime NDArray of ``dtype`` via numpy first.
+    return _ffi_api.AllocateConst(  # pylint: disable=no-member # type: ignore
+        ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+    )
+
+
+def realize(
+    buffer_slice: BufferRegion, storage_scope: str, condition: Union[PrimExpr, bool] = True
+) -> frame.RealizeFrame:
+    if isinstance(condition, bool):  # match allocate/While/If: pass bools as IntImm("bool")
+        condition = IntImm("bool", condition)
+    return _ffi_api.Realize(  # pylint: disable=no-member # type: ignore
+        buffer_slice, storage_scope, condition
+    )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    return _ffi_api.Attr(  # pylint: disable=no-member # type: ignore
+        convert(node), attr_key, convert(value)
+    )
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    # A Python bool becomes a boolean IntImm; any other value is passed through.
+    condition = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.While(condition)  # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    # A Python bool becomes a boolean IntImm; any other value is passed through.
+    condition = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.If(condition)  # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    return _ffi_api.Then()  # pylint: disable=no-member # type: ignore
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    return _ffi_api.Else()  # pylint: disable=no-member # type: ignore
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+    """Create a DeclBufferFrame through ``_ffi_api.DeclBuffer``.
+
+    The parameters and defaults mirror ``buffer_decl``. A scalar ``shape`` (a
+    PrimExpr or an integer) is promoted to a 1-tuple, and every argument is then
+    forwarded positionally with an empty string inserted as the third argument
+    (presumably the buffer name -- confirm against the C++ registration).
+
+    Entering the returned frame yields its ``buffer``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+
+    ffi_args = [shape, dtype, "", data, strides, elem_offset]
+    ffi_args += [scope, align, offset_factor, buffer_type, axis_separators]
+    return _ffi_api.DeclBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def launch_thread(  # parameter ``iter_var`` shadows the iter_var() helper defined below
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    return _ffi_api.LaunchThread(iter_var, extent)  # pylint: disable=no-member # type: ignore
+
+
+def env_thread(thread_tag: str) -> IterVar:  # IterVar built by _ffi_api.EnvThread from the tag
+    return _ffi_api.EnvThread(thread_tag)  # pylint: disable=no-member # type: ignore
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    expr_indices = []
+    for dim, index in enumerate(indices):
+        if isinstance(index, slice):
+            # Omitted bounds default to 0 and to the extent of that buffer dimension.
+            start = 0 if index.start is None else index.start
+            stop = buffer.shape[dim] if index.stop is None else index.stop
+            step = 1 if index.step is None else index.step
+            lanes = Analyzer().simplify((stop - start + step - 1) // step)
+            expr_indices.append(start if lanes == 1 else ramp(start, step, int(lanes)))
+        else:
+            expr_indices.append(index)
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    return _ffi_api.BufferStore(  # pylint: disable=no-member # type: ignore
+        buffer, value, expr_indices
+    )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+    return _ffi_api.Prefetch(buffer, indices)  # pylint: disable=no-member # type: ignore
+
+
+def evaluate(value: PrimExpr) -> None:
+    # A plain Python string is wrapped as a StringImm before being passed to the FFI.
+    value = StringImm(value) if isinstance(value, str) else value
+    return _ffi_api.Evaluate(value)  # pylint: disable=no-member # type: ignore
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def _to_prim_expr(expr):
+    """Return ``expr`` if it is already a PrimExpr, otherwise ``convert(expr)``.
+
+    Shared by the floating-point constructors below; ``None`` is converted as well.
+    """
+    return expr if isinstance(expr, PrimExpr) else convert(expr)
+
+
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Float8(_to_prim_expr(expr))  # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Float16(_to_prim_expr(expr))  # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Float32(_to_prim_expr(expr))  # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Float64(_to_prim_expr(expr))  # pylint: disable=no-member # type: ignore
+
+
+def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int32x4(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int32x8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Int32x16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:  # expr forwarded unconverted
+    return _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+
+
+def min(a, b):  # pylint: disable=redefined-builtin
+    """Build the TIR expression for the smaller of two operands.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        First operand.
+
+    b : PrimExpr
+        Second operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        The ``min`` expression produced by the IR-builder FFI.
+    """
+    return _ffi_api.min(a, b)  # pylint: disable=no-member # type: ignore
+
+
+def max(a, b):  # pylint: disable=redefined-builtin
+    """Build the TIR expression for the larger of two operands.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        First operand.
+
+    b : PrimExpr
+        Second operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        The ``max`` expression produced by the IR-builder FFI.
+    """
+    return _ffi_api.max(a, b)  # pylint: disable=no-member # type: ignore
+
+
+def var(dtype, name="") -> Var:
+    """Create a TIR ``Var`` of ``dtype``; note this helper takes dtype first, ``Var`` takes name first."""
+    new_var = Var(name, dtype)  # pylint: disable=no-member # type: ignore
+    return new_var
+
+
+def iter_var(v, dom, iter_type, thread_tag):
+    """Build an ``IterVar`` over ``dom`` bound to ``v``.
+
+    ``iter_type`` is the name of an ``IterVar`` class attribute and is resolved
+    with ``getattr`` before the ``IterVar`` is constructed.
+    """
+    kind = getattr(IterVar, iter_type)
+    return IterVar(dom, v, kind, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+    """Create a CommReducer from lambda inputs/outputs and the identities.
+
+    Parameters
+    ----------
+    combiner : Callable
+        Function whose parameters become the reducer's variables: the first
+        half are passed to ``CommReducer`` as the lhs vars, the second half
+        as the rhs vars. It may return one expression or a tuple of them.
+
+    identity : list or tuple
+        Identity elements. Each element also fixes the dtype of its matching
+        lhs/rhs variable: Python ``int`` -> int32, Python ``float`` ->
+        float32, anything else -> the element's own ``dtype``.
+
+    Returns
+    -------
+    reducer : CommReducer
+        The constructed reducer.
+    """
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    args = []
+    # identity + identity: lhs and rhs variables share the identities' dtypes.
+    for name, i in zip(params.keys(), identity + identity):
+        if isinstance(i, int):
+            args.append(Var(name, "int32"))
+        elif isinstance(i, float):
+            # Plain Python floats have no ``dtype`` attribute (previously an
+            # AttributeError); use float32, presumably matching how TVM
+            # converts Python float constants.
+            args.append(Var(name, "float32"))
+        else:
+            args.append(Var(name, i.dtype))
+    res = combiner(*args)
+    if not isinstance(res, tuple):
+        res = (res,)
+    return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+    """Return the LLVM intrinsic id for ``name`` via ``tvm.target.codegen``.
+
+    The import is deferred to call time (presumably to avoid an import cycle).
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm.target.codegen import llvm_lookup_intrinsic_id as _lookup
+
+    # pylint: enable=import-outside-toplevel
+    intrinsic_id = _lookup(name)
+    return intrinsic_id
+
+
+def _op_wrapper(func):
+    """Wrap ``func`` so that any ``dtype`` keyword argument is discarded.
+
+    Used for ops that do not accept ``dtype``; presumably the script parser
+    may supply ``dtype=`` uniformly, so it is dropped here.
+    """
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        kwargs.pop("dtype", None)
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):
+    """Wrap ``func`` so a ``dtype`` keyword is moved to the first positional slot.
+
+    Used for ops whose first positional parameter is the result dtype.
+    """
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" not in kwargs:
+            return func(*args, **kwargs)
+        dtype = kwargs.pop("dtype")
+        return func(dtype, *args, **kwargs)
+
+    return wrapped
+
+
+# pylint: disable=invalid-name
+
+# Script-facing aliases: _op_wrapper drops a ``dtype`` kwarg, _dtype_forward
+# moves it to the first positional argument. Each name is bound exactly once.
+buffer_var = ptr
+abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+# Left unwrapped: presumably so a ``dtype=`` kwarg reaches its own dtype parameter.
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+assume = _op_wrapper(_tir_op.assume)
+undef = _op_wrapper(_tir_op.undef)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+class inline:
+    """Wrapper marking a value for inline use in TVMScript.
+
+    The CUDA tensor intrinsics wrap Python-computed tuples in ``T.inline(...)``
+    before unpacking them. Iterating an ``inline`` lazily yields each element of
+    the wrapped value, itself wrapped in ``inline``.
+    """
+
+    def __init__(self, value) -> None:
+        self.value = value
+        self.i = 0  # NOTE(review): not read anywhere in this class
+
+    def __iter__(self):
+        for item in self.value:
+            yield inline(item)
+
+
+# pylint: enable=invalid-name
+
+
+# Public names, kept in strict (codepoint) sorted order.
+__all__ = [
+    "Assert",
+    "Cast",
+    "Else",
+    "If",
+    "Let",
+    "Select",
+    "Shuffle",
+    "TVMBackendAllocWorkspace",
+    "TVMBackendFreeWorkspace",
+    "Then",
+    "While",
+    "abs",
+    "acos",
+    "acosh",
+    "address_of",
+    "alloc_buffer",
+    "allocate",
+    "allocate_const",
+    "arg",
+    "asin",
+    "asinh",
+    "assume",
+    "atan",
+    "atan2",
+    "atanh",
+    "attr",
+    "axis",
+    "block",
+    "block_attr",
+    "boolean",
+    "broadcast",
+    "buffer_decl",
+    "buffer_store",
+    "buffer_var",
+    "call_cpacked",
+    "call_cpacked_lowered",
+    "call_extern",
+    "call_intrin",
+    "call_llvm_intrin",
+    "call_llvm_pure_intrin",
+    "call_packed",
+    "call_packed_lowered",
+    "call_pure_extern",
+    "cast",
+    "ceil",
+    "ceildiv",
+    "clz",
+    "comm_reducer",
+    "copysign",
+    "cos",
+    "cosh",
+    "decl_buffer",
+    "env_thread",
+    "erf",
+    "evaluate",
+    "exp",
+    "exp10",
+    "exp2",
+    "fabs",
+    "float16",
+    "float32",
+    "float64",
+    "float8",
+    "floor",
+    "floordiv",
+    "floormod",
+    "fmod",
+    "func_attr",
+    "func_name",
+    "func_ret",
+    "grid",
+    "handle",
+    "hypot",
+    "if_then_else",
+    "infinity",
+    "init",
+    "inline",
+    "int16",
+    "int32",
+    "int32x16",
+    "int32x4",
+    "int32x8",
+    "int64",
+    "int8",
+    "isfinite",
+    "isinf",
+    "isnan",
+    "isnullptr",
+    "iter_var",
+    "launch_thread",
+    "ldexp",
+    "let",
+    "likely",
+    "llvm_lookup_intrinsic_id",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "lookup_param",
+    "match_buffer",
+    "max",
+    "max_value",
+    "min",
+    "min_value",
+    "mma_fill",
+    "mma_store",
+    "nearbyint",
+    "nextafter",
+    "parallel",
+    "popcount",
+    "power",
+    "prefetch",
+    "preflattened_buffer",
+    "prim_func",
+    "ptr",
+    "ptx_commit_group",
+    "ptx_cp_async",
+    "ptx_ldmatrix",
+    "ptx_mma",
+    "ptx_mma_sp",
+    "ptx_wait_group",
+    "q_multiply_shift",
+    "ramp",
+    "reads",
+    "realize",
+    "reinterpret",
+    "ret",
+    "round",
+    "rsqrt",
+    "serial",
+    "shift_left",
+    "shift_right",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "target",
+    "thread_binding",
+    "trunc",
+    "truncdiv",
+    "truncmod",
+    "tvm_access_ptr",
+    "tvm_bmma_sync",
+    "tvm_call_cpacked",
+    "tvm_call_cpacked_lowered",
+    "tvm_call_packed",
+    "tvm_call_packed_lowered",
+    "tvm_fill_fragment",
+    "tvm_load_matrix_sync",
+    "tvm_mma_sync",
+    "tvm_stack_alloca",
+    "tvm_stack_make_array",
+    "tvm_stack_make_shape",
+    "tvm_store_matrix_sync",
+    "tvm_struct_get",
+    "tvm_struct_set",
+    "tvm_thread_allreduce",
+    "tvm_throw_last_error",
+    "tvm_tuple",
+    "type_annotation",
+    "uint16",
+    "uint32",
+    "uint64",
+    "uint8",
+    "undef",
+    "unroll",
+    "var",
+    "vectorcombine",
+    "vectorhigh",
+    "vectorized",
+    "vectorlow",
+    "void",
+    "where",
+    "writes",
+]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..e2cad37bb6 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,17 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin,invalid-name,no-member,protected-access
 """Operators used in TIR expression."""
+import warnings
 from typing import Any, Optional
+
 import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
 from tvm.ir import Array, Op
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
 
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
 from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, PrimExprWithOp, StringImm, Var
 
 
 def _pack_buffer(buf, span=None):
@@ -100,6 +102,64 @@ def call_cpacked(*args, span=None):
     return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
 
 
+def call_packed_lowered(*args, span=None):
+    """Lowered version of call packed.
+
+    The argument to packed function can be Expr or Buffer.
+    The argument is the corresponding POD type when Expr is presented.
+
+    When the argument is Buffer, the corresponding PackedFunc
+    will receive a TVMArrayHandle whose content is valid during the callback period.
+    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The int32 call expression on ``tir.tvm_call_packed_lowered``.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    # Buffers are packed with _pack_buffer; every other argument passes through.
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+    """Build a lowered c-packed call on ``tir.tvm_call_cpacked_lowered``.
+
+    Mirrors ``call_packed_lowered``, except that the first argument is the
+    function name (as in ``call_extern``) and the last argument is the
+    resource handle.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments; each ``Buffer`` is packed with ``_pack_buffer``.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The int32 call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    packed = []
+    for arg in args:
+        packed.append(_pack_buffer(arg) if isinstance(arg, Buffer) else arg)
+    op = Op.get("tir.tvm_call_cpacked_lowered")
+    return Call("int32", op, packed, span)
+
+
 def call_intrin(dtype, func_name, *args, span=None):
     """Build expression by calling an intrinsic function.
 
@@ -151,7 +211,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+        dtype,
+        Op.get("tir.call_pure_extern"),
+        convert((StringImm(func_name),) + args),
+        span,
     )
 
 
@@ -178,7 +241,10 @@ def call_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+        dtype,
+        Op.get("tir.call_extern"),
+        convert((StringImm(func_name),) + args),
+        span=span,
     )
 
 
@@ -207,10 +273,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -239,8 +317,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
@@ -250,6 +336,326 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+    """Build a handle-typed ``tir.tvm_access_ptr`` intrinsic call."""
+    operands = (ptype, data, offset, extent, rw_mask)
+    return call_intrin("handle", "tir.tvm_access_ptr", *operands)
+
+
+def tvm_throw_last_error():
+    """Build a handle-typed ``tir.tvm_throw_last_error`` intrinsic call."""
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):
+    """Build a handle-typed ``tir.tvm_stack_alloca`` call with ``dtype_str`` and ``num``."""
+    return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):
+    """Build a handle-typed ``tir.tvm_stack_make_shape`` call over ``args``."""
+    return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
+    """Build a handle-typed ``tir.tvm_stack_make_array`` call from the given fields."""
+    fields = (data, shape, strides, ndim, arr_dtype, elem_offset)
+    return call_intrin("handle", "tir.tvm_stack_make_array", *fields)
+
+
+def address_of(buffer_load, span=None):
+    """Build a ``tir.address_of`` call yielding the address of a buffer element.
+
+    Parameters
+    ----------
+    buffer_load: BufferLoad
+        The element whose address is taken.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        A handle-typed call expression.
+    """
+    op_name = "tir.address_of"
+    return call_intrin("handle", op_name, buffer_load, span=span)
+
+
+def lookup_param(param_name, span=None):
+    """Build a ``tir.lookup_param`` call that refers to a param by name.
+
+    Parameters
+    ----------
+    param_name : str
+        Name of the param to look up.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        A handle-typed call expression.
+    """
+    op_name = "tir.lookup_param"
+    return call_intrin("handle", op_name, param_name, span=span)
+
+
+def tvm_tuple(*value):
+    """Build a handle-typed ``tir.tvm_tuple`` call over ``value``."""
+    return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):
+    """Build a ``tir.tvm_struct_get`` call whose result type is ``dtype``."""
+    operands = (arr, index, field_id)
+    return call_intrin(dtype, "tir.tvm_struct_get", *operands)
+
+
+def tvm_struct_set(arr, index, field_id, value):
+    """Build a handle-typed ``tir.tvm_struct_set`` call storing ``value``."""
+    operands = (arr, index, field_id, value)
+    return call_intrin("handle", "tir.tvm_struct_set", *operands)
+
+
+def tvm_thread_allreduce(*freduce_args):
+    """Build a handle-typed ``tir.tvm_thread_allreduce`` call over ``freduce_args``."""
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    """Build a handle-typed ``tir.tvm_load_matrix_sync`` intrinsic call."""
+    operands = (fragment, m, n, k, index, buffer_ptr, stride, layout)
+    return call_intrin("handle", "tir.tvm_load_matrix_sync", *operands)
+
+
+def tvm_mma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    """Build a handle-typed ``tir.tvm_mma_sync`` intrinsic call (operands d, a, b, c)."""
+    operands = (
+        fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+    )
+    return call_intrin("handle", "tir.tvm_mma_sync", *operands)
+
+
+def tvm_bmma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    """Build a handle-typed ``tir.tvm_bmma_sync`` intrinsic call (operands d, a, b, c)."""
+    operands = (
+        fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+    )
+    return call_intrin("handle", "tir.tvm_bmma_sync", *operands)
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):
+    """Build a handle-typed ``tir.tvm_fill_fragment`` intrinsic call."""
+    operands = (fragment, m, n, k, index, value)
+    return call_intrin("handle", "tir.tvm_fill_fragment", *operands)
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    """Build a handle-typed ``tir.tvm_store_matrix_sync`` intrinsic call."""
+    operands = (fragment, m, n, k, index, buffer_ptr, stride, layout)
+    return call_intrin("handle", "tir.tvm_store_matrix_sync", *operands)
+
+
+def ptx_mma(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    saturate,
+    operator=None,
+):
+    """Build a ``tir.ptx_mma`` intrinsic call whose result type is ``dtype``.
+
+    All operands from ``shape`` through ``saturate`` are forwarded in order.
+    ``operator`` is optional: it is appended as a trailing operand only when
+    it is not None, so calls without it keep the shorter operand list.
+    """
+    operands = [
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        saturate,
+    ]
+    if operator is not None:
+        operands.append(operator)
+    return call_intrin(dtype, "tir.ptx_mma", *operands)
+
+
+def ptx_mma_sp(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    metadata,
+    meta_index,
+    sparse_selector,
+    saturate,
+):
+    """Build a ``tir.ptx_mma_sp`` intrinsic call whose result type is ``dtype``.
+
+    Every operand after ``dtype`` is forwarded in signature order, including
+    the ``metadata``/``meta_index``/``sparse_selector`` operands.
+    """
+    layout_and_types = (shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype)
+    fragments = (multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)
+    sparsity = (metadata, meta_index, sparse_selector)
+    return call_intrin(
+        dtype, "tir.ptx_mma_sp", *layout_and_types, *fragments, *sparsity, saturate
+    )
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+    """Build a ``tir.ptx_ldmatrix`` intrinsic call whose result type is ``dtype``."""
+    # ``type`` shadows the builtin but is part of the keyword interface.
+    operands = (trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
+    return call_intrin(dtype, "tir.ptx_ldmatrix", *operands)
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+    """Build a ``tir.ptx_cp_async`` intrinsic call whose result type is ``dtype``."""
+    # ``bytes`` shadows the builtin but is part of the keyword interface.
+    operands = (shared_ptr, shared_offset, global_ptr, global_offset, bytes)
+    return call_intrin(dtype, "tir.ptx_cp_async", *operands)
+
+
+def ptx_commit_group():
+    """Build a ``tir.ptx_commit_group`` call with an empty dtype string."""
+    return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):
+    """Build a ``tir.ptx_wait_group`` call on ``num`` with an empty dtype string."""
+    return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+    """Build a ``tir.mma_store`` intrinsic call whose result type is ``dtype``."""
+    operands = (m, n, dst_ptr, src_ptr, src_offset, dst_stride)
+    return call_intrin(dtype, "tir.mma_store", *operands)
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+    """Build a ``tir.mma_fill`` intrinsic call whose result type is ``dtype``."""
+    operands = (local_size, local_ptr, offset)
+    return call_intrin(dtype, "tir.mma_fill", *operands)
+
+
+def vectorlow(dtype, vec):
+    """Build a ``tir.vectorlow`` call on ``vec`` with result type ``dtype``."""
+    return call_intrin(dtype, "tir.vectorlow", vec)
+
+
+def vectorhigh(dtype, vec):
+    """Build a ``tir.vectorhigh`` call on ``vec`` with result type ``dtype``."""
+    return call_intrin(dtype, "tir.vectorhigh", vec)
+
+
+def vectorcombine(dtype, vec1, vec2):
+    """Build a ``tir.vectorcombine`` call on two vectors with result type ``dtype``."""
+    return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
+
+
+def assume(cond=None):
+    """Build an int32 ``tir.assume`` call on ``cond``."""
+    return call_intrin("int32", "tir.assume", cond)
+
+
+def undef():
+    """Build an int32 ``tir.undef`` call."""
+    return call_intrin("int32", "tir.undef")
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -286,9 +692,9 @@ def any(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore
+    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore # pylint: disable=no-member,protected-access
     for i in range(2, len(args)):
-        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore
+        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore # pylint: disable=no-member,protected-access
     return val
 
 
@@ -313,9 +719,9 @@ def all(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    val = _ffi_api._OpAnd(args[0], args[1], span)  # type: ignore
+    val = _ffi_api._OpAnd(args[0], args[1], span)  # type: ignore  # pylint: disable=no-member,protected-access
     for i in range(2, len(args)):
-        val = _ffi_api._OpAnd(val, args[i], span)  # type: ignore
+        val = _ffi_api._OpAnd(val, args[i], span)  # type: ignore  # pylint: disable=no-member,protected-access
     return val
 
 
@@ -394,6 +800,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
     return _ffi_api.max_value(dtype, span)  # type: ignore
 
 
+def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+    """Return the infinity value for ``dtype``.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type whose infinity is requested.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The infinity value of dtype, as produced by ``_ffi_api.infinity``.
+    """
+    inf_value = _ffi_api.infinity(dtype, span)  # type: ignore
+    return inf_value
+
+
+def reinterpret(dtype, value, span=None) -> Any:
+    """Reinterpret-cast value to dtype.
+
+    Parameters
+    ----------
+    dtype : str
+        The target data type.
+
+    value : PrimExpr
+        The input value.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The reinterpret cast value of dtype.
+    """
+    return _ffi_api.reinterpret(dtype, value, span)  # type: ignore
+
+
 def exp(x):
     """Take exponential of input x.
 
@@ -998,6 +1445,25 @@ def ldexp(x1, x2):
     return call_intrin(x1.dtype, "tir.ldexp", x1, x2)  # type: ignore
 
 
+def likely(cond, span=None):
+    """Wrap ``cond`` in a ``likely`` hint via ``_ffi_api.likely``.
+
+    Parameters
+    ----------
+    cond : PrimExpr
+        The condition to mark.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        The marked condition.
+    """
+    marked = _ffi_api.likely(cond, span)  # type: ignore
+    return marked
+
+
 def isnan(x, span=None):
     """Check if input value is Nan.
 
@@ -1017,6 +1483,25 @@ def isnan(x, span=None):
     return _ffi_api.isnan(x, span)  # type: ignore
 
 
+def isnullptr(x, span=None):
+    """Build a bool-typed ``tir.isnullptr`` test on ``x``.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        The value to test.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        A bool expression for the null-pointer check.
+    """
+    op_name = "tir.isnullptr"
+    return call_intrin("bool", op_name, x, span=span)  # type: ignore
+
+
 def isfinite(x, span=None):
     """Check if input value is finite.
 
@@ -1122,6 +1607,42 @@ def q_multiply_shift(x, y, q, s):
     return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
 
 
+def shift_left(x, y, span=None):
+    """Return the result of x left shifted by y bits.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    y : PrimExpr
+        Number of bits to shift by.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api.left_shift(x, y, span)  # type: ignore
+
+
+def shift_right(x, y, span=None):
+    """Return the result of x right shifted by y bits.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    y : PrimExpr
+        Number of bits to shift by.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api.right_shift(x, y, span)  # type: ignore
+
+
 def fmod(x, y):
     """Return the remainder of x divided by y with the same sign as x.
 
@@ -1306,8 +1827,8 @@ def truncmod(a, b, span=None):
     return _ffi_api._OpTruncMod(a, b, span)  # type: ignore
 
 
-def floordiv(a, b, span=None):
-    """Compute the floordiv of two expressions.
+def ceildiv(a, b, span=None):
+    """Compute the ceildiv of two expressions.
 
     Parameters
     ----------
@@ -1325,11 +1846,11 @@ def floordiv(a, b, span=None):
     res : PrimExpr
         The result expression.
     """
-    return _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
+    return _ffi_api._OpCeilDiv(a, b, span)  # type: ignore
 
 
-def floormod(a, b, span=None):
-    """Compute the floormod of two expressions.
+def floordiv(a, b, span=None):
+    """Compute the floordiv of two expressions.
 
     Parameters
     ----------
@@ -1347,27 +1868,29 @@ def floormod(a, b, span=None):
     res : PrimExpr
         The result expression.
     """
-    return _ffi_api._OpFloorMod(a, b, span)  # type: ignore
+    return _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
 
 
-def ceildiv(lhs, rhs, span=None):
-    """Generic ceildiv operator.
+def floormod(a, b, span=None):
+    """Compute the floormod of two expressions.
 
     Parameters
     ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
     span : Optional[Span]
         The location of this operator in the source.
 
     Returns
     -------
-    op : tvm.Expr
-        The result Expr of ceildiv operaton.
+    res : PrimExpr
+        The result expression.
     """
-    return _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
+    return _ffi_api._OpFloorMod(a, b, span)  # type: ignore
 
 
 def comm_reducer(fcombine, fidentity, name="reduce"):
@@ -1523,6 +2046,22 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     return reducer
 
 
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+    """Build a handle-typed ``tir.TVMBackendAllocWorkspace`` intrinsic call.
+
+    All five arguments are forwarded in signature order.
+    """
+    operands = (device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
+    return call_intrin("handle", "tir.TVMBackendAllocWorkspace", *operands)
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+    """Build an int32 ``tir.TVMBackendFreeWorkspace`` intrinsic call.
+
+    Parameters
+    ----------
+    device_type, device_id : PrimExpr
+        Forwarded as the first two operands.
+
+    ptr : PrimExpr
+        Forwarded as the last operand (presumably a pointer obtained from
+        ``TVMBackendAllocWorkspace``).
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    # The call must be returned; otherwise the script-level op evaluates to None.
+    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
 
 from tvm._ffi import register_object
 from tvm.runtime import Object
-from tvm.tir import Block, For
 
+from ..stmt import Block, For
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index cf031c014c..f26f954d51 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
 from . import _ffi_api
 from ._type_checker import type_checked
 from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
 from tvm._ffi import register_object
 from tvm.ir import IRModule
 from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
 
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
 from . import _ffi_api
 from .block_scope import BlockScope, StmtSRef
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,7 @@ def stmt_list(stmt):
             res += stmt_list(x)
         return res
     return [stmt]
+
+
+def type_annotation(dtype, span=None):
+    """Build a type annotation for ``dtype`` through ``_ffi_api.TypeAnnotation``."""
+    annotation = _ffi_api.TypeAnnotation(dtype, span)
+    return annotation
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index a3b47ff6d5..f0725b666e 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -16,8 +16,4 @@
 # under the License.
 # pylint: disable=unused-import
 """Intrinsics for tensorization."""
-from .x86 import *
-from .arm_cpu import *
-from .dot_product_common import *
-from .rocm import *
-from .cuda import *
+from . import arm_cpu, cuda, rocm, x86
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 3e934e1b9d..78b57d5fe1 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -17,8 +17,9 @@
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for ARM tensorization."""
 from tvm.script import tir as T
-from .. import TensorIntrin
 
+from .. import TensorIntrin
+from .dot_product_common import DP4A_INTRIN  # pylint: disable=unused-import
 
 # TODO(masahi): Parametrize the TVMScript description of dot product by
 # shape and dtype, and share the common description with x86.
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 4ac9338ba8..028402a756 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -137,7 +137,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
                     v0, v1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(shared[v0, v1])
 
-                    thread_id, local_id = index_map(v0, v1)
+                    thread_id, local_id = T.inline(index_map(v0, v1))
                     T.writes(warp[thread_id, local_id])
                     warp[thread_id, local_id] = shared[v0, v1]
 
@@ -242,11 +242,11 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                 with T.block("C"):
                     i, j, k = T.axis.remap("SSR", [i, j, k])
-                    b_row_ind, b_col_ind = maybe_swap(k, j)
+                    b_row_ind, b_col_ind = T.inline(maybe_swap(k, j))
 
-                    thread_id_C, local_id_C = index_map_C(i, j)
-                    thread_id_A, local_id_A = index_map_A(i, k)
-                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+                    thread_id_C, local_id_C = T.inline(index_map_C(i, j))
+                    thread_id_A, local_id_A = T.inline(index_map_A(i, k))
+                    thread_id_B, local_id_B = T.inline(index_map_B(b_row_ind, b_col_ind))
 
                     T.reads(
                         C[thread_id_C, local_id_C],
@@ -338,7 +338,7 @@ def get_mma_fill_intrin(dtype, local_size):
             for i0, i1 in T.grid(M_DIM, N_DIM):
                 with T.block("C_warp"):
                     i, j = T.axis.remap("SS", [i0, i1])
-                    thread_id, local_id = index_map(i, j)
+                    thread_id, local_id = T.inline(index_map(i, j))
                     T.reads()
                     T.writes(C_warp[thread_id, local_id])
                     C_warp[thread_id, local_id] = zero
@@ -375,7 +375,7 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
             for i0, i1 in T.grid(M_DIM, N_DIM):
                 with T.block("C_warp"):
                     v0, v1 = T.axis.remap("SS", [i0, i1])
-                    thread_id, local_id = index_map(v0, v1)
+                    thread_id, local_id = T.inline(index_map(v0, v1))
                     T.reads(C_warp[thread_id, local_id])
                     T.writes(C[v0, v1])
                     C[v0, v1] = C_warp[thread_id, local_id]
diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py
index 7a989d0bcc..017b2722a8 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -37,7 +37,7 @@ def sdot4(
             T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
             T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
             T.int32(0),
-            T.bool(1),
+            T.boolean(1),
             dtype="int32",
         )
 
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
 from typing import Dict
 
 import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
 from . import _ffi_api
 
 
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index 336575a93e..06670c8cc2 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -71,7 +71,11 @@ DiagnosticBuilder Diagnostic::Help(Span span) {
 /* Diagnostic Renderer */
 TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode);
 
-void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->renderer(ctx); }
+void DiagnosticRenderer::Render(const DiagnosticContext& ctx) {
+  if ((*this)->renderer != nullptr) {  // no renderer set: skip rendering
+    (*this)->renderer(ctx);
+  }
+}
 
 TVM_DLL DiagnosticRenderer::DiagnosticRenderer(
     TypedPackedFunc<void(DiagnosticContext ctx)> renderer) {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index a3318bf94f..a7d95848da 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -21,6 +21,7 @@
  * \file src/ir/expr.cc
  * \brief The expression AST nodes for the common IR infra.
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/function.h>
 #include <tvm/runtime/registry.h>
@@ -49,6 +50,24 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
   if (auto* ptr = ref.as<runtime::StringObj>()) {
     return tir::StringImm(GetRef<runtime::String>(ptr));
   }
+  if (auto* ptr = ref.as<tir::BufferRegionNode>()) {
+    // A BufferRegion used as a value becomes a BufferLoad: unit-extent dimensions are indexed
+    // by their min, other dimensions by a stride-1 Ramp covering the region.
+    tir::BufferRegion buffer_region = GetRef<tir::BufferRegion>(ptr);
+    arith::Analyzer analyzer;
+    Array<PrimExpr> indices;
+    for (Range r : buffer_region->region) {
+      if (analyzer.CanProveEqual(r->extent, 1)) {
+        indices.push_back(r->min);
+      } else {
+        // tir::Ramp requires base and stride to share a dtype, so build the stride from
+        // r->min's dtype; a plain int32 `1` would fail for int64 regions.
+        PrimExpr stride = IntImm(r->min.dtype(), 1);
+        indices.push_back(tir::Ramp(r->min, stride, Downcast<IntImm>(r->extent)->value));
+      }
+    }
+    return tir::BufferLoad(buffer_region->buffer, indices);
+  }
   Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get());
   ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName()
                                  << " but got " << actual_type.value();
diff --git a/src/ir/ir_builder.cc b/src/ir/ir_builder.cc
new file mode 100644
index 0000000000..9f42cdb168
--- /dev/null
+++ b/src/ir/ir_builder.cc
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/ir/ir_builder.cc
+ * \brief Generic IRBuilder: thread-local builder stack, frame scoping, IRModule frame and Namer.
+ */
+#include <tvm/ir/ir_builder.h>
+#include <tvm/ir/module.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ir_builder {
+
+// Pushes this frame onto the frame stack of the current builder.
+void IRBuilderFrameNode::EnterWithScope() {
+  IRBuilder::Current()->frames.push_back(GetRef<IRBuilderFrame>(this));
+}
+
+// Runs the registered callbacks in reverse registration order, clears them, then pops the
+// last-pushed frame of the current builder.
+void IRBuilderFrameNode::ExitWithScope() {
+  for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
+    (*it)();
+  }
+  this->callbacks.clear();
+  IRBuilder::Current()->frames.pop_back();
+}
+
+// Registers `callback` to run when a frame exits.
+// NOTE(review): the callback is attached to the last-pushed frame of the current builder rather
+// than to `this`; the two differ if this frame is not the innermost one - confirm intended.
+void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
+  if (IRBuilder::Current()->frames.empty()) {
+    LOG(FATAL) << "ValueError: No frames in Builder to add callback";
+  }
+  IRBuilder::Current()->frames.back()->callbacks.push_back(callback);
+}
+
+// Creates a builder with no frames and no result.
+IRBuilder::IRBuilder() {
+  ObjectPtr<IRBuilderNode> n = make_object<IRBuilderNode>();
+  n->frames.clear();
+  n->result = NullOpt;
+  data_ = n;
+}
+
+// Per-thread stack of active builders; the current builder is at the back.
+std::vector<IRBuilder>* ThreadLocalBuilderStack() {
+  thread_local std::vector<IRBuilder> stack;
+  return &stack;
+}
+
+// Makes this builder current on this thread. The builder must not hold frames left over from a
+// previous use, and its result is reset.
+void IRBuilder::EnterWithScope() {
+  IRBuilderNode* n = this->get();
+  CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
+                           << n->frames.size()
+                           << ". Please use a fresh new builder every time building IRs";
+  n->result = NullOpt;
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  stack->push_back(*this);
+}
+
+// Removes the current builder from this thread's builder stack.
+void IRBuilder::ExitWithScope() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  ICHECK(!stack->empty());
+  stack->pop_back();
+}
+
+// Returns the innermost active builder on this thread; fails if there is none.
+IRBuilder IRBuilder::Current() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  CHECK(!stack->empty()) << "ValueError: No builder in current scope";
+  return stack->back();
+}
+
+// Creates an IRModule frame with no global vars and no functions.
+IRModuleFrame IRModule() {
+  ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
+  n->global_vars.clear();
+  n->functions.clear();
+  return IRModuleFrame(n);
+}
+
+// Pairs global_vars[i] with functions[i] into an IRModule stored as the builder's result.
+// NOTE(review): this does not call IRBuilderFrameNode::ExitWithScope(), so (if it overrides it)
+// registered callbacks are not run and the frame is not popped - confirm intended.
+void IRModuleFrameNode::ExitWithScope() {
+  ICHECK_EQ(functions.size(), global_vars.size());
+  int n = functions.size();
+  Map<GlobalVar, BaseFunc> func_map;
+  for (int i = 0; i < n; ++i) {
+    func_map.Set(global_vars[i], functions[i]);
+  }
+  IRBuilder builder = IRBuilder::Current();
+  ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+  builder->result = tvm::IRModule(func_map);
+}
+
+namespace details {
+
+// Returns the per-type dispatch table consulted by Namer::Name.
+Namer::FType& Namer::vtable() {
+  static FType inst;
+  return inst;
+}
+
+// Names `node` via the handler registered for its type; fails on null or unregistered types.
+void Namer::Name(ObjectRef node, String name) {
+  static const FType& f = vtable();
+  CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
+  CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
+                              << node->GetTypeKey() << "\"";
+  f(node, name);
+}
+
+}  // namespace details
+
+TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode);
+TVM_REGISTER_NODE_TYPE(IRBuilderNode);
+TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameEnter")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameExit")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameAddCallback")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::AddCallback);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); });
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderGet")
+    .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRModule").set_body_typed(IRModule);
+
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 7979c9f47a..e56a7bc4af 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -20,6 +20,7 @@
 /*!
  * \file expr.cc
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
@@ -828,12 +829,33 @@ TVM_REGISTER_GLOBAL("tir.Call")
     .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
       Array<PrimExpr> prim_expr_args;
       for (const auto& it : args) {
-        ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>())
-            << "Argument " << it << " is not a string or primexpr";
+        ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
+               it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>())
+            << "Argument " << it << " is not a string, PrimExpr, IterVar or BufferRegion";
         if (const auto* str = it.as<runtime::StringObj>()) {
           prim_expr_args.push_back(StringImm(str->data));
+        } else if (const auto* expr = it.as<PrimExprNode>()) {
+          prim_expr_args.push_back(GetRef<PrimExpr>(expr));
+        } else if (const auto* br = it.as<BufferRegionNode>()) {
+          // A BufferRegion argument is passed as a BufferLoad: unit-extent dimensions are
+          // indexed by their min, other dimensions by a stride-1 Ramp covering the region.
+          BufferRegion buffer_region = GetRef<BufferRegion>(br);
+          arith::Analyzer analyzer;
+          Array<PrimExpr> indices;
+          for (Range r : buffer_region->region) {
+            if (analyzer.CanProveEqual(r->extent, 1)) {
+              indices.push_back(r->min);
+            } else {
+              // tir::Ramp requires base and stride to share a dtype, so build the stride from
+              // r->min's dtype; a plain int32 `1` would fail for int64 regions.
+              PrimExpr stride = IntImm(r->min.dtype(), 1);
+              indices.push_back(tir::Ramp(r->min, stride, Downcast<IntImm>(r->extent)->value));
+            }
+          }
+          prim_expr_args.push_back(BufferLoad(buffer_region->buffer, indices));
         } else {
-          prim_expr_args.push_back(Downcast<PrimExpr>(it));
+          // Only IterVar can remain after the ICHECK above.
+          prim_expr_args.push_back(Downcast<IterVar>(it).operator PrimExpr());
         }
       }
       return Call(type, op, prim_expr_args, span);
@@ -1089,6 +1111,12 @@ BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) {
       << "-dimensional indices provided.";
 
   ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
+  // Reject index expressions whose dtype is not an integer type (e.g. floats).
+  for (const PrimExpr& i : indices) {
+    ICHECK(i->dtype.is_int() || i->dtype.is_uint())
+        << "ValueError: index of BufferLoad should be int, but got type " << i->dtype
+        << " for index " << i;
+  }
   node->buffer = std::move(buffer);
   node->indices = std::move(indices);
   node->span = std::move(span);
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 7e3d3d1075..b11ca6650a 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -22,11 +22,10 @@
  * \brief Used by TVM Script parser to expand incomplete TIR input
  */
 
+#include "./script_complete.h"
+
 #include <tvm/arith/int_set.h>
-#include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/stmt_functor.h>
 
 #include <utility>
 
diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h
new file mode 100644
index 0000000000..8df0456646
--- /dev/null
+++ b/src/tir/ir/script/script_complete.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.h
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+#ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Expand an incomplete PrimFunc produced by the TVM Script parser.
+ * \param func The PrimFunc to complete.
+ * \param root_allocates Buffers allocated at the function root (presumably the IR builder's
+ *        PrimFunc-level root_alloc_buffers - confirm against callers).
+ * \return The completed PrimFunc.
+ */
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates);
+
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..ef11d10257 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -810,6 +810,14 @@ BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
   CHECK_EQ(buffer->shape.size(), region.size())
       << "The dimension between " << buffer << " and region " << region
       << " mismatched, the buffer is " << buffer;
+  for (const Range& r : region) {  // min and extent of every range must be int/uint typed
+    ICHECK(r->min->dtype.is_int() || r->min->dtype.is_uint())
+        << "ValueError: ranges of BufferRegion should be int, but got type " << r->min->dtype
+        << " for range " << r << " in its min value " << r->min;
+    ICHECK(r->extent->dtype.is_int() || r->extent->dtype.is_uint())
+        << "ValueError: ranges of BufferRegion should be int, but got type " << r->extent->dtype
+        << " for range " << r << " in its extent value " << r->extent;
+  }
   ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
   node->buffer = std::move(buffer);
   node->region = std::move(region);
@@ -1091,5 +1099,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
 TVM_REGISTER_OP("tir.type_annotation")
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder.cc b/src/tir/ir_builder/ir_builder.cc
new file mode 100644
index 0000000000..96568e356c
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder.cc
@@ -0,0 +1,664 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+using tvm::tir::IterVar;
+
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators) {
+  Var buffer_data;
+  if (!data.defined()) {
+    DataType storage_dtype = dtype;
+    if (storage_dtype == DataType::Bool()) {
+      storage_dtype = DataType::Int(8);
+    }
+    buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope));
+  } else {
+    buffer_data = data.value();
+  }
+  if (!elem_offset.defined() && offset_factor) {
+    DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
+    elem_offset = tvm::tir::Var("elem_offset", shape_dtype);
+  }
+  return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
+                elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor,
+                (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault,
+                axis_separators.value_or(Array<IntImm>()));
+}
+
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  ObjectPtr<DeclBufferFrameNode> n = make_object<DeclBufferFrameNode>();
+  n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope,
+                         align, offset_factor, buffer_type, axis_separators);
+  return DeclBufferFrame(n);
+}
+
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
+  return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
+}
+
+BlockFrame Block(String name, bool no_realize) {
+  ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
+  n->name = name;
+  n->iter_vars.clear();
+  n->reads = NullOpt;
+  n->writes = NullOpt;
+  n->init = NullOpt;
+  n->alloc_buffers.clear();
+  n->match_buffers.clear();
+  n->annotations = NullOpt;
+  n->iter_values.clear();
+  n->predicate = NullOpt;
+  n->no_realize = no_realize;
+  return BlockFrame(n);
+}
+
+BlockInitFrame Init() { return BlockInitFrame(make_object<BlockInitFrameNode>()); }
+
+void Where(PrimExpr predicate) {
+  BlockFrame frame = FindBlockFrame("T.where");
+  if (frame->predicate.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+               << frame->predicate;
+  }
+  frame->predicate = predicate;
+}
+
+void Reads(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame frame = FindBlockFrame("T.reads");
+  if (frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
+  }
+  Array<BufferRegion> reads;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      reads.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      reads.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer reads.";
+    }
+  }
+  frame->reads = reads;
+}
+
+void Writes(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame frame = FindBlockFrame("T.writes");
+  if (frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << frame->writes;
+  }
+  Array<BufferRegion> writes;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      writes.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      writes.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer writes.";
+    }
+  }
+  frame->writes = writes;
+}
+
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  BlockFrame frame = FindBlockFrame("T.block_attr");
+  if (frame->annotations.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+  }
+  frame->annotations = attrs;
+}
+
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<BlockFrame> frame = builder->GetLastFrame<BlockFrame>()) {
+    frame.value()->alloc_buffers.push_back(buffer);
+  } else if (Optional<PrimFuncFrame> frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    frame.value()->root_alloc_buffers.push_back(buffer);  // no block frame: root allocation
+  } else {
+    LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not found. Please ensure "
+                  "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  }
+  return buffer;
+}
+
+namespace axis {
+
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+  if (Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    BlockFrame frame = opt_frame.value();
+    frame->iter_vars.push_back(iter_var);
+    frame->iter_values.push_back(binding);
+  } else {
+    LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+  }
+  return iter_var;
+}
+
+#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name)                                           \
+  Var Method(Range dom, PrimExpr binding, DataType dtype) {                                   \
+    ICHECK(dom.defined()) << Name << " axis must have a domain";                              \
+    int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
+    return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)),          \
+                                /*iter_type=*/Kind, /*thread_tag=*/""),                       \
+                        binding)                                                              \
+        ->var;                                                                                \
+  }
+TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
+TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
+TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
+TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
+#undef TVM_TIR_IR_BUILDER_AXIS
+
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+  // Declares one block iter var per character of `kinds` ('S' spatial, 'R' reduction), bound
+  // to the loop var at the same position of `bindings` and ranging over that loop's domain.
+  // Fails if a binding is not a Var, is not an enclosing loop var, or has an unknown kind.
+  // NOTE(review): `dtype` is unused; each block var takes the dtype of its loop var.
+  using namespace tvm::tir;
+  Array<Var> results;
+  ICHECK_EQ(kinds.size(), bindings.size());
+  int n = bindings.size();
+  results.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    char c = kinds.c_str()[i];
+    PrimExpr e = bindings[i];
+    const VarNode* v = e.as<VarNode>();
+    ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap";
+    // Find the domain of `v` in the loop frames, searching from the outermost frame.
+    Range dom{nullptr};
+    for (const auto& frame : IRBuilder::Current()->frames) {
+      if (const auto* for_frame = frame.as<ForFrameNode>()) {
+        ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+        int num_loop_vars = for_frame->doms.size();
+        for (int j = 0; j < num_loop_vars; ++j) {
+          if (for_frame->vars[j].get() == v) {
+            dom = for_frame->doms[j];
+            break;
+          }
+        }
+        if (dom.defined()) {
+          break;
+        }
+      }
+    }
+    ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v);
+    // Map the axis kind character to its iter var type.
+    IterVarType iter_type = IterVarType::kDataPar;
+    if (c == 'S') {
+      iter_type = IterVarType::kDataPar;
+    } else if (c == 'R') {
+      iter_type = IterVarType::kCommReduce;
+    } else {
+      LOG(FATAL) << "Unknown axis kind: " << c;
+    }
+    // The new block var reuses the loop var's dtype and domain and is bound to the loop var.
+    IterVar iter_var(/*dom=*/dom, /*var=*/Var("", v->dtype), /*iter_type=*/iter_type,
+                     /*thread_tag=*/"");
+    results.push_back(PushBlockVar(iter_var, e)->var);
+  }
+  return results;
+}
+
+}  // namespace axis
+
+#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                                \
+  ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {  \
+    PrimExpr min = start;                                                                         \
+    PrimExpr extent = arith::Analyzer().Simplify(stop - start);                                   \
+    ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();                                      \
+    int bits = std::max(min.dtype().bits(), extent.dtype().bits());                               \
+    n->vars = {Var("v", DataType::Int(bits))};                                                    \
+    n->doms = {Range::FromMinExtent(min, extent)};                                                \
+    n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \
+      ICHECK_EQ(vars.size(), 1);                                                                  \
+      ICHECK_EQ(doms.size(), 1);                                                                  \
+      return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt,           \
+                           annotations.value_or(Map<String, ObjectRef>()));                       \
+    };                                                                                            \
+    return ForFrame(n);                                                                           \
+  }
+
+TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
+
+#undef TVM_TIR_IR_BUILDER_FOR_FRAME
+
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations) {
+  using namespace tvm::tir;
+  PrimExpr min = start;
+  PrimExpr extent = arith::Analyzer().Simplify(stop - start);
+  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  int bits = std::max(min.dtype().bits(), extent.dtype().bits());
+  n->vars = {Var("v", DataType::Int(bits))};
+  n->doms = {Range::FromMinExtent(min, extent)};
+  n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
+                     thread);
+    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
+               annotations.value_or(Map<String, ObjectRef>()));
+  };
+  return ForFrame(n);
+}
+
+ForFrame Grid(Array<PrimExpr> extents) {
+  using namespace tvm::tir;
+  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  n->vars.reserve(extents.size());
+  n->doms.reserve(extents.size());
+  for (const auto& extent : extents) {
+    DataType dtype = extent.dtype();
+    n->vars.push_back(Var("v", extent.dtype()));
+    n->doms.push_back(Range(make_const(dtype, 0), extent));
+  }
+  n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    int n = vars.size();
+    for (int i = n - 1; i >= 0; --i) {
+      Range dom = doms[i];
+      Var var = vars[i];
+      body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
+                 /*thread_binding=*/NullOpt, /*annotations=*/{});
+    }
+    return body;
+  };
+  return ForFrame(n);
+}
+
+PrimFuncFrame PrimFunc() {
+  ObjectPtr<PrimFuncFrameNode> n = make_object<PrimFuncFrameNode>();
+  n->name = NullOpt;
+  n->args.clear();
+  n->ret_type = NullOpt;
+  n->buffer_map.clear();
+  n->preflattened_buffer_map.clear();
+  n->attrs = NullOpt;
+  n->env_threads.clear();
+  n->root_alloc_buffers.clear();
+  return PrimFuncFrame(n);
+}
+
+// Name `var` and append it to the enclosing PrimFunc's parameter list.
+Var Arg(String name, Var var) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(var, name);
+  func_frame->args.push_back(var);
+  return var;
+}
+
+// Name `buffer` and expose it as a PrimFunc parameter: a handle var named "<buffer>_handle" is
+// appended to the parameter list and bound to the buffer in buffer_map.
+Buffer Arg(String name, Buffer buffer) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(buffer, name);
+  Var buffer_handle(buffer->name + "_handle", DataType::Handle());
+  func_frame->args.push_back(buffer_handle);
+  func_frame->buffer_map.Set(buffer_handle, buffer);
+  return buffer;
+}
+
+// Attach `name` to the enclosing PrimFunc; setting the name twice is a fatal error.
+void FuncName(String name) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_name");
+  const bool already_named = func_frame->name.defined();
+  if (already_named) {
+    LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is "
+               << func_frame->name.value();
+  }
+  func_frame->name = name;
+}
+
+// Attach the attribute dictionary `attrs` to the enclosing PrimFunc; setting attributes twice is a
+// fatal error.
+void FuncAttrs(Map<String, ObjectRef> attrs) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr");
+  if (frame->attrs.defined()) {
+    // Print the stored map itself, matching how FuncName/FuncRet report their previous value.
+    LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is "
+               << frame->attrs.value();
+  }
+  frame->attrs = attrs;
+}
+
+// Set the enclosing PrimFunc's return type and hand it back; setting it twice is a fatal error.
+tvm::Type FuncRet(tvm::Type ret_type) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.ret_type");
+  const bool has_ret_type = func_frame->ret_type.defined();
+  if (has_ret_type) {
+    LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
+               << func_frame->ret_type.value();
+  }
+  func_frame->ret_type = ret_type;
+  return ret_type;
+}
+
+// Declare a buffer from the given shape/dtype/layout arguments and bind it to `param`:
+//  - Var: `param` must already be a parameter of the enclosing PrimFunc; the buffer is recorded
+//    in that frame's buffer_map under `param`.
+//  - BufferLoad / BufferRegion: a MatchBufferRegion is appended to the enclosing block frame.
+// Any other kind of `param` is a fatal error. Returns the declared buffer.
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  // The buffer is declared with an empty name here.
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  if (const auto* var = param.as<tvm::tir::VarNode>()) {
+    PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
+    Var v = GetRef<Var>(var);
+    // Only variables that are already function parameters may be bound to a buffer.
+    for (auto const& arg : frame->args) {
+      if (arg.same_as(v)) {
+        frame->buffer_map.Set(v, buffer);
+        return buffer;
+      }
+    }
+    LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
+  } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) {
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    // A point load is widened into a region of unit extents at the load indices.
+    frame->match_buffers.push_back(tvm::tir::MatchBufferRegion(
+        buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(buffer_load))));
+  } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(
+        tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(buffer_region)));
+  } else {
+    LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
+  }
+  return buffer;
+}
+
+// Declare the pre-flattened view "<name>_preflatten" of a buffer that is already bound in the
+// enclosing PrimFunc's buffer_map, and record it in preflattened_buffer_map under the same
+// parameter. When `data` is not given, the view reuses the bound buffer's data var.
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
+                        Optional<Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
+                        String storage_scope, int align, int offset_factor, String buffer_type_str,
+                        Array<IntImm> axis_separators) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.preflattened_buffer");
+  for (auto const& entry : func_frame->buffer_map) {
+    if (!entry.second.same_as(postflattened_buffer)) {
+      continue;
+    }
+    String preflatten_name(postflattened_buffer->name + "_preflatten");
+    Buffer preflattened =
+        BufferDecl(shape, dtype, preflatten_name, data.value_or(entry.second->data), strides,
+                   elem_offset, storage_scope, align, offset_factor, buffer_type_str,
+                   axis_separators);
+    details::Namer::Name(preflattened, preflatten_name);
+    func_frame->preflattened_buffer_map.Set(entry.first, preflattened);
+    return;
+  }
+  LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
+             << " does not exist.";
+}
+
+// Open an assert frame; the message is stored as a TIR StringImm.
+AssertFrame Assert(PrimExpr condition, String message) {
+  ObjectPtr<AssertFrameNode> frame = make_object<AssertFrameNode>();
+  frame->message = tvm::tir::StringImm(message);
+  frame->condition = condition;
+  return AssertFrame(frame);
+}
+
+// Open a frame that binds `var` to `value` around the statements emitted inside it.
+LetFrame Let(Var var, PrimExpr value) {
+  ObjectPtr<LetFrameNode> frame = make_object<LetFrameNode>();
+  frame->value = value;
+  frame->var = var;
+  return LetFrame(frame);
+}
+
+// Open an allocation frame. `condition` defaults to true and `annotations` to an empty map.
+// A backing buffer with the given extents, dtype and storage scope is declared here; the
+// frame's ExitWithScope takes the allocation's data var, dtype and shape from that buffer.
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  ObjectPtr<AllocateFrameNode> n = make_object<AllocateFrameNode>();
+  n->extents = extents;
+  n->dtype = dtype;
+  n->storage_scope = storage_scope;
+  n->condition = condition.value_or(tvm::Bool(true));
+  n->annotations = annotations.value_or(Map<String, ObjectRef>());
+  // Unnamed buffer, no explicit data/strides/elem_offset, align = offset_factor = 0.
+  n->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0,
+                         "default", NullOpt);
+  return AllocateFrame(n);
+}
+
+// Open a constant-allocation frame holding `data` with the given dtype and extents. An unnamed
+// backing buffer with an empty storage scope is declared; its data var is used as the
+// allocation's var on exit.
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>();
+  n->dtype = dtype;
+  n->extents = extents;
+  n->data = data;
+  n->annotations = annotations;
+  n->buffer =
+      BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+  return AllocateConstFrame(n);
+}
+
+// Open a frame that launches the environment thread `var` (which must have been created by
+// T.env_thread inside the enclosing PrimFunc) with the given extent. The first launch fixes the
+// thread's domain to [0, extent); later launches must use a provably equal extent.
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
+  IterVar iter_var{nullptr};
+
+  if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
+    if (Optional<IterVar> opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
+      iter_var = opt_iter_var.value();
+    } else {
+      // Must be fatal: `iter_var` is still null here and is dereferenced below.
+      LOG(FATAL) << "ValueError: " << var->name_hint
+                 << " is not an env_thread created using T.env_thread.";
+    }
+  } else {
+    LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc";
+  }
+  ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
+  if (!iter_var->dom.defined()) {
+    // The domain is written into the shared IterVar in place. The zero lower bound uses the
+    // extent's dtype so min and extent agree (e.g. for int64 extents), as Grid does.
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
+        Range(tvm::tir::make_const(extent.dtype(), 0), extent);
+  } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+    LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+               << iter_var->dom->extent << " vs " << extent;
+  }
+  n->iter_var = iter_var;
+  n->extent = extent;
+  // "vthread" tags are emitted under a different attribute key than other thread tags.
+  n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent";
+  return LaunchThreadFrame(n);
+}
+
+// Open a realize frame over `buffer_slice` with the given storage scope and condition.
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                     PrimExpr condition) {
+  ObjectPtr<RealizeFrameNode> frame = make_object<RealizeFrameNode>();
+  frame->condition = condition;
+  frame->storage_scope = storage_scope;
+  frame->buffer_slice = buffer_slice;
+  return RealizeFrame(frame);
+}
+
+// Open a frame that attaches attribute `attr_key` = `value` to `node` around its body.
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  ObjectPtr<AttrFrameNode> frame = make_object<AttrFrameNode>();
+  frame->value = value;
+  frame->attr_key = attr_key;
+  frame->node = node;
+  return AttrFrame(frame);
+}
+
+// Open a while-loop frame guarded by `condition`.
+WhileFrame While(PrimExpr condition) {
+  ObjectPtr<WhileFrameNode> frame = make_object<WhileFrameNode>();
+  frame->condition = condition;
+  return WhileFrame(frame);
+}
+
+// Open an if-then-else frame for `condition`. Both branch statement lists start unset; the
+// nested Then()/Else() frames fill them in when they exit.
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> n = make_object<IfFrameNode>();
+  n->condition = condition;
+  n->then_stmts = NullOpt;
+  n->else_stmts = NullOpt;
+  return IfFrame(n);
+}
+
+// Open the then-branch frame of the innermost if-frame.
+ThenFrame Then() { return ThenFrame(make_object<ThenFrameNode>()); }
+
+// Open the else-branch frame of the innermost if-frame.
+ElseFrame Else() { return ElseFrame(make_object<ElseFrameNode>()); }
+
+// Create an environment thread var tagged `thread_tag` and register its IterVar in the enclosing
+// PrimFunc. The IterVar's domain is left undefined; LaunchThread sets it on first launch.
+Var EnvThread(String thread_tag) {
+  IterVar thread_iter(Range{nullptr}, Var("", DataType::Int(32)),
+                      tvm::tir::IterVarType::kThreadIndex, thread_tag);
+  Var thread_var = thread_iter->var;
+  Optional<PrimFuncFrame> func_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>();
+  if (!func_frame.defined()) {
+    LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
+  }
+  func_frame.value()->env_threads.Set(thread_var, thread_iter);
+  return thread_var;
+}
+
+// Emit a BufferStore of `value` at `indices` into the current parent frame (see AddToParent).
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  AddToParent(tvm::tir::BufferStore(buffer, value, indices));
+}
+
+// Emit a Prefetch of `buffer` over `bounds` into the current parent frame.
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+  AddToParent(tvm::tir::Prefetch(buffer, bounds));
+}
+
+// Emit an Evaluate statement wrapping `value` into the current parent frame.
+void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
+
+using tvm::ir_builder::details::Namer;
+
+// Namer dispatch table for TIR nodes. Names are assigned by mutating the node in place
+// (through const_cast).
+//
+// Buffer: sets the buffer name, gives its data var the same name, and names every stride
+// that is a plain Var as "<name>_s<i>".
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
+      tvm::tir::BufferNode* buffer =
+          const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>());
+      buffer->name = name;
+      Namer::Name(buffer->data, name);
+      int n = buffer->strides.size();
+      for (int i = 0; i < n; ++i) {
+        PrimExpr e = buffer->strides[i];
+        if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) {
+          Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i));
+        }
+      }
+    });
+
+// SizeVar: overwrite name_hint.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
+      using namespace tvm::tir;
+      SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>());
+      var->name_hint = name;
+    });
+
+// Var: overwrite name_hint.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
+      using namespace tvm::tir;
+      VarNode* var = const_cast<VarNode*>(node.as<VarNode>());
+      var->name_hint = name;
+    });
+
+// IterVar: forward the name to the underlying var.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
+      using namespace tvm::tir;
+      IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>());
+      Namer::Name(var->var, name);
+    });
+
+// FFI registrations exposing the TIR IR-builder API under the "ir_builder.tir." prefix.
+
+// Buffer / pointer declaration.
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Ptr").set_body_typed(Ptr);
+
+// Block frames and block-level declarations.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Init").set_body_typed(Init);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Where").set_body_typed(Where);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Reads").set_body_typed(Reads);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Writes").set_body_typed(Writes);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
+// Loop frames.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Serial").set_body_typed(Serial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Parallel").set_body_typed(Parallel);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Vectorized").set_body_typed(Vectorized);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Unroll").set_body_typed(Unroll);
+TVM_REGISTER_GLOBAL("ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Grid").set_body_typed(Grid);
+
+// PrimFunc frame and function-level declarations.
+TVM_REGISTER_GLOBAL("ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
+// Arg is overloaded in C++, so the FFI entry dispatches on the runtime type of `obj`:
+// Var and Buffer go to the matching overload; anything else is fatal.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Arg")
+    .set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
+      using namespace tvm::tir;
+      if (const auto* var = obj.as<VarNode>()) {
+        return Arg(name, GetRef<tvm::tir::Var>(var));
+      }
+      if (const auto* buffer = obj.as<BufferNode>()) {
+        return Arg(name, GetRef<Buffer>(buffer));
+      }
+      LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
+      throw;
+    });
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncName").set_body_typed(FuncName);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncRet").set_body_typed(FuncRet);
+TVM_REGISTER_GLOBAL("ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
+
+// Statement frames and single-statement emitters.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Prefetch").set_body_typed(Prefetch);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+
+// Typed constant / dtype helpers and min/max.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int8").set_body_typed(Int8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int16").set_body_typed(Int16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32").set_body_typed(Int32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int64").set_body_typed(Int64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt8").set_body_typed(UInt8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt16").set_body_typed(UInt16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt32").set_body_typed(UInt32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt64").set_body_typed(UInt64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float8").set_body_typed(Float8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float16").set_body_typed(Float16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float32").set_body_typed(Float32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float64").set_body_typed(Float64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32x4").set_body_typed(Int32x4);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32x8").set_body_typed(Int32x8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32x16").set_body_typed(Int32x16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Boolean").set_body_typed(Boolean);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Void").set_body_typed(Void);
+TVM_REGISTER_GLOBAL("ir_builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+  return tvm::min(a, b);
+});
+TVM_REGISTER_GLOBAL("ir_builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+  return tvm::max(a, b);
+});
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder_frame.cc b/src/tir/ir_builder/ir_builder_frame.cc
new file mode 100644
index 0000000000..18ccc2a1ad
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder_frame.cc
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/function.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "../../tir/ir/script/script_complete.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+// Close a block frame: build the tir::Block from the collected pieces, then emit either the bare
+// block (when no_realize is set) or a BlockRealize wrapping it into the parent frame.
+void BlockFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  Array<tvm::tir::Buffer> block_allocs;
+  for (const tvm::tir::Buffer& alloc : alloc_buffers) {
+    block_allocs.push_back(alloc);
+  }
+  // Bit 0 is set when `reads` was not given and bit 1 when `writes` was not given. The mask is
+  // recorded as a block attribute only when at least one bit is set.
+  int detect_access = 0;
+  if (!reads.defined()) {
+    detect_access |= 1;
+  }
+  if (!writes.defined()) {
+    detect_access |= 2;
+  }
+  Map<String, ObjectRef> block_attrs = annotations.value_or({});
+  if (detect_access != 0) {
+    block_attrs.Set("tir.script_parsing_detect_access",
+                    tvm::IntImm(DataType::Int(64), detect_access));
+  }
+  Array<tvm::tir::BufferRegion> read_regions = reads.value_or(Array<tvm::tir::BufferRegion>());
+  Array<tvm::tir::BufferRegion> write_regions = writes.value_or(Array<tvm::tir::BufferRegion>());
+  tvm::tir::Block block(iter_vars, read_regions, write_regions, name, AsStmt(stmts), init,
+                        block_allocs, match_buffers, block_attrs);
+  if (!no_realize) {
+    AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
+    return;
+  }
+  // A bare block has no realize node, so it cannot carry bindings or a predicate.
+  CHECK(iter_values.empty())
+      << "ValueError: Block bindings are not allowed when `no_realize=True`";
+  CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
+  AddToParent(block);
+}
+
+// Entering T.init() requires an enclosing block frame that has no init statement yet.
+void BlockInitFrameNode::EnterWithScope() {
+  BlockFrame frame = FindBlockFrame("T.init");
+  if (frame->init.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// On exit, the statements collected in this frame become the enclosing block's init.
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  BlockFrame frame = FindBlockFrame("T.init");
+  frame->init = AsStmt(stmts);
+}
+
+// Build the loop nest with the callback installed by the frame's factory (e.g. Grid or
+// ThreadBinding) around the collected body, and emit it into the parent frame.
+void ForFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
+}
+
+// Assemble the PrimFunc from the collected params, body, buffer maps and attrs, pass it through
+// ScriptComplete together with root_alloc_buffers, and publish it. It becomes the builder's
+// result when no frames remain; otherwise it is appended to the enclosing IRModule frame.
+void PrimFuncFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // An unset return type defaults to the empty tuple type; unset attrs give null DictAttrs.
+  tvm::tir::PrimFunc func(
+      /*params=*/args,
+      /*body=*/AsStmt(stmts),
+      /*ret_type=*/ret_type.value_or(TupleType::Empty()),
+      /*buffer_map=*/buffer_map,
+      /*preflattened_buffer_map=*/preflattened_buffer_map,
+      /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue<DictAttrs>());
+  func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = func;
+  } else if (Optional<IRModuleFrame> opt_frame = builder->FindFrame<IRModuleFrame>()) {
+    IRModuleFrame frame = opt_frame.value();
+    // A function without T.func_name is registered under an empty GlobalVar name.
+    frame->global_vars.push_back(GlobalVar(name.value_or("")));
+    frame->functions.push_back(func);
+  } else {
+    LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
+  }
+}
+
+// Wrap the collected body in an AssertStmt and emit it into the parent frame.
+void AssertFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts)));
+}
+
+// Wrap the collected body in a LetStmt binding `var` to `value` and emit it.
+void LetFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts)));
+}
+
+// Emit an Allocate around the collected body. The data var, dtype and shape come from the buffer
+// declared when the frame was created.
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition,
+                                 AsStmt(stmts), annotations));
+}
+
+// Emit an AllocateConst of `data` around the collected body, using the declared buffer's data var.
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(
+      tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts), annotations));
+}
+
+// Emit an AttrStmt on the thread IterVar, keyed by attr_key ("thread_extent" or
+// "virtual_thread"), whose value is the launch extent.
+void LaunchThreadFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts)));
+}
+
+// Emit a "realize_scope" AttrStmt carrying the storage scope. It wraps a BufferRealize of the
+// sliced buffer region around the collected body.
+void RealizeFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope",
+                                 tvm::tir::StringImm(storage_scope),
+                                 tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region,
+                                                         condition, AsStmt(stmts))));
+}
+
+// Wrap the collected body in an AttrStmt (node, attr_key, value) and emit it.
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts)));
+}
+
+// Wrap the collected body in a While loop on `condition` and emit it.
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::While(condition, AsStmt(stmts)));
+}
+
+// Emit an IfThenElse built from the branch lists filled in by nested Then/Else frames. Statements
+// placed directly in the if-frame, or a missing then branch, are fatal errors. With no else
+// branch, the else statement is null.
+void IfFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  if (!stmts.empty()) {
+    LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
+  }
+  if (!then_stmts.defined()) {
+    LOG(FATAL) << "IfThenElse frame should have at least one then branch";
+  }
+  AddToParent(tvm::tir::IfThenElse(
+      condition, AsStmt(then_stmts.value()),
+      else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr)));
+}
+
+// A then branch may be declared only once per if-frame.
+void ThenFrameNode::EnterWithScope() {
+  IfFrame frame = FindIfFrame("T.then_");
+  if (frame->then_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+               << frame->then_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// Hand the collected statements to the enclosing if-frame as its then branch.
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  FindIfFrame("T.then_")->then_stmts = stmts;
+}
+
+// An else branch must come after the then branch and may be declared only once.
+void ElseFrameNode::EnterWithScope() {
+  IfFrame frame = FindIfFrame("T.else_");
+  if (!frame->then_stmts.defined()) {
+    LOG(FATAL) << "The else branch should follow then branch";
+  }
+  if (frame->else_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << frame->else_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// Hand the collected statements to the enclosing if-frame as its else branch.
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  FindIfFrame("T.else_")->else_stmts = stmts;
+}
+
+// Wrap the collected body in a DeclBuffer for `buffer` and emit it.
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts)));
+}
+
+// Register every TIR IR-builder frame node type with the TVM object system.
+TVM_REGISTER_NODE_TYPE(TIRFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
+TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
diff --git a/src/tir/ir_builder/utils.h b/src/tir/ir_builder/utils.h
new file mode 100644
index 0000000000..6d9cbda72a
--- /dev/null
+++ b/src/tir/ir_builder/utils.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_UTILS_H_
+#define TVM_TIR_IR_BUILDER_UTILS_H_
+
+#include <tvm/tir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+// Append `stmt` to the innermost open frame, which must be a TIR frame. If the builder has no
+// open frames, `stmt` becomes the builder's result, which must not already be set.
+inline void AddToParent(tvm::tir::Stmt stmt) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = stmt;
+  } else if (const auto* tir_frame = builder->frames.back().as<TIRFrameNode>()) {
+    GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt);
+  } else {
+    LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
+  }
+}
+
+// Collapse a statement list into a single statement: an empty list becomes Evaluate(0), a
+// single statement is returned as-is, and longer lists become a SeqStmt.
+inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmt) {
+  switch (stmt.size()) {
+    case 0:
+      return tvm::tir::Evaluate(0);
+    case 1:
+      return stmt[0];
+    default:
+      return tvm::tir::SeqStmt(stmt);
+  }
+}
+
+// Frame lookup helpers. Each returns the frame of the requested kind via
+// IRBuilder::GetLastFrame, or reports a fatal error naming the calling `method`.
+// LOG(FATAL) does not return normally; the trailing `throw;` only keeps the compiler from
+// warning about a missing return value.
+inline BlockFrame FindBlockFrame(const String& method) {
+  if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: Block frame not found. Please ensure '" << method
+             << "' is called under T.block()";
+  throw;
+}
+
+inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
+  if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: PrimFunc frame not found. Please ensure '" << method
+             << "' is called under T.prim_func()";
+  throw;
+}
+
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+             << "' is called under T.if_()";
+  throw;
+}
+
+// Turn a point load into a region: each load index becomes a range starting at that index with
+// extent 1 (in the index's dtype).
+inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
+  Array<Range> region;
+  for (const PrimExpr& idx : buffer_load->indices) {
+    const IntImm unit_extent(idx->dtype, 1);
+    region.push_back(Range::FromMinExtent(idx, unit_extent));
+  }
+  return tvm::tir::BufferRegion(buffer_load->buffer, region);
+}
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_BUILDER_UTILS_H_
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 114571218b..dd3c7a0b0f 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
                    {x, y, q, s}, span);
 }
 
+// address_of: handle-typed call to builtin::address_of on the given BufferLoad
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+  return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
+}
+
+// lookup_param: handle-typed call to builtin::lookup_param with the name as a StringImm
+PrimExpr lookup_param(String param_name, Span span) {
+  return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
+                   span);
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -702,6 +713,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
   }
 }
 
+// isnullptr: scalar-boolean call to builtin::isnullptr on `x`
+PrimExpr isnullptr(PrimExpr x, Span span) {
+  return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
+}
+
 // isinf
 PrimExpr isinf(PrimExpr x, Span span) {
   DataType t = DataType::Bool(x.dtype().lanes());
@@ -931,6 +947,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
 
 TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
 
+TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
+
 TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
 
 TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
@@ -949,6 +967,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
 
 TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
 
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);
+
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                                                \
   TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -997,6 +1019,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
 REGISTER_MAKE_BIT_OP(left_shift, left_shift);  // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, right_shift);
 
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);
+
 TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
     .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
       return if_then_else(cond, true_value, false_value, span);
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index c3b8fd6766..4e0aa4ae24 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -121,7 +121,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
   // Create block vars, block's accessed region and accessing indices
   for (const PrimExpr& dim : cache_region->buffer->shape) {
     Var var("v" + std::to_string(access_indices.size()), dim.dtype());
-    block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, dim),
+    block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim),
                                  /*var=*/var,
                                  /*IterVarType=*/kDataPar));
     access_indices.push_back(var);
diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py
index 5a60c19568..04ccdc514d 100644
--- a/tests/python/integration/test_meta_schedule_auto_tensorize.py
+++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py
@@ -14,12 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Integration test for metascheduler's auto tensorization."""
+"""Integration test for MetaSchedule's auto tensorization."""
 import tempfile
 
 import numpy as np
 import pytest
-
 import tvm
 import tvm.testing
 import tvm.topi.testing
@@ -29,8 +28,9 @@ from tvm.meta_schedule import ApplyHistoryBest, postproc, schedule_rule
 from tvm.meta_schedule.relay_integration import extract_task_from_relay
 from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
 from tvm.meta_schedule.tune import tune_extracted_tasks
-from tvm.tir.tensor_intrin import AMDGPU_SDOT4_INTRIN, DP4A_INTRIN
-from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
+from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
 
 CONFIG = ms.TuneConfig(
     strategy="evolutionary",
@@ -393,7 +393,7 @@ def test_cuda_tensor_core(model_name, input_shape):
             )
         print(profiler.table())
 
-        # Compile without meta-scheduler for correctness check
+        # Compile without MetaSchedule for correctness check
         with tvm.transform.PassContext(opt_level=0):
             rt_mod2 = relay.build(mod, target=target, params=params)
 
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index 9c597a55e5..bf0edaef5c 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -85,7 +85,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -94,7 +94,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -103,7 +103,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
index a1184c1edf..fc624cd5a6 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
@@ -16,9 +16,9 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
-import tvm.tir.tensor_intrin
 from tvm.meta_schedule import TuneContext, postproc
 from tvm.script import tir as T
+from tvm.tir.tensor_intrin import arm_cpu, cuda, rocm, x86
 
 
 @tvm.script.ir_module
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index d415ae9ce6..4da870e455 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -31,8 +31,8 @@ from tvm.meta_schedule.tune_context import TuneContext
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.te import create_prim_func
-from tvm.tir.tensor_intrin import DP4A_INTRIN
-from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
 
 
 def _create_context(mod, target, rule) -> TuneContext:
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index ce333887ec..97e8a69556 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -910,7 +910,7 @@ def test_cuda_nrm():
                 for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
                     with T.block("D"):
                         b = T.axis.spatial(1, i0_1)
-                        T.where(0 * 128 + i0_1 < 1)
+                        T.where(T.int32(0) * T.int32(128) + i0_1 < 1)
                         T.reads(C_shared[b])
                         T.writes(D[b])
                         D[b] = T.sqrt(C_shared[b], dtype="float32")
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index d86b6fe48b..7d85b8757a 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -152,7 +152,7 @@ def test_meta_schedule_tune_relay(
                 work_dir=work_dir,
             )
         print(profiler.table())
-        # Compile without meta-scheduler for correctness check
+        # Compile without MetaSchedule for correctness check
         with tvm.transform.PassContext(opt_level=0):
             rt_mod2 = relay.build(mod, target=target, params=params)
 
@@ -252,7 +252,7 @@ def test_meta_schedule_te2primfunc_argument_order():
         ):
             rt_mod1 = relay.build(mod, target=target, params=params)
 
-    # Compile without meta-scheduler for correctness check
+    # Compile without MetaSchedule for correctness check
     with tvm.transform.PassContext(opt_level=0):
         rt_mod2 = relay.build(mod, target=target, params=params)
 
@@ -314,7 +314,7 @@ def test_meta_schedule_relay_lowering():
             ):
                 rt_mod1 = relay.build(mod, target=target, params=params)
 
-        # Compile without meta-scheduler for correctness check
+        # Compile without MetaSchedule for correctness check
         with tvm.transform.PassContext(opt_level=0):
             rt_mod2 = relay.build(mod, target=target, params=params)
 
@@ -516,7 +516,7 @@ def test_tune_relay_manual_tir_vnni():
         attrs={"schedule_rule": "meta_schedule.dense_vnni"},
     )
 
-    When the meta scheduler encounters a TensorIR block with the "schedule_rule" annotation,
+    When MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation,
     it looks up the packed func registry for a function that is associated with the given schedule
     rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule
     functions must be
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 18bf9d1184..c576483828 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -17,20 +17,19 @@
 import collections
 import ctypes
 import json
+import math
+import re
 import sys
 
+import numpy as np
+import pytest
 import tvm
 import tvm.testing
 from tvm import te
+from tvm.contrib import clang, utils
 from tvm.relay.backend import Runtime
-from tvm.contrib import utils, clang
-from tvm.target.codegen import llvm_lookup_intrinsic_id, llvm_get_intrinsic_name
-import tvm.script.tir as T
-import numpy as np
-
-import math
-import re
-import pytest
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_get_intrinsic_name, llvm_lookup_intrinsic_id
 
 
 @tvm.testing.requires_llvm
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index 93b7caf9cd..0e08c0df8c 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -16,7 +16,6 @@
 # under the License.
 
 import pytest
-
 import tvm
 from tvm.script import tir as T
 
@@ -63,10 +62,23 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None:
 
 
 @tvm.ir.register_op_attr("tir.intrin_test", "")
-def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1):
+def _intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1):
     return 0
 
 
+def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1, dtype):
+    return tvm.tir.call_intrin(
+        dtype,
+        "tir.intrin_test",
+        data,
+        elem_offset,
+        stride_0,
+        stride_1,
+        shape_0,
+        shape_1,
+    )
+
+
 @T.prim_func
 def opaque_access(a: T.handle, b: T.handle) -> None:
     A = T.match_buffer(a, (32, 64, 128))
@@ -82,7 +94,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
                 offset_factor=1,
             )
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     sub_A.data,
                     sub_A.elem_offset,
                     sub_A.strides[0],
@@ -105,7 +117,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
                 offset_factor=1,
             )
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     sub_B.data,
                     sub_B.elem_offset,
                     sub_B.strides[0],
@@ -126,7 +138,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
             T.reads([])
             T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     A.data,
                     i * 131072 + j * 128 + k * 16,
                     8192,
@@ -141,7 +153,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
             T.reads([])
             T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     B.data,
                     i * 4096 + j * 2048 + k * 8,
                     64,
@@ -169,7 +181,7 @@ def high_dim_opaque_access(a: T.handle) -> None:
                 offset_factor=1,
             )
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     sub_A.data,
                     sub_A.elem_offset,
                     sub_A.strides[0],
@@ -189,7 +201,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None:
             T.reads([])
             T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     A.data,
                     i * 2048 + j * 1024 + k * 16,
                     64,
@@ -217,7 +229,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
                 offset_factor=1,
             )
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     sub_A.data,
                     sub_A.elem_offset,
                     sub_A.strides[0],
@@ -237,7 +249,7 @@ def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
             T.reads([])
             T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     A.data,
                     i * 2576 + j * 1280 + k * 16,
                     80,
@@ -298,7 +310,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
                         offset_factor=1,
                     )
                     T.evaluate(
-                        T.intrin_test(
+                        intrin_test(
                             sub_sub_A.data,
                             sub_sub_A.elem_offset,
                             sub_sub_A.strides[0],
@@ -343,7 +355,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None:
                         ]
                     )
                     T.evaluate(
-                        T.intrin_test(
+                        intrin_test(
                             A.data,
                             i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4,
                             64,
@@ -375,7 +387,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None:
                 sub_A[ii, jj] = 1
             for j in range(0, 4):
                 T.evaluate(
-                    T.intrin_test(
+                    intrin_test(
                         sub_B.data,
                         sub_B.elem_offset,
                         sub_B.strides[0],
@@ -399,7 +411,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32)
                 A[i * m + ii, jj] = 1
             for j in range(0, 4):
                 T.evaluate(
-                    T.intrin_test(
+                    intrin_test(
                         B.data,
                         i * n * (m * 4),
                         m * 4,
@@ -423,7 +435,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None:
             sub_B = T.match_buffer(B[i, j], (), offset_factor=1)
             sub_A[()] = 1
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     sub_B.data,
                     sub_B.elem_offset,
                     0,
@@ -445,7 +457,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None:
             T.writes([A[i, j], B[i, j]])
             A[i, j] = 1
             T.evaluate(
-                T.intrin_test(
+                intrin_test(
                     B.data,
                     i * 8 + j,
                     0,
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index a97060f01b..929a6cfa19 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -16,19 +16,20 @@
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
 import sys
+
 import pytest
 import tvm
 import tvm.testing
-from tvm import tir, te
+from tvm import te, tir
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
-from tvm.tir.tensor_intrin import (
-    VNNI_DOT_16x4_INTRIN,
+from tvm.tir.tensor_intrin.arm_cpu import (
+    DP4A_INTRIN,
     ARM_DOT_4x4_i8_NEON_INTRIN,
     ARM_DOT_4x4_i8_SDOT_INTRIN,
-    AMDGPU_SDOT4_INTRIN,
-    DP4A_INTRIN,
 )
+from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py
index 01496e0e0f..056a13cb1a 100644
--- a/tests/python/unittest/test_tir_transform_helpers.py
+++ b/tests/python/unittest/test_tir_transform_helpers.py
@@ -15,15 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
-
 import tvm
-from tvm.script import tir as T
 import tvm.testing
+from tvm.script import tir as T
 
 
 def test_annotate_entry_func_single_primfunc():
     @tvm.script.ir_module
-    class MockModule:
+    class MockModule2:
         @T.prim_func
         def func1(A: T.Buffer[(16,), "float32"]):
             for i in T.serial(16):
@@ -31,7 +30,7 @@ def test_annotate_entry_func_single_primfunc():
                     if i == 5:
                         A[i] = 0.0
 
-    mod = MockModule
+    mod = MockModule2
     assert mod
     assert mod["func1"].attrs is None
     after = tvm.tir.transform.AnnotateEntryFunc()(mod)
diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py
index 8b7fc98bfd..26464baed3 100644
--- a/tests/python/unittest/test_tir_transform_hoist_expression.py
+++ b/tests/python/unittest/test_tir_transform_hoist_expression.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir
 import tvm.testing
-
+from tvm import tir
+from tvm.script import from_source
 from tvm.script import tir as T
-from tvm.tir.transform import HoistExpression, HoistedConditionals, HoistedLetBindings
+from tvm.tir.transform import HoistedConditionals, HoistedLetBindings, HoistExpression
 
 
 class BaseBeforeAfter:
@@ -27,7 +27,7 @@ class BaseBeforeAfter:
     hoisted_let_bindings = tvm.testing.parameter(HoistedLetBindings.All)
 
     def test_hoist(self, hoisted_conditionals, hoisted_let_bindings):
-        before = self.before
+        before = from_source(self.before)
         before_mod = tvm.IRModule.from_expr(before)
 
         config = {
@@ -41,7 +41,7 @@ class BaseBeforeAfter:
             after_mod = tvm.tir.transform.HoistExpression()(before_mod)
 
         after = after_mod["main"]
-        expected = self.expected
+        expected = from_source(self.expected)
 
         try:
             tvm.ir.assert_structural_equal(after, expected)
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index edaeb7c9b6..34f988c77c 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -14,16 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
 import sys
-import numpy as np
 
+import numpy as np
+import pytest
 import tvm
 import tvm.testing
 import tvm.tir.tensor_intrin.cuda
-from tvm import tir, te, TVMError
-from tvm.script import tir as T
+from tvm import TVMError, te, tir
 from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_DYN_INTRIN,
@@ -1060,7 +1060,7 @@ def test_simple_compute_async():
                     T.writes(B[0, tx, 0])
                     with T.attr(0, "async_commit_queue_scope", 0):
                         with T.attr(0, "async_scope", 1):
-                            B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+                            B[T.int32(0) % 2, tx, 0] = A[tx, 0] * T.float32(2)
                 with T.block():
                     T.reads(A[tx, 1:16], B[0:2, tx, 0])
                     T.writes(B[0:2, tx, 0], C[tx, 0:15])
@@ -1080,11 +1080,11 @@ def test_simple_compute_async():
                                 with T.attr(0, "async_wait_inflight_count", 1):
                                     C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1)
                 with T.block():
-                    T.reads(B[15 % 2, tx, 0])
+                    T.reads(B[T.int32(15) % 2, tx, 0])
                     T.writes(C[tx, 15])
                     with T.attr(0, "async_wait_queue_scope", 0):
                         with T.attr(0, "async_wait_inflight_count", 0):
-                            C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+                            C[tx, 15] = B[T.int32(15) % 2, tx, 0] + T.float32(1)
 
     tvm.ir.assert_structural_equal(mod["main"], ref, True)
 
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index b96afb6a09..f80573c43b 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import tvm.testing
 from tvm import te
-
 from tvm.script import tir as T
 
 vthread_name = tvm.testing.parameter("vthread", "cthread")
@@ -153,10 +153,10 @@ def test_vthread_simplified():
         B = T.allocate([16], "int32", "shared")
         # The indices for B should each be a single Ramp node, and
         # should not be the sum of a Ramp and Broadcast node.
-        B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4)
-        B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4)
-        B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4)
-        B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4)
+        B[T.int32(0) * 4 : T.int32(0) * 4 + 4] = T.broadcast(0, 4)
+        B[T.int32(1) * 4 : T.int32(1) * 4 + 4] = T.broadcast(1, 4)
+        B[T.int32(2) * 4 : T.int32(2) * 4 + 4] = T.broadcast(2, 4)
+        B[T.int32(3) * 4 : T.int32(3) * 4 + 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
@@ -178,10 +178,10 @@ def test_vthread_vectorized():
     @T.prim_func
     def expected_func():
         B = T.allocate([4], "int32x4", "shared")
-        B[0 * 4 / 4] = T.broadcast(0, 4)
-        B[1 * 4 / 4] = T.broadcast(1, 4)
-        B[2 * 4 / 4] = T.broadcast(2, 4)
-        B[3 * 4 / 4] = T.broadcast(3, 4)
+        B[T.int32(0) * 4 / 4] = T.broadcast(0, 4)
+        B[T.int32(1) * 4 / 4] = T.broadcast(1, 4)
+        B[T.int32(2) * 4 / 4] = T.broadcast(2, 4)
+        B[T.int32(3) * 4 / 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index acc68af065..aff503c272 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -76,14 +76,6 @@ def test_missing_type_annotation():
     check_error(missing_type_annotation, 1)
 
 
-def invalid_expr_stmt() -> None:
-    T.max(1, 2)  # error
-
-
-def test_invalid_expr_stmt():
-    check_error(invalid_expr_stmt, 2)
-
-
 def invalid_for_function(a: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
 
@@ -115,14 +107,6 @@ def test_return_not_allowed():
     check_error(return_not_allowed, 2)
 
 
-def tir_assert(a: T.handle) -> None:
-    T.Assert(0, "")  # error
-
-
-def test_tir_assert():
-    check_error(tir_assert, 2)
-
-
 def no_body(a: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     T.realize(A, "")  # error
@@ -250,19 +234,6 @@ def test_invalid_match_buffer_region():
     check_error(invalid_match_buffer_region, 5)
 
 
-def duplicate_buffer() -> None:
-    A = T.alloc_buffer((128, 128), "float32")
-    for i, j in T.grid(128, 128):
-        with T.block():
-            vi, vj = T.axis.remap("SS", [i, j])
-            A = T.alloc_buffer((128, 128), "float32")  # error
-            T.evaluate(1.0)
-
-
-def test_duplicate_buffer():
-    check_error(duplicate_buffer, 6)
-
-
 def duplicate_reads() -> None:
     A = T.alloc_buffer((128, 128), "float32")
     for i, j in T.grid(128, 128):
@@ -334,7 +305,7 @@ def opaque_access_during_complete(a: T.handle) -> None:  # error
 
 
 def test_opaque_access_during_complete():
-    check_error(opaque_access_during_complete, 1)
+    check_error(opaque_access_during_complete, 0)
 
 
 def convert_slice_to_bufferload() -> None:
@@ -608,15 +579,6 @@ def test_binop_bad_type():
     check_error(binop_bad_type, 3)
 
 
-def floor_dtype(h: T.handle):
-    h_ = T.match_buffer(h, [1])
-    h_[0] = T.floor(2)  # error floor requires a dtype
-
-
-def test_floor_dtype():
-    check_error(floor_dtype, 3)
-
-
 def non_integer_typed_block_iter():
     with T.block():
         i = T.axis.S(0.1, 0.1)  # error IterVar requires an integer dtype
diff --git a/tests/python/unittest/test_tvmscript_spans.py b/tests/python/unittest/test_tvmscript_spans.py
deleted file mode 100644
index f863a4dd98..0000000000
--- a/tests/python/unittest/test_tvmscript_spans.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-from tvm.script import tir as T
-
-
-@T.prim_func
-def loops() -> None:
-    for i in T.parallel(0, 2):
-        for j in T.serial(0, 1):
-            for z in T.vectorized(3, 4):
-                T.evaluate(0)
-
-
-def test_loops():
-    start_line = 23
-    parsed = loops
-
-    assert parsed.span.line == start_line
-
-    assert parsed.body.span.line == start_line + 1
-    assert parsed.body.min.span.column == 25
-    assert parsed.body.extent.span.column == 28
-    assert parsed.body.extent.span.line == start_line + 1
-
-    assert parsed.body.body.span.line == start_line + 2
-    assert parsed.body.body.loop_var.span.line == start_line + 2
-    assert parsed.body.body.loop_var.span.column == 13
-
-    assert parsed.body.body.body.span.line == start_line + 3
-    assert parsed.body.body.body.span.column == 22
-
-    assert parsed.body.body.body.body.span.line == start_line + 4
-    assert parsed.body.body.body.body.span.column == 17
-
-
-@T.prim_func
-def statements() -> None:
-    T.evaluate(1)
-    T.evaluate("test")
-
-
-def test_statements():
-    start_line = 53
-    parsed = statements
-
-    assert parsed.body.span.line == start_line + 1
-
-    assert parsed.body[0].span.line == start_line + 1
-    assert parsed.body[0].span.column == 5
-
-    assert parsed.body[0].span.line == start_line + 1
-    assert parsed.body[0].span.column == 5
-
-
-if __name__ == "__main__":
-    test_loops()
-    test_statements()
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 7248a3a5f4..d09a0d143a 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -20,8 +20,8 @@ import sys
 import pytest
 import tvm.testing
 from tvm.ir import assert_structural_equal
+from tvm.script import from_source
 from tvm.script import tir as T
-from tvm.script.parser import from_source
 from tvm.testing import check_error
 
 
@@ -164,15 +164,24 @@ def test_match_buffer_1d():
 
 
 # match buffer failed case
-def test_match_buffer_no_kwargs_failed():
-    with pytest.raises(ValueError) as e:
-
-        @T.prim_func
-        def elementwise_buffer_no_kwargs_failed(
-            a: T.Buffer[(128, 128, 128, 128)],
-            b: T.Buffer[(128, 128, 128, 128)],
-        ) -> None:
-            pass
+def test_match_buffer_without_dtype():
+    @T.prim_func
+    def no_dtype(
+        a: T.Buffer[(128, 128, 128, 128)],
+        b: T.Buffer[(128, 128, 128, 128)],
+    ) -> None:
+        pass
+
+    a0, a1, a2, a3 = no_dtype.buffer_map[no_dtype.params[0]].shape
+    b0, b1, b2, b3 = no_dtype.buffer_map[no_dtype.params[1]].shape
+    assert a0 == 128
+    assert a1 == 128
+    assert a2 == 128
+    assert a3 == 128
+    assert b0 == 128
+    assert b1 == 128
+    assert b2 == 128
+    assert b3 == 128
 
 
 # dynamic shape gemm
@@ -274,8 +283,8 @@ def test_letstmt_bind_with_constant():
 
     @T.prim_func
     def constant_binds_wrapped():
-        x = T.int32(1)
-        y = T.float32(42.0)
+        x = T.inline(T.int32(1))
+        y = T.inline(T.float32(42.0))
         T.evaluate(T.cast(x, "float32") + y)
 
     assert_structural_equal(constant_binds, constant_binds_wrapped)
@@ -298,9 +307,9 @@ def test_func_call():
             for i, j, k in T.grid(16, 16, 16):
                 with T.block("C"):
                     i, j, k = T.axis.remap("SSR", [i, j, k])
-                    thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
-                    thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
-                    thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
+                    thread_id_C, local_id_C = T.inline(shared_16x16_to_ldmatrix_32x8_layout(i, j))
+                    thread_id_A, local_id_A = T.inline(shared_16x16_to_ldmatrix_32x8_layout(i, k))
+                    thread_id_B, local_id_B = T.inline(shared_16x16_to_ldmatrix_32x8_layout(k, j))
 
                     T.reads(
                         C[thread_id_C, local_id_C],