You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 22:31:54 UTC

[tvm] 01/01: [TVMScript] New Parser

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch ir-builder-v2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit fbba02c1a41917cdf803070cc2f66f2a8b8b03c7
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Aug 15 12:33:04 2022 -0700

    [TVMScript] New Parser
---
 include/tvm/script/ir_builder/base.h               |  166 +++
 include/tvm/script/ir_builder/ir/frame.h           |   61 +
 include/tvm/script/ir_builder/ir/ir.h              |   39 +
 include/tvm/script/ir_builder/tir/frame.h          |  459 +++++++
 include/tvm/script/ir_builder/tir/ir.h             |  142 ++
 include/tvm/tir/op.h                               |   34 +-
 python/tvm/script/__init__.py                      |    6 +-
 python/tvm/script/context_maintainer.py            |  251 ----
 python/tvm/script/diagnostics.py                   |   55 -
 python/tvm/script/{ => ir_builder}/__init__.py     |    7 +-
 .../script/{__init__.py => ir_builder/_ffi_api.py} |    7 +-
 python/tvm/script/ir_builder/base.py               |   76 ++
 python/tvm/script/{ => ir_builder/ir}/__init__.py  |    8 +-
 .../{__init__.py => ir_builder/ir/_ffi_api.py}     |    7 +-
 .../{tir/__init__.py => ir_builder/ir/frame.py}    |   17 +-
 .../script/{__init__.py => ir_builder/ir/ir.py}    |    9 +-
 python/tvm/script/{ => ir_builder/tir}/__init__.py |    7 +-
 .../{__init__.py => ir_builder/tir/_ffi_api.py}    |    7 +-
 python/tvm/script/ir_builder/tir/frame.py          |  116 ++
 python/tvm/script/ir_builder/tir/ir.py             |  954 ++++++++++++++
 python/tvm/script/meta_unparser.py                 |   45 -
 python/tvm/script/parser.py                        | 1385 --------------------
 python/tvm/script/{ => parser}/__init__.py         |   12 +-
 python/tvm/script/{__init__.py => parser/_core.py} |   13 +-
 python/tvm/script/{ => parser/core}/__init__.py    |    7 +-
 python/tvm/script/parser/core/diagnostics.py       |  175 +++
 python/tvm/script/parser/core/dispatch.py          |   63 +
 python/tvm/script/parser/core/doc.py               |  361 +++++
 .../script/{printer => parser/core}/doc_core.py    |    0
 .../{tir/prim_func.py => parser/core/entry.py}     |   46 +-
 python/tvm/script/parser/core/evaluator.py         |  282 ++++
 python/tvm/script/parser/core/parser.py            |  273 ++++
 .../{tir/__init__.py => parser/core/utils.py}      |   27 +-
 python/tvm/script/{ => parser/ir}/__init__.py      |    8 +-
 .../{tir/prim_func.py => parser/ir/entry.py}       |   32 +-
 .../{tir/__init__.py => parser/ir/parser.py}       |   28 +-
 python/tvm/script/{ => parser/tir}/__init__.py     |   11 +-
 python/tvm/script/parser/tir/entry.py              |  101 ++
 python/tvm/script/parser/tir/operation.py          |   84 ++
 python/tvm/script/parser/tir/parser.py             |  268 ++++
 python/tvm/script/registry.py                      |   62 -
 python/tvm/script/tir/__init__.pyi                 |  477 -------
 python/tvm/script/tir/intrin.py                    |  222 ----
 python/tvm/script/tir/node.py                      |  218 ---
 python/tvm/script/tir/scope_handler.py             |  788 -----------
 python/tvm/script/tir/special_stmt.py              |  964 --------------
 python/tvm/script/tir/ty.py                        |  216 ---
 python/tvm/script/utils.py                         |  105 --
 python/tvm/tir/__init__.py                         |  216 ++-
 python/tvm/tir/analysis/analysis.py                |    6 +-
 python/tvm/tir/expr.py                             |   15 +-
 python/tvm/tir/op.py                               |  601 ++++++++-
 python/tvm/tir/schedule/block_scope.py             |    2 +-
 python/tvm/tir/schedule/schedule.py                |    6 +-
 python/tvm/tir/schedule/state.py                   |    3 +-
 python/tvm/tir/stmt.py                             |    4 +
 python/tvm/tir/usmp/transform/transform.py         |    5 +-
 src/script/ir_builder/base.cc                      |  115 ++
 src/script/ir_builder/ir/frame.cc                  |   43 +
 src/script/ir_builder/ir/ir.cc                     |   38 +
 src/script/ir_builder/tir/frame.cc                 |  210 +++
 src/script/ir_builder/tir/ir.cc                    |  665 ++++++++++
 src/script/ir_builder/tir/utils.h                  |   95 ++
 src/tir/ir/script/script_complete.h                |   37 +
 src/tir/ir/stmt.cc                                 |   10 +
 src/tir/op/op.cc                                   |   24 +
 66 files changed, 5782 insertions(+), 5014 deletions(-)

diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
new file mode 100644
index 0000000000..179cca42df
--- /dev/null
+++ b/include/tvm/script/ir_builder/base.h
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_
+#define TVM_SCRIPT_IR_BUILDER_BASE_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame //////////////////////////////
+/*! \brief Base node of all IRBuilder frames; a frame is entered and exited as a scope. */
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+  std::vector<runtime::TypedPackedFunc<void()>> callbacks;  // presumably filled by AddCallback
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `callbacks` is not visited.
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+  virtual ~IRBuilderFrameNode() = default;
+  virtual void EnterWithScope();  // invoked by IRBuilderFrame::EnterWithScope
+  virtual void ExitWithScope();   // invoked by IRBuilderFrame::ExitWithScope
+
+  void AddCallback(runtime::TypedPackedFunc<void()> callback);  // presumably appends to `callbacks`
+};
+
+class IRBuilderFrame : public runtime::ObjectRef {  // managed reference to IRBuilderFrameNode
+ public:
+  virtual ~IRBuilderFrame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+  IRBuilderFrame() = default;
+
+ public:
+  inline void EnterWithScope();  // forwards to the node's EnterWithScope
+  inline void ExitWithScope();   // forwards to the node, then releases this reference
+};
+
+////////////////////////////// Core Infra: Builder //////////////////////////////
+/*! \brief The IR builder: the currently open `frames` and the built `result`. */
+class IRBuilderNode : public runtime::Object {
+ public:
+  runtime::Array<IRBuilderFrame> frames;  // searched from the back by FindFrame/GetLastFrame
+  Optional<ObjectRef> result;             // retrieved via Get<TObjectRef>()
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("frames", &frames);
+    v->Visit("result", &result);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+  template <typename TFrame>
+  inline Optional<TFrame> FindFrame() const;  // last frame of type TFrame, if any
+  template <typename TFrame>
+  inline Optional<TFrame> GetLastFrame() const;  // the last frame, only if it is a TFrame
+  template <typename TObjectRef>
+  inline TObjectRef Get() const;  // `result` as TObjectRef; CHECK-fails if absent/mistyped
+};
+
+class IRBuilder : public runtime::ObjectRef {  // managed reference to IRBuilderNode
+ public:
+  IRBuilder();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+  void EnterWithScope();
+  void ExitWithScope();
+  static IRBuilder Current();
+  template <class TObjectRef>
+  inline static TObjectRef Name(String name, TObjectRef obj);  // names `obj`, then returns it
+};
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+class Namer {  // per-type dispatch table (NodeFunctor) used by IRBuilder::Name
+ public:
+  using FType = NodeFunctor<void(const ObjectRef&, String)>;
+  static FType& vtable();
+  static void Name(ObjectRef node, String name);
+};
+
+}  // namespace details
+
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+  details::Namer::Name(obj, name);
+  return Downcast<TObjectRef>(obj);
+}
+
+inline void IRBuilderFrame::EnterWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
+}
+
+inline void IRBuilderFrame::ExitWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
+  data_.reset();  // this reference is null after exiting
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {  // most recent frame first
+    if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
+      return GetRef<TFrame>(p);
+    }
+  }
+  return NullOpt;
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
+    return Downcast<TFrame>(frames.back());
+  }
+  return NullOpt;
+}
+
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  CHECK(n != nullptr) << "IndexError: IRBuilder result is not of type: " << TObject::_type_key;
+  return GetRef<TObjectRef>(n);
+}
+
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_BASE_H_
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
new file mode 100644
index 0000000000..9a8791be7c
--- /dev/null
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+#include <tvm/script/ir_builder/base.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+class IRModuleFrameNode : public IRBuilderFrameNode {  // frame holding module globals
+ public:
+  Array<GlobalVar> global_vars;  // NOTE(review): presumably index-aligned with `functions`
+  Array<BaseFunc> functions;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("global_vars", &global_vars);
+    v->Visit("functions", &functions);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+  void ExitWithScope() final;  // presumably defined in src/script/ir_builder/ir/frame.cc
+};
+
+class IRModuleFrame : public IRBuilderFrame {  // managed reference to IRModuleFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+                                                    IRModuleFrameNode);
+};
+
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_
diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h
new file mode 100644
index 0000000000..b58e51a945
--- /dev/null
+++ b/include/tvm/script/ir_builder/ir/ir.h
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_IR_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_IR_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+#include <tvm/script/ir_builder/ir/frame.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+/*! \brief Make an IRModuleFrame. NOTE: this name hides tvm::IRModule in this namespace. */
+TVM_DLL IRModuleFrame IRModule();
+
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_IR_IR_H_
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
new file mode 100644
index 0000000000..d2d2485bbe
--- /dev/null
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -0,0 +1,459 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
+#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
+
+#include <tvm/script/ir_builder/base.h>
+#include <tvm/script/ir_builder/ir/frame.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace tir {
+
+class TIRFrameNode : public IRBuilderFrameNode {  // base of all TIR frames
+ public:
+  Array<tvm::tir::Stmt> stmts;  // NOTE(review): presumably the body stmts emitted in this frame
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+class TIRFrame : public IRBuilderFrame {  // managed reference to TIRFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+  TIRFrame() = default;
+};
+
+class BlockFrameNode : public TIRFrameNode {  // frame for a TIR block (see Block() in ir.h)
+ public:
+  String name;
+  Array<tvm::tir::IterVar> iter_vars;
+  Optional<Array<tvm::tir::BufferRegion>> reads;
+  Optional<Array<tvm::tir::BufferRegion>> writes;
+  Optional<tvm::tir::Stmt> init;
+  Array<tvm::tir::Buffer> alloc_buffers;
+  Array<tvm::tir::MatchBufferRegion> match_buffers;
+  Optional<Map<String, ObjectRef>> annotations;
+
+  Array<PrimExpr> iter_values;  // NOTE(review): presumably the bindings of `iter_vars`
+  Optional<PrimExpr> predicate;
+  bool no_realize;  // see Block(); NOTE(review): declared without a default initializer
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("iter_vars", &iter_vars);
+    v->Visit("reads", &reads);
+    v->Visit("writes", &writes);
+    v->Visit("init", &init);
+    v->Visit("alloc_buffers", &alloc_buffers);
+    v->Visit("match_buffers", &match_buffers);
+    v->Visit("annotations", &annotations);
+    v->Visit("iter_values", &iter_values);
+    v->Visit("predicate", &predicate);
+    v->Visit("no_realize", &no_realize);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class BlockFrame : public TIRFrame {  // managed reference to BlockFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+class BlockInitFrameNode : public TIRFrameNode {  // frame returned by Init() (see ir.h)
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class BlockInitFrame : public TIRFrame {  // managed reference to BlockInitFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+class ForFrameNode : public TIRFrameNode {  // loop frame over `vars` with domains `doms`
+ public:
+  using FMakeForLoop =
+      runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+  Array<tvm::tir::Var> vars;
+  Array<Range> doms;
+  FMakeForLoop f_make_for_loop;  // maps (vars, doms, stmt) to the loop Stmt
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("vars", &vars);
+    v->Visit("doms", &doms);
+    // `f_make_for_loop` is not visited.
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class ForFrame : public TIRFrame {  // managed reference to ForFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+class PrimFuncFrameNode : public TIRFrameNode {  // frame for a TIR PrimFunc (see PrimFunc())
+ public:
+  Optional<String> name;
+  Array<tvm::tir::Var> args;
+  Optional<Type> ret_type;
+  Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+  Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+  Optional<Map<String, ObjectRef>> attrs;
+  Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
+  Array<tvm::tir::Buffer> root_alloc_buffers;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+    v->Visit("attrs", &attrs);
+    v->Visit("env_threads", &env_threads);
+    v->Visit("root_alloc_buffers", &root_alloc_buffers);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class PrimFuncFrame : public TIRFrame {  // managed reference to PrimFuncFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+class AssertFrameNode : public TIRFrameNode {  // frame returned by Assert(condition, message)
+ public:
+  PrimExpr condition;
+  PrimExpr message;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AssertFrame : public TIRFrame {  // managed reference to AssertFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+class LetFrameNode : public TIRFrameNode {  // frame returned by Let(var, value)
+ public:
+  tvm::tir::Var var;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LetFrame : public TIRFrame {  // managed reference to LetFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+class AllocateFrameNode : public TIRFrameNode {  // frame returned by Allocate(...)
+ public:
+  Array<PrimExpr> extents;
+  DataType dtype;
+  String storage_scope;
+  PrimExpr condition;
+  Map<String, ObjectRef> annotations;
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extents", &extents);
+    v->Visit("dtype", &dtype);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+    v->Visit("annotations", &annotations);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateFrame : public TIRFrame {  // managed reference to AllocateFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+class AllocateConstFrameNode : public TIRFrameNode {  // frame returned by AllocateConst(...)
+ public:
+  DataType dtype;
+  Array<PrimExpr> extents;
+  tvm::runtime::NDArray data;
+  tvm::tir::Buffer buffer;
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("data", &data);
+    v->Visit("buffer", &buffer);
+    v->Visit("annotations", &annotations);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateConstFrame : public TIRFrame {  // managed reference to AllocateConstFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+                                                    AllocateConstFrameNode);
+};
+
+class LaunchThreadFrameNode : public TIRFrameNode {  // frame returned by LaunchThread(var, extent)
+ public:
+  PrimExpr extent;
+  String attr_key;
+  tvm::tir::IterVar iter_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extent", &extent);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("iter_var", &iter_var);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LaunchThreadFrame : public TIRFrame {  // managed reference to LaunchThreadFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+                                                    LaunchThreadFrameNode);
+};
+
+class RealizeFrameNode : public TIRFrameNode {  // frame returned by Realize(...)
+ public:
+  tvm::tir::BufferRegion buffer_slice;
+  String storage_scope;
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer_slice", &buffer_slice);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class RealizeFrame : public TIRFrame {  // managed reference to RealizeFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+class AttrFrameNode : public TIRFrameNode {  // frame returned by Attr(node, attr_key, value)
+ public:
+  ObjectRef node;
+  String attr_key;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AttrFrame : public TIRFrame {  // managed reference to AttrFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+class WhileFrameNode : public TIRFrameNode {  // frame returned by While(condition)
+ public:
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class WhileFrame : public TIRFrame {  // managed reference to WhileFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+class IfFrameNode : public TIRFrameNode {  // frame returned by If(condition)
+ public:
+  PrimExpr condition;
+  Optional<Array<tvm::tir::Stmt>> then_stmts;  // presumably filled by the ThenFrame
+  Optional<Array<tvm::tir::Stmt>> else_stmts;  // presumably filled by the ElseFrame
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("then_stmts", &then_stmts);
+    v->Visit("else_stmts", &else_stmts);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IfFrame : public TIRFrame {  // managed reference to IfFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+class ThenFrameNode : public TIRFrameNode {  // frame returned by Then()
+ public:
+  static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ThenFrame : public TIRFrame {  // managed reference to ThenFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+class ElseFrameNode : public TIRFrameNode {  // frame returned by Else()
+ public:
+  static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ElseFrame : public TIRFrame {  // managed reference to ElseFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+class DeclBufferFrameNode : public TIRFrameNode {  // frame that declares `buffer`
+ public:
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class DeclBufferFrame : public TIRFrame {  // managed reference to DeclBufferFrameNode
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
new file mode 100644
index 0000000000..c26d552737
--- /dev/null
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
+#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
+
+#include <tvm/script/ir_builder/base.h>
+#include <tvm/script/ir_builder/tir/frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;
+using tvm::tir::Buffer;
+using tvm::tir::Var;
+
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators);
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+BlockFrame Block(String name, bool no_realize = false);  // see BlockFrameNode::name/no_realize
+BlockInitFrame Init();
+void Where(PrimExpr predicate);                // presumably sets BlockFrameNode::predicate
+void Reads(Array<ObjectRef> buffer_slices);    // presumably sets BlockFrameNode::reads
+void Writes(Array<ObjectRef> buffer_slices);   // presumably sets BlockFrameNode::writes
+void BlockAttrs(Map<String, ObjectRef> attrs);  // presumably sets BlockFrameNode::annotations
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+                   int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+
+namespace axis {
+Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+}  // namespace axis
+
+ForFrame Serial(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+                  Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+                    Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Grid(Array<PrimExpr> extents);  // one ForFrame; presumably one loop var per extent
+
+PrimFuncFrame PrimFunc();
+Var Arg(String name, Var var);
+Buffer Arg(String name, Buffer buffer);
+void FuncName(String name);                    // presumably sets PrimFuncFrameNode::name
+void FuncAttrs(Map<String, ObjectRef> attrs);  // presumably sets PrimFuncFrameNode::attrs
+Type FuncRet(Type ret_type);                   // presumably sets PrimFuncFrameNode::ret_type
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
+                   int align = -1, int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+                        DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+                        Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+                        String storage_scope = "global", int align = -1, int offset_factor = 0,
+                        String buffer_type = "default", Array<IntImm> axis_separators = {});
+
+AssertFrame Assert(PrimExpr condition, String message);
+LetFrame Let(Var var, PrimExpr value);
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                       Optional<PrimExpr> condition = NullOpt,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+AllocateConstFrame AllocateConst(
+    NDArray data, DataType dtype, Array<PrimExpr> extents,
+    Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+WhileFrame While(PrimExpr condition);
+IfFrame If(PrimExpr condition);  // presumably followed by Then() and optionally Else()
+ThenFrame Then();
+ElseFrame Else();
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+Var EnvThread(String thread_tag);
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+void Prefetch(Buffer buffer, Array<Range> bounds);
+void Evaluate(PrimExpr value);
+// Dtype helpers: FuncName(expr) casts `expr` to DType; FuncName() gives an unnamed DType Var.
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                             \
+  inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                        \
+    DataType dtype = DType;                                                            \
+    return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
+  }
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
  * \return The result expression.
  */
 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check if x is nullptr.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
 /*!
  * \brief Check if x is infinite.
  * \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span());
 
+/*!
+ * \brief Calculate fmod(x, y)
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                   Span span = Span());
 
+/*!
+ * \brief Returns the address of an element in the buffer
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Returns the param by name
+ * \param param_name The param name.
+ * \param span The location of this operation in the source.
+ * \return The handle of param.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..3107ada88b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+from . import parser
+from .parser import ir, ir_module, parse, tir
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
deleted file mode 100644
index f7f16855c7..0000000000
--- a/python/tvm/script/context_maintainer.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Context Maintainer for TIR"""
-
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
-
-import tvm
-from tvm.ir import Span
-from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
-from tvm.runtime import Object
-from tvm.tir.expr import IterVar
-from .tir.node import BufferSlice
-
-
-class BlockInfo:
-    """Information for block and block_realize signature
-
-    Examples
-    ----------
-    .. code-block:: python
-
-        @T.prim_func
-        def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
-            A = T.match_buffer(a, (16, 16), "float32")
-            B = T.match_buffer(b, (16, 16), "float32")
-            C = T.match_buffer(a, (16, 16), "float32")
-
-            for i, j, k in T.grid(16, 16, 16):
-                with T.block("matmul"):
-                    vi = T.axis.S(16, i)
-                    vj = T.axis.S(16, j)
-                    vk = T.axis.R(16, k)         # iter_bindings = {vj: i, vj: j, vk: k}
-
-                    T.where(True)         # predicate of the block_realize
-
-                    T.reads(A[0:16, 0:16], B[0: 16, 0: 16])      # reads region of the block
-                    T.writes(C[0: 16, 0: 16])                    # writes region of the block
-                    T.block_attr({"attr_key": "attr_value"})     # block annotations
-
-                    # alloc_buffers inside the block
-                    CC = T.alloc_buffer((1, 1), dtype="float32")
-
-                    # match_buffers of the block,
-                    # which bind a sub-region of source buffer into a new buffer
-                    D = T.match_buffer(C[vi, vj], ())
-
-                    # init part of the block, executed when all reduce axes are the beginning value
-                    with T.init():
-                        C[vi, vj] = T.float32(0)
-
-                    # block body
-                    CC[0, 0] = A[vi, vk] * B[vj, vk]
-                    D[()] += CC[0, 0]         # The same as C[vi, vj] += CC[0, 0]
-    """
-
-    alloc_buffers: List[Buffer] = []
-    """List[Buffer]: list of T.alloc_buffer statements in the block signature"""
-    match_buffers: List[MatchBufferRegion] = []
-    """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
-    iter_values: List[PrimExpr] = []
-    """List[PrimExpr]: list of binding values for iter vars"""
-    iter_vars: List[IterVar] = []
-    """List[PrimExpr]: list of iter vars in the block"""
-    reads: Optional[List[BufferSlice]] = None
-    """Optional[List[BufferSlice]]:
-    list of T.reads statements in the block signature, None for not-visited"""
-    writes: Optional[List[BufferSlice]] = None
-    """Optional[List[BufferSlice]]:
-    list of T.writes statements in the block signature, None for not-visited"""
-    annotations: Optional[Mapping[str, Object]] = None
-    """Optional[Mapping[str, Object]]:
-    list of T.block_attr statements in the block signature, None for not-visited"""
-    predicate: Optional[PrimExpr] = None
-    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
-    init: Optional[Stmt] = None
-    """Optional[Stmt]: init part of the block, None for not-visited"""
-
-    def __init__(self):
-        self.alloc_buffers = []
-        self.match_buffers = []
-        self.iter_values = []
-        self.iter_vars = []
-        self.reads = None
-        self.writes = None
-        self.annotations = None
-        self.predicate = None
-        self.init = None
-
-
-class ContextMaintainer:
-    """Maintain all the necessary context info
-    Parameters
-    ----------
-    _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
-        The report error function handle
-    """
-
-    # scope context
-    node_stack: List[List[synr.ast.Node]] = []
-    """List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
-    block_info_stack: List[BlockInfo] = []
-    """List[BlockInfo]: The block info for the current block scope"""
-    loop_stack: Dict[Var, Range] = {}
-    """Dict[Var, Range]: The dict from loop var to its domain outside the block"""
-    symbols: List[Dict[str, Union[Var, Buffer]]] = []
-    """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
-    closure_vars: Dict[str, Object] = {}
-    """ClosureVars: The closure vars defined in Python interpreter"""
-
-    # function context
-    func_params: List[Var] = []
-    """List[Var]: The function parameters"""
-    func_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map"""
-    func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
-    func_dict_attr: Mapping[str, Object] = {}
-    """Mapping[str, Object]: The function attrs"""
-    func_var_env_dict: Mapping[Var, str] = {}
-    """Mapping[Var, str]: The map from var to env thread"""
-
-    # parser and analyzer
-    analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
-    """tvm.arith.Analyzer: The analyzer for simplifying"""
-    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-    """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""
-
-    # root alloc_buffer
-    root_alloc_buffers: List[Buffer] = []
-    """List[Buffer]: The buffers allocated under root block"""
-
-    def __init__(
-        self,
-        _report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
-        closure_vars: Dict[str, Object],
-    ):
-        # scope context
-        self.node_stack = []
-        self.block_info_stack = []
-        self.loop_stack = {}
-        self.symbols = []
-        self.closure_vars = closure_vars
-        # function context
-        self.func_params = []
-        self.func_buffer_map = {}
-        self.func_preflattened_buffer_map = {}
-        self.func_dict_attr = {}
-        self.func_var_env_dict = {}
-        # parser and analyzer
-        self._report_error = _report_error
-        self.analyzer = tvm.arith.Analyzer()
-        # root alloc_buffer
-        self.root_alloc_buffers = []
-
-    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
-        """Creates a new scope
-
-        Note
-        ----
-        This function is used for normal scopes that do not involve
-        a `with block` scope. Use `enter_block_scope`
-        for block scope cases.
-
-        Parameters
-        ----------
-        nodes : Optional[List[synr.ast.Node]]
-            The synr AST nodes in new scope
-        """
-        if nodes is None:
-            nodes = []
-        self.node_stack.append(list(reversed(nodes)))
-        self.symbols.append(dict())
-
-    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
-        """Creates a new block scope, the function will call `enter_scope` implicitly
-        Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
-        to maintain block info.
-
-        Note
-        ----
-        This function should be used to handle a block scope,
-        aka the blocks that involve a `with block` scope.
-
-        Parameters
-        ----------
-        nodes : Optional[List[synr.ast.Node]]
-            The synr AST nodes in new scope
-        """
-        self.enter_scope(nodes)
-        # Create a new BlockInfo for the new block
-        self.block_info_stack.append(BlockInfo())
-
-    def exit_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
-
-    def exit_block_scope(self):
-        """Pop the inner most block scope, the function will call `exit_scope` implicitly"""
-        self.exit_scope()
-        # Pop block_info
-        self.block_info_stack.pop()
-
-    def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
-        """Append a symbol into current scope"""
-        if isinstance(symbol, Buffer):
-            if name in self.symbols[0]:
-                self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
-            self.symbols[0][name] = symbol
-        else:
-            self.symbols[-1][name] = symbol
-
-    def remove_symbol(self, name: str):
-        """Remove a symbol"""
-        for symbols in reversed(self.symbols):
-            if name in symbols:
-                symbols.pop(name)
-                return
-        raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)
-
-    def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
-        """Look up symbol by name"""
-        for symbols in reversed(self.symbols):
-            if name in symbols:
-                return symbols[name]
-        return self.closure_vars.get(name)
-
-    def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
-        self._report_error(message, span)
-
-    def current_block_scope(self) -> BlockInfo:
-        if self.block_info_stack:
-            return self.block_info_stack[-1]
-        return None
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/diagnostics.py
deleted file mode 100644
index e676461ab3..0000000000
--- a/python/tvm/script/diagnostics.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Bridge from synr's (the library used for parsing the python AST)
-   DiagnosticContext to TVM's diagnostics
-"""
-from synr import DiagnosticContext, ast
-
-import tvm
-from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
-
-
-class TVMDiagnosticCtx(DiagnosticContext):
-    """TVM diagnostics for synr"""
-
-    diag_ctx: TVMCtx
-
-    def __init__(self) -> None:
-        self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer())
-        self.source_name = None
-
-    def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span:
-        return tvm.ir.Span(
-            src_name,
-            ast_span.start_line,
-            ast_span.end_line,
-            ast_span.start_column,
-            ast_span.end_column,
-        )
-
-    def add_source(self, name: str, source: str) -> None:
-        src_name = self.diag_ctx.module.source_map.add(name, source)
-        self.source_name = src_name
-
-    def emit(self, _level, message, span):
-        span = self.to_tvm_span(self.source_name, span)
-        self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message))
-        self.diag_ctx.render()  # Raise exception on the first error we hit. TODO remove
-
-    def render(self):
-        self.diag_ctx.render()
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/__init__.py
similarity index 87%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/__init__.py
index 555659d0c5..53721d2432 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/__init__.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+"""tvm.script.ir_builder is a generic IR builder for TVM."""
 from . import tir
-
-from .parser import ir_module, from_source
+from .base import IRBuilder
+from .ir import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/_ffi_api.py
similarity index 84%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/_ffi_api.py
index 555659d0c5..68811c9e01 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/_ffi_api.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""FFI APIs for tvm.script.ir_builder"""
+import tvm._ffi
 
-from . import tir
-
-from .parser import ir_module, from_source
+tvm._ffi._init_api("script.ir_builder", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
new file mode 100644
index 0000000000..d8b965d03b
--- /dev/null
+++ b/python/tvm/script/ir_builder/base.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_api
+
+
+@_register_object("script.ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+    def __enter__(self) -> "IRBuilderFrame":
+        _ffi_api.IRBuilderFrameEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderFrameExit(self)  # pylint: disable=no-member # type: ignore
+
+    def add_callback(self, callback) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderFrameAddCallback(  # pylint: disable=no-member # type: ignore
+            self, callback
+        )
+
+
+@_register_object("script.ir_builder.IRBuilder")
+class IRBuilder(_Object):
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.IRBuilder  # pylint: disable=no-member # type: ignore
+        )
+
+    def __enter__(self) -> "IRBuilder":
+        _ffi_api.IRBuilderEnter(self)  # pylint: disable=no-member # type: ignore
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderExit(self)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def current() -> "IRBuilder":
+        return _ffi_api.IRBuilderCurrent()  # pylint: disable=no-member # type: ignore
+
+    def get(self) -> _Object:
+        return _ffi_api.IRBuilderGet(self)  # pylint: disable=no-member # type: ignore
+
+
+DefType = TypeVar("DefType", bound=_Object)
+
+
+def name(s: str, v: DefType) -> DefType:
+    return _ffi_api.IRBuilderName(s, v)  # pylint: disable=no-member # type: ignore
+
+
+def name_many(  # pylint: disable=invalid-name
+    s: List[str],
+    vs: List[DefType],
+) -> List[DefType]:
+    assert len(s) == len(vs)
+    return [name(i, v) for i, v in zip(s, vs)]
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/ir/__init__.py
index 555659d0c5..ebb9728737 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -14,8 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""Package tvm.script.ir_builder.ir"""
+from .frame import IRModuleFrame
+from .ir import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/ir/_ffi_api.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/ir/_ffi_api.py
index 555659d0c5..874cc278af 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/ir/_ffi_api.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""FFI APIs"""
+import tvm._ffi
 
-from . import tir
-
-from .parser import ir_module, from_source
+tvm._ffi._init_api("script.ir_builder.ir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/ir_builder/ir/frame.py
similarity index 63%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/ir_builder/ir/frame.py
index 2f2b4bbc25..e16d86dc22 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/frame.py
@@ -14,18 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+"""Package tvm.script.ir_builder.ir.frame"""
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
+from tvm._ffi import register_object as _register_object
 
-from .prim_func import prim_func
+from ..base import IRBuilderFrame
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+@_register_object("script.ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):
+    ...
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/ir/ir.py
similarity index 79%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/ir/ir.py
index 555659d0c5..df92036435 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -14,8 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""Package tvm.script.ir_builder.ir.ir"""
 
-from . import tir
+from . import _ffi_api
+from .frame import IRModuleFrame
 
-from .parser import ir_module, from_source
+
+def ir_module() -> IRModuleFrame:
+    return _ffi_api.IRModule()  # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/tir/__init__.py
index 555659d0c5..1e43d1af34 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/tir/__init__.py
@@ -14,8 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""Package tvm.script.ir_builder.tir"""
+from .ir import *  # pylint: disable=wildcard-import,redefined-builtin
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/tir/_ffi_api.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/tir/_ffi_api.py
index 555659d0c5..876f5f3a35 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/tir/_ffi_api.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""FFI APIs"""
+import tvm._ffi
 
-from . import tir
-
-from .parser import ir_module, from_source
+tvm._ffi._init_api("script.ir_builder.tir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
new file mode 100644
index 0000000000..22b03ccdd4
--- /dev/null
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder for TIR"""
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.tir import Buffer, Var
+
+from ..base import IRBuilderFrame
+
+
+@_register_object("script.ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    def __enter__(self) -> List[Var]:
+        super().__enter__()
+        return self.vars if len(self.vars) > 1 else self.vars[0]
+
+
+@_register_object("script.ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        return self.buffer
+
+
+@_register_object("script.ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        return self.buffer
+
+
+@_register_object("script.ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    ...
+
+
+@_register_object("script.ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    def __enter__(self) -> Buffer:
+        super().__enter__()
+        return self.buffer
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
new file mode 100644
index 0000000000..ebd764cf1d
--- /dev/null
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -0,0 +1,954 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.tir import Broadcast as broadcast
+from tvm.tir import (
+    Buffer,
+    BufferLoad,
+    BufferRegion,
+    Cast,
+    CommReducer,
+    IntImm,
+    IterVar,
+    Let,
+    PrimExpr,
+)
+from tvm.tir import Ramp as ramp
+from tvm.tir import Select, Shuffle, StringImm, Var, cast
+from tvm.tir import op as _tir_op
+from tvm.tir import type_annotation
+
+from . import _ffi_api, frame
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    """Declare a buffer through ``_ffi_api.BufferDecl``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is wrapped into a one-element
+    tuple; any other value is forwarded unchanged. An empty string is always
+    passed as the third FFI argument (presumably the buffer name -- confirm
+    against the C++ ``BufferDecl`` signature). Unlike ``alloc_buffer``,
+    ``strides=None`` is forwarded as ``None`` rather than ``[]``.
+    """
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    return _ffi_api.BufferDecl(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        "",
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def ptr(dtype, storage_scope="global"):
+    """Forward ``dtype`` and ``storage_scope`` to ``_ffi_api.Ptr`` (also exported as ``buffer_var``)."""
+    return _ffi_api.Ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
+    """Create a ``BlockFrame`` via ``_ffi_api.Block`` with the given name and ``no_realize`` flag."""
+    return _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
+
+
+def init() -> frame.BlockInitFrame:
+    """Create a ``BlockInitFrame`` via ``_ffi_api.Init``."""
+    return _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+
+
+def where(predicate) -> None:
+    """Set a predicate through ``_ffi_api.Where``.
+
+    Python ``bool`` values and the integers ``0``/``1`` are turned into boolean
+    ``IntImm`` constants; any other value is forwarded unchanged.
+
+    Raises
+    ------
+    ValueError
+        If ``predicate`` is a Python ``int`` other than 0 or 1.
+    """
+    if isinstance(predicate, bool):
+        predicate = IntImm("bool", predicate)
+    # A bool has already become an IntImm above, so only plain ints reach here.
+    if isinstance(predicate, int):
+        if predicate in [0, 1]:
+            predicate = IntImm("bool", predicate)
+        else:
+            raise ValueError("Invalid value for predicate: {}".format(predicate))
+    _ffi_api.Where(predicate)  # pylint: disable=no-member # type: ignore
+
+
+def _as_buffer_slice_list(buffer_slices) -> List[Union[BufferRegion, BufferLoad]]:
+    """Normalize the positional arguments of ``reads``/``writes`` into a list.
+
+    A single tuple argument is copied into a list, a single list argument is
+    returned as-is, and a single other object is wrapped in a one-element list.
+    With zero or several arguments, all of them are collected into a list.
+    """
+    if len(buffer_slices) == 1:
+        only = buffer_slices[0]
+        if isinstance(only, tuple):
+            return list(only)
+        if isinstance(only, list):
+            return only
+        return [only]
+    return list(buffer_slices)
+
+
+def reads(*buffer_slices: Union[BufferRegion, BufferLoad, List[Any], tuple]) -> None:
+    """Pass the read regions to ``_ffi_api.Reads``.
+
+    Accepts either several region arguments or one list/tuple of them.
+    """
+    _ffi_api.Reads(_as_buffer_slice_list(buffer_slices))  # pylint: disable=no-member # type: ignore
+
+
+def writes(*buffer_slices: Union[BufferRegion, BufferLoad, List[Any], tuple]) -> None:
+    """Pass the write regions to ``_ffi_api.Writes``.
+
+    Accepts either several region arguments or one list/tuple of them.
+    """
+    _ffi_api.Writes(_as_buffer_slice_list(buffer_slices))  # pylint: disable=no-member # type: ignore
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """Forward the ``attrs`` dictionary to ``_ffi_api.BlockAttrs``."""
+    return _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Allocate a buffer through ``_ffi_api.AllocBuffer``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is wrapped into a one-element
+    tuple, and ``strides=None`` is forwarded as an empty list.
+    """
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    return _ffi_api.AllocBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def _as_range(dom) -> Range:
+    """Coerce ``dom`` into a ``Range``.
+
+    A ``Range`` is returned unchanged; a list/tuple is read as ``(begin, end)``
+    via its first two elements; anything else becomes ``Range(0, dom)``.
+    """
+    if isinstance(dom, Range):
+        return dom
+    if isinstance(dom, (list, tuple)):
+        return Range(dom[0], dom[1])
+    return Range(0, dom)
+
+
+class axis:  # pylint: disable=invalid-name
+    """Namespace of block-axis constructors (``T.axis.spatial``, ``T.axis.S``, ...).
+
+    Each constructor coerces ``dom`` with ``_as_range`` and forwards it with the
+    ``binding`` and ``dtype`` to the matching ``_ffi_api.Axis*`` function.
+    """
+
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        """Forward to ``_ffi_api.AxisSpatial``."""
+        return _ffi_api.AxisSpatial(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        """Forward to ``_ffi_api.AxisReduce``."""
+        return _ffi_api.AxisReduce(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        """Forward to ``_ffi_api.AxisScan``."""
+        return _ffi_api.AxisScan(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        """Forward to ``_ffi_api.AxisOpaque``."""
+        return _ffi_api.AxisOpaque(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        """Forward to ``_ffi_api.AxisRemap``; a one-element result is unwrapped."""
+        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: ignore
+            kinds, bindings, dtype
+        )
+        return iter_vars[0] if len(iter_vars) == 1 else iter_vars
+
+    # Short aliases: ``T.axis.S`` / ``T.axis.R``.
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def _normalize_loop_range(start, stop):
+    """Return ``(start, stop)``; a lone ``start`` (``stop is None``) means ``(0, start)``."""
+    if stop is None:
+        return 0, start
+    return start, stop
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open a serial loop frame from ``start`` to ``stop`` (or ``0`` to ``start``)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Serial(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open a parallel loop frame from ``start`` to ``stop`` (or ``0`` to ``start``)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Parallel(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open a vectorized loop frame from ``start`` to ``stop`` (or ``0`` to ``start``)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Open an unrolled loop frame from ``start`` to ``stop`` (or ``0`` to ``start``)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Unroll(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    """Open a thread-bound loop frame via ``_ffi_api.ThreadBinding``.
+
+    Accepted call forms:
+
+    * ``thread_binding(extent, "threadIdx.x")`` -- a string in the ``stop`` slot
+      is taken as the thread and the range becomes ``0`` to ``extent``;
+    * ``thread_binding(extent, thread=...)`` -- same range;
+    * ``thread_binding(start, stop, thread)``.
+
+    Raises
+    ------
+    ValueError
+        If ``thread`` is None and ``stop`` is not a string.
+    """
+    if thread is None:
+        # Two-positional form: the second argument is the thread, not the bound.
+        if not isinstance(stop, str):
+            raise ValueError("Thread cannot be None for thread_binding")
+        thread = stop
+        stop = start
+        start = 0
+    elif stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents) -> frame.ForFrame:
+    """Open a loop nest frame over all ``extents`` via ``_ffi_api.Grid``."""
+    return _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+
+
+def prim_func() -> frame.PrimFuncFrame:
+    """Create a ``PrimFuncFrame`` via ``_ffi_api.PrimFunc``."""
+    return _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+
+
+def arg(name, obj):
+    """Forward ``name`` and ``obj`` to ``_ffi_api.Arg`` and return its result."""
+    return _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+
+
+def func_name(name: str) -> str:
+    """Forward ``name`` to ``_ffi_api.FuncName``."""
+    return _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:
+    """Forward the ``attrs`` dictionary to ``_ffi_api.FuncAttrs``."""
+    return _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+
+
+def func_ret(ret_type) -> Type:
+    """Forward ``ret_type`` to ``_ffi_api.FuncRet``."""
+    return _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Match ``param`` to a buffer through ``_ffi_api.MatchBuffer``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is wrapped into a one-element
+    tuple, and ``strides=None`` is forwarded as an empty list. Note the default
+    ``scope`` here is ``"global"``, unlike ``alloc_buffer``'s ``""``.
+    """
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    return _ffi_api.MatchBuffer(  # pylint: disable=no-member # type: ignore
+        param,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def preflattened_buffer(
+    postflattened,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> None:
+    """Forward a pre-flattened buffer description to ``_ffi_api.PreflattenedBuffer``.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integer) is wrapped into a one-element
+    tuple, and ``strides=None`` is forwarded as an empty list. The FFI result is
+    discarded.
+    """
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    _ffi_api.PreflattenedBuffer(  # pylint: disable=no-member # type: ignore
+        postflattened,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    """Create an ``AssertFrame`` for ``condition`` with ``message`` via ``_ffi_api.Assert``."""
+    return _ffi_api.Assert(condition, message)  # pylint: disable=no-member # type: ignore
+
+
+def let(
+    v: Var,
+    value: PrimExpr,
+    body: Optional[PrimExpr] = None,
+) -> Union[frame.LetFrame, Let]:
+    """Bind ``v`` to ``value``.
+
+    Without ``body`` this returns the IR builder's ``LetFrame`` from
+    ``_ffi_api.Let``; with ``body`` it constructs a ``tir.Let`` expression
+    directly.
+    """
+    if body is None:
+        return _ffi_api.Let(v, value)  # pylint: disable=no-member # type: ignore
+    return Let(v, value, body)
+
+
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: PrimExpr = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    """Create an ``AllocateFrame`` via ``_ffi_api.Allocate``.
+
+    A Python ``bool`` ``condition`` is turned into a boolean ``IntImm``; the
+    default ``None`` is forwarded unchanged.
+    """
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Allocate(  # pylint: disable=no-member # type: ignore
+        extents, dtype, scope, condition, annotations
+    )
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+    """Create an ``AllocateConstFrame`` via ``_ffi_api.AllocateConst``.
+
+    ``data`` is packed into a TVM NDArray through ``np.asarray(data, dtype)``.
+    """
+
+    return _ffi_api.AllocateConst(  # pylint: disable=no-member # type: ignore
+        ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+    )
+
+
+def realize(
+    buffer_slice: BufferRegion,
+    storage_scope: str,
+    condition: Union[PrimExpr, bool] = True,
+) -> frame.RealizeFrame:
+    """Create a ``RealizeFrame`` for ``buffer_slice`` via ``_ffi_api.Realize``.
+
+    A Python ``bool`` ``condition`` (including the default ``True``) is turned
+    into a boolean ``IntImm``, the same way ``allocate``/``While``/``If`` do.
+    """
+    # Without this, a raw Python bool crossed the FFI as a plain integer instead
+    # of a boolean IntImm, unlike the sibling builders in this module.
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Realize(  # pylint: disable=no-member # type: ignore
+        buffer_slice, storage_scope, condition
+    )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    """Create an ``AttrFrame`` via ``_ffi_api.Attr``.
+
+    Both ``node`` and ``value`` are passed through ``convert`` first, so plain
+    Python values are accepted.
+    """
+    node = convert(node)
+    value = convert(value)
+    return _ffi_api.Attr(node, attr_key, value)  # pylint: disable=no-member # type: ignore
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    """Create a ``WhileFrame``; a Python ``bool`` condition becomes a boolean ``IntImm``."""
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.While(condition)  # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    """Create an ``IfFrame``; a Python ``bool`` condition becomes a boolean ``IntImm``."""
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.If(condition)  # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    """Create a ``ThenFrame`` via ``_ffi_api.Then``."""
+    return _ffi_api.Then()  # pylint: disable=no-member # type: ignore
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    """Create an ``ElseFrame`` via ``_ffi_api.Else``."""
+    return _ffi_api.Else()  # pylint: disable=no-member # type: ignore
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+    """Create a ``DeclBufferFrame`` via ``_ffi_api.DeclBuffer``.
+
+    Arguments are forwarded exactly as in ``buffer_decl``: a scalar ``shape`` is
+    wrapped into a one-element tuple and an empty string is passed as the third
+    FFI argument (presumably the buffer name -- confirm against the C++ side).
+    """
+
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    return _ffi_api.DeclBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        "",
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def launch_thread(
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    """Create a ``LaunchThreadFrame`` for ``iter_var`` with ``extent`` via ``_ffi_api.LaunchThread``."""
+    return _ffi_api.LaunchThread(iter_var, extent)  # pylint: disable=no-member # type: ignore
+
+
+def env_thread(thread_tag: str) -> IterVar:
+    """Forward ``thread_tag`` to ``_ffi_api.EnvThread`` and return the resulting ``IterVar``."""
+    return _ffi_api.EnvThread(thread_tag)  # pylint: disable=no-member # type: ignore
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    """Emit a store of ``value`` into ``buffer`` at ``indices`` via ``_ffi_api.BufferStore``.
+
+    A ``slice`` index ``start:stop:step`` becomes a ``ramp`` with
+    ``ceil((stop - start) / step)`` lanes, or just ``start`` when that is one
+    lane. An omitted ``start`` defaults to 0 and an omitted ``stop`` to the
+    extent of the matching buffer dimension, so ``A[:4]`` and ``A[2:]`` are
+    accepted like ``A[0:4]``. A Python ``bool`` stored into a ``bool`` buffer
+    becomes a boolean ``IntImm``.
+    """
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    analyzer = Analyzer()  # shared by all slice indices instead of one per slice
+    expr_indices = []
+    for dim, index in enumerate(indices):
+        if isinstance(index, slice):
+            # Previously a missing start/stop crashed on ``None`` arithmetic.
+            start = 0 if index.start is None else index.start
+            stop = buffer.shape[dim] if index.stop is None else index.stop
+            step = 1 if index.step is None else index.step
+            lanes = analyzer.simplify((stop - start + step - 1) // step)
+            if lanes == 1:
+                expr_indices.append(start)
+            else:
+                expr_indices.append(ramp(start, step, int(lanes)))
+        else:
+            expr_indices.append(index)
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    return _ffi_api.BufferStore(  # pylint: disable=no-member # type: ignore
+        buffer, value, expr_indices
+    )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+    """Forward ``buffer`` and ``indices`` to ``_ffi_api.Prefetch``."""
+    return _ffi_api.Prefetch(buffer, indices)  # pylint: disable=no-member # type: ignore
+
+
+def evaluate(value: PrimExpr) -> None:
+    """Forward ``value`` to ``_ffi_api.Evaluate``; a Python ``str`` is wrapped in ``StringImm`` first."""
+    if isinstance(value, str):
+        value = StringImm(value)
+    return _ffi_api.Evaluate(value)  # pylint: disable=no-member # type: ignore
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int8``."""
+    return _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int16``."""
+    return _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int32``."""
+    return _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int64``."""
+    return _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt8``."""
+    return _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt16``."""
+    return _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt32``."""
+    return _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.UInt64``."""
+    return _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+
+
+# Unlike the integer helpers above, the float helpers first run any
+# non-PrimExpr argument (Python numbers, and also the default None) through
+# ``convert``. NOTE(review): confirm this asymmetry is intended.
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float8``, converting non-PrimExpr values first."""
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float16``, converting non-PrimExpr values first."""
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float32``, converting non-PrimExpr values first."""
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to ``_ffi_api.Float64``, converting non-PrimExpr values first."""
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int32x4``."""
+    return _ffi_api.Int32x4(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int32x8``."""
+    return _ffi_api.Int32x8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Int32x16``."""
+    return _ffi_api.Int32x16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Boolean``."""
+    return _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Handle``."""
+    return _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` (possibly None) to ``_ffi_api.Void``."""
+    return _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+
+
+def min(a, b):  # pylint: disable=redefined-builtin
+    """Compute the minimum value of two expressions.
+
+    Delegates to ``_ffi_api.min``. Shadows the builtin ``min`` inside this module.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api.min(a, b)  # pylint: disable=no-member # type: ignore
+
+
+def max(a, b):  # pylint: disable=redefined-builtin
+    """Compute the maximum value of two expressions.
+
+    Delegates to ``_ffi_api.max``. Shadows the builtin ``max`` inside this module.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api.max(a, b)  # pylint: disable=no-member # type: ignore
+
+
+def var(dtype, name="") -> Var:
+    """Create a ``Var``; note the argument order is ``(dtype, name)``, the reverse of ``Var``."""
+    return Var(name, dtype)  # pylint: disable=no-member # type: ignore
+
+
+def iter_var(v, dom, iter_type, thread_tag):
+    """Create an ``IterVar``; ``iter_type`` is the name of an ``IterVar`` class attribute."""
+    # Resolve the string name to the corresponding IterVar attribute value.
+    iter_type = getattr(IterVar, iter_type)
+    return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+    """Create a CommReducer from lambda inputs/outputs and the identities.
+
+    Parameters
+    ----------
+    combiner : Callable
+        A function of ``2 * k`` parameters: the first ``k`` are the left-hand
+        operands and the last ``k`` the right-hand operands. It returns one
+        expression or a tuple of ``k`` expressions.
+
+    identity : PrimExpr, int, float, or a sequence of them
+        The identity element(s), one per reduced value. A single value is
+        treated as a one-element list; any other iterable (list, tuple, or a
+        TVM ``Array``) is converted to a list.
+
+    Returns
+    -------
+    reducer : CommReducer
+        The reducer built from the combiner's parameters and its result.
+    """
+    # Normalize first: ``identity + identity`` below only works on lists/tuples,
+    # so a scalar identity or a TVM Array previously failed here.
+    if isinstance(identity, (PrimExpr, Integral, float)):
+        identity = [identity]
+    else:
+        identity = list(identity)
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    args = []
+    # Assuming 2 * len(identity) parameters, parameters j and j + k both take
+    # their dtype from identity[j].
+    for name, ident in zip(params.keys(), identity + identity):
+        if isinstance(ident, int):
+            args.append(Var(name, "int32"))
+        else:
+            # PrimExprs keep their own dtype; Python floats (which have no
+            # ``dtype`` attribute) take the dtype ``convert`` assigns them.
+            args.append(Var(name, convert(ident).dtype))
+    res = combiner(*args)
+    if not isinstance(res, tuple):
+        res = (res,)
+    return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+    """Look up an LLVM intrinsic id by ``name`` via ``tvm.target.codegen``.
+
+    The import is deferred to call time, so importing this module does not
+    import ``tvm.target.codegen``.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm.target.codegen import llvm_lookup_intrinsic_id as f
+
+    # pylint: enable=import-outside-toplevel
+    return f(name)
+
+
+def _op_wrapper(func):
+    """Wrap ``func`` so that a ``dtype`` keyword argument is silently discarded."""
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            kwargs.pop("dtype")
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):
+    """Wrap ``func`` so that a ``dtype`` keyword is moved to the first positional argument."""
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            args = (kwargs.pop("dtype"),) + args
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+# pylint: disable=invalid-name
+
+# TIR intrinsics re-exported for TVMScript. ``_op_wrapper`` drops a ``dtype``
+# keyword before calling the op; ``_dtype_forward`` passes it as the first
+# positional argument instead.
+buffer_var = ptr
+abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+# NOTE(review): bound without a wrapper, so a ``dtype`` keyword reaches
+# ``_tir_op.tvm_struct_get`` unchanged -- presumably intentional.
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+assume = _op_wrapper(_tir_op.assume)
+undef = _op_wrapper(_tir_op.undef)
+# Legacy ``tvm_``-prefixed spellings of the call_* helpers.
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+class inline:
+    """Wrapper around ``value`` whose iteration yields each element wrapped in ``inline``."""
+
+    def __init__(self, value) -> None:
+        self.value = value
+        # NOTE(review): ``self.i`` is not read anywhere in this class.
+        self.i = 0
+
+    def __iter__(self):
+        # Return a fresh generator on every iteration over the wrapped value.
+        def f():
+            for i in self.value:
+                yield inline(i)
+
+        return f()
+
+
+# pylint: enable=invalid-name
+
+
+# Public names exported by ``from ... import *``; keep in sync with the
+# definitions and aliases above.
+__all__ = [
+    "Assert",
+    "Cast",
+    "Else",
+    "If",
+    "Let",
+    "Select",
+    "Shuffle",
+    "TVMBackendAllocWorkspace",
+    "TVMBackendFreeWorkspace",
+    "Then",
+    "While",
+    "abs",
+    "acos",
+    "acosh",
+    "address_of",
+    "alloc_buffer",
+    "allocate",
+    "allocate_const",
+    "arg",
+    "asin",
+    "asinh",
+    "assume",
+    "atan",
+    "atan2",
+    "atanh",
+    "attr",
+    "axis",
+    "block",
+    "block_attr",
+    "boolean",
+    "broadcast",
+    "buffer_decl",
+    "buffer_store",
+    "buffer_var",
+    "call_cpacked",
+    "call_cpacked_lowered",
+    "call_extern",
+    "call_intrin",
+    "call_llvm_intrin",
+    "call_llvm_pure_intrin",
+    "call_packed",
+    "call_packed_lowered",
+    "call_pure_extern",
+    "cast",
+    "ceil",
+    "ceildiv",
+    "clz",
+    "comm_reducer",
+    "copysign",
+    "cos",
+    "cosh",
+    "env_thread",
+    "erf",
+    "evaluate",
+    "exp",
+    "exp10",
+    "exp2",
+    "decl_buffer",
+    "fabs",
+    "float16",
+    "float32",
+    "float64",
+    "float8",
+    "floor",
+    "floordiv",
+    "floormod",
+    "fmod",
+    "func_attr",
+    "func_name",
+    "func_ret",
+    "grid",
+    "handle",
+    "hypot",
+    "if_then_else",
+    "infinity",
+    "init",
+    "inline",
+    "int16",
+    "int32",
+    "int32x16",
+    "int32x4",
+    "int32x8",
+    "int64",
+    "int8",
+    "isfinite",
+    "isinf",
+    "isnan",
+    "isnullptr",
+    "iter_var",
+    "launch_thread",
+    "ldexp",
+    "let",
+    "likely",
+    "llvm_lookup_intrinsic_id",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "lookup_param",
+    "match_buffer",
+    "max",
+    "max_value",
+    "min",
+    "min_value",
+    "mma_fill",
+    "mma_store",
+    "nearbyint",
+    "nextafter",
+    "parallel",
+    "popcount",
+    "power",
+    "prefetch",
+    "preflattened_buffer",
+    "prim_func",
+    "ptr",
+    "ptx_commit_group",
+    "ptx_cp_async",
+    "ptx_ldmatrix",
+    "ptx_mma",
+    "ptx_mma_sp",
+    "ptx_wait_group",
+    "q_multiply_shift",
+    "ramp",
+    "reads",
+    "realize",
+    "reinterpret",
+    "ret",
+    "round",
+    "rsqrt",
+    "serial",
+    "shift_left",
+    "shift_right",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "thread_binding",
+    "trunc",
+    "truncdiv",
+    "truncmod",
+    "tvm_access_ptr",
+    "tvm_bmma_sync",
+    "tvm_call_cpacked",
+    "tvm_call_cpacked_lowered",
+    "tvm_call_packed",
+    "tvm_call_packed_lowered",
+    "tvm_fill_fragment",
+    "tvm_load_matrix_sync",
+    "tvm_mma_sync",
+    "tvm_stack_alloca",
+    "tvm_stack_make_array",
+    "tvm_stack_make_shape",
+    "tvm_store_matrix_sync",
+    "tvm_struct_get",
+    "tvm_struct_set",
+    "tvm_thread_allreduce",
+    "tvm_throw_last_error",
+    "tvm_tuple",
+    "type_annotation",
+    "uint16",
+    "uint32",
+    "uint64",
+    "uint8",
+    "undef",
+    "unroll",
+    "var",
+    "vectorcombine",
+    "vectorhigh",
+    "vectorized",
+    "vectorlow",
+    "void",
+    "where",
+    "writes",
+]
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/meta_unparser.py
deleted file mode 100644
index b1472ccdc7..0000000000
--- a/python/tvm/script/meta_unparser.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Unparse meta AST node into a dict"""
-# pylint: disable=invalid-name
-
-from synr import Transformer
-
-
-class MetaUnparser(Transformer):
-    """Python AST Visitor to unparse meta AST node into a dict"""
-
-    def transform(self, node):
-        method = "transform_" + node.__class__.__name__
-        visitor = getattr(self, method, None)
-        if visitor is None:
-            self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span)
-        return visitor(node)
-
-    def transform_DictLiteral(self, node):
-        keys = [self.visit(key) for key in node.keys]
-        values = [self.visit(value) for value in node.values]
-        return dict(zip(keys, values))
-
-    def transform_Tuple(self, node):
-        return tuple(self.visit(element) for element in node.elts)
-
-    def transform_ArrayLiteral(self, node):
-        return [self.visit(element) for element in node.elts]
-
-    def transform_Constant(self, node):
-        return node.value
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
deleted file mode 100644
index 908af081c9..0000000000
--- a/python/tvm/script/parser.py
+++ /dev/null
@@ -1,1385 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser For TIR
-
-We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
-different python versions. Synr also provides an error handling context that we
-use for error reporting.
-"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
-import json
-import operator
-import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
-
-import tvm
-from tvm import IRModule
-from tvm._ffi.base import TVMError
-from tvm.ir import GlobalVar
-from tvm.ir.function import BaseFunc
-from tvm.tir import buffer
-from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
-
-from .context_maintainer import ContextMaintainer
-from .meta_unparser import MetaUnparser
-from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
-from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
-from .tir.special_stmt import SpecialStmt
-from .tir import ty
-
-
-class CallArgumentReader(object):
-    """Helper class to read required arguments from passed arguments.
-
-    When parsing a function call, we need to match the arguments provided in
-    the AST to the required arguments of the function. This class makes sure
-    all the positional arguments are filled and also fill keyword arguments
-    with thier default value if a different value was not provided.
-    """
-
-    def __init__(self, func_name, args, kwargs, parser, node):
-        self.func_name = func_name
-        self.args = args
-        self.kwargs = kwargs
-        self.parser = parser
-        self.node = node
-
-    def get_pos_only_arg(self, pos, name):
-        """Get corresponding position only function argument from argument list"""
-        if len(self.args) >= pos:
-            arg = self.args[pos - 1]
-        elif name not in self.kwargs:
-            # If no positional argument was found in the AST, we see if it was
-            # defined by name instead.
-            # TODO(tkonolige): this error message is not quite correct. The
-            # number of required arguments is >= pos
-            self.parser.report_error(
-                f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.",
-                self.node.span,
-            )
-        else:
-            arg = self.kwargs[name]
-
-        return arg
-
-    def get_kwarg(self, pos, name, default):
-        """Get corresponding keyword function argument from argument list.
-
-        If the user hasn't provided the argument, set it to the default value.
-        """
-        if len(self.args) >= pos:
-            arg = self.args[pos - 1]
-        elif name in self.kwargs:
-            arg = self.kwargs[name]
-        else:
-            return default
-
-        return arg
-
-    def get_varargs(self, pos):
-        """Get corresponding variable argument from argument list"""
-        if len(self.args) >= pos and len(self.kwargs) == 0:
-            return self.args[pos - 1 :]
-        return []
-
-
-class TVMScriptParser(Transformer):
-    """Synr AST visitor pass which finally lowers to TIR.
-
-    Notes for Extension
-    -------------------
-    1. To support a new type of AST node, add a function transform_xxx().
-    2. To support new functions, add the function to the appropriate registry:
-        We divide allowed function calls in TVM script into 3 categories,
-        intrin, scope_handler and special_stmt.
-        1. intrin functions are low level functions like mod, load, and
-           constants. They correspond to a tir `IRNode`. They must have a
-           return value. The user can register intrin functions for the parser to
-           use.
-        2. scope_handler functions have no return value. They take two
-           arguments: the parser and the AST node. scope_handler functions are
-           used in with and for statements.
-        3. special_stmt functions handle cases that do not have a corresponding
-           tir `IRNode`. These functions take the parser and the AST node as
-           arguments and may return a value.
-        When visiting a Call node, we check the special_stmt registry first. If
-        no registered function is found, we then check the intrin registry.
-        When visiting With node, we check the with_scope registry.
-        When visiting For node, we check the for_scope registry.
-    """
-
-    _binop_maker = {
-        ast.BuiltinOp.Add: tvm.tir.Add,
-        ast.BuiltinOp.Sub: tvm.tir.Sub,
-        ast.BuiltinOp.Mul: tvm.tir.Mul,
-        ast.BuiltinOp.Div: tvm.tir.Div,
-        ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv,
-        ast.BuiltinOp.Mod: tvm.tir.FloorMod,
-        ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs),
-        ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs),
-        ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs),
-        ast.BuiltinOp.GT: tvm.tir.GT,
-        ast.BuiltinOp.GE: tvm.tir.GE,
-        ast.BuiltinOp.LT: tvm.tir.LT,
-        ast.BuiltinOp.LE: tvm.tir.LE,
-        ast.BuiltinOp.Eq: tvm.tir.EQ,
-        ast.BuiltinOp.NotEq: tvm.tir.NE,
-        ast.BuiltinOp.And: tvm.tir.And,
-        ast.BuiltinOp.Or: tvm.tir.Or,
-    }
-
-    _unaryop_maker = {
-        ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs),
-        ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs),
-        ast.BuiltinOp.Not: tvm.tir.Not,
-    }
-
-    # pylint gets confused here with synr.Transformer which doesn't have a
-    # custom init, so just disable it
-    def __init__(
-        self, base_lineno, tir_namespace, closure_vars
-    ):  # pylint: disable=super-init-not-called
-        self.context = None
-
-        self.base_lineno = base_lineno
-        self.current_lineno = 0
-        self.current_col_offset = 0
-        self.tir_namespace = tir_namespace
-        self.closure_vars = closure_vars
-        self.meta = None
-        self._inside_buffer_sugar = False
-
-    def init_function_parsing_env(self):
-        """Initialize function parsing environment"""
-        self.context = ContextMaintainer(self.report_error, self.closure_vars)  # scope emitter
-
-    def init_meta(self, meta_dict):
-        if meta_dict is not None:
-            self.meta = tvm.ir.load_json(json.dumps(meta_dict))
-
-    def transform(self, node):
-        """Generic transformation for visiting the AST. Dispatches to
-        `transform_ClassName` for the appropriate ClassName."""
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-
-        if hasattr(node, "lineno"):
-            self.current_lineno = self.base_lineno + node.lineno - 1
-        if hasattr(node, "col_offset"):
-            self.current_col_offset = node.col_offset
-
-        method = "transform_" + node.__class__.__name__
-        visitor = getattr(self, method, self.generic_visit)
-        transform_res = visitor(node)
-
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-
-        return transform_res
-
-    def match_tir_namespace(self, identifier: str) -> bool:
-        """Check if the namespace is equal to tvm.script.tir"""
-        return identifier in self.tir_namespace
-
-    def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
-        """Report an error occuring at a location.
-
-        This just dispatches to synr's DiagnosticContext.
-
-        Parameters
-        ----------
-        message : str
-            Error message
-        span : Union[synr.ast.Span, tvm.ir.Span]
-            Location of the error
-        """
-        if isinstance(span, tvm.ir.Span):
-            span = synr_span_from_tvm(span)
-        self.error(message, span)
-
-    def parse_body(self, parent):
-        """Parse remaining statements in this scope.
-
-        Parameters
-        ----------
-        parent : synr.ast.Node
-            Parent node of this scope. Errors will be reported here.
-        """
-        body = []
-        spans = []
-        stmt = parent
-        while len(self.context.node_stack[-1]) > 0:
-            stmt = self.context.node_stack[-1].pop()
-            spans.append(stmt.span)
-            res = self.transform(stmt)
-            if res is not None:
-                body.append(res)
-        if len(body) == 0:
-            self.report_error(
-                "Expected another statement at the end of this block. Perhaps you "
-                "used a concise statement and forgot to include a body afterwards.",
-                stmt.span,
-            )
-        else:
-            return (
-                tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans)))
-                if len(body) > 1
-                else body[0]
-            )
-
-    def parse_arg_list(self, func, node_call):
-        """Match the arguments of a function call in the AST to the required
-        arguments of the function. This handles positional arguments,
-        positional arguments specified by name, keyword arguments, and varargs.
-
-        Parameters
-        ----------
-        func : Function
-            The function that provides the signature
-
-        node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
-            The AST call node that calls into the function.
-
-        Returns
-        -------
-        arg_list : list
-            The parsed positional argument.
-        """
-        assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
-        # collect arguments
-        args = [self.transform(arg) for arg in node_call.params]
-        if isinstance(node_call, ast.TypeApply):
-            kw_args = {}  # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
-        else:
-            kw_args = {
-                self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
-            }
-        # get the name and parameter list of func
-        if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
-            func_name, param_list = func.signature()
-        else:
-            self.report_error(
-                "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, "
-                f"but it is {type(func).__name__}",
-                node_call.span,
-            )
-        # check arguments and parameter list and get a list of arguments
-        reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
-        pos_only, kwargs, varargs = param_list
-        internal_args = list()
-
-        for i, arg_name in enumerate(pos_only):
-            internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
-        for i, arg_info in enumerate(kwargs):
-            arg_name, default = arg_info
-            internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default))
-        if varargs is not None:
-            internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1))
-        elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
-            self.report_error(
-                "Arguments mismatched. "
-                + f"Expected {len(pos_only) + len(kwargs)} args but got "
-                + f"{len(args) + len(kw_args)}",
-                node_call.span,
-            )
-        return internal_args
-
-    def parse_type(self, type_node, parent):
-        """Parse a type annotation.
-
-        We require the parent object to the type so that we have a place to
-        report the error message if the type does not exist.
-        """
-        if type_node is None:
-            self.report_error("A type annotation is required", parent.span)
-        res_type = self.transform(type_node)
-        return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate()
-
-    def generic_visit(self, node):
-        """Fallback visitor if node type is not handled. Reports an error."""
-
-        self.report_error(type(node).__name__ + " AST node is not supported", node.span)
-
-    def transform_Module(self, node):
-        """Module visitor
-
-        Right now, we only support two formats for TVM Script.
-
-        Example
-        -------
-        1. Generate a PrimFunc (If the code is printed, then it may also contain metadata)
-        .. code-block:: python
-
-            import tvm
-
-            @tvm.script
-            def A(...):
-                ...
-
-            # returns a PrimFunc
-            func = A
-
-        2. Generate an IRModule
-        .. code-block:: python
-
-            import tvm
-
-            @tvm.script.ir_module
-            class MyMod():
-                @T.prim_func
-                def A(...):
-                    ...
-                @T.prim_func
-                def B(...):
-                    ...
-
-                __tvm_meta__ = ...
-
-            # returns an IRModule
-            mod = MyMod
-        """
-        if len(node.funcs) == 1:
-            return self.transform(next(iter(node.funcs.values())))
-        elif len(node.funcs) == 0:
-            self.report_error(
-                "You must supply at least one class or function definition", node.span
-            )
-        else:
-            self.report_error(
-                "Only one-function, one-class or function-with-meta source code is allowed",
-                ast.Span.union([x.span for x in list(node.funcs.values())[1:]]),
-            )
-
-    def transform_Class(self, node):
-        """Class definition visitor.
-
-        A class can have multiple function definitions and a single
-        :code:`__tvm_meta__` statement. Each class corresponds to a single
-        :code:`IRModule`.
-
-        Example
-        -------
-        .. code-block:: python
-
-            @tvm.script.ir_module
-            class MyClass:
-                __tvm_meta__ = {}
-                def A():
-                    T.evaluate(0)
-        """
-        if len(node.assignments) == 1:
-            if not (
-                len(node.assignments[0].lhs) == 1
-                and isinstance(node.assignments[0].lhs[0], ast.Var)
-                and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
-            ):
-                self.report_error(
-                    "The only top level assignments allowed are `__tvm_meta__ = ...`",
-                    node.assignments[0].span,
-                )
-            self.init_meta(
-                MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
-            )
-        elif len(node.assignments) > 1:
-            self.report_error(
-                "Only a single top level `__tvm_meta__` is allowed",
-                ast.Span.union([x.span for x in node.assignments[1:]]),
-            )
-
-        return IRModule(
-            {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()}
-        )
-
-    def transform_Function(self, node):
-        """Function definition visitor.
-
-        Each function definition is translated to a single :code:`PrimFunc`.
-
-        There are a couple restrictions on TVM Script functions:
-        1. Function arguments must have their types specified.
-        2. The body of the function can contain :code:`func_attr` to specify
-           attributes of the function (like it's name).
-        3. The body of the function can also contain multiple :code:`buffer_bind`s,
-           which give shape and dtype information to arguments.
-        4. Return statements are implicit.
-
-        Example
-        -------
-        .. code-block:: python
-
-            @T.prim_func
-            def my_function(x: T.handle):  # 1. Argument types
-                T.func_attr({"global_symbol": "mmult"})  # 2. Function attributes
-                X_1 = tir.buffer_bind(x, [1024, 1024])  # 3. Buffer binding
-                T.evaluate(0)  # 4. This function returns 0
-        """
-
-        def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]):
-            if isinstance(decorator, ast.Call):
-                if len(decorator.params) != 1:
-                    return False
-                func_name = decorator.func_name
-            else:
-                func_name = decorator
-            if isinstance(func_name, ast.Var):
-                return func_name.id.name == "as_torch"
-
-        def check_decorator(decorators: List[ast.Expr]) -> bool:
-            """Check the decorator is `T.prim_func"""
-            if len(decorators) > 2 or len(decorators) == 0:
-                return False
-            if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
-                return False
-            d: ast.Expr = decorators[-1]
-            return (
-                isinstance(d, ast.Attr)
-                and isinstance(d.object, ast.Var)
-                and self.match_tir_namespace(d.object.id.name)
-                and d.field.name == "prim_func"
-            )
-
-        self.init_function_parsing_env()
-        self.context.enter_scope(nodes=node.body.stmts)
-
-        # add parameters of function
-        for arg in node.params:
-            # Note that this case is for T.match_buffer syntax sugar
-            if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
-                self.transform(arg.ty.func_name), ty.GenericBufferType
-            ):
-                result = self.handle_match_buffer_type(arg.ty, arg.name)
-                if not isinstance(result, buffer.Buffer):
-                    self.report_error(
-                        "The result type of evaluating TypeCall and TypeApply stmt"
-                        f" is wrong: {type(result)}. It should be a Buffer",
-                        node.span,
-                    )
-                arg_name_with_handle = arg.name + "_handle"
-                arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
-                self.context.func_buffer_map[arg_var] = result
-                self.context.update_symbol(arg.name, result, node)
-            else:
-                arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-                self.context.update_symbol(arg.name, arg_var, node)
-            self.context.func_params.append(arg_var)
-
-        if not check_decorator(node.decorators):
-            self.report_error(
-                "All functions should be decorated by `T.prim_func`",
-                node.span,
-            )
-
-        # fetch the body of root block
-        body = self.parse_body(node.body)
-
-        # return a tir.PrimFunc
-        dict_attr = self.context.func_dict_attr
-        ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None
-        func = tvm.tir.PrimFunc(
-            self.context.func_params,
-            body,
-            ret_type,
-            buffer_map=self.context.func_buffer_map,
-            preflattened_buffer_map=self.context.func_preflattened_buffer_map,
-            attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
-            span=tvm_span_from_synr(node.span),
-        )
-
-        # New Scope : Implicit root block
-        # Each function contains an implicit root block in TensorIR,
-        # so here we need a block scope for it.
-        # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
-        # the root block will not be added. The logic to add root block is in `_ffi_api.Complete`
-
-        # Fix the PrimFunc
-        # 1. generate root block if necessary
-        # 2. generate surrounding loops for blocks if necessary
-
-        func = call_with_error_reporting(
-            self.report_error,
-            node.span,
-            _ffi_api.Complete,
-            func,
-            self.context.root_alloc_buffers,
-        )
-
-        self.context.exit_scope()
-        return func
-
-    def transform_Lambda(self, node):
-        """Lambda visitor
-
-        Return an array of input parameters and the transformed lambda body.
-        """
-
-        self.context.enter_scope(nodes=[node.body])
-
-        # add parameters of the lambda
-        arg_vars = []
-        for arg in node.params:
-            # Use "void" for dtype here. The actual type is not yet known and will be
-            # determined later. Using void type will allow IRSubstitute to do the
-            # replacement without flagging a type-mismatch error.
-            arg_var = tvm.te.var(arg.name, dtype="")
-            arg_vars.append(arg_var)
-            self.context.update_symbol(arg.name, arg_var, node)
-
-        # the body of a lambda must be an expr
-        if not isinstance(node.body, ast.Expr):
-            self.report_error("The body of a lambda must be an expression", node.span)
-
-        # transform the body of the lambda
-        body = self.transform(node.body)
-
-        self.context.exit_scope()
-        return arg_vars, body
-
-    def transform_Assign(self, node):
-        """Assign visitor
-        AST abstract grammar:
-            Assign(expr* targets, expr value, string? type_comment)
-
-        By now 5 patterns of Assign is supported:
-            1. special stmts with return value
-                1.1 Buffer = T.match_buffer()/T.buffer_decl()
-                1.2 Var = T.var()
-                1.3 Var = T.env_thread()
-            2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
-            3. (Store)       Var[PrimExpr] = PrimExpr
-            4. with scope handlers with concise scoping and var def
-                4.1 var = T.allocate()
-            5. A call to a pure python function, consuming and producing TVMScript values.
-               The outputs are inlined into the following body (no variable is created).
-               x, y = f(...)
-        """
-
-        if isinstance(node.rhs, ast.Call):
-            # Pattern 1 & Pattern 4
-            if isinstance(node.rhs.func_name, ast.Op):
-                func = None
-            else:
-                func = self.transform(node.rhs.func_name)
-
-            if isinstance(func, WithScopeHandler):
-                if not func.concise_scope or not func.def_symbol:
-                    self.report_error(
-                        "with scope handler " + func.signature()[0] + " is not suitable here",
-                        node.rhs.span,
-                    )
-                # Pattern 4
-                arg_list = self.parse_arg_list(func, node.rhs)
-                func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-                func.body = self.parse_body(node)
-                return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-            elif isinstance(func, SpecialStmt):
-                # Pattern 1
-                arg_list = self.parse_arg_list(func, node.rhs)
-                func.handle(node, self.context, arg_list, node.rhs.func_name.span)
-                return self.parse_body(node)
-            elif isinstance(func, types.FunctionType):
-                # Pattern 5
-                args = [self.transform(arg) for arg in node.rhs.params]
-                try:
-                    out = func(*args)
-                except Exception as e:
-                    self.report_error(
-                        "Error occured when invoking the function "
-                        + func.__name__
-                        + ": \n"
-                        + str(e),
-                        node.rhs.span,
-                    )
-
-                if len(node.lhs) == 1 and not isinstance(out, list):
-                    out = [out]
-
-                assert len(out) == len(node.lhs)
-
-                for var, value in zip(node.lhs, out):
-                    self.context.update_symbol(var.id.name, value, node)
-
-                body = self.parse_body(node)
-
-                for var, value in zip(node.lhs, out):
-                    self.context.remove_symbol(var.id.name)
-
-                return body
-
-        if isinstance(node.rhs, (ast.Call, ast.Constant)):
-            # Pattern 4 of let binding
-            value = self.transform(node.rhs)
-            if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
-                # This is a little confusing because it only is true when
-                # we have taken this branch. We might need to clarify what
-                # exectly is allowed in Assignments in tvmscript.
-                self.report_error(
-                    "Left hand side of assignment must be an unqualified variable",
-                    node.span,
-                )
-            ast_var = node.lhs[0]
-
-            if node.ty is None and hasattr(value, "dtype"):
-                var_ty = value.dtype
-            else:
-                var_ty = self.parse_type(node.ty, ast_var)
-
-            var = tvm.te.var(
-                ast_var.id.name,
-                var_ty,
-                span=tvm_span_from_synr(ast_var.span),
-            )
-            self.context.update_symbol(var.name, var, node)
-            body = self.parse_body(node)
-            self.context.remove_symbol(var.name)
-            return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-        self.report_error(
-            """Assignments should be one of:
-            1. A "special statement" with return value
-                1.1 Buffer = T.match_buffer()/T.buffer_decl()
-                1.2 Var = T.var()
-                1.3 Var = T.env_thread()
-            2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
-            3. A store into a variable: Var[PrimExpr] = PrimExpr
-            4. A with scope handler with concise scoping and var def
-                4.1 var = T.allocate()
-            5. The right-hand side being a call to a pure python function, consuming and
-               producing TVMScript values.
-               x, y = f(...)""",
-            node.span,
-        )
-
-    def transform_SubscriptAssign(self, node):
-        """Visitor for statements of the form :code:`x[1] = 2`."""
-        symbol = self.transform(node.params[0])
-        indexes = self.transform(node.params[1])
-        rhs = self.transform(node.params[2])
-        rhs_span = tvm_span_from_synr(node.params[2].span)
-        if isinstance(symbol, tvm.tir.Buffer):
-            if len(indexes) != len(symbol.shape):
-                self.report_error(
-                    f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, "
-                    f"cannot be indexed by {len(indexes)}-dimensional indices.",
-                    node.params[1].span,
-                )
-
-            def __convert_index(x):
-                if isinstance(x, Slice):
-                    return x.as_index_expr(self.report_error)
-                return x
-
-            # BufferStore
-            indexes = [__convert_index(x) for x in indexes]
-            return tvm.tir.BufferStore(
-                symbol,
-                tvm.runtime.convert(rhs, span=rhs_span),
-                indexes,
-                span=tvm_span_from_synr(node.span),
-            )
-        else:
-            if symbol.dtype == "handle" and len(indexes) != 1:
-                self.report_error(
-                    "Handles only support one-dimensional indexing. Use `T.match_buffer` to "
-                    "construct a multidimensional buffer from a handle.",
-                    node.params[0].span,
-                )
-            if len(indexes) != 1:
-                self.report_error(
-                    f"Store is only allowed with one index, but {len(indexes)} were provided.",
-                    node.params[1].span,
-                )
-            self.report_error(
-                "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span
-            )
-
-    def transform_AttrAssign(self, node):
-        """Visitor for statements of the form :code:`x.y = 2`."""
-        obj = self.transform(node.params[0])
-        field = node.params[1]
-        value = self.transform(node.params[2])
-
-        if not hasattr(obj, field.name):
-            self.error(f"Field {field.name} does not exist", field.span)
-
-        var = getattr(obj, field.name)
-
-        if not isinstance(var, tvm.tir.Var):
-            self.error(
-                f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span
-            )
-
-        body = self.parse_body(node)
-        return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-    def transform_Assert(self, node):
-        """Assert visitor
-
-        Pattern corresponds to concise mode of :code:`with T.Assert()`.
-        """
-
-        condition = self.transform(node.condition)
-        if node.msg is None:
-            self.report_error("Assert statements must have an error message.", node.span)
-        message = self.transform(node.msg)
-        body = self.parse_body(node)
-        return tvm.tir.AssertStmt(
-            condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span)
-        )
-
-    def transform_For(self, node):
-        """For visitor
-        AST abstract grammar:
-            For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
-        By now 1 pattern of For is supported:
-            1. for scope handler
-                for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
-                            T.grid()/T.thread_binding()
-        """
-
-        if not isinstance(node.rhs, ast.Call):
-            self.report_error("The loop iterator should be a function call.", node.rhs.span)
-        func = self.transform(node.rhs.func_name)
-        if not isinstance(func, ForScopeHandler):
-            self.report_error(
-                "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span
-            )
-        # prepare for new for scope
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno = node.span.start_line
-        self.current_col_offset = node.span.start_column
-        self.context.enter_scope(nodes=node.body.stmts)
-        # for scope handler process the scope
-        arg_list = [
-            tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
-            for arg in self.parse_arg_list(func, node.rhs)
-        ]
-        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        func.body = self.parse_body(node)
-        res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        # exit the scope
-        self.context.exit_scope()
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-        return res
-
-    def transform_While(self, node):
-        """While visitor
-        AST abstract grammar:
-            While(expr condition, stmt* body)
-        """
-        condition = self.transform(node.condition)
-        # body
-        self.context.enter_scope(nodes=node.body.stmts)
-        body = self.parse_body(node)
-        self.context.exit_scope()
-
-        return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span))
-
-    def transform_With(self, node):
-        """With visitor
-        AST abstract grammar:
-            With(withitem* items, stmt* body, string? type_comment)
-            withitem = (expr context_expr, expr? optional_vars)
-        By now 2 patterns of With is supported:
-            1. with scope handler with symbol def
-                with T.allocate() as targets:
-            2. with scope handler without symbol def
-                with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize()
-        """
-
-        if not isinstance(node.rhs, ast.Call):
-            self.report_error(
-                "The context expression of a `with` statement should be a function call.",
-                node.rhs.span,
-            )
-
-        func = self.transform(node.rhs.func_name)
-
-        if not isinstance(func, WithScopeHandler):
-            self.report_error(
-                f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span
-            )
-        # prepare for new block scope
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno = node.body.span.start_line
-        self.current_col_offset = node.body.span.start_column
-        self.context.enter_block_scope(nodes=node.body.stmts)
-        # with scope handler process the scope
-        arg_list = self.parse_arg_list(func, node.rhs)
-        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        func.body = self.parse_body(node)
-        res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        # exit the scope
-        self.context.exit_block_scope()
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-        return res
-
-    def transform_If(self, node):
-        """If visitor
-        AST abstract grammar:
-            If(expr test, stmt* body, stmt* orelse)
-        """
-
-        condition = self.transform(node.condition)
-        # then body
-        self.context.enter_scope(nodes=node.true.stmts)
-        then_body = self.parse_body(node)
-        self.context.exit_scope()
-
-        # else body
-        if len(node.false.stmts) > 0:
-            self.context.enter_scope(nodes=node.false.stmts)
-            else_body = self.parse_body(node)
-            self.context.exit_scope()
-        else:
-            else_body = None
-
-        return tvm.tir.IfThenElse(
-            condition, then_body, else_body, span=tvm_span_from_synr(node.span)
-        )
-
-    def transform_Call(self, node):
-        """Call visitor
-
-        3 different Call patterns are allowed:
-            1. Intrin representing a PrimExpr/IterVar
-                1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
-                1.2 tir.range/reduce_axis/scan_axis/opaque_axis
-            2. tir.Op(dtype, ...)
-            3. other callable functions
-        """
-
-        if isinstance(node.func_name, ast.Op):
-            if node.func_name.name == ast.BuiltinOp.Subscript:
-                return self.transform_Subscript(node)
-            if node.func_name.name in self._binop_maker:
-                lhs = self.transform(node.params[0])
-                # There is no supertype for everything that can appear in
-                # an expression, so we manually add what we might get here.
-                if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
-                    # We would really like to report a more specific
-                    # error here, but this parser contains no distinction
-                    # between parsing statements and parsing expressions. All
-                    # rules just call `transform`.
-                    self.report_error(
-                        f"Left hand side of binary op must be a PrimExpr, "
-                        "but it is a {type(lhs).__name__}",
-                        node.params[0].span,
-                    )
-                rhs = self.transform(node.params[1])
-                if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
-                    self.report_error(
-                        f"Right hand side of binary op must be a PrimExpr, "
-                        "but it is a {type(rhs).__name__}",
-                        node.params[1].span,
-                    )
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.span,
-                    lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
-                        lhs, rhs, span=span
-                    ),
-                    node,
-                    lhs,
-                    rhs,
-                    tvm_span_from_synr(node.span),
-                )
-            if node.func_name.name in self._unaryop_maker:
-                rhs = self.transform(node.params[0])
-                return self._unaryop_maker[node.func_name.name](
-                    rhs, span=tvm_span_from_synr(node.span)
-                )
-            self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span)
-        else:
-            func = self.transform(node.func_name)
-            if isinstance(func, Intrin) and not func.stmt:
-                # pattern 1
-                arg_list = self.parse_arg_list(func, node)
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.func_name.span,
-                    func.handle,
-                    arg_list,
-                    node.func_name.span,
-                )
-            else:
-                args = [self.transform(arg) for arg in node.params]
-                kw_args = {
-                    self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
-                }
-                if isinstance(func, tvm.tir.op.Op):
-                    if not "dtype" in kw_args.keys():
-                        self.report_error(f"{func} requires a dtype keyword argument.", node.span)
-                    # pattern 2
-                    return tvm.tir.Call(
-                        kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span)
-                    )
-                elif callable(func):
-                    # pattern 3
-                    return func(*args, **kw_args)
-                else:
-                    self.report_error(
-                        f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).",
-                        node.func_name.span,
-                    )
-
-    def transform_UnassignedCall(self, node):
-        """Visitor for statements that are function calls.
-
-        This handles function calls that appear on thier own line like `tir.realize`.
-
-        Examples
-        --------
-        .. code-block:: python
-
-            @T.prim_func
-            def f():
-                A = T.buffer_decl([10, 10])
-                T.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
-                A[1, 1] = 2  # This is also an UnassignedCall
-        """
-        # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
-        if isinstance(node.call.func_name, ast.Op):
-            if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign:
-                return self.transform_SubscriptAssign(node.call)
-
-            if node.call.func_name.name == ast.BuiltinOp.AttrAssign:
-                return self.transform_AttrAssign(node.call)
-
-            self.report_error(
-                "Binary and unary operators are not allowed as a statement", node.span
-            )
-
-        # handle a regular function call
-        func = self.transform(node.call.func_name)
-        arg_list = self.parse_arg_list(func, node.call)
-
-        if isinstance(func, tir.scope_handler.AssertHandler):
-            self.report_error(
-                "A standalone `T.Assert` is not allowed. Use `assert condition, message` "
-                "instead.",
-                node.call.func_name.span,
-            )
-
-        if isinstance(func, Intrin):
-            if func.stmt:
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.call.func_name.span,
-                    func.handle,
-                    arg_list,
-                    node.call.func_name.span,
-                )
-            else:
-                self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span)
-        elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
-            func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
-            func.body = self.parse_body(node)
-            return func.exit_scope(node, self.context, arg_list, node.call.func_name.span)
-        elif isinstance(func, SpecialStmt) and not func.def_symbol:
-            func.handle(node, self.context, arg_list, node.call.func_name.span)
-            return
-
-        self.report_error(
-            "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
-            f"special statement, but got {type(func).__name__}.",
-            node.call.func_name.span,
-        )
-
-    def transform_Slice(self, node):
-        """Index slice visitor."""
-        start = self.transform(node.start)
-        end = self.transform(node.end)
-        if not (
-            isinstance(node.step, ast.Constant)
-            and isinstance(node.step.value, int)
-            and node.step.value > 0
-        ):
-            self.report_error(
-                "Only positive integer step size is supported for slices.", node.step.span
-            )
-        return Slice(start, end, node.step.value, tvm_span_from_synr(node.span))
-
-    def transform_Subscript(self, node):
-        """Array access visitor.
-
-        By now only 3 types of Subscript are supported:
-            1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
-               Var[index] Buffer element access()
-            2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
-            3. Array[index], Buffer element access
-        """
-
-        symbol = self.transform(node.params[0])
-        if symbol is None:
-            self.report_error(
-                f"Variable {node.params[0].id.name} is not defined.", node.params[0].span
-            )
-
-        indexes = [self.transform(x) for x in node.params[1].values]
-        if isinstance(symbol, tvm.tir.expr.Var):
-            if symbol.dtype == "handle":
-                self.report_error(
-                    "Cannot read directly from a handle, use `T.match_buffer` "
-                    "to create a buffer to read from.",
-                    node.params[0].span,
-                )
-            if len(indexes) > 1:
-                self.report_error(
-                    "Only a single index can be provided when indexing into a `var`.",
-                    node.params[1].span,
-                )
-            index = indexes[0]
-            if not isinstance(index, (tvm.tir.PrimExpr, int)):
-                self.report_error(
-                    "Var load index should be an int or PrimExpr, but it is a" + type(index),
-                    node.span,
-                )
-
-            self.report_error(
-                "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span
-            )
-        elif isinstance(symbol, tvm.tir.Buffer):
-            return BufferSlice(
-                symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
-            )
-        elif isinstance(symbol, tvm.container.Array):
-            if len(indexes) > 1:
-                self.report_error(
-                    "Array access should be one-dimension access, but the indices are "
-                    + str(indexes),
-                    node.span,
-                )
-            index = indexes[0]
-            if not isinstance(index, (int, tvm.tir.expr.IntImm)):
-                self.report_error(
-                    "Array access index expected int or IntImm, but got " + type(index),
-                    node.span,
-                )
-            if int(index) >= len(symbol):
-                self.report_error(
-                    f"Array access out of bound, size: {len(symbol)}, got index {index}.",
-                    node.span,
-                )
-            return symbol[int(index)]
-        else:
-            self.report_error(
-                f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
-                "buffers are supported.",
-                node.params[0].span,
-            )
-
-    def transform_Attr(self, node):
-        """Visitor for field access of the form `x.y`.
-
-        This visitor is used to lookup function and symbol names. We have two
-        cases to handle here:
-        1. If we have a statement of the form `tir.something`, then we lookup
-           `tir.something` in the `Registry`. If the function is not in the
-           registry, then we try to find a `tvm.ir.op.Op` with the same name.
-        2. All other names `tvm.something` are lookup up in this current python
-           namespace.
-        """
-
-        def get_full_attr_name(node: ast.Attr) -> str:
-            reverse_field_names = [node.field.name]
-            while isinstance(node.object, ast.Attr):
-                node = node.object
-                reverse_field_names.append(node.field.name)
-            if isinstance(node.object, ast.Var):
-                reverse_field_names.append(node.object.id.name)
-            return ".".join(reversed(reverse_field_names))
-
-        if isinstance(node.object, (ast.Var, ast.Attr)):
-            full_attr_name = get_full_attr_name(node)
-            attr_object, fields = full_attr_name.split(".", maxsplit=1)
-            if self.match_tir_namespace(attr_object):
-                func_name = "tir." + fields
-                res = Registry.lookup(func_name)
-                if res is not None:
-                    return res
-                try:
-                    return tvm.ir.op.Op.get(func_name)
-                except TVMError as e:
-                    # Check if we got an attribute error
-                    if e.args[0].find("AttributeError"):
-                        self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
-                    else:
-                        raise e
-
-        symbol = self.transform(node.object)
-        if symbol is None:
-            self.report_error("Unsupported Attribute expression.", node.object.span)
-        if not hasattr(symbol, node.field.name):
-            self.report_error(
-                f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span
-            )
-        res = getattr(symbol, node.field.name)
-        return res
-
-    def transform_TypeAttr(self, node):
-        """Visitor for field access of the form `x.y` for types.
-
-        We have two cases here:
-        1. If the type is of the form `T.something`, we look up the type in
-           the `tir` namespace in this module.
-        2. If the type is of the form `tvm.x.something` then we look up
-           `tvm.x.something` in this modules namespace.
-        """
-        if isinstance(node.object, ast.TypeVar):
-            if self.match_tir_namespace(node.object.id.name):
-                if not hasattr(tir, node.field.name):
-                    self.report_error(
-                        f"Invalid type annotation `tir.{node.field.name}`.", node.span
-                    )
-                return getattr(tir, node.field.name)
-
-        symbol = self.transform(node.object)
-        if symbol is None:
-            self.report_error("Unsupported Attribute expression", node.object.span)
-        if not hasattr(symbol, node.field):
-            self.report_error(
-                f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span
-            )
-        res = getattr(symbol, node.field)
-        return res
-
-    def transform_DictLiteral(self, node):
-        """Dictionary literal visitor.
-
-        Handles dictionary literals of the form `{x:y, z:2}`.
-        """
-
-        keys = [self.transform(key) for key in node.keys]
-        values = [self.transform(value) for value in node.values]
-
-        return dict(zip(keys, values))
-
-    def transform_Tuple(self, node):
-        """Tuple visitor.
-
-        Handles tuples of the form `(x, y, 2)`.
-        """
-
-        return tuple(self.transform(element) for element in node.values)
-
-    def transform_ArrayLiteral(self, node):
-        """List literal visitor.
-
-        Handles lists of the form `[x, 2, 3]`.
-        """
-
-        return [self.transform(element) for element in node.values]
-
-    def transform_Var(self, node):
-        """Variable visitor
-
-        Handles variables like `x` in `x = 2`.
-        """
-
-        name = node.id.name
-        if name == "meta":
-            return self.meta
-        symbol = Registry.lookup(name)
-        if symbol is not None:
-            return symbol
-        symbol = self.context.lookup_symbol(name)
-        if symbol is not None:
-            return symbol
-        self.report_error(f"Unknown identifier {name}.", node.span)
-
-    def transform_TypeVar(self, node):
-        """Type variable visitor.
-
-        Equivalent to `transform_Var` but for types.
-        """
-        name = node.id.name
-        symbol = Registry.lookup(name) or self.context.lookup_symbol(name)
-        if symbol is not None:
-            return symbol
-        self.report_error(f"Unknown identifier {name}.", node.span)
-
-    def transform_Constant(self, node):
-        """Constant value visitor.
-
-        Constant values include `None`, `"strings"`, `2` (integers), `4.2`
-        (floats), and `true` (booleans).
-        """
-        return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span))
-
-    def transform_TypeConstant(self, node):
-        """Constant value visitor for types.
-
-        See `transform_Constant`.
-        """
-        if self._inside_buffer_sugar:
-            return self.transform_Constant(node)
-
-        return node.value
-
-    def transform_TypeTuple(self, node):
-        """Tuple value visitor for types.
-
-        Mostly used in `transform_TypeCall` and `transform_TypeApply`.
-        """
-        return [self.transform(value) for value in node.values]
-
-    def transform_TypeCall(self, node):
-        """TypeCall visitor
-
-        This occurs when an expression is used inside a T.Buffer
-        parameter annotation.
-        """
-
-        # ast.Call has the BuiltinOp as node.func_name.name, where
-        # ast.TypeCall has the BuiltinOp as node.func_name.  So we can
-        # delegate to self.transform_Call, but the error messages for
-        # unsupported operations will highlight the entire expression
-        # and not just the function itself.
-        op = ast.Op(node.span, node.func_name)
-        call = ast.Call(node.span, op, node.params, node.keyword_params)
-        return self.transform_Call(call)
-
-    def transform_TypeApply(self, node):
-        """Visitor for Type[Type] expressions.
-
-        Mostly used for ``T.Ptr`` expressions.
-        """
-        func = self.transform(node.func_name)
-
-        if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
-            self.report_error(
-                f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
-                f"but found {type(func).__name__} instead.",
-                node.span,
-            )
-
-        param_types = []
-        for idx, param in enumerate(node.params):
-            param_type = self.transform(param)
-            if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx):
-                self.report_error(
-                    f"Expected a type but found {type(param).__name__} "
-                    f"at {idx}th type argument",
-                    param.span,
-                )
-
-            param_types.append(param_type)
-
-        if len(param_types) == 1:
-            return func[param_types[0]]
-        else:
-            return func[param_types]
-
-    def handle_match_buffer_type(self, node, buffer_name):
-        """special function to handle syntax sugar for match buffer.
-
-        This method is for buffer declarations in the function parameters.
-        """
-        func = self.transform(node.func_name)
-        assert isinstance(func, SpecialStmt)
-
-        # parse args and kwargs for TypeCall and TypeApply
-        self._inside_buffer_sugar = True
-        try:
-            arg_list = self.parse_arg_list(func, node)
-        finally:
-            self._inside_buffer_sugar = False
-
-        # Note that the third element in arg_list would always be the 'name'
-        # TODO: This index is hardcoded as a workaround. Better to make it programmatic
-        if arg_list[2] is None:
-            arg_list[2] = buffer_name
-        buf = func.handle(node, self.context, arg_list, node.func_name.span)
-        return buf
-
-    def transform_Return(self, node):
-        self.report_error(
-            "TVM script does not support return statements. Instead the last statement in any "
-            "block is implicitly returned.",
-            node.span,
-        )
-
-
-def get_tir_namespace(script: Union[Callable, type]) -> List[str]:
-    assert inspect.isfunction(script) or inspect.isclass(script)
-    env: Dict[str, Any] = script.__globals__
-    return [key for key in env.keys() if env[key] == tir]
-
-
-def from_source(
-    input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None
-) -> Union[PrimFunc, IRModule]:
-    """Parse function or string into PrimFunc or IRModule.
-
-    If possible, pass the TVM script in as a function so that line numbers and
-    filename will be accurate.
-
-    Parameters
-    ----------
-    input_module : Union[str, Callable]
-        The python function to be parsed.
-
-    tir_prefix : Optional[List[str]]
-        The tir prefix list. Only works for str input, default by "tir" and "T".
-
-    Returns
-    -------
-    output : Union[Function, Module]
-        The Function or Module in IR.
-    """
-    if isinstance(input_func, str):
-        tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
-        return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
-    elif inspect.isfunction(input_func):
-        _, start_line = inspect.getsourcelines(input_func)
-        env: Dict[str, Any] = input_func.__globals__
-        namespace = [key for key in env.keys() if env[key] is tir]
-        _closure_vars = inspect.getclosurevars(input_func)
-        closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
-        parser = TVMScriptParser(start_line, namespace, closure_vars)
-        result = to_ast(input_func, TVMDiagnosticCtx(), parser)
-        return result
-    else:
-        raise TypeError("Only function definitions are supported.")
-
-
-def ir_module(input_module: type) -> IRModule:
-    """Decorate a python class as tvm IRModule.
-
-    Parameters
-    ----------
-    input_module : type
-        The python class to be parsed.
-
-    Returns
-    -------
-    output : IRModule
-        The result IRModule.
-    """
-    if inspect.isclass(input_module):
-        func_dict = {
-            name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
-        }
-        return IRModule(func_dict)
-    raise TypeError("Only class definitions are supported.")
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 83%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..5161a2601c 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The parser"""
+from . import _core, ir, tir
+from ._core import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/_core.py
similarity index 76%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/_core.py
index 555659d0c5..4f5411dc36 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/_core.py
@@ -13,9 +13,10 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The core parser infra"""
+# pylint: disable=unused-import
+from .core import dispatch, doc, utils
+from .core.dispatch import OpMethod, register_op
+from .core.entry import parse
+from .core.parser import Parser
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/core/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/core/__init__.py
index 555659d0c5..94d8dab032 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/core/__init__.py
@@ -14,8 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""The core parser infra"""
+from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils
diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
new file mode 100644
index 0000000000..51c26bbc24
--- /dev/null
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -0,0 +1,257 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Source-location bookkeeping and error reporting for the TVMScript parser."""
+import inspect
+import re
+import sys
+from typing import Union
+
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+
+
+class Source:
+    """The text of a TVMScript program and its location in the originating file.
+
+    ``program`` is either the script text itself or an object (e.g. a function or
+    class) whose source is retrieved through ``inspect``.
+
+    Attributes
+    ----------
+    source_name : str
+        File containing the program, or ``"<str>"`` for string input.
+    start_line : int
+        1-based line number of the program's first line in ``source_name``.
+    start_column : int
+        Indentation of the program's first line, stripped from every line of
+        ``source``.
+    source : str
+        The dedented program text parsed by ``as_ast``.
+    full_source : str
+        Text of the whole enclosing file/module where it can be obtained.
+    """
+
+    source_name: str
+    start_line: int
+    start_column: int
+    source: str
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = getsourcelines(program)  # type: ignore
+        if lines:
+            self.start_column = len(lines[0]) - len(lines[0].lstrip())
+        else:
+            self.start_column = 0
+        if self.start_column and lines:
+            self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
+        else:
+            self.source = "".join(lines)
+        try:
+            # It will cause a problem when running in Jupyter Notebook.
+            # `mod` will be <module '__main__'>, which is a built-in module
+            # and `getsource` will throw a TypeError
+            mod = inspect.getmodule(program)
+            if mod:
+                self.full_source = inspect.getsource(mod)
+            else:
+                self.full_source = self.source
+        except TypeError:
+            # It's a work around for Jupyter problem.
+            # Since `findsource` is an internal API of inspect, we just use it
+            # as a fallback method.
+            src, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src)
+
+    def as_ast(self) -> doc.AST:
+        """Parse ``source`` into a ``doc`` AST."""
+        return doc.parse(self.source)
+
+
+# Keep the stdlib implementations; ``inspect.getfile`` is replaced further below.
+_getfile = inspect.getfile  # pylint: disable=invalid-name
+_findsource = inspect.findsource  # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):
+    """Drop-in ``inspect.getfile`` that can also locate classes without a module file.
+
+    Non-class objects go to the stdlib implementation. For a class, its module's
+    ``__file__`` is used when available; otherwise the file of a function defined
+    directly in the class body is returned.
+
+    Raises
+    ------
+    TypeError
+        If no file can be determined.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod = getattr(obj, "__module__", None)
+    if mod is not None:
+        # ``sys.modules.get`` (as the stdlib does): a class whose module is not loaded
+        # falls through to the member search instead of raising KeyError.
+        file = getattr(sys.modules.get(mod), "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        # Only functions whose qualname places them directly in this class count;
+        # inherited ones may have been defined in another file.
+        if inspect.isfunction(member):
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    # ``!r`` is a conversion flag, so it must not be preceded by ``:``.
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+def findsource(obj):
+    """Like ``inspect.findsource``, but locate classes through their ``__qualname__``.
+
+    For a class, every component of the qualified name (e.g. ``outer.<locals>.Inner``)
+    becomes a ``def``/``class`` header pattern; the patterns must match in order while
+    scanning the file top to bottom. Other objects go to the stdlib implementation.
+
+    Returns
+    -------
+    lines : List[str]
+        All lines of the file that defines ``obj``.
+    index : int
+        0-based index into ``lines`` of the definition's header line.
+
+    Raises
+    ------
+    OSError
+        If the source cannot be read or the definition is not found.
+    """
+    import linecache  # pylint: disable=import-outside-toplevel
+
+    if not inspect.isclass(obj):
+        return _findsource(obj)
+
+    file = inspect.getsourcefile(obj)
+    if file:
+        linecache.checkcache(file)
+    else:
+        file = inspect.getfile(obj)
+        if not (file.startswith("<") and file.endswith(">")):
+            raise OSError("source code not available")
+
+    module = inspect.getmodule(obj, file)
+    if module:
+        lines = linecache.getlines(file, module.__dict__)
+    else:
+        lines = linecache.getlines(file)
+    if not lines:
+        raise OSError("could not get source code")
+    # "f.<locals>.C" -> ["f<locals>", "C"]: a "<locals>" suffix marks a function scope.
+    qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
+    pattern_list = []
+    for name in qual_names:
+        if name.endswith("<locals>"):
+            func_name = re.escape(name[: -len("<locals>")])
+            pattern_list.append(re.compile(r"^(\s*)(?:async\s+)?def\s+" + func_name + r"\b"))
+        else:
+            pattern_list.append(re.compile(r"^(\s*)class\s+" + re.escape(name) + r"\b"))
+    for i, line in enumerate(lines):
+        match = pattern_list[0].match(line)
+        if match:
+            pattern_list.pop(0)
+        if not pattern_list:
+            return lines, i
+    raise OSError("could not find class definition")
+
+
+def getsourcelines(obj):
+    """Return ``(definition_lines, lineno)`` for ``obj``, like ``inspect.getsourcelines``.
+
+    Locates the definition with this module's ``findsource`` (so classes are found by
+    qualified name); ``lineno`` is the 1-based line of the first returned line.
+    """
+    obj = inspect.unwrap(obj)
+    lines, l_num = findsource(obj)
+    return inspect.getblock(lines[l_num:]), l_num + 1
+
+
+# Installed process-wide at import time: code that looks up ``inspect.getfile``
+# (for instance ``inspect.getsourcefile``) now reaches the patched version.
+inspect.getfile = _patched_inspect_getfile
+
+
+class Diagnostics:
+    """Reports diagnostics against a single ``Source``.
+
+    The ``DiagnosticContext`` is backed by a fresh ``IRModule`` whose source map holds
+    ``source.full_source`` under ``source.source_name``.
+    """
+
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        mod = IRModule()
+        mod.source_map.add(source.source_name, source.full_source)
+        self.source = source
+        self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        """Queue a diagnostic of ``level`` whose span covers ``node``."""
+        # Node positions are relative to the dedented ``source`` (1-based lines, 0-based
+        # columns) and may be None. Test ``is None`` explicitly (column 0 is valid) and
+        # default to snippet-relative positions so the file offset below is applied
+        # exactly once.
+        lineno = node.lineno if node.lineno is not None else 1
+        col_offset = node.col_offset if node.col_offset is not None else 0
+        end_lineno = node.end_lineno if node.end_lineno is not None else lineno
+        end_col_offset = node.end_col_offset if node.end_col_offset is not None else col_offset
+        # Translate into file coordinates (the +1 makes columns 1-based for the Span).
+        lineno += self.source.start_line - 1
+        end_lineno += self.source.start_line - 1
+        col_offset += self.source.start_column + 1
+        end_col_offset += self.source.start_column + 1
+        self.ctx.emit(
+            diagnostics.Diagnostic(
+                level=level,
+                span=Span(
+                    source_name=SourceName(self.source.source_name),
+                    line=lineno,
+                    end_line=end_lineno,
+                    column=col_offset,
+                    end_column=end_col_offset,
+                ),
+                message=message,
+            )
+        )
+
+    def error(self, node: doc.AST, message: str) -> None:
+        """Emit an error at ``node`` and render the collected diagnostics.
+
+        NOTE(review): ``DiagnosticContext.render`` presumably raises once an error has
+        been emitted, so callers should not rely on this returning -- confirm.
+        """
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()
diff --git a/python/tvm/script/parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/core/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+
+# Parse-method table: (dispatch token, AST node type name) -> handler.
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+# Operator table: (operand dispatch type, AST operator class, operand index) -> impl.
+# The second key element is an operator *class* such as ``doc.Add``, not a node instance.
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, Type, int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Returns a decorator that stores the decorated method in ``ParseVTable``
+    under ``(token, type_name)``. The decorator returns the method unchanged,
+    so the decorated name stays bound to the function instead of ``None``.
+    """
+
+    def f(method: ParseMethod) -> ParseMethod:
+        ParseVTable[(token, type_name)] = method
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Return the parse method registered for ``(token, type_name)``, else ``default``."""
+    key = (token, type_name)
+    if key in ParseVTable:
+        return ParseVTable[key]
+    return default
+
+
+def register_op(ty: Type, op: Type, operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator implementation.
+
+    ``ty`` is the operand's dispatch type and ``op`` is the AST operator class
+    (e.g. ``doc.Add``). ``operand_index`` is the position of the operand whose
+    type selects this implementation. The returned decorator stores the method
+    in ``OpVTable`` and returns it unchanged.
+    """
+
+    def f(method: OpMethod) -> OpMethod:
+        OpVTable[(ty, op, operand_index)] = method
+        return method
+
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type,
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Return the implementation for ``(ty, op, operand_index)``, else ``default``."""
+    key = (ty, op, operand_index)
+    return OpVTable[key] if key in OpVTable else default
diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py
new file mode 100644
index 0000000000..f6a641cb64
--- /dev/null
+++ b/python/tvm/script/parser/core/doc.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import *  # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+# Converter signatures: Python ``ast`` node -> ``doc`` node, and back.
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+    """Pair of converters for one node type; either side may still be ``None``."""
+
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        self.to_doc, self.from_doc = None, None
+
+
+class Registry:
+    """Converter table keyed by AST class name.
+
+    ``_inst`` holds the instance used by the module-level helpers. ``table`` is
+    a ``defaultdict``, so indexing a missing name creates an empty ``Entry``.
+    """
+
+    _inst: typing.Optional["Registry"] = None
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        entries: typing.Dict[str, Entry] = defaultdict(Entry)
+        self.table = entries
+
+
+def register_to_doc(name: str):
+    """Return a decorator registering an ``ast`` -> ``doc`` converter for ``name``.
+
+    The converter is stored on the registry entry for ``name`` and returned
+    unchanged, so it still works when used with ``@`` syntax.
+    """
+
+    def f(to_doc_fn: FnToDoc) -> FnToDoc:
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].to_doc = to_doc_fn
+        return to_doc_fn
+
+    return f
+
+
+def register_from_doc(name: str):
+    """Return a decorator registering a ``doc`` -> ``ast`` converter for ``name``.
+
+    The converter is stored on the registry entry for ``name`` and returned
+    unchanged, so it still works when used with ``@`` syntax.
+    """
+
+    def f(from_doc_fn: FnFromDoc) -> FnFromDoc:
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = from_doc_fn
+        return from_doc_fn
+
+    return f
+
+
+def _is_atomic_type(node):
+    """Tell whether ``node`` is a leaf value that the converters pass through as-is."""
+    if node is None:
+        return True
+    # Membership uses ``==``; kept as-is so values equal to these singletons match too.
+    if node in [..., True, False]:
+        return True
+    leaf_types = (int, float, str, bool, bytes, complex)
+    return isinstance(node, leaf_types)
+
+
+def _get_registry_entry(cls_name, attr):
+    """Return converter ``attr`` (``"to_doc"``/``"from_doc"``) registered for ``cls_name``.
+
+    Only the last dotted component of ``cls_name`` is used. Returns ``None``
+    when no entry exists; the membership test avoids inserting into the table.
+    """
+    short_name = cls_name.rsplit(".", 1)[-1]
+    table = Registry._inst.table  # pylint: disable=protected-access
+    if short_name not in table:
+        return None
+    return getattr(table[short_name], attr, None)
+
+
+def from_doc(node):
+    """Convert a ``doc`` node (or list/tuple of them) back into Python ``ast`` nodes.
+
+    Leaf values are returned unchanged and containers keep their list/tuple
+    type. Raises ``NotImplementedError`` when no converter is registered.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (tuple, list)):
+        converted = [from_doc(item) for item in node]
+        return tuple(converted) if isinstance(node, tuple) else converted
+    converter = _get_registry_entry(node.__class__.__name__, "from_doc")
+    if not converter:
+        raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}")
+    return converter(node)
+
+
+def to_doc(node):
+    """Convert a Python ``ast`` node (or list/tuple of them) into ``doc`` nodes.
+
+    Leaf values are returned unchanged and containers keep their list/tuple
+    type. Raises ``NotImplementedError`` when no converter is registered.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (tuple, list)):
+        converted = [to_doc(item) for item in node]
+        return tuple(converted) if isinstance(node, tuple) else converted
+    converter = _get_registry_entry(node.__class__.__name__, "to_doc")
+    if not converter:
+        raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}")
+    return converter(node)
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    """Parse Python ``source`` with ``ast.parse`` and convert the tree into ``doc`` nodes.
+
+    The 3.8 grammar is requested first via ``feature_version``. If that request
+    fails, parsing is retried with the running interpreter's own grammar. The
+    request fails with ``TypeError`` on interpreters without the keyword, and
+    with ``SyntaxError`` for syntax newer than 3.8. ``ValueError`` is also
+    caught in case the 3.8 feature version is no longer accepted. Only these
+    errors trigger the retry; interrupts such as ``KeyboardInterrupt`` now
+    propagate instead of being swallowed by a bare ``except``.
+    """
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except (TypeError, SyntaxError, ValueError):
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    """Read-only walker over ``doc`` trees.
+
+    ``visit`` calls ``visit_<ClassName>`` when the subclass defines it and
+    otherwise ``generic_visit``, which recurses into node/list/tuple fields.
+    Lists and tuples are walked item by item; other values are ignored.
+    """
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for child in node:
+                self.visit(child)
+        elif isinstance(node, doc.AST):
+            short_name = node.__class__.__name__.split(".")[-1]
+            handler = getattr(self, f"visit_{short_name}", self.generic_visit)
+            handler(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            if isinstance(child, (doc.AST, list, tuple)):
+                self.visit(child)
+
+
+class NodeTransformer:
+    """Rebuilding walker over ``doc`` trees.
+
+    ``visit`` returns the result of ``visit_<ClassName>`` when defined, else of
+    ``generic_visit``. ``generic_visit`` builds a new node of the same class
+    from the (recursively visited) field values. Non-node values are returned
+    unchanged; lists and tuples are rebuilt with the same container type.
+    """
+
+    def visit(self, node: doc.AST) -> doc.AST:
+        if isinstance(node, (list, tuple)):
+            rebuilt = [self.visit(child) for child in node]
+            return tuple(rebuilt) if isinstance(node, tuple) else rebuilt
+        if not isinstance(node, doc.AST):
+            return node
+        short_name = node.__class__.__name__.split(".")[-1]
+        handler = getattr(self, f"visit_{short_name}", self.generic_visit)
+        return handler(node)
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        new_fields: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            if isinstance(child, (doc.AST, list, tuple)):
+                child = self.visit(child)
+            new_fields[field] = child
+        return node.__class__(**new_fields)
+
+
+def _register_default():
+    """Create the registry instance and register field-by-field converters.
+
+    Every ``doc.AST`` subclass whose name also exists in ``ast`` gets a
+    ``to_doc`` converter (building the ``doc`` class) and a ``from_doc``
+    converter (building the ``ast`` class). Both copy the fields listed in the
+    ``doc`` class's ``_FIELDS``, converting each value recursively.
+    """
+
+    class DefaultTranslator:
+        """Build ``doc_cls`` from a node by converting each listed field with ``func``."""
+
+        def __init__(self, doc_cls, func, fields):
+            self.doc_cls = doc_cls
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            converted = {}
+            for attr in self.fields:
+                converted[attr] = self.func(getattr(node, attr, None))
+            return self.doc_cls(**converted)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue
+        if not (inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST)):
+            continue
+        assert "." not in cls_name
+        fields = doc_cls._FIELDS  # pylint: disable=protected-access
+        register_to_doc(cls_name)(DefaultTranslator(doc_cls, to_doc, fields))
+        register_from_doc(cls_name)(DefaultTranslator(getattr(ast, cls_name), from_doc, fields))
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    """Return the running interpreter's ``(major, minor)`` version pair."""
+    major, minor = sys.version_info[:2]
+    return (major, minor)
+
+
+def _register_constant_handling():
+    """On Python 3.6/3.7, convert the legacy literal nodes into ``doc.Constant``.
+
+    The converters registered below cover the ``Str``, ``NameConstant``,
+    ``Num``, ``Bytes`` and ``Ellipsis`` nodes of those versions. Because these
+    nodes carry no end position, the start position is reused as the end.
+    """
+    if _py_version() not in [(3, 6), (3, 7)]:
+        return
+
+    def as_constant(
+        f: typing.Union[str, typing.Callable[[ast.AST], typing.Any]],
+    ) -> typing.Callable[[ast.AST], doc.Constant]:
+        """Build an ``ast`` -> ``doc.Constant`` converter.
+
+        ``f`` is either the name of the attribute holding the literal value or
+        a callable computing the value from the node.
+        """
+
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            return doc.Constant(
+                value=getattr(x, f) if isinstance(f, str) else f(x),
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                end_lineno=x.lineno,
+                end_col_offset=x.col_offset,
+            )
+
+        return to_doc_func
+
+    register_to_doc("Str")(as_constant("s"))
+    register_to_doc("NameConstant")(as_constant("value"))
+    register_to_doc("Num")(as_constant("n"))
+    register_to_doc("Bytes")(as_constant("s"))
+    register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+    """On Python < 3.9, convert between legacy subscript slices and the ``doc`` form.
+
+    Before 3.9, ``ast.Subscript.slice`` takes one of three forms:
+
+    * ``ast.Index`` for a plain index, including ``a[i, j]``, whose value is an
+      ``ast.Tuple``;
+    * ``ast.Slice``;
+    * ``ast.ExtSlice`` for a tuple containing at least one slice; each of its
+      dims is itself an ``Index`` or a ``Slice``.
+
+    The ``doc`` form stores the index expression, a ``Slice`` or a ``Tuple``
+    directly, as Python 3.9+ does.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        if isinstance(x.slice, ast.Slice):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Slice(
+                    lower=to_doc(x.slice.lower),
+                    upper=to_doc(x.slice.upper),
+                    step=to_doc(x.slice.step),
+                    lineno=getattr(x.slice, "lineno", None),
+                    col_offset=getattr(x.slice, "col_offset", None),
+                    end_lineno=getattr(x.slice, "end_lineno", None),
+                    end_col_offset=getattr(x.slice, "end_col_offset", None),
+                ),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        if isinstance(x.slice, ast.ExtSlice):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Tuple(
+                    elts=[to_doc(i) for i in x.slice.dims],
+                    ctx=doc.Load(
+                        lineno=None,
+                        col_offset=None,
+                        end_lineno=None,
+                        end_col_offset=None,
+                    ),
+                    lineno=getattr(x, "lineno", None),
+                    col_offset=getattr(x, "col_offset", None),
+                    end_lineno=getattr(x, "end_lineno", None),
+                    end_col_offset=getattr(x, "end_col_offset", None),
+                ),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        if isinstance(x.slice, ast.Index):
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=to_doc(x.slice.value),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        if isinstance(x.slice, doc.Slice):
+            slice_node = from_doc(x.slice)
+        elif isinstance(x.slice, doc.Tuple) and any(
+            isinstance(elt, doc.Slice) for elt in x.slice.elts
+        ):
+            # A tuple containing a slice is an ExtSlice. Its dims must all be
+            # slice nodes, so plain expressions are wrapped in ``ast.Index``.
+            slice_node = ast.ExtSlice(
+                dims=[
+                    from_doc(elt)
+                    if isinstance(elt, doc.Slice)
+                    else ast.Index(value=from_doc(elt))
+                    for elt in x.slice.elts
+                ],
+            )
+        else:
+            # A plain index -- including a tuple of plain expressions such as
+            # ``a[i, j]``, which pre-3.9 Python represents as Index(Tuple(...)).
+            slice_node = ast.Index(value=from_doc(x.slice))
+        result = ast.Subscript(
+            value=from_doc(x.value),
+            slice=slice_node,
+            ctx=from_doc(x.ctx),
+        )
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+def _register_index_handling():
+    """On Python < 3.9, unwrap ``ast.Index`` into its value expression and back."""
+    if _py_version() >= (3, 9):
+        return
+
+    def index_to_doc(x: ast.Index) -> doc.expr:
+        return to_doc(x.value)
+
+    def index_from_doc(x: doc.expr) -> ast.Index:
+        # ``ast.Index`` has a single ``value`` field. Most expressions have no
+        # ``ctx`` attribute, so none is passed.
+        result = ast.Index(value=from_doc(x))
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Index")(index_to_doc)
+    register_from_doc("Index")(index_from_doc)
+
+
+# Registration order matters. The registry and the default field-by-field
+# converters are created first. The version-specific handlers then overwrite
+# the entries they need (e.g. ``Subscript``) on older Python versions.
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
+_register_index_handling()
diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/parser/core/doc_core.py
similarity index 100%
rename from python/tvm/script/printer/doc_core.py
rename to python/tvm/script/parser/core/doc_core.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/core/entry.py
similarity index 51%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/core/entry.py
index 923eb97d27..ccf42e8c15 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -14,32 +14,30 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
 
-import inspect
-from typing import Callable
+from ...ir_builder import IRBuilder
+from . import doc
+from .diagnostics import Source
+from .parser import Parser
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
 
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+    """Parse ``program`` into IR with a fresh ``IRBuilder`` and return ``builder.get()``.
+
+    ``program`` is wrapped in ``Source``. ``extra_vars`` supplies the names
+    visible to the parsed code. When it is omitted, ``I``/``ir`` and
+    ``T``/``tir`` are bound to ``tvm.script.parser.ir`` and
+    ``tvm.script.parser.tir``; they are imported lazily inside the function.
+    """
+    if extra_vars is None:
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+        extra_vars = {
+            "I": ir,
+            "ir": ir,
+            "T": tir,
+            "tir": tir,
+        }
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
-
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+    source = Source(program)
+    parser = Parser(source)
+    with IRBuilder() as builder:
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
new file mode 100644
index 0000000000..405281a65e
--- /dev/null
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+# Plain-Python fallback for each AST operator class. ``_eval_op`` uses it when
+# no operand type has a ``dispatch.register_op`` override. Binary, comparison
+# and boolean operators take two operands; unary operators take one.
+# NOTE: ``And``/``Or`` receive operands that were already evaluated, so unlike
+# Python's ``and``/``or`` they do not short-circuit.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+    doc.Add: lambda a, b: a + b,
+    doc.Sub: lambda a, b: a - b,
+    doc.Mult: lambda a, b: a * b,
+    doc.Div: lambda a, b: a / b,
+    doc.FloorDiv: lambda a, b: a // b,
+    doc.Mod: lambda a, b: a % b,
+    doc.LShift: lambda a, b: a << b,
+    doc.RShift: lambda a, b: a >> b,
+    doc.BitOr: lambda a, b: a | b,
+    doc.BitXor: lambda a, b: a ^ b,
+    doc.BitAnd: lambda a, b: a & b,
+    doc.MatMult: lambda a, b: a @ b,
+    doc.Pow: lambda a, b: a ** b,
+    doc.Eq: lambda a, b: a == b,
+    doc.NotEq: lambda a, b: a != b,
+    doc.Lt: lambda a, b: a < b,
+    doc.LtE: lambda a, b: a <= b,
+    doc.Gt: lambda a, b: a > b,
+    doc.GtE: lambda a, b: a >= b,
+    doc.Is: lambda a, b: a is b,
+    doc.IsNot: lambda a, b: a is not b,
+    doc.In: lambda a, b: a in b,
+    doc.NotIn: lambda a, b: a not in b,
+    doc.And: lambda a, b: a and b,
+    doc.Or: lambda a, b: a or b,
+    doc.Invert: lambda a: ~a,
+    doc.Not: lambda a: not a,
+    doc.UAdd: lambda a: +a,
+    doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+    """Evaluate a ``doc`` expression tree against a table of Python values.
+
+    Sub-expressions are evaluated bottom-up. Each result is stored in
+    ``value_table`` under a fresh ``__tvm_tmp_value_<n>`` name and replaced by
+    a ``doc.Name`` referring to it. Operators therefore go through ``_eval_op``
+    (and thus ``dispatch`` overrides), while every other expression is
+    compiled and evaluated by Python itself.
+    """
+
+    parser: "Parser"
+    value_table: Dict[str, Any]
+    new_value_count: int
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` and return its Python value.
+
+        ``value_table`` is updated in place with the intermediate results.
+        """
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        """Store ``value`` under a fresh temporary name and return a ``doc.Name`` for it."""
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        # Synthetic node: it has no real position in the user's source.
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        """Evaluate ``node`` bottom-up.
+
+        Names, constants and operator tokens are returned unchanged, and so are
+        nodes that are neither expressions nor slices. Every other expression
+        is evaluated and replaced by a temporary ``doc.Name``. Evaluation errors
+        are reported as diagnostics on ``node`` and then re-raised.
+        """
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            return self._eval_lambda(node)
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            # Never fall through with ``value`` unbound if reporting returns.
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        """Evaluate a lambda as a whole (its body is not visited) and store the function."""
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        """Fold ``And``/``Or`` left to right over operands that are already evaluated."""
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate a comparison; a chain ``a < b < c`` means ``(a < b) and (b < c)``.
+
+        Each operand is evaluated once. The pairwise results of a chain are
+        combined with ``And`` through ``_eval_op``, so dispatched operand types
+        can supply their own conjunction. A single comparison is unaffected.
+        """
+        lhs = self._eval_expr(fields["left"])
+        result = lhs
+        and_op = None
+        for index, (op, rhs_node) in enumerate(zip(fields["ops"], fields["comparators"])):
+            rhs = self._eval_expr(rhs_node)
+            current = _eval_op(op, values=[lhs, rhs])
+            if index == 0:
+                result = current
+            else:
+                if and_op is None:
+                    # NOTE(review): assumes ``doc.And`` accepts the same position
+                    # keywords as the ``doc.Load`` built in
+                    # ``_add_intermediate_result`` -- confirm.
+                    and_op = doc.And(
+                        lineno=None,
+                        col_offset=None,
+                        end_lineno=None,
+                        end_col_offset=None,
+                    )
+                result = _eval_op(and_op, values=[result, current])
+            lhs = rhs
+        return result
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply a unary operator to its already-visited operand."""
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply a binary operator to its already-visited operands."""
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        """Build a Python ``slice`` from the visited bounds; missing bounds stay ``None``."""
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        """Evaluate ``v`` with ``value_table`` as the global namespace."""
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate ``node`` with the names in ``dict_globals`` visible.
+
+    ``dict_globals`` is copied, so intermediate values recorded during the
+    evaluation are not written into the caller's mapping.
+    """
+    value_table = {} if dict_globals is None else dict(dict_globals)
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Bind ``source`` to the assignment ``target`` and return the new name/value pairs.
+
+    Any failure is reported as a diagnostic on ``target`` and then re-raised.
+    """
+    try:
+        bindings = _eval_assign(target, source)
+    except Exception as e:  # pylint: disable=broad-except,invalid-name
+        parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
+        raise
+    return bindings
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Convert ``node`` back to an ``ast.Expression``, compile it and ``eval`` it.
+
+    ``dict_globals`` (or a fresh empty dict) is used directly as the global
+    namespace.
+    """
+    py_node = doc.from_doc(node)
+    if isinstance(py_node, ast.expr):
+        py_node = ast.Expression(body=py_node)
+    assert isinstance(py_node, ast.Expression), (
+        "Expects an ast.Expression, but gets: " + str(py_node)
+    )
+    namespace = {} if dict_globals is None else dict_globals
+    code = compile(ast.fix_missing_locations(py_node), filename="<ast>", mode="eval")
+    return eval(code, namespace)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply operator ``op`` to ``values``.
+
+    Operands are scanned left to right. The implementation comes from the
+    first operand whose type declares a ``_dispatch_type`` and has a
+    registration for ``(type(op), operand position)``. If none matches, the
+    plain-Python entry in ``DEFAULT_OP`` is used.
+    """
+    op_type = type(op)
+    dispatch_types = (getattr(type(value), "_dispatch_type", None) for value in values)
+    overrides = (
+        dispatch.get_op(ty=ty, op=op_type, operand_index=position, default=None)
+        for position, ty in enumerate(dispatch_types)
+        if ty is not None
+    )
+    impl = next((candidate for candidate in overrides if candidate is not None), None)
+    if impl is None:
+        impl = DEFAULT_OP[op_type]
+    return impl(*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Run ``<target> = source`` in a scratch namespace and return the bindings.
+
+    The value is exposed under a reserved local name. An ``Assign`` statement
+    reading that name is compiled and executed with empty globals, and the
+    resulting locals, minus the reserved name, are returned.
+    """
+    py_target = doc.from_doc(target)
+    assert isinstance(py_target, ast.expr)
+    rhs_name = "__tvm_rhs_var__"
+    scope = {rhs_name: source}
+    assign = ast.Assign(
+        targets=[py_target],
+        value=ast.Name(id=rhs_name, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assign], type_ignores=[]))
+    exec(compile(module, filename="<ast>", mode="exec"), {}, scope)  # pylint: disable=exec-used
+    scope.pop(rhs_name)
+    return scope
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
new file mode 100644
index 0000000000..e26324262f
--- /dev/null
+++ b/python/tvm/script/parser/core/parser.py
@@ -0,0 +1,273 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+from tvm.error import DiagnosticError
+
+from . import dispatch, doc
+from .diagnostics import Diagnostics, Source
+from .evaluator import eval_assign, eval_expr
+
+# AST node type names that ``Parser.visit`` hands straight to ``generic_visit``
+# instead of looking up a ``visit_<Name>`` method.
+DEFAULT_VISIT = {
+    "Interactive",
+    "Module",
+    "Expression",
+    "Pass",
+}
+
+
+def _deferred(f: Callable[[], None]):
+    """Return a context manager that calls ``f`` when its ``with`` block exits.
+
+    ``f`` runs on normal and on exceptional exit; exceptions raised in the
+    block still propagate.
+    """
+
+    @contextmanager
+    def _run_on_exit():
+        try:
+            yield
+        finally:
+            f()
+
+    manager = _run_on_exit()
+    return manager
+
+
+class VarTableFrame:
+    """The set of names bound in one scope of a ``VarTable``."""
+
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        """Record ``var``; binding the same name twice in one frame raises ``ValueError``."""
+        if var not in self.vars:
+            self.vars.add(var)
+            return
+        raise ValueError(f"Variable {var} already defined in current scope")
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        """Call ``fn_pop`` once for every recorded name, then forget all names."""
+        for name in self.vars:
+            fn_pop(name)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped mapping from names to values.
+
+    Each name maps to a stack (list) of values, and the most recent binding is
+    the visible one. A frame created by ``with_frame`` records the names it
+    bound, so leaving the frame pops exactly those bindings.
+    """
+
+    frames: List[VarTableFrame]
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        """Push a new frame; the returned context manager pops it and its bindings on exit."""
+
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return _deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        """Bind ``var`` to ``value`` in the innermost frame.
+
+        Raises ``ValueError`` if ``var`` is already bound in that frame.
+        """
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        """Return a new dict holding the visible (most recent) binding of every name."""
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        """Return whether ``value`` (compared by identity) is bound to any name.
+
+        Shadowed bindings count too. ``name2value`` maps each name to a *list*
+        of bindings, so the entries inside each list are compared rather than
+        the list objects themselves.
+        """
+        return any(
+            bound is value for values in self.name2value.values() for bound in values
+        )
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap a parse method so unexpected failures get a diagnostic on the node.
+
+    A ``DiagnosticError`` passes through untouched (presumably it has already
+    been reported). Any other exception is reported on ``node`` and re-raised.
+    """
+
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            return func(self, node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            if not isinstance(e, DiagnosticError):
+                self.report_error(node, str(e))
+            raise
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Resolve the parse method for ``type_name``.
+
+    The innermost dispatch token is tried first, then ``"default"``. If neither
+    has a registration, the parser's ``generic_visit`` is used. The result is
+    always wrapped by ``_dispatch_wrapper``.
+    """
+
+    def _fallback(parser: "Parser", node: doc.AST) -> None:
+        return parser.generic_visit(node)
+
+    candidates = (
+        dispatch.get(token=token, type_name=type_name, default=None)
+        for token in (self.dispatch_tokens[-1], "default")
+    )
+    func = next((found for found in candidates if found is not None), None)
+    if func is None:
+        func = _fallback
+    return _dispatch_wrapper(func)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser"""
+
+    diag: Diagnostics
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        """Set up diagnostics for ``source``, the dispatch-token stack and the variable table."""
+        self.var_table = VarTable()
+        self.dispatch_tokens = ["default"]
+        self.diag = Diagnostics(source)
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Visit the source's AST with ``extra_vars`` bound in an enclosing frame.
+
+        The frame, and with it those bindings, is dropped when the visit ends.
+        """
+        initial_vars = {} if extra_vars is None else extra_vars
+        with self.var_table.with_frame():
+            for name, value in initial_vars.items():
+                self.var_table.add(name, value)
+            self.visit(self.diag.source.as_ast())
+
+    def with_dispatch_token(self, token: str):
+        """Make ``token`` the active dispatch token; the returned context manager restores it."""
+        self.dispatch_tokens.append(token)
+
+        def _restore() -> None:
+            self.dispatch_tokens.pop()
+
+        return _deferred(_restore)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate ``node`` against the visible variables, with ``extra_vars`` taking precedence."""
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            var_values.update(extra_vars.items())
+        return eval_expr(self, node, var_values)
+
+    def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
+        """Collect the names bound by an assignment ``target``.
+
+        Returns ``True`` as soon as a name is bound twice within a (possibly
+        nested) tuple/list target; otherwise returns the set of bound names.
+        Any target other than a name, tuple or list is reported as an error.
+        """
+        if isinstance(target, doc.Name):
+            return {target.id}
+        if not isinstance(target, (doc.Tuple, doc.List)):
+            self.report_error(target, "Invalid type in assign statement")
+            raise NotImplementedError
+        seen: Set[str] = set()
+        for elt in target.elts:
+            names = self._duplicate_lhs_check(elt)
+            if names is True:
+                return True
+            assert isinstance(names, set)
+            if not seen.isdisjoint(names):
+                return True
+            seen |= names
+        return seen
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        if self._duplicate_lhs_check(target) is True:
+            self.report_error(target, "Duplicate vars assigned.")
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        try:
+            func(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
+
+    def visit_Return(self, node: doc.Return) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Return")(self, node)
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/core/utils.py
similarity index 58%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/parser/core/utils.py
index 2f2b4bbc25..aae45fe6ff 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -14,18 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Any, Callable, Dict
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
 
-from .prim_func import prim_func
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    """Collect the names visible to ``func``'s body: its module globals and closure vars.
+
+    Globals are merged first and closure (nonlocal) variables second.  A closure variable
+    therefore shadows a global of the same name, matching Python's own lookup order
+    (enclosing scopes before globals).
+    """
+    captured = {
+        **func.__globals__,  # type: ignore
+        **inspect.getclosurevars(func).nonlocals,
+    }
+    return captured
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    """Merge the captured variables of every plain function defined directly on ``cls``.
+
+    Members are visited in ``cls.__dict__`` order.  When two functions capture the same
+    name, the value from the later one wins.
+    """
+    merged: Dict[str, Any] = {}
+    for member in cls.__dict__.values():
+        if not inspect.isfunction(member):
+            continue
+        merged.update(**inspect_function_capture(member))
+    return merged
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/__init__.py
index 555659d0c5..4cbd9910a2 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,8 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
 
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = ["ir_module"]
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/ir/entry.py
similarity index 54%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser/ir/entry.py
index 923eb97d27..3c1e4de5a7 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,32 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
-
+# pylint: disable=missing-docstring
 import inspect
-from typing import Callable
+from typing import Type
+
+from tvm.ir import IRModule
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from .._core import parse, utils
 
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+def ir_module(f: Type) -> IRModule:
+    """Parse a class decorated with ``@ir_module`` through the TVMScript parser.
+
+    The variables captured by the class's functions (module globals and closures) are
+    handed to the parser so that names used inside the method bodies can be resolved.
+
+    Raises
+    ------
+    TypeError
+        If ``f`` is not a class.
+    """
+    if inspect.isclass(f):
+        class_vars = utils.inspect_class_capture(f)
+        return parse(f, class_vars)
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
+    raise TypeError(f"Expect a class, but got: {f}")
 
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
 
-    raise TypeError("Only function definitions are supported.")
+# Tag the decorator with its dispatch token; the core parser reads ``dispatch_token``
+# from an evaluated decorator to choose the handler set.  ``setattr`` is presumably used
+# instead of plain attribute assignment to keep static type checkers quiet.
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 55%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser/ir/parser.py
index 2f2b4bbc25..8871d3b415 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,18 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder import ir as I
+from .._core import Parser, dispatch, doc
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
 
-from .prim_func import prim_func
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Parse an ``@ir_module`` class body into a single ``I.ir_module()`` frame.
+
+    The body is visited in a fresh variable-table frame with ``"ir"`` as the active
+    dispatch token.  Context managers are entered left to right and exited in reverse.
+    """
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    """Ignore assignments directly inside an ``@ir_module`` class body (no IR emitted)."""
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    """Ignore bare expression statements (e.g. docstrings) in an ``@ir_module`` body."""
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 71%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..930764f73d 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder.tir import *  # pylint: disable=redefined-builtin
+from ...ir_builder.tir import ir as _tir
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
 
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..db4e2dd9a3
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+
+from ...ir_builder.tir import buffer_decl, ptr
+from .._core import parse, utils
+
+
+def _is_defined_in_class(frames):
+    """Heuristically tell whether ``prim_func`` is being applied inside a class body.
+
+    ``frames`` is the ``inspect.stack()`` list captured in ``prim_func``.  The source line
+    currently executing two frames out (``frames[2]``) is inspected.  If it is a
+    ``class`` statement, or a decorator line mentioning ``ir_module``, the function is
+    assumed to be a method of an ``@ir_module`` class.  Returns ``False`` whenever the
+    source line is unavailable.
+    """
+    if len(frames) <= 2:
+        return False
+    # Index 4 of a FrameInfo is ``code_context``: a list of source lines, or None.
+    code_context = frames[2][4]
+    # Guard both None (source unavailable) and an empty list before indexing [0].
+    if not code_context:
+        return False
+    line = code_context[0].strip()
+    if line.startswith("class "):
+        return True
+    return line.startswith("@") and "ir_module" in line
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    """Decorate a Python function as a TVMScript ``PrimFunc``.
+
+    Outside a class body the function is parsed immediately, with its globals and closure
+    variables available to the parser.  Inside an ``@ir_module`` class body it is returned
+    unchanged, and the class-level parse picks it up later.
+
+    Raises
+    ------
+    TypeError
+        If ``f`` is not a plain Python function.
+    """
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    # NOTE: ``inspect.stack()`` must be called from this very frame: the class-detection
+    # heuristic indexes the stack relative to ``prim_func``.
+    in_class_body = _is_defined_in_class(inspect.stack())
+    if in_class_body:
+        return f
+    captured = utils.inspect_function_capture(f)
+    return parse(f, captured)
+
+
+# Tag the decorator so the core parser (which reads ``dispatch_token`` off an evaluated
+# decorator) routes ``@T.prim_func`` definitions to the handlers registered under "tir".
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Factory behind ``T.Buffer``: usable both as ``T.Buffer(...)`` and ``T.Buffer[...]``.
+
+    Both spellings forward to ``buffer_decl`` from the TIR IR builder.
+    """
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        """Declare a buffer; ``shape`` is positional and the rest go by keyword."""
+        options = {
+            "dtype": dtype,
+            "data": data,
+            "strides": strides,
+            "elem_offset": elem_offset,
+            "scope": scope,
+            "align": align,
+            "offset_factor": offset_factor,
+            "buffer_type": buffer_type,
+            "axis_separators": axis_separators,
+        }
+        return buffer_decl(shape, **options)
+
+    def __getitem__(self, keys) -> Buffer:
+        """Subscript form of ``__call__``.
+
+        - ``T.Buffer[(16, 16), "float32"]``: the tuple is unpacked into positional args.
+        - ``T.Buffer[16, 16]``: the second item is not a string, so the whole tuple is the
+          shape.
+        - A non-tuple key is the shape itself.
+        """
+        unpack = isinstance(keys, tuple) and (len(keys) < 2 or isinstance(keys[1], str))
+        if unpack:
+            return self(*keys)  # pylint: disable=no-member # type: ignore
+        return self(keys)
+
+
+class PtrProxy:
+    """Factory behind ``T.Ptr``: usable both as ``T.Ptr(...)`` and ``T.Ptr[...]``."""
+
+    def __call__(self, dtype, storage_scope="global"):
+        """Build a pointer type via ``ptr``.
+
+        A callable ``dtype`` (such as ``T.float32``) is invoked and its ``.dtype`` is used.
+        """
+        resolved = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        """Subscript form: ``T.Ptr[dtype]`` or ``T.Ptr[dtype, storage_scope]``."""
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+# Module-level proxies re-exported as ``T.Buffer`` / ``T.Ptr``.  ``Buffer`` rebinds the
+# ``tvm.tir.Buffer`` name imported above; the ``-> Buffer`` annotations in BufferProxy
+# were already evaluated against the original class when that class was defined.
+Buffer, Ptr = BufferProxy(), PtrProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..716525b984
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .._core import OpMethod, doc, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    """Register TIR implementations of Python operators for operands of type ``ty``.
+
+    Binary, comparison and boolean operators are registered for operand positions 0 and
+    1 (presumably ``ty`` on the left / right side); unary operators only for position 0.
+    ``ty`` is also tagged with ``_dispatch_type``.
+    """
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool_expr(x):
+        # Python ``True``/``False`` literals become boolean IntImm constants.
+        return IntImm("bool", x) if isinstance(x, bool) else x
+
+    def _and(a, b):
+        return tir.And(_as_bool_expr(a), _as_bool_expr(b))
+
+    def _or(a, b):
+        return tir.Or(_as_bool_expr(a), _as_bool_expr(b))
+
+    def _register(op: Type, index: int, method: OpMethod):
+        register_op(ty, op, index)(method)
+
+    binary_table = [
+        # Case 1. binop
+        (doc.Add, tir.Add),
+        (doc.Sub, tir.Sub),
+        (doc.Mult, tir.Mul),
+        (doc.Div, tir.Div),
+        (doc.FloorDiv, tir.FloorDiv),
+        (doc.Mod, tir.FloorMod),
+        (doc.LShift, lambda a, b: a << b),
+        (doc.RShift, lambda a, b: a >> b),
+        (doc.BitOr, lambda a, b: a | b),
+        (doc.BitXor, lambda a, b: a ^ b),
+        (doc.BitAnd, lambda a, b: a & b),
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop
+        (doc.Eq, tir.EQ),
+        (doc.NotEq, tir.NE),
+        (doc.Lt, tir.LT),
+        (doc.LtE, tir.LE),
+        (doc.Gt, tir.GT),
+        (doc.GtE, tir.GE),
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop
+        (doc.And, _and),
+        (doc.Or, _or),
+    ]
+    unary_table = [
+        # Case 4. unaryop
+        (doc.Invert, lambda a: ~a),
+        (doc.Not, tir.Not),
+        (doc.UAdd, lambda a: +a),
+        (doc.USub, lambda a: -a),
+    ]
+    for index in (0, 1):
+        for op, method in binary_table:
+            _register(op, index, method)
+    for op, method in unary_table:
+        _register(op, 0, method)
+
+
+# Import-time side effect: install the operator overloads for TIR expressions and
+# IterVars so the expression evaluator can lower Python operators on them.
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..351238c06f
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,268 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+
+from ...ir_builder import tir as T
+from ...ir_builder.base import IRBuilderFrame as Frame
+from ...ir_builder.base import name
+from .._core import Parser, dispatch, doc
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name the value(s) bound by ``with ... as <target>`` and return ``value``.
+
+    Buffers and Vars are named ``var_name``.  Lists/tuples are handled element-wise
+    (recursively) with ``_<i>`` suffixes.  Any other type is reported as an error.
+    """
+    if isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    if isinstance(value, (list, tuple)):
+        for idx, item in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{idx}", item)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+    raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name the loop variable(s) produced by a ``for`` statement and return ``value``.
+
+    Only ``Var`` values are accepted.  Lists/tuples are bound element-wise with the same
+    for-loop rules, each element named ``<var_name>_<i>``.  Anything else is reported as
+    an error.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the for-loop binder (not the with-statement one) so nested items
+            # are validated as loop vars and errors name the right statement kind.
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Var):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind the right-hand side of ``<target> = value``; return what the name refers to.
+
+    - ``T.inline(...)``: unwrapped; the wrapped Python value is returned.
+    - list/tuple: each element is bound with these same assignment rules as
+      ``<var_name>_<i>``; the original sequence is returned.
+    - IR builder frame: entered now, with a callback registered that calls its
+      ``__exit__``; the ``__enter__`` result is named and returned.
+    - Buffer, IterVar, or a Var not yet in the variable table: named in place.
+    - any other PrimExpr: bound to a fresh Var through a ``T.let`` frame (entered now,
+      exited via callback); the new Var is returned.
+    - anything else: returned unchanged.
+    """
+    if isinstance(value, T.inline):
+        return value.value
+    elif isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the assignment binder so elements get assignment semantics
+            # (frames, let-bound PrimExprs, inline values) rather than with-statement rules.
+            bind_assign_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """Lower a Python ``for`` loop onto the ForFrame produced by evaluating its iterator.
+
+    The loop target is bound to the frame's iteration variables in a fresh var frame.
+    """
+    for_frame = self.eval_expr(node.iter)
+    if not isinstance(for_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame(), for_frame as iters:
+        self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """Lower a ``while`` loop into a ``T.While`` frame inside a fresh var frame."""
+    with self.var_table.with_frame(), T.While(self.eval_expr(node.test)):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """Lower ``target = value``.
+
+    A subscript target becomes a ``T.buffer_store``; any other target is bound through
+    ``bind_assign_value``.  Chained assignments are rejected.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    target = node.targets[0]
+    value = self.eval_expr(node.value)
+    if not isinstance(target, doc.Subscript):
+        self.eval_assign(target=target, source=value, bind_value=bind_assign_value)
+        return
+    index_nodes = target.slice.elts if isinstance(target.slice, doc.Tuple) else [target.slice]
+    indices = [self.eval_expr(index_node) for index_node in index_nodes]
+    T.buffer_store(self.eval_expr(target.value), value, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """Lower ``target <op>= value`` as ``target = target <op> value``.
+
+    The current target value and the right-hand side are bound to temporary names in a
+    throw-away var frame.  A synthetic ``BinOp`` over those names is then evaluated, so
+    the operator goes through the normal expression evaluator.  The result is stored
+    into a buffer element (subscript target) or rebound to the name.
+
+    Note: ``node.target.ctx`` is mutated in place (set to Load, then to Store).
+    """
+
+    def _span(expr):
+        # (lineno, col_offset, end_lineno, end_col_offset): positional args of doc nodes.
+        return (expr.lineno, expr.col_offset, expr.end_lineno, expr.end_col_offset)
+
+    lhs_pos = _span(node.target)
+    rhs_pos = _span(node.value)
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        current_value = self.eval_expr(node.target)
+        operand = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, current_value)
+        self.var_table.add(rhs_name, operand)
+        combined = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(combined)
+    target = node.target
+    target.ctx = doc.Store(*lhs_pos)
+    if not isinstance(target, doc.Subscript):
+        self.eval_assign(target=target, source=rhs, bind_value=bind_assign_value)
+        return
+    index_nodes = target.slice.elts if isinstance(target.slice, doc.Tuple) else [target.slice]
+    indices = [self.eval_expr(index_node) for index_node in index_nodes]
+    T.buffer_store(self.eval_expr(target.value), rhs, indices)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """Lower ``name: <annotation> = value`` into a let-binding.
+
+    The annotation must evaluate to a ``Var``.  That Var is bound to the target name, and
+    a ``T.let`` frame is entered now, with a callback registered to exit it.
+    """
+    value = self.eval_expr(node.value)
+    ann_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(ann_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=node.target, source=ann_var, bind_value=bind_assign_value)
+    let_frame = T.let(ann_var, value)
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """Lower a ``with`` statement whose context expressions are IR builder frames.
+
+    All items are entered left to right on one ExitStack, inside a fresh var frame.
+    Each ``as`` target is bound via ``bind_with_value``.
+    """
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(self.var_table.with_frame())
+        for with_item in node.items:
+            ctx_expr = with_item.context_expr
+            frame = self.eval_expr(ctx_expr)
+            if not isinstance(frame, Frame):
+                self.report_error(ctx_expr, "Invalid context expression in the with-statement.")
+            entered = stack.enter_context(frame)
+            if with_item.optional_vars is None:
+                continue
+            self.eval_assign(
+                target=with_item.optional_vars, source=entered, bind_value=bind_with_value
+            )
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """Lower a ``@T.prim_func`` definition into a ``T.prim_func()`` frame.
+
+    Inside the function, ``range`` is rebound to ``T.serial``.  A callable return
+    annotation (e.g. ``T.int32``) is converted to a ``PrimType`` of its dtype.
+    Parameters and body are visited under the ``"tir"`` dispatch token.
+    """
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                T.func_ret(PrimType(ret_type().dtype) if callable(ret_type) else ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """Declare each positional parameter with ``T.arg`` and bind it by name.
+
+    Every parameter must carry a type annotation.
+    """
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    for param_node in node.args:
+        if param_node.annotation is None:
+            self.report_error(param_node, "Type annotation is required for function parameters.")
+        annotation = self.visit_tvm_annotation(param_node.annotation)
+        self.var_table.add(param_node.arg, T.arg(param_node.arg, annotation))
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """Evaluate an annotation; callables such as ``T.int32`` are invoked for an instance."""
+    evaluated = self.eval_expr(node)
+    return evaluated() if callable(evaluated) else evaluated
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """Evaluate an expression statement.
+
+    If it yields an IR builder frame, the frame is entered now and a callback is
+    registered to exit it; any other result is discarded.
+    """
+    result = self.eval_expr(node.value)
+    if not isinstance(result, Frame):
+        return
+    result.add_callback(partial(result.__exit__, None, None, None))
+    result.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """Lower ``if``/``else`` into ``T.If`` with ``T.Then`` and (when present) ``T.Else``."""
+    with self.var_table.with_frame():
+        condition = self.eval_expr(node.test)
+        with T.If(condition):
+            with T.Then():
+                self.visit_body(node.body)
+            if not node.orelse:
+                return
+            with T.Else():
+                self.visit_body(node.orelse)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """Lower ``assert cond[, msg]`` into a ``T.Assert`` frame.
+
+    The frame is entered now, with a callback registered to exit it.  A bare
+    ``assert cond`` (no message; ``node.msg`` is None in the AST) uses an empty message
+    instead of passing ``None`` on.
+    """
+    cond = self.eval_expr(node.test)
+    msg = self.eval_expr(node.msg) if node.msg is not None else ""
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+    """Reject ``return`` statements inside a TIR function body with an error."""
+    self.report_error(node, "Return is not allowed.")
diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py
deleted file mode 100644
index e7d90dd515..0000000000
--- a/python/tvm/script/registry.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Function Registry """
-# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
-import types
-from typing import Union, Callable, Dict, Optional, Any
-
-
-class Registry(object):
-    """Registration map
-    All these maps are static
-    """
-
-    registrations: Dict[str, type] = dict()
-
-    @staticmethod
-    def lookup(name: str) -> Optional[Any]:
-        if name in Registry.registrations:
-            # every time we create a new handler
-            # since we may want to keep some local info inside it
-            return Registry.registrations[name]()
-        return None
-
-
-def register(inputs: Union[Callable, type]) -> type:
-    """Register Intrin/ScopeHandler/SpecialStmt"""
-    registration: type
-    if isinstance(inputs, types.FunctionType):
-        # is function
-        from .tir.intrin import Intrin
-
-        def create_new_intrin(func) -> type:
-            class NewIntrin(Intrin):
-                def __init__(self):
-                    super().__init__(func)
-
-            return NewIntrin
-
-        registration = create_new_intrin(inputs)
-    elif isinstance(inputs, type):
-        # is class
-        registration = inputs
-    else:
-        raise ValueError()
-
-    key: str = registration().signature()[0]
-    Registry.registrations[key] = registration
-    return registration
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
deleted file mode 100644
index a62fb102be..0000000000
--- a/python/tvm/script/tir/__init__.pyi
+++ /dev/null
@@ -1,477 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Dict,
-    Iterable,
-    Optional,
-    Tuple,
-    Union,
-    Sequence,
-    List,
-    Mapping,
-    overload,
-)
-from numbers import Number
-import builtins
-
-from tvm.tir.function import PrimFunc
-from tvm.tir import Range
-from tvm.runtime import Object
-from tvm.target import Target
-from .node import BufferSlice
-
-"""
-redefine types
-"""
-
-class PrimExpr:
-    def __init__(self: PrimExpr) -> None: ...
-    @overload
-    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __mod__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __floordiv__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __index__(self: PrimExpr) -> int: ...  # so range doesn't complain
-
-class Var(PrimExpr): ...
-class IterVar(Var): ...
-
-class Buffer:
-    @overload
-    def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
-    @overload
-    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
-    @overload
-    def __setitem__(
-        self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
-    ) -> None: ...
-    @overload
-    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
-    @property
-    def data(self: Buffer) -> Ptr: ...
-
-"""
-Intrinsic
-"""
-
-def min_value(dtype: str) -> PrimExpr: ...
-def max_value(dtype: str) -> PrimExpr: ...
-def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def abs(x: PrimExpr) -> PrimExpr: ...
-def load(
-    dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
-) -> PrimExpr: ...
-def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
-def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
-def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
-def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
-def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
-def evaluate(value: PrimExpr) -> None: ...
-def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def store(
-    var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...
-def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ...
-def preflattened_buffer(
-    buf: Buffer,
-    shape: Sequence[PrimExpr],
-    dtype: str = "float32",
-    data: Optional[Ptr] = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-) -> Buffer: ...
-
-"""
-Intrinsics - tvm builtin
-"""
-
-def tvm_thread_allreduce(
-    *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
-) -> PrimExpr: ...
-
-"""
-Unary operator
-Note that any intrinsics not registered in script.tir.intrin
-should add "dtype" as an argument. This is different from their
-definition but intentional.
-"""
-
-def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-
-"""
-special_stmt - Buffers
-"""
-
-def match_buffer(
-    param: Union[Var, BufferSlice],
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def decl_buffer(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def buffer_decl(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def alloc_buffer(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-
-"""
-special_stmt - Reads/Writes
-"""
-
-@overload
-def reads(read_regions: List[BufferSlice]) -> None: ...
-@overload
-def reads(*read_regions: BufferSlice) -> None: ...
-@overload
-def writes(write_region: List[BufferSlice]) -> None: ...
-@overload
-def writes(*write_region: BufferSlice) -> None: ...
-def block_attr(attrs: Mapping[str, Object]) -> None: ...
-
-"""
-special_stmt - Axis
-"""
-
-class axis:
-    @overload
-    @staticmethod
-    def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def spatial(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def reduce(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def scan(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def opaque(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @staticmethod
-    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
-
-def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
-
-"""
-special_stmt - Annotations
-"""
-
-def buffer_var(dtype: str, storage_scope: str) -> Var: ...
-def func_attr(attrs: Mapping[str, Union[Object, str, bool, int, float]]) -> None: ...
-def prim_func(input_func: Callable) -> PrimFunc: ...
-
-"""
-special_stmt - Threads and Bindings
-"""
-
-def env_thread(env_name: str) -> IterVar: ...
-def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
-
-"""
-Scope handler
-"""
-
-class block(ContextManager):
-    def __init__(self, name_hint: str = "") -> None: ...
-    def __enter__(self) -> Sequence[IterVar]: ...
-
-class init(ContextManager):
-    def __init__(self) -> None: ...
-
-class let(ContextManager):
-    def __init__(self, var: Var, value: PrimExpr) -> None: ...
-
-def where(cond: PrimExpr) -> None: ...
-def allocate(
-    extents: List[PrimExpr],
-    dtype: str,
-    scope: str,
-    condition: Union[PrimExpr, builtins.bool] = True,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Buffer: ...
-def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
-def realize(
-    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
-def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
-
-"""
-Scope handler - Loops
-"""
-
-@overload
-def serial(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def serial(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    thread: str,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
-    end: Union[PrimExpr, int],
-    thread: str,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...
-
-"""
-ty - redefine types
-"""
-
-class boolean: ...
-
-class handle(Var):
-    @overload
-    def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ...
-    @overload
-    def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ...
-    @overload
-    def __setitem__(
-        self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer
-    ) -> None: ...
-    @overload
-    def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ...
-    @property
-    def data(self: handle) -> Ptr: ...
-
-class Ptr: ...
-
-def target(target_str: Union[str, Mapping[str, Object]]) -> Target: ...
-
-class var(Var):
-    def __init__(self: Var, dtype: str): ...
-
-class bool(PrimExpr):
-    def __init__(self: bool, imm: Union[PrimExpr, builtins.bool, builtins.int]): ...
-
-class int8(PrimExpr):
-    def __init__(self: int8, imm: Union[PrimExpr, int]): ...
-
-class int16(PrimExpr):
-    def __init__(self: int16, imm: Union[PrimExpr, int]): ...
-
-class int32(PrimExpr):
-    def __init__(self: int32, imm: Union[PrimExpr, int]): ...
-
-class int64(PrimExpr):
-    def __init__(self: int64, imm: Union[PrimExpr, int]): ...
-
-class uint8(PrimExpr):
-    def __init__(self: uint8, imm: Union[PrimExpr, int]): ...
-
-class uint16(PrimExpr):
-    def __init__(self: uint16, imm: Union[PrimExpr, int]): ...
-
-class uint32(PrimExpr):
-    def __init__(self: uint32, imm: Union[PrimExpr, int]): ...
-
-class uint64(PrimExpr):
-    def __init__(self: uint64, imm: Union[PrimExpr, int]): ...
-
-class float8(PrimExpr):
-    def __init__(self: float8, imm: Union[PrimExpr, int, float]): ...
-
-class float16(PrimExpr):
-    def __init__(self: float16, imm: Union[PrimExpr, int, float]): ...
-
-class float32(PrimExpr):
-    def __init__(self: float32, imm: Union[PrimExpr, int, float]): ...
-
-class float64(PrimExpr):
-    def __init__(self: float64, imm: Union[PrimExpr, int, float]): ...
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
deleted file mode 100644
index 382431c229..0000000000
--- a/python/tvm/script/tir/intrin.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Intrinsic Classes"""
-# pylint: disable=redefined-builtin, relative-beyond-top-level
-import builtins
-from typing import List, Any
-
-import tvm.tir
-from ..registry import register
-from ...target import codegen
-from ..utils import get_param_list, tvm_span_from_synr
-
-
-class Intrin:
-    def __init__(self, intrin, stmt=False):
-        self.intrin = intrin
-        self.stmt = stmt
-
-    def signature(self):
-        return "tir." + self.intrin.__name__, get_param_list(self.intrin)
-
-    def handle(self, arg_list: List[Any], span: tvm.ir.Span):
-        return self.intrin(*arg_list, span=tvm_span_from_synr(span))
-
-
-@register
-def bool(imm, span):
-    return imm.astype("bool", span)
-
-
-# register all datatypes
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            _name = _dtype + _size + _lanes
-
-            # nest closures so we copy the name string
-            def wrap(name):
-                def f(imm, span):
-                    return imm.astype(name, span)
-
-                f.__name__ = name
-                return f
-
-            _intrin = wrap(_name)
-            register(_intrin)
-
-
-@register
-def min_value(dtype, span):
-    return tvm.tir.min_value(dtype, span)
-
-
-@register
-def max_value(dtype, span):
-    return tvm.tir.max_value(dtype, span)
-
-
-@register
-def floordiv(x, y, span):
-    return tvm.tir.floordiv(x, y, span)
-
-
-@register
-def floormod(x, y, span):
-    return tvm.tir.floormod(x, y, span)
-
-
-@register
-def truncmod(x, y, span):
-    return tvm.tir.truncmod(x, y, span)
-
-
-@register
-def ceildiv(x, y, span):
-    return tvm.tir.ceildiv(x, y, span)
-
-
-@register
-def abs(x, span):
-    return tvm.tir.abs(x, span)
-
-
-@register
-def load(dtype, var, index, predicate=None, span=None):
-    return tvm.tir.Load(dtype, var, index, predicate, span)
-
-
-@register
-def cast(value, dtype, span):
-    return tvm.tir.Cast(dtype, value, span)
-
-
-@register
-def ramp(base, stride, lanes, span):
-    return tvm.tir.Ramp(base, stride, lanes.value, span)
-
-
-@register
-def broadcast(value, lanes, span):
-    return tvm.tir.Broadcast(value, lanes.value, span)
-
-
-@register
-def iter_var(var, dom, iter_type, thread_tag, span):
-    iter_type = getattr(tvm.tir.IterVar, iter_type)
-    return tvm.tir.IterVar(dom, var, iter_type, thread_tag, span)
-
-
-@register
-def max(a, b, span):  # pylint: disable=redefined-builtin
-    return tvm.tir.Max(a, b, span)
-
-
-@register
-def min(a, b, span):  # pylint: disable=redefined-builtin
-    return tvm.tir.Min(a, b, span)
-
-
-def get_axis(begin, end, iter_type, span):
-    ana = tvm.arith.Analyzer()
-    extent = ana.simplify(end - begin)
-    block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
-
-    iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
-    return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)
-
-
-@register
-def range(begin, end, span):
-    return get_axis(begin, end, "data_par", span)
-
-
-@register
-def reduce_axis(begin, end, span):
-    return get_axis(begin, end, "reduce", span)
-
-
-@register
-def scan_axis(begin, end, span):
-    return get_axis(begin, end, "scan", span)
-
-
-@register
-def opaque_axis(begin, end, span):
-    return get_axis(begin, end, "opaque", span)
-
-
-@register
-def Select(cond, if_body, else_body, span):  # pylint: disable=invalid-name
-    return tvm.tir.Select(cond, if_body, else_body, span)
-
-
-@register
-def Let(var, value, body, span):  # pylint: disable=invalid-name
-    return tvm.tir.Let(var, value, body, span)
-
-
-@register
-class EvaluateIntrin(Intrin):
-    def __init__(self):
-        def evaluate(value, span):
-            return tvm.tir.Evaluate(value, span)
-
-        super().__init__(evaluate, stmt=True)
-
-
-@register
-class StoreIntrin(Intrin):
-    def __init__(self):
-        def store(var, index, value, predicate=True, span=None):
-            return tvm.tir.Store(var, value, index, predicate, span)
-
-        super().__init__(store, stmt=True)
-
-
-@register
-class AssumeIntrin(Intrin):
-    def __init__(self):
-        def assume(constraint, span):
-            return tvm.tir.Evaluate(
-                tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
-            )
-
-        super().__init__(assume, stmt=True)
-
-
-@register
-def comm_reducer(lambda_io, identities, span):
-    """Create a CommReducer from lambda inputs/outputs and the identities"""
-    lambda_input = lambda_io[0]
-    lambda_output = lambda_io[1]
-
-    num_args = len(lambda_input)
-    num_arg_per_group = num_args // 2
-    x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)]
-    y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)]
-
-    if not isinstance(lambda_output, tuple):
-        lambda_output = (lambda_output,)
-
-    return tvm.tir.CommReducer(x, y, lambda_output, identities, span)
-
-
-@register
-def llvm_lookup_intrinsic_id(name, span):
-    # pylint: disable=unused-argument
-    return codegen.llvm_lookup_intrinsic_id(name)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
deleted file mode 100644
index 29e79607fb..0000000000
--- a/python/tvm/script/tir/node.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-"""TVM Script nodes."""
-
-from typing import Optional, Union, List, Callable
-import synr
-from tvm.arith import Analyzer
-from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
-
-
-class Slice:
-    """A helper class to present slice information for BufferSlice
-
-    Parameters
-    ----------
-    start : Union[PrimExpr, int]
-        The start index.
-
-    stop : Optional[Union[PrimExpr, int]]
-        The stop index, None means the Slice is an element-wise index
-
-    step : int
-        The slice step
-
-    span : Optional[Span]
-        The location of the slice in the source.
-    """
-
-    start: Union[PrimExpr, int]
-    stop: Optional[Union[PrimExpr, int]]
-    step: int
-    span: Optional[Span]
-
-    def __init__(
-        self,
-        start: Union[PrimExpr, int],
-        stop: Optional[Union[PrimExpr, int]] = None,
-        step: int = 1,
-        span: Optional[Span] = None,
-    ):
-        self.start = start
-        self.stop = stop
-        self.step = step
-        self.span = span
-
-    def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
-        """Helper to create index PrimExpr from slice object
-        Parameters
-        ----------
-        report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-            The error report func
-        """
-        if self.stop is None:
-            # scalar index
-            return self.start
-        if self.step < 1:
-            report_error("Slice's step should be positive integer", self.span)
-        lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step)
-        if not isinstance(lanes, (int, IntImm)):
-            report_error("Slice's lanes should be constant for buffer indices", self.span)
-        if lanes == 1:
-            return self.start
-        return Ramp(self.start, self.step, int(lanes), self.span)
-
-
-class BufferSlice(ObjectGeneric):
-    """A generic object for representing general buffer access. Following cases are supported:
-        - element wise access buffer[i, j], which can be converted to BufferLoad if necessary
-        - slice access buffer[i: i + 1, j : j + 2]
-        - union of element and slice buffer[i, j: j + 2]
-
-        This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize
-
-    Parameters
-    ----------
-    buffer : Buffer
-        The buffer.
-
-    indices : List[Union[Slice, PrimExpr, int]]
-        The access indexes can be slice, PrimExpr or int.
-
-    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-        The error report func
-
-    span : Optional[Span]
-        The location of the buffer access in the source.
-    """
-
-    buffer: Buffer
-    slices: List[Slice]
-    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-    span: Optional[Span]
-
-    def __init__(
-        self,
-        buffer: Buffer,
-        indices: List[Union[Slice, PrimExpr, int]],
-        report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
-        span: Optional[Span] = None,
-    ):
-        def check_index(index: Union[int, PrimExpr]):
-            """Check input index is non-negative integer or PrimExpr"""
-            if isinstance(index, int):
-                if index < 0:
-                    report_error("Negative index is not allowed during buffer access", span)
-            elif isinstance(index, PrimExpr):
-                element_dtype = index.dtype.split("x", maxsplit=1)[0]
-                if element_dtype[:3] != "int":
-                    report_error(
-                        "index expected an integer type PrimExpr but got " + str(index.dtype),
-                        index.span,
-                    )
-            else:
-                report_error(
-                    "Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
-                    + str(type(index)),
-                    span,
-                )
-
-        slices: List[Union[Slice, BufferSlice]] = []
-        for index in indices:
-            if isinstance(index, Slice):
-                index.start, index.stop = [convert(_) for _ in [index.start, index.stop]]
-                check_index(index.start)
-                check_index(index.stop)
-                slices.append(index)
-            elif isinstance(index, (PrimExpr, int)):
-                check_index(index)
-                slices.append(Slice(index))
-            elif isinstance(index, BufferSlice):
-                buffer_load = index.asobject()
-                check_index(buffer_load)
-                slices.append(Slice(buffer_load))
-            else:
-                report_error(
-                    "Unsupported index type for BufferSlice, "
-                    + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
-                    + str(type(index)),
-                    span,
-                )
-
-        self.buffer = buffer
-        self.slices = slices
-        self.report_error = report_error
-        self.span = span
-
-    def __str__(self):
-        regions: List[str] = []
-        for s in self.slices:
-            if s.stop is None:
-                regions.append(str(s.start))
-            else:
-                regions.append(str(s.start) + ": " + str(s.stop))
-
-        return self.buffer.name + "[" + ", ".join(regions) + "]"
-
-    def asobject(self) -> BufferLoad:
-        """Convert object."""
-        indices = [s.as_index_expr(self.report_error) for s in self.slices]
-        return BufferLoad(self.buffer, indices, span=self.span)
-
-    def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
-        """Construct BufferRegion from BufferSlice
-
-        Parameters
-        ----------
-        analyzer : Optional[tvm.arith.Analyzer]
-            The analyzer for simplifying. If not provided, the method will construct a new one
-
-        Returns
-        -------
-        buffer_region : BufferRegion
-            The constructed BufferRegion.
-        """
-        region: List[Range] = []
-        for s in self.slices:
-            start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
-            extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
-            if not analyzer:
-                analyzer = Analyzer()
-            if isinstance(extent, PrimExpr):
-                extent = analyzer.simplify(extent)
-            if s.step != 1:
-                self.report_error("BufferRegion do not support non-trivial stride", s.span)
-            region.append(Range.from_min_extent(start, extent, span=s.span))
-        return BufferRegion(self.buffer, region)
-
-    def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
-        return self.asobject().astype(dtype, span)
-
-    @property
-    def dtype(self) -> str:
-        """Return the dtype referenced by the slice.
-
-        Implemented as a property so that ``slice.dtype`` has the same
-        calling convention as ``primexpr.dtype``.  This allows a
-        BufferSlice object can be assigned to a variable without
-        requiring a type annotation on the variable, similar to other
-        expressions.
-        """
-        return self.asobject().dtype
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
deleted file mode 100644
index da7545c9a9..0000000000
--- a/python/tvm/script/tir/scope_handler.py
+++ /dev/null
@@ -1,788 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Scope Handler Classes"""
-# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
-
-import synr
-import numpy as np
-import tvm.tir
-from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
-
-from ..context_maintainer import ContextMaintainer
-from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
-
-
-class ScopeHandler:
-    """Base class for all scope handlers"""
-
-    def __init__(self, func: Callable):
-        self.func: Callable = func
-        self.body: Optional[Stmt] = None
-        self.node: Optional[synr.ast.Node] = None
-        self.context: Optional[ContextMaintainer] = None
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir." + self.func.__name__, get_param_list(self.func)
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        pass
-
-    def exit_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        self.node = node
-        self.context = context
-        return call_with_error_reporting(
-            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
-        )
-
-
-class WithScopeHandler(ScopeHandler):
-    """Base class for all with scope handlers"""
-
-    def __init__(self, func, concise_scope, def_symbol):
-        super().__init__(func)
-        self.concise_scope = concise_scope
-        self.def_symbol = def_symbol
-
-    @staticmethod
-    def get_optional_vars(node, context):
-        """Get a list synr.ast.With's optional_vars"""
-        assert isinstance(
-            node, synr.ast.With
-        ), f"WithScopeHandler expected synr.ast.With but got {type(node)}"
-
-        if isinstance(node.lhs, list):
-            for var in node.lhs:
-                if not isinstance(var, synr.ast.Var):
-                    context.report_error(
-                        f"Invalid optional var definition, expected Var but got {type(var)}",
-                        node.span,
-                    )
-            vars = node.lhs
-        else:
-            context.report_error(
-                f"Invalid optional var definition, expected list of Var but got {type(node.lhs)}",
-                node.span,
-            )
-        return vars
-
-
-@register
-class Allocate(WithScopeHandler):
-    """With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""
-
-    def __init__(self):
-        def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
-            condition = tvm.runtime.convert(condition)
-            scope = tvm.runtime.convert(scope)
-
-            return tvm.tir.Allocate(
-                self.buffer.data,
-                self.buffer.dtype,
-                self.buffer.shape,
-                condition,
-                self.body,
-                annotations=annotations,
-                span=span,
-            )
-
-        super().__init__(allocate, concise_scope=True, def_symbol=True)
-        self.buffer = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(
-            extents, dtype, scope, condition=True, annotations=None, span: Span = None
-        ):
-            """Setup buffer object for a given type."""
-            self.buffer = tvm.tir.decl_buffer(
-                shape=extents,
-                dtype=dtype,
-                name=name,
-                scope=scope,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class AllocateConst(WithScopeHandler):
-    """With scope handler T.allocate_const(data, extents, dtype, condition)
-
-    TIR constant node to represent non-scalar constant
-    """
-
-    def __init__(self):
-        def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
-            list_data = []
-            for i in raw_data:
-                list_data.append(i.value)
-            nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
-            n = tvm.tir.AllocateConst(
-                self.buffer.data,
-                dtype,
-                shape,
-                nd_data,
-                self.body,
-                annotations=annotations,
-                span=span,
-            )
-            return n
-
-        super().__init__(allocate_const, concise_scope=True, def_symbol=True)
-        self.buffer = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None):
-            """Setup buffer var for a given type."""
-            self.buffer = tvm.tir.decl_buffer(
-                shape=shape,
-                dtype=dtype,
-                name=name,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class DeclBuffer(WithScopeHandler):
-    """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type, axis_separators)
-    Example
-    -------
-    .. code-block:: python
-        A = T.decl_buffer((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def decl_buffer(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
-
-        super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(
-            shape,
-            dtype,
-            data,
-            strides,
-            elem_offset,
-            scope,
-            align,
-            offset_factor,
-            buffer_type,
-            axis_separators,
-            span: Span = None,
-        ):
-            self.buffer = tvm.tir.decl_buffer(
-                shape=shape,
-                dtype=dtype,
-                data=data,
-                strides=strides,
-                elem_offset=elem_offset,
-                scope=scope,
-                data_alignment=align,
-                offset_factor=offset_factor,
-                buffer_type=buffer_type,
-                axis_separators=axis_separators,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class LaunchThread(WithScopeHandler):
-    """With scope handler T.launch_thread(env_var, extent)"""
-
-    def __init__(self):
-        def launch_thread(env_var, extent, span):
-            extent = tvm.runtime.convert(extent, span=span)
-            thread_id = self.context.func_var_env_dict[env_var]
-            attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
-            return tvm.tir.AttrStmt(
-                IterVar(
-                    (0, extent),
-                    env_var,
-                    getattr(IterVar, "ThreadIndex"),
-                    thread_id,
-                    span=span,
-                ),
-                attr_key,
-                extent,
-                self.body,
-                span=span,
-            )
-
-        super().__init__(launch_thread, concise_scope=True, def_symbol=False)
-
-
-@register
-class Realize(WithScopeHandler):
-    """With scope handler T.realize(buffer_bounds, scope, condition)"""
-
-    def __init__(self):
-        def realize(
-            buffer_slice: BufferSlice, scope: str, condition: bool = True, span: bool = None
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            buffer: Buffer = buffer_slice.buffer
-            bounds: List[Range] = []
-            for s in buffer_slice.slices:
-                min: Union[PrimExpr, int] = s.start
-                extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
-                if isinstance(extent, PrimExpr):
-                    extent = self.context.analyzer.simplify(extent)
-                bounds.append(Range.from_min_extent(min, extent, span=s.span))
-
-            scope = tvm.runtime.convert(scope, span=span)
-            return tvm.tir.AttrStmt(
-                buffer,
-                "realize_scope",
-                scope,
-                tvm.tir.BufferRealize(buffer, bounds, condition, self.body, span=span),
-                span=span,
-            )
-
-        super().__init__(realize, concise_scope=True, def_symbol=False)
-
-
-@register
-class Attr(WithScopeHandler):
-    """With scope handler T.attr(attr_node, attr_key, value)"""
-
-    def __init__(self):
-        def attr(attr_node, attr_key, value, span):
-            attr_node = tvm.runtime.convert(attr_node, span=span)
-            value = tvm.runtime.convert(value, span=span)
-            return tvm.tir.AttrStmt(attr_node, attr_key, value, self.body, span=span)
-
-        super().__init__(attr, concise_scope=True, def_symbol=False)
-
-
-@register
-class AssertHandler(WithScopeHandler):
-    """With scope handler T.Assert(condition, message)"""
-
-    def __init__(self):
-        def Assert(condition, message, span):
-            return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.body, span=span)
-
-        super().__init__(Assert, concise_scope=True, def_symbol=False)
-
-
-@register
-class Let(WithScopeHandler):
-    """With scope handler T.let(var, value)"""
-
-    def __init__(self):
-        def let(var, value, span):
-            return tvm.tir.LetStmt(var, value, self.body, span=span)
-
-        super().__init__(let, concise_scope=False, def_symbol=False)
-
-    def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr):
-        return tvm.tir.Let(var, value, body)
-
-
-@register
-class Block(WithScopeHandler):
-    """With scope handler T.block(name)"""
-
-    def __init__(self):
-        def block(name_hint: str = "", span: Optional[Span] = None):
-            assert (
-                self.node and self.context and self.body
-            ), "call 'exit_scope' before 'enter_scope'"
-            block_info = self.context.block_info_stack[-1]
-
-            # create block read/write regions
-            reads: List[BufferRegion] = (
-                [read.as_buffer_region() for read in block_info.reads] if block_info.reads else []
-            )
-            writes: List[BufferRegion] = (
-                [write.as_buffer_region() for write in block_info.writes]
-                if block_info.writes
-                else []
-            )
-
-            region_detect_mask: int = (block_info.reads is None) | (
-                (block_info.writes is None) << 1
-            )
-            annotations = {} if block_info.annotations is None else block_info.annotations
-            if region_detect_mask != 0:
-                annotations["tir.script_parsing_detect_access"] = region_detect_mask
-            inner = tvm.tir.Block(
-                block_info.iter_vars,
-                reads,
-                writes,
-                name_hint,
-                self.body,
-                block_info.init,
-                block_info.alloc_buffers,
-                block_info.match_buffers,
-                annotations,
-                span,
-            )
-            assert len(block_info.iter_vars) == len(block_info.iter_values)
-            predicate = (
-                tvm.tir.const(True, "bool")
-                if block_info.predicate is None
-                else block_info.predicate
-            )
-            body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span)
-            return body
-
-        super().__init__(func=block, concise_scope=False, def_symbol=True)
-        self.block_vars = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define block vars
-        assert isinstance(
-            node, synr.ast.With
-        ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}"
-
-        optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)]
-        if optional_vars:
-            context.report_error(
-                f"Block expected no optional_vars (e.g., `x` in `with block() as x`), "
-                f"but got {optional_vars}",
-                node.span,
-            )
-
-
-@register
-class InitBlock(WithScopeHandler):
-    """With scope handler T.init()"""
-
-    def __init__(self):
-        def init(span: Span = None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            if self.context.block_info_stack[-2].init is not None:
-                self.context.report_error("Duplicate init block declaration", span)
-            self.context.block_info_stack[-2].init = self.body
-
-        super().__init__(func=init, concise_scope=False, def_symbol=True)
-
-
-class LoopInfo:
-    """Helper class for loop information"""
-
-    loop_var: Var
-    begin: PrimExpr
-    extent: PrimExpr
-    kind: ForKind
-    thread_binding: Optional[str]
-    annotations: Optional[Mapping[str, Object]]
-
-    def __init__(
-        self,
-        begin: PrimExpr,
-        extent: PrimExpr,
-        kind: ForKind,
-        thread_binding: Optional[str] = None,
-        annotations: Optional[Mapping[str, Object]] = None,
-    ) -> None:
-        self.begin = begin
-        self.extent = extent
-        self.kind = kind
-        self.thread_binding = thread_binding
-        self.annotations = annotations
-
-
-class ForScopeHandler(ScopeHandler):
-    """Base class for all for scope handlers"""
-
-    def __init__(self, func):
-        super().__init__(func)
-        self.loop_vars: List[Var] = []
-        self.loop_info: List[LoopInfo] = []
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        assert isinstance(
-            node, synr.ast.For
-        ), f"ForScopeHandler expected synr.ast.For but got {type(node)}"
-
-        loop_var_names = list()
-        spans = list()
-        if isinstance(node.lhs, synr.ast.Var):
-            loop_var_names.append(node.lhs.id.name)
-            spans.append(tvm_span_from_synr(node.lhs.id.span))
-        elif isinstance(node.lhs, list):
-            for elt in node.lhs:
-                if not isinstance(elt, synr.ast.Var):
-                    context.report_error(
-                        f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span
-                    )
-                loop_var_names.append(elt.id.name)
-                spans.append(tvm_span_from_synr(elt.id.span))
-        else:
-            context.report_error(
-                f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}",
-                span,
-            )
-
-        self.node = node
-        self.context = context
-        # collect loop infos by calling self.func
-        call_with_error_reporting(context.report_error, span, self.func, *arg_list)
-        if len(loop_var_names) != len(self.loop_info):
-            self.context.report_error(
-                f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
-                + f"vs {len(self.loop_info)}",
-                self.node.span,
-            )
-        # generate loop vars
-        self.loop_vars = []
-        for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
-            if not li.begin.dtype.startswith("int"):
-                raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}")
-            if not li.extent.dtype.startswith("int"):
-                raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}")
-            dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32"
-            self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))
-
-        for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
-            context.update_symbol(loop_var.name, loop_var, node)
-            context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent)
-
-    def exit_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
-        for loop_var in self.loop_vars:
-            context.loop_stack.pop(loop_var)
-        # Use assert here since we have check it in `enter_scope`
-        assert len(self.loop_vars) == len(self.loop_info)
-
-        body = self.body
-        for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)):
-            body = tvm.tir.For(
-                var,
-                info.begin,
-                info.extent,
-                info.kind,
-                body,
-                info.thread_binding,
-                info.annotations,
-                span=tvm_span_from_synr(span),
-            )
-
-        return body
-
-    def create_loop_info(
-        self,
-        begin: PrimExpr,
-        end: PrimExpr,
-        kind: ForKind,
-        thread_binding: Optional[str] = None,
-        annotations: Optional[Mapping[str, Object]] = None,
-    ) -> None:
-        """
-        Helper function for creating For in TVM Script parser.
-
-        Parameters
-        ----------
-        begin : PrimExpr
-            The beginning value.
-
-        end : PrimExpr
-            The endding value.
-
-        kind : ForKind
-            The type of the for.
-
-        thread_binding: Optional[str]
-            The thread this loop binds to.
-
-        annotations : Optional[Mapping[str, Object]]
-            Additional annotation hints.
-
-        span : Optional[Span]
-            The location of this for in the source code.
-
-        Returns
-        -------
-        for : For
-            The constructed For.
-        """
-        begin, end = [convert(_) for _ in [begin, end]]
-        assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
-        extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
-        self.annotations: Mapping[str, Object] = {}
-        if annotations is not None:
-            self.annotations = {
-                key: String(val) if isinstance(val, str) else val
-                for key, val in annotations.items()
-            }
-
-        self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
-
-
-@register
-class Serial(ForScopeHandler):
-    """For scope handler T.serial(begin, end, annotations)"""
-
-    def __init__(self):
-        def serial(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
-
-        super().__init__(serial)
-
-
-@register
-class Parallel(ForScopeHandler):
-    """For scope handler T.parallel(begin, end, annotations)"""
-
-    def __init__(self):
-        def parallel(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations)
-
-        super().__init__(parallel)
-
-
-@register
-class Vectorized(ForScopeHandler):
-    """For scope handler T.vectorized(begin, end, annotations)"""
-
-    def __init__(self):
-        def vectorized(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations)
-
-        super().__init__(vectorized)
-
-
-@register
-class Unroll(ForScopeHandler):
-    """For scope handler T.unroll(begin, end, annotations)"""
-
-    def __init__(self):
-        def unroll(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations)
-
-        super().__init__(unroll)
-
-
-@register
-class ThreadBinding(ForScopeHandler):
-    """For scope handler T.thread_binding(begin, end, thread, annotations)"""
-
-    def __init__(self):
-        def thread_binding(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            thread: str = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if thread is None:
-                if isinstance(end, str):  # handle case like thread_binding(128, "threadIdx.x")
-                    thread = end
-                    end = None
-                else:
-                    raise ValueError("Thread cannot be None for thread_binding")
-            if end is None:
-                end = begin
-                begin = 0
-            thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread)
-            self.create_loop_info(
-                begin,
-                end,
-                ForKind.THREAD_BINDING,
-                thread_binding=thread_iter_var,
-                annotations=annotations,
-            )
-
-        super().__init__(thread_binding)
-
-
-@register
-class RangeHandler(ForScopeHandler):
-    """For scope handler range(begin, end, annotations)
-    Note that tir.range is totally the same as T.serial
-    """
-
-    def __init__(self):
-        def for_range(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
-
-        super().__init__(for_range)
-
-    def signature(self):
-        return "range", get_param_list(self.func)
-
-
-@register
-class Grid(ForScopeHandler):
-    """For scope handler T.grid(extents)"""
-
-    def __init__(self):
-        def grid(*extents: List[PrimExpr]):
-            for extent in extents:
-                self.create_loop_info(0, extent, ForKind.SERIAL)
-
-        super().__init__(grid)
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
deleted file mode 100644
index 15502055b7..0000000000
--- a/python/tvm/script/tir/special_stmt.py
+++ /dev/null
@@ -1,964 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Special Stmt Classes"""
-# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
-# pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
-
-import synr
-from synr import ast
-from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
-from tvm.runtime import Object, String
-from tvm.target import Target
-from tvm.ir import Span
-from tvm.tir import IntImm, IterVar, Var
-
-from .node import BufferSlice
-
-from ..context_maintainer import BlockInfo, ContextMaintainer
-from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
-
-
-def convert_to_int(
-    value: Union[IntImm, int],
-    arg_name: str,
-    report_error: Callable,
-    span: Union[Span, synr.ast.Span],
-) -> int:
-    """convert a const int or TVM IntImm to Python int.
-    Reports an error when input cannot be converted to int.
-
-    Parameters
-    ----------
-    value : Union[tvm.tir.IntImm, int]
-        The input value to be converted.
-    arg_name : str
-        Function argument name for error reporting.
-    report_error: Callable
-        The report error function handle
-    span : Union[synr.ast.Span, tvm.ir.Span]
-        Location of the error
-    """
-    if isinstance(value, IntImm):
-        return value.value
-    if isinstance(value, int):
-        return value
-    report_error(
-        f"Expected int or IntImm for {arg_name}, but got {str(type(value))}",
-        span,
-    )
-
-
-class SpecialStmt:
-    """Base class for all Special Stmts"""
-
-    def __init__(self, func: Callable, def_symbol: bool):
-        self.func: Callable = func
-        self.def_symbol: bool = def_symbol
-        self.node: Optional[synr.ast.Node] = None
-        self.context: Optional[ContextMaintainer] = None
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir." + self.func.__name__, get_param_list(self.func)
-
-    def handle(
-        self,
-        node: ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        self.node = node
-        self.context = context
-        return call_with_error_reporting(
-            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
-        )
-
-
-@register
-class MatchBuffer(SpecialStmt):
-    """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
-                                 offset_factor, buffer_type, axis_separators)
-
-    Note
-    ----
-    This Special Stmt will perform different behavior depends on the type of param.
-    If the param is a var in function parameter, it will create a buffer from DLTensor.
-    Else if the param is a subregion of other buffers, then create a subregion match inside a block.
-
-    Example
-    -------
-    Match buffer from function parameter
-    .. code-block:: python
-        A = T.match_buffer(a, (128, 128), dtype="float32")
-
-    Match buffer from Buffer subregion
-    .. code-block:: python
-        A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def match_buffer(
-            param,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`match_buffer` must be assigned to a single buffer, "
-                    "e.g. A = match_buffer(...)",
-                    self.node.span,
-                )
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            if isinstance(param, tvm.tir.Var):
-                if param not in self.context.func_params:
-                    self.context.report_error(
-                        "Can not bind non-input param to buffer", self.node.rhs.params[0].span
-                    )
-                self.context.func_buffer_map[param] = buffer
-            elif isinstance(param, BufferSlice):
-                buffer_region = param.as_buffer_region()
-                self.context.current_block_scope().match_buffers.append(
-                    tvm.tir.MatchBufferRegion(buffer, buffer_region)
-                )
-            else:
-                self.context.report_error(
-                    "The source of match_buffer expected Var or BufferSlice, but got "
-                    + str(type(param)),
-                    self.node.rhs.params[0].span,
-                )
-            self.context.update_symbol(buffer_name, buffer, self.node)
-
-        super().__init__(match_buffer, def_symbol=True)
-
-
-@register
-class BufferDeclare(SpecialStmt):
-    """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type, axis_separators)
-    Example
-    -------
-    .. code-block:: python
-        A = T.buffer_decl((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def buffer_decl(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)",
-                    self.node.span,
-                )
-
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            self.context.update_symbol(buffer_name, buffer, self.node)
-            return buffer
-
-        super().__init__(buffer_decl, def_symbol=True)
-
-
-@register
-class AllocBuffer(SpecialStmt):
-    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                     offset_factor, buffer_type, axis_separators)
-
-    Example
-    -------
-    .. code-block:: python
-
-        A = T.alloc_buffer((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def alloc_buffer(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`alloc_buffer` must be assigned to a single buffer, "
-                    "e.g. A = alloc_buffer(...)",
-                    self.node.span,
-                )
-
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            if self.context.current_block_scope():
-                self.context.current_block_scope().alloc_buffers.append(buffer)
-            else:
-                # If it is allocated outside all blocks, allocate it under root block.
-                self.context.root_alloc_buffers.append(buffer)
-            self.context.update_symbol(buffer_name, buffer, self.node)
-
-        super().__init__(alloc_buffer, def_symbol=True)
-
-
-@register
-class BlockReads(SpecialStmt):
-    """Special function reads([read_regions], *other_regions)
-
-    Note
-    ----
-    *other_region is an unpackable list of BufferSlice to support
-    reads syntax sugar like reads(BufferRegion1, BufferRegion2, ...)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
-    """
-
-    def __init__(self):
-        def reads(
-            *read_regions: Union[BufferSlice, List[BufferSlice]],
-            span: Span = None,
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare read regions inside a block.",
-                    span,
-                )
-            if block_scope.reads is not None:
-                self.context.report_error(
-                    "Duplicate write region declaration, "
-                    + "previous one is "
-                    + str(", ".join(str(x) for x in block_scope.reads)),
-                    span,
-                )
-            if len(read_regions) > 1:
-                for read_region in read_regions:
-                    if not isinstance(read_region, BufferSlice):
-                        self.context.report_error(
-                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
-                            + f" but got {type(read_regions)}",
-                            span,
-                        )
-            elif len(read_regions) == 1:
-                if isinstance(read_regions[0], list):
-                    read_regions = read_regions[0]
-
-            block_scope.reads = read_regions
-
-        super().__init__(reads, def_symbol=False)
-
-
-@register
-class BlockWrites(SpecialStmt):
-    """Special function writes([write_regions], *other_regions)
-
-    Note
-    ----
-    *other_region is an unpackable list of BufferSlice to support
-    writes syntax sugar like writes(BufferRegion1, BufferRegion2, ...)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.writes([C[vi: vi + 4, vj])
-    """
-
-    def __init__(self):
-        def writes(
-            *write_regions: Union[BufferSlice, List[BufferSlice]],
-            span: Span = None,
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare write regions inside a block.",
-                    span,
-                )
-            if block_scope.writes is not None:
-                self.context.report_error(
-                    "Duplicate write region declaration, "
-                    + "previous one is "
-                    + str(", ".join(str(x) for x in block_scope.writes)),
-                    span,
-                )
-            if len(write_regions) > 1:
-                for write_region in write_regions:
-                    if not isinstance(write_region, BufferSlice):
-                        self.context.report_error(
-                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
-                            + f" but got {type(write_regions)}",
-                            span,
-                        )
-            elif len(write_regions) == 1:
-                if isinstance(write_regions[0], list):
-                    write_regions = write_regions[0]
-            block_scope.writes = write_regions
-
-        super().__init__(writes, def_symbol=False)
-
-
-@register
-class BlockAttr(SpecialStmt):
-    """Special function block_attr({attr_key: attr_value})
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.block_attr({"double_buffer_scope": 1})
-    """
-
-    def __init__(self):
-        def block_attr(attrs: Mapping[str, Object], span: Span = None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare block annotations inside a block.",
-                    span,
-                )
-            if block_scope.annotations is not None:
-                self.context.report_error(
-                    "Duplicate block annotations declaration, "
-                    + "previous one is "
-                    + str(block_scope.annotations),
-                    span,
-                )
-            attrs = {
-                key: String(val) if isinstance(val, str) else val for key, val in attrs.items()
-            }
-            block_scope.annotations = attrs
-
-        super().__init__(block_attr, def_symbol=False)
-
-
-class BlockAxis(SpecialStmt):
-    """Special stmt for defining a spatial block axis
-    axis.S(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.S(128, i * 4 + j)
-    """
-
-    def axis(
-        self,
-        var_name: str,
-        dom: Union[PrimExpr, Range],
-        value: PrimExpr,
-        iter_type: int,
-        span: Optional[Span] = None,
-    ) -> None:
-        """
-        Helper function for creating block axis
-
-        Parameters
-        ----------
-        var_name : str
-            The name_hint of var
-
-        dom : Union[PrimExpr, Range]
-            The iter domain.
-
-        value : PrimExpr
-            The binding value
-
-        iter_type : int
-            The iteration type.
-
-        span : Optional[Span]
-            The location of this for in the source code.
-        """
-        assert self.context, "call 'exit_scope' before 'enter_scope'"
-        block_scope: BlockInfo = self.context.current_block_scope()
-        if block_scope is None:
-            self.context.report_error(
-                "Expected to declare block axes inside a block.",
-                self.node.span,
-            )
-        if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]:
-            self.context.report_error("Duplicate block axis " + var_name, self.node.span)
-
-        dom = tvm.runtime.convert(dom)
-        if isinstance(dom, PrimExpr):
-            dom = tvm.ir.Range(dom)
-        elif isinstance(dom, tvm.ir.container.Array) and len(dom) == 2:
-            dom = tvm.ir.Range(dom[0], dom[1])
-        elif not isinstance(dom, tvm.ir.Range):
-            self.context.report_error(
-                f"Block axis domain expected PrimExpr or Range, but got {type(dom)}",
-                self.node.span,
-            )
-        block_var = tvm.tir.Var(var_name, dtype=dom.extent.dtype)
-        value = tvm.runtime.convert(value)
-        if not isinstance(value, PrimExpr):
-            self.context.report_error(
-                f"Block axis value expected PrimExpr, but got {type(value)}",
-                self.node.span,
-            )
-        iter_var = tvm.tir.IterVar(dom, block_var, iter_type)
-        block_scope.iter_vars.append(iter_var)
-        block_scope.iter_values.append(value)
-        self.context.update_symbol(var_name, block_var, self.node)
-
-
-@register
-class BlockAxisSpatial(BlockAxis):
-    """Special stmt for defining a spatial block axis
-    axis.spatial(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.spatial(128, k)
-    """
-
-    def __init__(self):
-        def axis_spatial(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
-        super().__init__(axis_spatial, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.spatial", get_param_list(self.func)
-
-
-@register
-class BlockAxisS(BlockAxis):
-    """The sugar special stmt for defining a spatial block axis
-    axis.S(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.S(128, k)
-    """
-
-    def __init__(self):
-        def axis_spatial(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
-        super().__init__(axis_spatial, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.S", get_param_list(self.func)
-
-
-@register
-class BlockAxisReduce(BlockAxis):
-    """Special stmt for defining a reduce block axis
-    axis.reduce(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.reduce(128, k)
-    """
-
-    def __init__(self):
-        def axis_reduce(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
-        super().__init__(axis_reduce, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.reduce", get_param_list(self.func)
-
-
-@register
-class BlockAxisR(BlockAxis):
-    """The sugar special stmt for defining a reduce block axis
-    axis.R(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.R(128, k)
-    """
-
-    def __init__(self):
-        def axis_reduce(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
-        super().__init__(axis_reduce, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.R", get_param_list(self.func)
-
-
-@register
-class BlockAxisScan(BlockAxis):
-    """Special stmt for defining a ordered block axis
-    axis.scan(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.scan(128, k)
-    """
-
-    def __init__(self):
-        def axis_scan(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered)
-
-        super().__init__(axis_scan, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.scan", get_param_list(self.func)
-
-
-@register
-class BlockAxisOpaque(BlockAxis):
-    """Special stmt for defining a opaque block axis
-    axis.opaque(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.opaque(128, k)
-    """
-
-    def __init__(self):
-        def axis_opaque(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo)
-
-        super().__init__(axis_opaque, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.opaque", get_param_list(self.func)
-
-
-@register
-class BlockAxisRemap(BlockAxis):
-    """Special stmt for remapping loops vars to block axes.
-    axis.remap(iter_type, iter_value)
-
-    Note
-    ----
-    Iter_type is a string consisting of 'S' and 'R', where 'S' means
-    for spatial and 'R' means for reduce.
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi, vj = T.axis.remap("SS", [i, j])
-    """
-
-    def __init__(self):
-        def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1:
-                self.context.report_error(
-                    "`axis.remap` must be assigned to one or more vars, "
-                    "e.g. vi, vj = axis.remap(...)",
-                    self.node.span,
-                )
-            var_num: int = len(self.node.lhs)
-            if var_num != len(iter_types):
-                self.context.report_error(
-                    f"`iter_type` expected {var_num} charactor(s), "
-                    f"but got {len(iter_types)}: {iter_types}",
-                    span,
-                )
-            if var_num != len(loop_vars):
-                self.context.report_error(
-                    f"`iter_type` expected {var_num} loop var(s), "
-                    f"but got {len(loop_vars)}: {loop_vars}",
-                    span,
-                )
-            for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars):
-                iter_type: int
-                if iter_ty == "S":
-                    iter_type = IterVar.DataPar
-                elif iter_ty == "R":
-                    iter_type = IterVar.CommReduce
-                else:
-                    self.context.report_error(
-                        f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), '
-                        f'but got "{iter_ty}"',
-                        span,
-                    )
-
-                if not isinstance(loop_var, tvm.tir.expr.Var):
-                    self.context.report_error(
-                        f"Values of `axis.remap` expected single loop var, but got {loop_var}",
-                        loop_var.span,
-                    )
-                loops = self.context.loop_stack
-                if loop_var not in loops:
-                    self.context.report_error(
-                        f"Cannot find loop var {loop_var} in loop nesting.",
-                        span,
-                    )
-                self.axis(var.id.name, loops[loop_var], loop_var, iter_type)
-
-        super().__init__(axis_remap, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.remap", get_param_list(self.func)
-
-
-@register
-class BlockPredicate(SpecialStmt):
-    """Special function where(predicate)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.where(i < 4)
-    """
-
-    def __init__(self):
-        def where(predicate, span=None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare the predicate inside a block.",
-                    span,
-                )
-            if block_scope.predicate is not None:
-                self.context.report_error(
-                    "Duplicate block predicate declaration, "
-                    + "previous one is "
-                    + str(block_scope.predicate),
-                    span,
-                )
-
-            block_scope.predicate = predicate
-
-        super().__init__(where, def_symbol=False)
-
-
-@register
-class VarDef(SpecialStmt):
-    """Special function for defining a Var"""
-
-    def __init__(self):
-        def var(dtype, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"VarDef expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            v = Var(names[0], dtype, span=span)
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(var, def_symbol=True)
-
-
-@register
-class BufferVarDef(SpecialStmt):
-    """Special function for defining a variable of pointer type"""
-
-    def __init__(self):
-        def buffer_var(dtype, storage_scope, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"BufferVarDef expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
-            v = Var(names[0], ptr_type, span=span)
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(buffer_var, def_symbol=True)
-
-
-@register
-class EnvThread(SpecialStmt):
-    """Bind a var to thread env"""
-
-    def __init__(self):
-        def env_thread(env_name, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"EnvThread expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            v = Var(names[0], dtype="int32", span=span)
-            self.context.func_var_env_dict[v] = env_name
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(env_thread, def_symbol=True)
-
-
-@register
-class FuncAttr(SpecialStmt):
-    """Special Stmt for declaring the DictAttr of PrimFunc
-    Example
-    -------
-    .. code-block:: python
-         T.func_attr({"tir.noalias": True, "global_symbol"})
-    """
-
-    def __init__(self):
-        def func_attr(dict_attr, span):
-            self.context.func_dict_attr = dict_attr
-
-        super().__init__(func_attr, def_symbol=False)
-
-
-@register
-class PreflattenedBufferMap(SpecialStmt):
-    """Special Stmt for declaring the PrimFunc::preflattened_buffer_map
-
-    Example
-    -------
-    .. code-block:: python
-         A0 = T.match_buffer(A, (48,), dtype="float32")
-         T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
-    """
-
-    def __init__(self):
-        def preflattened_buffer(
-            postflattened,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            span=None,
-        ):
-
-            param = None
-            for key, value in self.context.func_buffer_map.items():
-                if value.same_as(postflattened):
-                    param = key
-                    break
-
-            assert (
-                param is not None
-            ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
-
-            if data is None:
-                data = self.context.func_buffer_map[param].data
-
-            buffer_name: str = f"{postflattened.name}_preflatten"
-            if align != -1:
-                if isinstance(align, IntImm):
-                    align = align.value
-                else:
-                    assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
-
-            if offset_factor != 0:
-                if isinstance(offset_factor, IntImm):
-                    offset_factor = offset_factor.value
-                else:
-                    assert isinstance(
-                        offset_factor, int
-                    ), f"offset_factor: want int or IntImm, got {offset_factor!r}"
-
-            preflattened = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                span=span,
-            )
-
-            self.context.func_preflattened_buffer_map[param] = preflattened
-
-        super().__init__(preflattened_buffer, def_symbol=False)
-
-
-@register
-class TargetAttrValue(SpecialStmt):
-    """Special Stmt for target attr value.
-    Example
-    -------
-    .. code-block:: python
-        T.target("llvm")
-    """
-
-    def __init__(self):
-        def target(*args, span):
-            self.context.report_error(f"T.target should not appear as a stmt", span)
-
-        super().__init__(target, def_symbol=False)
-
-    def __call__(self, target_config):
-        if not isinstance(target_config, (str, dict)):
-            raise ValueError(
-                f"T.target expected a config dict or string, but got {type(target_config)}"
-            )
-        return Target(target_config)
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
deleted file mode 100644
index 4548102a9e..0000000000
--- a/python/tvm/script/tir/ty.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Typing Class for TIR
-
-This module provides typing class for TVM script type annotation usage, it can be viewed as
-a wrapper for uniform Type system in IR
-"""
-# pylint: disable=invalid-name
-from numbers import Integral
-
-import tvm
-from .special_stmt import SpecialStmt, convert_to_int
-
-
-class TypeGeneric:  # pylint: disable=too-few-public-methods
-    """Base class for all the TVM script typing class"""
-
-    def evaluate(self):
-        """Return an actual ir.Type Object that this Generic class wraps"""
-        raise TypeError("Cannot get tvm.Type from a generic type")
-
-    def require_type_generic_at(self, idx):  # pylint: disable=unused-argument
-        """If True, the `idx`th type argument must be TypeGeneric"""
-        return True
-
-    # This function is added here to avoid a pylint error
-    # for T.int/float below not being callable
-    def __call__(self):
-        raise NotImplementedError()
-
-
-class ConcreteType(TypeGeneric):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for uniform Type objects
-
-    Params
-    ------
-    vtype: Union[str, tvm.ir.Type]
-
-        The IR type represented by the type annotation.  If a string
-        (e.g. "float32"), this represents a `ir.PrimType` generated
-        from that string.  If a `ir.Type` is provided, this represents
-        the type provided.
-    """
-
-    def __init__(self, vtype):
-        if isinstance(vtype, tvm.ir.Type):
-            self.type = vtype
-        else:
-            self.type = tvm.ir.PrimType(vtype)
-
-    def __call__(self, *args):  # pylint: disable=arguments-differ
-        pass
-
-    def evaluate(self):
-        return self.type
-
-
-class VoidType(ConcreteType):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for void type"""
-
-    def __init__(self):
-        super().__init__("")
-
-
-class GenericPtrType(TypeGeneric):  # pylint: disable=abstract-method
-    """TVM script typing class generator for PtrType
-
-    [] operator is overloaded, accepts a ConcreteType and an optional storage scope string,
-    returns a ConcreteType wrapping PtrType
-    """
-
-    def __getitem__(self, args):
-        if isinstance(args, TypeGeneric):
-            args = [args]
-        if len(args) == 1:
-            vtype, scope = args[0], "global"
-        elif len(args) == 2:
-            vtype, scope = args[0], args[1]
-        else:
-            raise TypeError(f"Illegal type argument num for Ptr")
-        if not isinstance(vtype, TypeGeneric):
-            raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
-        if not isinstance(scope, str):
-            raise TypeError(f"Ptr expects storage scope argument be a string")
-        return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope))
-
-    def require_type_generic_at(self, idx):
-        return idx != 1  # the second argument is storage scope for Ptr
-
-
-class GenericTupleType(TypeGeneric):  # pylint: disable=abstract-method
-    """TVM script typing class generator for TupleType
-
-    [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType
-    wrapping TupleType
-    """
-
-    def __getitem__(self, vtypes):
-        if isinstance(vtypes, TypeGeneric):
-            vtypes = [vtypes]
-        return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
-
-
-class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for uniform Type objects"""
-
-    def __init__(self, vtype):
-        def match_buffer_syntax_sugar(
-            shape,
-            dtype: str = "float32",
-            name: str = None,
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            return buffer
-
-        self.type = vtype
-        super().__init__(match_buffer_syntax_sugar, def_symbol=True)
-
-    def __call__(
-        self,
-        shape,
-        dtype="float32",
-        *,
-        name: str = None,
-        data=None,
-        strides=None,
-        elem_offset=None,
-        scope="global",
-        align=-1,
-        offset_factor=0,
-        buffer_type="default",
-        axis_separators=None,
-        span=None,
-    ):
-        """
-        This function is for Buffer(...) syntax sugar.
-        """
-        pass  # pylint: disable=unnecessary-pass
-
-    def __getitem__(self, args):
-        """
-        This function is for Buffer[...] syntax sugar
-        Note that args is the list of all arguments
-        """
-        if len(args) < 2:
-            raise ValueError("T.Buffer[...] needs at least two arguments: shape and dtype.")
-
-        shape = args[0]
-        dtype = args[1]
-
-        valid_shape = isinstance(shape, (tvm.ir.PrimExpr, Integral, tuple, list))
-        valid_dtype = isinstance(dtype, str)
-        if not (valid_shape and valid_dtype):
-            raise ValueError(
-                "The first argument of T.Buffer[...] needs to be a tuple, "
-                "followed by the second argument dtype as a string"
-            )
-
-
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            _name = _dtype + _size + _lanes
-            globals()[_name] = ConcreteType(_name)
-
-boolean = ConcreteType("bool")
-handle = ConcreteType("handle")
-void = VoidType()
-Ptr = GenericPtrType()
-Tuple = GenericTupleType()
-# we don't have 'buffer' type on the cpp side
-# thus 'handle' is used here for convenience's sake
-Buffer = GenericBufferType("handle")
diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py
deleted file mode 100644
index c655a62237..0000000000
--- a/python/tvm/script/utils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Helper functions in TVM Script Parser"""
-
-from typing import Callable, List, Any, Optional, Tuple
-
-import inspect
-import synr
-
-from tvm.ir import Span, SourceName
-from tvm.error import DiagnosticError
-
-
-def get_param_list(
-    func: Callable,
-) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]:
-    """Get the parameter list from definition of function"""
-    full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func)
-
-    args: List[str]
-    defaults: Optional[Tuple[Any, ...]]
-    kwonlyargs: List[str]
-    args, defaults, kwonlyargs = (
-        full_arg_spec.args,
-        full_arg_spec.defaults,
-        full_arg_spec.kwonlyargs,
-    )
-
-    if defaults is None:
-        defaults = tuple()
-
-    if full_arg_spec.varkw is not None:
-        raise RuntimeError(
-            "TVM Script register error : variable keyword argument is not supported now"
-        )
-
-    if len(kwonlyargs) == 1 and kwonlyargs[0] == "span":
-        pass
-    elif not len(kwonlyargs) == 0:
-        raise RuntimeError("TVM Script register error : keyword only argument is not supported now")
-
-    pos_only: List[str] = list()
-    for arg in args[: len(args) - len(defaults)]:
-        if arg != "span":
-            pos_only.append(arg)
-    kwargs: List[Tuple[str, Tuple[Any, ...]]] = list()
-    for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
-        if arg != "span":
-            kwargs.append((arg, default))
-
-    return pos_only, kwargs, full_arg_spec.varargs
-
-
-def tvm_span_from_synr(span: synr.ast.Span) -> Span:
-    """Convert a synr span to a TVM span"""
-    return Span(
-        SourceName(span.filename),
-        span.start_line,
-        span.end_line,
-        span.start_column,
-        span.end_column,
-    )
-
-
-def synr_span_from_tvm(span: Span) -> synr.ast.Span:
-    """Convert a TVM span to a synr span"""
-    return synr.ast.Span(
-        span.source_name.name,
-        span.line,
-        span.column,
-        span.end_line,
-        span.end_column,
-    )
-
-
-def call_with_error_reporting(
-    report_error,
-    node_span,
-    func,
-    *args,
-    **kwargs,
-):
-    """Call function with exception handling and report error using node_span"""
-    try:
-        return func(*args, **kwargs)
-    except DiagnosticError:
-        raise
-    except Exception as err:  # pylint: disable=broad-except
-        # printing last non-empty row of error message.
-        error_msg = list(filter(None, str(err).split("\n")))[-1]
-        report_error(error_msg, node_span)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..aec3eceacb 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,184 @@
 from tvm.ir import PrimExpr
 from tvm.runtime import const
 
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis, schedule, stmt_functor, transform, usmp
+from .generic import cast
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Any,
+    Broadcast,
+    BufferLoad,
+    Call,
+    CallEffectKind,
+    Cast,
+    CommReducer,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
+    IntImm,
+    IterVar,
+    Let,
+    Load,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
+    Select,
+    Shuffle,
+    SizeVar,
+    StringImm,
+    Sub,
+    Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+    TVMBackendAllocWorkspace,
+    TVMBackendFreeWorkspace,
+    abs,
+    acos,
+    acosh,
+    address_of,
+    all,
+    any,
+    asin,
+    asinh,
+    assume,
+    atan,
+    atan2,
+    atanh,
+    call_cpacked,
+    call_cpacked_lowered,
+    call_extern,
+    call_intrin,
+    call_llvm_intrin,
+    call_llvm_pure_intrin,
+    call_packed,
+    call_packed_lowered,
+    call_pure_extern,
+    ceil,
+    ceildiv,
+    clz,
+    comm_reducer,
+    copysign,
+    cos,
+    cosh,
+    div,
+    erf,
+    exp,
+    exp2,
+    exp10,
+    floor,
+    floordiv,
+    floormod,
+    fmod,
+    hypot,
+    if_then_else,
+    indexdiv,
+    indexmod,
+    isfinite,
+    isinf,
+    isnan,
+    isnullptr,
+    ldexp,
+    likely,
+    log,
+    log1p,
+    log2,
+    log10,
+    lookup_param,
+    max,
+    max_value,
+    min,
+    min_value,
+    mma_fill,
+    mma_store,
+    nearbyint,
+    nextafter,
+    popcount,
+    power,
+    ptx_commit_group,
+    ptx_cp_async,
+    ptx_ldmatrix,
+    ptx_mma,
+    ptx_mma_sp,
+    ptx_wait_group,
+    q_multiply_shift,
+    ret,
+    round,
+    rsqrt,
+    shift_left,
+    shift_right,
+    sigmoid,
+    sin,
+    sinh,
+    sqrt,
+    sum,
+    tan,
+    tanh,
+    trace,
+    trunc,
+    truncdiv,
+    truncmod,
+    tvm_access_ptr,
+    tvm_bmma_sync,
+    tvm_fill_fragment,
+    tvm_load_matrix_sync,
+    tvm_mma_sync,
+    tvm_stack_alloca,
+    tvm_stack_make_array,
+    tvm_stack_make_shape,
+    tvm_store_matrix_sync,
+    tvm_struct_get,
+    tvm_struct_set,
+    tvm_thread_allreduce,
+    tvm_throw_last_error,
+    tvm_tuple,
+    undef,
+    vectorcombine,
+    vectorhigh,
+    vectorlow,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
 from .stmt import (
-    BufferStore,
-    BufferRealize,
-    Store,
-    ProducerStore,
     Allocate,
     AllocateConst,
+    AssertStmt,
     AttrStmt,
+    Block,
+    BlockRealize,
+    BufferRealize,
+    BufferRegion,
+    BufferStore,
     DeclBuffer,
+    Evaluate,
+    For,
+    ForKind,
+    IfThenElse,
+    LetStmt,
+    MatchBufferRegion,
+    Prefetch,
+    ProducerRealize,
+    ProducerStore,
+    SeqStmt,
+    Stmt,
+    Store,
+    While,
+    stmt_list,
+    stmt_seq,
+    type_annotation,
 )
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..7959e82e7b 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -20,11 +20,11 @@ from typing import Dict, List, Union
 
 from tvm import Object
 from tvm.ir import IRModule
-from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
 
-from .. import Buffer, Stmt
+from ..buffer import Buffer
+from ..expr import Var
 from ..function import PrimFunc
+from ..stmt import Block, BufferRegion, PrimExpr, Stmt
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 from typing import Optional, Union
-from tvm import ir
+
 import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
 from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
 
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
 from . import _ffi_api
+from . import generic as _generic
 
 
 def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
 
-    # TODO(tkonolige): use inspect to add source information to these objects
-
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
     """
 
     def __init__(self, condition, true_value, false_value, span=None):
+        if isinstance(condition, bool):
+            condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
             _ffi_api.Select, condition, true_value, false_value, span  # type: ignore
         )
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..e2cad37bb6 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,17 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin,invalid-name,no-member,protected-access
 """Operators used in TIR expression."""
+import warnings
 from typing import Any, Optional
+
 import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
 from tvm.ir import Array, Op
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
 
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
 from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, PrimExprWithOp, StringImm, Var
 
 
 def _pack_buffer(buf, span=None):
@@ -100,6 +102,64 @@ def call_cpacked(*args, span=None):
     return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
 
 
+def call_packed_lowered(*args, span=None):
+    """Lowered version of call packed.
+
+    The argument to packed function can be Expr or Buffer.
+    The argument is the corresponding POD type when Expr is presented.
+
+    When the argument is Buffer, the corresponding PackedFunc
+    will recieve an TVMArrayHandle whose content is valid during the callback period.
+    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+    """Lowered version of call c-packed.
+
+    Same as call_packed, except that the first argument is the function name
+    (as in call_extern), and the last argument is the resource handle.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
+
+
 def call_intrin(dtype, func_name, *args, span=None):
     """Build expression by calling an intrinsic function.
 
@@ -151,7 +211,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+        dtype,
+        Op.get("tir.call_pure_extern"),
+        convert((StringImm(func_name),) + args),
+        span,
     )
 
 
@@ -178,7 +241,10 @@ def call_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+        dtype,
+        Op.get("tir.call_extern"),
+        convert((StringImm(func_name),) + args),
+        span=span,
     )
 
 
@@ -207,10 +273,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -239,8 +317,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
@@ -250,6 +336,326 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+    return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):
+    return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):
+    return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
+    return call_intrin(
+        "handle", "tir.tvm_stack_make_array", data, shape, strides, ndim, arr_dtype, elem_offset
+    )
+
+
+def address_of(buffer_load, span=None):
+    """Returns the address of an element in the buffer
+
+    Parameters
+    ----------
+    buffer_load: BufferLoad
+        The buffer load.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+
+
+def lookup_param(param_name, span=None):
+    """Returns the param by name
+
+    Parameters
+    ----------
+    param_name : str
+        The name of param.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.lookup_param", param_name, span=span)
+
+
+def tvm_tuple(*value):
+    return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):
+    return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field_id)
+
+
+def tvm_struct_set(arr, index, field_id, value):
+    return call_intrin("handle", "tir.tvm_struct_set", arr, index, field_id, value)
+
+
+def tvm_thread_allreduce(*freduce_args):
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    return call_intrin(
+        "handle",
+        "tir.tvm_load_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def tvm_mma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_mma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_bmma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_bmma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):
+    return call_intrin(
+        "handle",
+        "tir.tvm_fill_fragment",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        value,
+    )
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    return call_intrin(
+        "handle",
+        "tir.tvm_store_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def ptx_mma(  # pylint: disable=missing-docstring
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    saturate,
+    operator=None,
+):
+    if operator is None:
+        return call_intrin(
+            dtype,
+            "tir.ptx_mma",
+            shape,
+            A_layout,
+            B_layout,
+            A_dtype,
+            B_dtype,
+            C_dtype,
+            multiplicand_a,
+            a_index,
+            multiplicand_b,
+            b_index,
+            accumulator,
+            c_index,
+            saturate,
+        )
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        saturate,
+        operator,
+    )
+
+
+def ptx_mma_sp(  # pylint: disable=missing-docstring
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    metadata,
+    meta_index,
+    sparse_selector,
+    saturate,
+):
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma_sp",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        metadata,
+        meta_index,
+        sparse_selector,
+        saturate,
+    )
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+    return call_intrin(
+        dtype,
+        "tir.ptx_ldmatrix",
+        trans,
+        num,
+        type,
+        local_ptr,
+        local_offset,
+        smem_ptr,
+        smem_offset,
+    )
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+    return call_intrin(
+        dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
+    )
+
+
+def ptx_commit_group():
+    return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):
+    return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+    return call_intrin(
+        dtype,
+        "tir.mma_store",
+        m,
+        n,
+        dst_ptr,
+        src_ptr,
+        src_offset,
+        dst_stride,
+    )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+    return call_intrin(
+        dtype,
+        "tir.mma_fill",
+        local_size,
+        local_ptr,
+        offset,
+    )
+
+
+def vectorlow(dtype, vec):
+    return call_intrin(dtype, "tir.vectorlow", vec)
+
+
+def vectorhigh(dtype, vec):
+    return call_intrin(dtype, "tir.vectorhigh", vec)
+
+
+def vectorcombine(dtype, vec1, vec2):
+    return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
+
+
+def assume(cond=None):
+    return call_intrin("int32", "tir.assume", cond)
+
+
+def undef():
+    return call_intrin("int32", "tir.undef")
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -286,9 +692,9 @@ def any(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore
+    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore # pylint: disable=no-member,protected-access
     for i in range(2, len(args)):
-        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore
+        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore # pylint: disable=no-member,protected-access
     return val
 
... 1655 lines suppressed ...