You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 22:31:53 UTC

[tvm] branch ir-builder-v2 created (now fbba02c1a4)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a change to branch ir-builder-v2
in repository https://gitbox.apache.org/repos/asf/tvm.git


      at fbba02c1a4 [TVMScript] New Parser

This branch includes the following new commits:

     new fbba02c1a4 [TVMScript] New Parser

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 01/01: [TVMScript] New Parser

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch ir-builder-v2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit fbba02c1a41917cdf803070cc2f66f2a8b8b03c7
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Aug 15 12:33:04 2022 -0700

    [TVMScript] New Parser
---
 include/tvm/script/ir_builder/base.h               |  166 +++
 include/tvm/script/ir_builder/ir/frame.h           |   61 +
 include/tvm/script/ir_builder/ir/ir.h              |   39 +
 include/tvm/script/ir_builder/tir/frame.h          |  459 +++++++
 include/tvm/script/ir_builder/tir/ir.h             |  142 ++
 include/tvm/tir/op.h                               |   34 +-
 python/tvm/script/__init__.py                      |    6 +-
 python/tvm/script/context_maintainer.py            |  251 ----
 python/tvm/script/diagnostics.py                   |   55 -
 python/tvm/script/{ => ir_builder}/__init__.py     |    7 +-
 .../script/{__init__.py => ir_builder/_ffi_api.py} |    7 +-
 python/tvm/script/ir_builder/base.py               |   76 ++
 python/tvm/script/{ => ir_builder/ir}/__init__.py  |    8 +-
 .../{__init__.py => ir_builder/ir/_ffi_api.py}     |    7 +-
 .../{tir/__init__.py => ir_builder/ir/frame.py}    |   17 +-
 .../script/{__init__.py => ir_builder/ir/ir.py}    |    9 +-
 python/tvm/script/{ => ir_builder/tir}/__init__.py |    7 +-
 .../{__init__.py => ir_builder/tir/_ffi_api.py}    |    7 +-
 python/tvm/script/ir_builder/tir/frame.py          |  116 ++
 python/tvm/script/ir_builder/tir/ir.py             |  954 ++++++++++++++
 python/tvm/script/meta_unparser.py                 |   45 -
 python/tvm/script/parser.py                        | 1385 --------------------
 python/tvm/script/{ => parser}/__init__.py         |   12 +-
 python/tvm/script/{__init__.py => parser/_core.py} |   13 +-
 python/tvm/script/{ => parser/core}/__init__.py    |    7 +-
 python/tvm/script/parser/core/diagnostics.py       |  175 +++
 python/tvm/script/parser/core/dispatch.py          |   63 +
 python/tvm/script/parser/core/doc.py               |  361 +++++
 .../script/{printer => parser/core}/doc_core.py    |    0
 .../{tir/prim_func.py => parser/core/entry.py}     |   46 +-
 python/tvm/script/parser/core/evaluator.py         |  282 ++++
 python/tvm/script/parser/core/parser.py            |  273 ++++
 .../{tir/__init__.py => parser/core/utils.py}      |   27 +-
 python/tvm/script/{ => parser/ir}/__init__.py      |    8 +-
 .../{tir/prim_func.py => parser/ir/entry.py}       |   32 +-
 .../{tir/__init__.py => parser/ir/parser.py}       |   28 +-
 python/tvm/script/{ => parser/tir}/__init__.py     |   11 +-
 python/tvm/script/parser/tir/entry.py              |  101 ++
 python/tvm/script/parser/tir/operation.py          |   84 ++
 python/tvm/script/parser/tir/parser.py             |  268 ++++
 python/tvm/script/registry.py                      |   62 -
 python/tvm/script/tir/__init__.pyi                 |  477 -------
 python/tvm/script/tir/intrin.py                    |  222 ----
 python/tvm/script/tir/node.py                      |  218 ---
 python/tvm/script/tir/scope_handler.py             |  788 -----------
 python/tvm/script/tir/special_stmt.py              |  964 --------------
 python/tvm/script/tir/ty.py                        |  216 ---
 python/tvm/script/utils.py                         |  105 --
 python/tvm/tir/__init__.py                         |  216 ++-
 python/tvm/tir/analysis/analysis.py                |    6 +-
 python/tvm/tir/expr.py                             |   15 +-
 python/tvm/tir/op.py                               |  601 ++++++++-
 python/tvm/tir/schedule/block_scope.py             |    2 +-
 python/tvm/tir/schedule/schedule.py                |    6 +-
 python/tvm/tir/schedule/state.py                   |    3 +-
 python/tvm/tir/stmt.py                             |    4 +
 python/tvm/tir/usmp/transform/transform.py         |    5 +-
 src/script/ir_builder/base.cc                      |  115 ++
 src/script/ir_builder/ir/frame.cc                  |   43 +
 src/script/ir_builder/ir/ir.cc                     |   38 +
 src/script/ir_builder/tir/frame.cc                 |  210 +++
 src/script/ir_builder/tir/ir.cc                    |  665 ++++++++++
 src/script/ir_builder/tir/utils.h                  |   95 ++
 src/tir/ir/script/script_complete.h                |   37 +
 src/tir/ir/stmt.cc                                 |   10 +
 src/tir/op/op.cc                                   |   24 +
 66 files changed, 5782 insertions(+), 5014 deletions(-)

diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
new file mode 100644
index 0000000000..179cca42df
--- /dev/null
+++ b/include/tvm/script/ir_builder/base.h
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_
+#define TVM_SCRIPT_IR_BUILDER_BASE_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame //////////////////////////////
+
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+  std::vector<runtime::TypedPackedFunc<void()>> callbacks;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `callbacks` is not visited.
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+  virtual ~IRBuilderFrameNode() = default;
+  virtual void EnterWithScope();
+  virtual void ExitWithScope();
+
+  void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+class IRBuilderFrame : public runtime::ObjectRef {
+ public:
+  virtual ~IRBuilderFrame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+  IRBuilderFrame() = default;
+
+ public:
+  inline void EnterWithScope();
+  inline void ExitWithScope();
+};
+
+////////////////////////////// Core Infra: Builder //////////////////////////////
+
+class IRBuilderNode : public runtime::Object {
+ public:
+  runtime::Array<IRBuilderFrame> frames;
+  Optional<ObjectRef> result;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("frames", &frames);
+    v->Visit("result", &result);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+  template <typename TFrame>
+  inline Optional<TFrame> FindFrame() const;
+  template <typename TFrame>
+  inline Optional<TFrame> GetLastFrame() const;
+  template <typename TObjectRef>
+  inline TObjectRef Get() const;
+};
+
+class IRBuilder : public runtime::ObjectRef {
+ public:
+  IRBuilder();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+  void EnterWithScope();
+  void ExitWithScope();
+  static IRBuilder Current();
+  template <class TObjectRef>
+  inline static TObjectRef Name(String name, TObjectRef obj);
+};
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+class Namer {
+ public:
+  using FType = NodeFunctor<void(const ObjectRef&, String)>;
+  static FType& vtable();
+  static void Name(ObjectRef node, String name);
+};
+
+}  // namespace details
+
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+  details::Namer::Name(obj, name);
+  return Downcast<TObjectRef>(obj);
+}
+
+inline void IRBuilderFrame::EnterWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
+}
+
+inline void IRBuilderFrame::ExitWithScope() {
+  ICHECK(data_ != nullptr);
+  static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
+  data_.reset();
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
+    if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
+      return GetRef<TFrame>(p);
+    }
+  }
+  return NullOpt;
+}
+
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
+    return Downcast<TFrame>(frames.back());
+  }
+  return NullOpt;
+}
+
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  CHECK(n != nullptr) << "IndexError: IRBuilder result is not of type: " << TObject::_type_key;
+  return GetRef<TObjectRef>(n);
+}
+
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_BASE_H_
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
new file mode 100644
index 0000000000..9a8791be7c
--- /dev/null
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+#include <tvm/script/ir_builder/base.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+class IRModuleFrameNode : public IRBuilderFrameNode {
+ public:
+  Array<GlobalVar> global_vars;
+  Array<BaseFunc> functions;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("global_vars", &global_vars);
+    v->Visit("functions", &functions);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IRModuleFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+                                                    IRModuleFrameNode);
+};
+
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_
diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h
new file mode 100644
index 0000000000..b58e51a945
--- /dev/null
+++ b/include/tvm/script/ir_builder/ir/ir.h
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_IR_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_IR_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+#include <tvm/script/ir_builder/ir/frame.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+TVM_DLL IRModuleFrame IRModule();
+
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_IR_IR_H_
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
new file mode 100644
index 0000000000..d2d2485bbe
--- /dev/null
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -0,0 +1,459 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
+#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
+
+#include <tvm/script/ir_builder/base.h>
+#include <tvm/script/ir_builder/ir/frame.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace tir {
+
+class TIRFrameNode : public IRBuilderFrameNode {
+ public:
+  Array<tvm::tir::Stmt> stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    IRBuilderFrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+class TIRFrame : public IRBuilderFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+  TIRFrame() = default;
+};
+
+class BlockFrameNode : public TIRFrameNode {
+ public:
+  String name;
+  Array<tvm::tir::IterVar> iter_vars;
+  Optional<Array<tvm::tir::BufferRegion>> reads;
+  Optional<Array<tvm::tir::BufferRegion>> writes;
+  Optional<tvm::tir::Stmt> init;
+  Array<tvm::tir::Buffer> alloc_buffers;
+  Array<tvm::tir::MatchBufferRegion> match_buffers;
+  Optional<Map<String, ObjectRef>> annotations;
+
+  Array<PrimExpr> iter_values;
+  Optional<PrimExpr> predicate;
+  bool no_realize;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("iter_vars", &iter_vars);
+    v->Visit("reads", &reads);
+    v->Visit("writes", &writes);
+    v->Visit("init", &init);
+    v->Visit("alloc_buffers", &alloc_buffers);
+    v->Visit("match_buffers", &match_buffers);
+    v->Visit("annotations", &annotations);
+    v->Visit("iter_values", &iter_values);
+    v->Visit("predicate", &predicate);
+    v->Visit("no_realize", &no_realize);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class BlockFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class BlockInitFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+class ForFrameNode : public TIRFrameNode {
+ public:
+  using FMakeForLoop =
+      runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+  Array<tvm::tir::Var> vars;
+  Array<Range> doms;
+  FMakeForLoop f_make_for_loop;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("vars", &vars);
+    v->Visit("doms", &doms);
+    // `f_make_for_loop` is not visited.
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class ForFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+class PrimFuncFrameNode : public TIRFrameNode {
+ public:
+  Optional<String> name;
+  Array<tvm::tir::Var> args;
+  Optional<Type> ret_type;
+  Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+  Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+  Optional<Map<String, ObjectRef>> attrs;
+  Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
+  Array<tvm::tir::Buffer> root_alloc_buffers;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+    v->Visit("attrs", &attrs);
+    v->Visit("env_threads", &env_threads);
+    v->Visit("root_alloc_buffers", &root_alloc_buffers);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class PrimFuncFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+class AssertFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+  PrimExpr message;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AssertFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+class LetFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::Var var;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LetFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+  Array<PrimExpr> extents;
+  DataType dtype;
+  String storage_scope;
+  PrimExpr condition;
+  Map<String, ObjectRef> annotations;
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extents", &extents);
+    v->Visit("dtype", &dtype);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+    v->Visit("annotations", &annotations);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+  DataType dtype;
+  Array<PrimExpr> extents;
+  tvm::runtime::NDArray data;
+  tvm::tir::Buffer buffer;
+  Map<String, ObjectRef> annotations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("data", &data);
+    v->Visit("buffer", &buffer);
+    v->Visit("annotations", &annotations);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AllocateConstFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+                                                    AllocateConstFrameNode);
+};
+
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr extent;
+  String attr_key;
+  tvm::tir::IterVar iter_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("extent", &extent);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("iter_var", &iter_var);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class LaunchThreadFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+                                                    LaunchThreadFrameNode);
+};
+
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::BufferRegion buffer_slice;
+  String storage_scope;
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer_slice", &buffer_slice);
+    v->Visit("storage_scope", &storage_scope);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class RealizeFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+class AttrFrameNode : public TIRFrameNode {
+ public:
+  ObjectRef node;
+  String attr_key;
+  PrimExpr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class AttrFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+class WhileFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class WhileFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+class IfFrameNode : public TIRFrameNode {
+ public:
+  PrimExpr condition;
+  Optional<Array<tvm::tir::Stmt>> then_stmts;
+  Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("condition", &condition);
+    v->Visit("then_stmts", &then_stmts);
+    v->Visit("else_stmts", &else_stmts);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class IfFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+class ThenFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ThenFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+class ElseFrameNode : public TIRFrameNode {
+ public:
+  static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+  void EnterWithScope() final;
+  void ExitWithScope() final;
+};
+
+class ElseFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+  tvm::tir::Buffer buffer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("buffer", &buffer);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+  void ExitWithScope() final;
+};
+
+class DeclBufferFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
new file mode 100644
index 0000000000..c26d552737
--- /dev/null
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
+#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
+
+#include <tvm/script/ir_builder/base.h>
+#include <tvm/script/ir_builder/tir/frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;
+using tvm::tir::Buffer;
+using tvm::tir::Var;
+
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators);
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+BlockFrame Block(String name, bool no_realize = false);
+BlockInitFrame Init();
+void Where(PrimExpr predicate);
+void Reads(Array<ObjectRef> buffer_slices);
+void Writes(Array<ObjectRef> buffer_slices);
+void BlockAttrs(Map<String, ObjectRef> attrs);
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+                   int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+
+namespace axis {
+Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+}  // namespace axis
+
+ForFrame Serial(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+                  Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+                    Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Grid(Array<PrimExpr> extents);
+
+PrimFuncFrame PrimFunc();
+Var Arg(String name, Var var);
+Buffer Arg(String name, Buffer buffer);
+void FuncName(String name);
+void FuncAttrs(Map<String, ObjectRef> attrs);
+Type FuncRet(Type ret_type);
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
+                   int align = -1, int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+                        DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+                        Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+                        String storage_scope = "global", int align = -1, int offset_factor = 0,
+                        String buffer_type = "default", Array<IntImm> axis_separators = {});
+
+AssertFrame Assert(PrimExpr condition, String message);
+LetFrame Let(Var var, PrimExpr value);
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+                       Optional<PrimExpr> condition = NullOpt,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+AllocateConstFrame AllocateConst(
+    NDArray data, DataType dtype, Array<PrimExpr> extents,
+    Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+WhileFrame While(PrimExpr condition);
+IfFrame If(PrimExpr condition);
+ThenFrame Then();
+ElseFrame Else();
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+Var EnvThread(String thread_tag);
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+void Prefetch(Buffer buffer, Array<Range> bounds);
+void Evaluate(PrimExpr value);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                             \
+  inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                        \
+    DataType dtype = DType;                                                            \
+    return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
+  }
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
+
+}  // namespace tir
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
  * \return The result expression.
  */
 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check if x is nullptr.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
 /*!
  * \brief Check if x is infinite.
  * \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span());
 
+/*!
+ * \brief Calculate fmod(x, y)
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                   Span span = Span());
 
+/*!
+ * \brief Returns the address of an element in the buffer
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Returns the param by name
+ * \param param_name The param name.
+ * \param span The location of this operation in the source.
+ * \return The handle of param.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..3107ada88b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+from . import parser
+from .parser import ir, ir_module, parse, tir
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
deleted file mode 100644
index f7f16855c7..0000000000
--- a/python/tvm/script/context_maintainer.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Context Maintainer for TIR"""
-
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
-
-import tvm
-from tvm.ir import Span
-from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
-from tvm.runtime import Object
-from tvm.tir.expr import IterVar
-from .tir.node import BufferSlice
-
-
-class BlockInfo:
-    """Information for block and block_realize signature
-
-    Examples
-    ----------
-    .. code-block:: python
-
-        @T.prim_func
-        def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
-            A = T.match_buffer(a, (16, 16), "float32")
-            B = T.match_buffer(b, (16, 16), "float32")
-            C = T.match_buffer(a, (16, 16), "float32")
-
-            for i, j, k in T.grid(16, 16, 16):
-                with T.block("matmul"):
-                    vi = T.axis.S(16, i)
-                    vj = T.axis.S(16, j)
-                    vk = T.axis.R(16, k)         # iter_bindings = {vj: i, vj: j, vk: k}
-
-                    T.where(True)         # predicate of the block_realize
-
-                    T.reads(A[0:16, 0:16], B[0: 16, 0: 16])      # reads region of the block
-                    T.writes(C[0: 16, 0: 16])                    # writes region of the block
-                    T.block_attr({"attr_key": "attr_value"})     # block annotations
-
-                    # alloc_buffers inside the block
-                    CC = T.alloc_buffer((1, 1), dtype="float32")
-
-                    # match_buffers of the block,
-                    # which bind a sub-region of source buffer into a new buffer
-                    D = T.match_buffer(C[vi, vj], ())
-
-                    # init part of the block, executed when all reduce axes are the beginning value
-                    with T.init():
-                        C[vi, vj] = T.float32(0)
-
-                    # block body
-                    CC[0, 0] = A[vi, vk] * B[vj, vk]
-                    D[()] += CC[0, 0]         # The same as C[vi, vj] += CC[0, 0]
-    """
-
-    alloc_buffers: List[Buffer] = []
-    """List[Buffer]: list of T.alloc_buffer statements in the block signature"""
-    match_buffers: List[MatchBufferRegion] = []
-    """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
-    iter_values: List[PrimExpr] = []
-    """List[PrimExpr]: list of binding values for iter vars"""
-    iter_vars: List[IterVar] = []
-    """List[PrimExpr]: list of iter vars in the block"""
-    reads: Optional[List[BufferSlice]] = None
-    """Optional[List[BufferSlice]]:
-    list of T.reads statements in the block signature, None for not-visited"""
-    writes: Optional[List[BufferSlice]] = None
-    """Optional[List[BufferSlice]]:
-    list of T.writes statements in the block signature, None for not-visited"""
-    annotations: Optional[Mapping[str, Object]] = None
-    """Optional[Mapping[str, Object]]:
-    list of T.block_attr statements in the block signature, None for not-visited"""
-    predicate: Optional[PrimExpr] = None
-    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
-    init: Optional[Stmt] = None
-    """Optional[Stmt]: init part of the block, None for not-visited"""
-
-    def __init__(self):
-        self.alloc_buffers = []
-        self.match_buffers = []
-        self.iter_values = []
-        self.iter_vars = []
-        self.reads = None
-        self.writes = None
-        self.annotations = None
-        self.predicate = None
-        self.init = None
-
-
-class ContextMaintainer:
-    """Maintain all the necessary context info
-    Parameters
-    ----------
-    _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
-        The report error function handle
-    """
-
-    # scope context
-    node_stack: List[List[synr.ast.Node]] = []
-    """List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
-    block_info_stack: List[BlockInfo] = []
-    """List[BlockInfo]: The block info for the current block scope"""
-    loop_stack: Dict[Var, Range] = {}
-    """Dict[Var, Range]: The dict from loop var to its domain outside the block"""
-    symbols: List[Dict[str, Union[Var, Buffer]]] = []
-    """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
-    closure_vars: Dict[str, Object] = {}
-    """ClosureVars: The closure vars defined in Python interpreter"""
-
-    # function context
-    func_params: List[Var] = []
-    """List[Var]: The function parameters"""
-    func_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map"""
-    func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
-    func_dict_attr: Mapping[str, Object] = {}
-    """Mapping[str, Object]: The function attrs"""
-    func_var_env_dict: Mapping[Var, str] = {}
-    """Mapping[Var, str]: The map from var to env thread"""
-
-    # parser and analyzer
-    analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
-    """tvm.arith.Analyzer: The analyzer for simplifying"""
-    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-    """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""
-
-    # root alloc_buffer
-    root_alloc_buffers: List[Buffer] = []
-    """List[Buffer]: The buffers allocated under root block"""
-
-    def __init__(
-        self,
-        _report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
-        closure_vars: Dict[str, Object],
-    ):
-        # scope context
-        self.node_stack = []
-        self.block_info_stack = []
-        self.loop_stack = {}
-        self.symbols = []
-        self.closure_vars = closure_vars
-        # function context
-        self.func_params = []
-        self.func_buffer_map = {}
-        self.func_preflattened_buffer_map = {}
-        self.func_dict_attr = {}
-        self.func_var_env_dict = {}
-        # parser and analyzer
-        self._report_error = _report_error
-        self.analyzer = tvm.arith.Analyzer()
-        # root alloc_buffer
-        self.root_alloc_buffers = []
-
-    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
-        """Creates a new scope
-
-        Note
-        ----
-        This function is used for normal scopes that do not involve
-        a `with block` scope. Use `enter_block_scope`
-        for block scope cases.
-
-        Parameters
-        ----------
-        nodes : Optional[List[synr.ast.Node]]
-            The synr AST nodes in new scope
-        """
-        if nodes is None:
-            nodes = []
-        self.node_stack.append(list(reversed(nodes)))
-        self.symbols.append(dict())
-
-    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
-        """Creates a new block scope, the function will call `enter_scope` implicitly
-        Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
-        to maintain block info.
-
-        Note
-        ----
-        This function should be used to handle a block scope,
-        aka the blocks that involve a `with block` scope.
-
-        Parameters
-        ----------
-        nodes : Optional[List[synr.ast.Node]]
-            The synr AST nodes in new scope
-        """
-        self.enter_scope(nodes)
-        # Create a new BlockInfo for the new block
-        self.block_info_stack.append(BlockInfo())
-
-    def exit_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
-
-    def exit_block_scope(self):
-        """Pop the inner most block scope, the function will call `exit_scope` implicitly"""
-        self.exit_scope()
-        # Pop block_info
-        self.block_info_stack.pop()
-
-    def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
-        """Append a symbol into current scope"""
-        if isinstance(symbol, Buffer):
-            if name in self.symbols[0]:
-                self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
-            self.symbols[0][name] = symbol
-        else:
-            self.symbols[-1][name] = symbol
-
-    def remove_symbol(self, name: str):
-        """Remove a symbol"""
-        for symbols in reversed(self.symbols):
-            if name in symbols:
-                symbols.pop(name)
-                return
-        raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)
-
-    def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
-        """Look up symbol by name"""
-        for symbols in reversed(self.symbols):
-            if name in symbols:
-                return symbols[name]
-        return self.closure_vars.get(name)
-
-    def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
-        self._report_error(message, span)
-
-    def current_block_scope(self) -> BlockInfo:
-        if self.block_info_stack:
-            return self.block_info_stack[-1]
-        return None
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/diagnostics.py
deleted file mode 100644
index e676461ab3..0000000000
--- a/python/tvm/script/diagnostics.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Bridge from synr's (the library used for parsing the python AST)
-   DiagnosticContext to TVM's diagnostics
-"""
-from synr import DiagnosticContext, ast
-
-import tvm
-from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
-
-
-class TVMDiagnosticCtx(DiagnosticContext):
-    """TVM diagnostics for synr"""
-
-    diag_ctx: TVMCtx
-
-    def __init__(self) -> None:
-        self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer())
-        self.source_name = None
-
-    def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span:
-        return tvm.ir.Span(
-            src_name,
-            ast_span.start_line,
-            ast_span.end_line,
-            ast_span.start_column,
-            ast_span.end_column,
-        )
-
-    def add_source(self, name: str, source: str) -> None:
-        src_name = self.diag_ctx.module.source_map.add(name, source)
-        self.source_name = src_name
-
-    def emit(self, _level, message, span):
-        span = self.to_tvm_span(self.source_name, span)
-        self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message))
-        self.diag_ctx.render()  # Raise exception on the first error we hit. TODO remove
-
-    def render(self):
-        self.diag_ctx.render()
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/__init__.py
similarity index 87%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/__init__.py
index 555659d0c5..53721d2432 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/__init__.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+"""tvm.script.ir_builder is a generic IR builder for TVM."""
 from . import tir
-
-from .parser import ir_module, from_source
+from .base import IRBuilder
+from .ir import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/_ffi_api.py
similarity index 84%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/_ffi_api.py
index 555659d0c5..68811c9e01 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/_ffi_api.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""FFI APIs for tvm.script.ir_builder"""
+import tvm._ffi
 
-from . import tir
-
-from .parser import ir_module, from_source
+tvm._ffi._init_api("script.ir_builder", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
new file mode 100644
index 0000000000..d8b965d03b
--- /dev/null
+++ b/python/tvm/script/ir_builder/base.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_api
+
+
+@_register_object("script.ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+    def __enter__(self) -> "IRBuilderFrame":
+        _ffi_api.IRBuilderFrameEnter(self)  # pylint: disable=no-member # type: ignore
+        return self  # `with frame as f` binds the frame itself
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderFrameExit(self)  # pylint: disable=no-member # type: ignore
+
+    def add_callback(self, callback) -> None:
+        _ffi_api.IRBuilderFrameAddCallback(  # pylint: disable=no-member # type: ignore
+            self, callback
+        )
+
+
+@_register_object("script.ir_builder.IRBuilder")
+class IRBuilder(_Object):
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.IRBuilder  # pylint: disable=no-member # type: ignore
+        )  # the FFI constructor is invoked without extra arguments
+
+    def __enter__(self) -> "IRBuilder":
+        _ffi_api.IRBuilderEnter(self)  # pylint: disable=no-member # type: ignore
+        return self  # `with IRBuilder() as b` binds the builder itself
+
+    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
+        _ffi_api.IRBuilderExit(self)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def current() -> "IRBuilder":
+        return _ffi_api.IRBuilderCurrent()  # pylint: disable=no-member # type: ignore
+
+    def get(self) -> _Object:
+        return _ffi_api.IRBuilderGet(self)  # pylint: disable=no-member # type: ignore
+
+
+DefType = TypeVar("DefType", bound=_Object)
+
+
+def name(s: str, v: DefType) -> DefType:
+    return _ffi_api.IRBuilderName(s, v)  # pylint: disable=no-member # type: ignore
+
+
+def name_many(s: List[str], vs: List[DefType]) -> List[DefType]:  # pylint: disable=invalid-name
+    """Apply `name` to each (s[i], vs[i]) pair; `s` and `vs` must have equal length."""
+    if len(s) != len(vs):
+        # Raised explicitly because a bare `assert` is stripped under `python -O`.
+        raise ValueError("len(s)=%d does not match len(vs)=%d" % (len(s), len(vs)))
+    return [name(i, v) for i, v in zip(s, vs)]
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/ir/__init__.py
index 555659d0c5..ebb9728737 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -14,8 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""Package tvm.script.ir_builder.ir"""
+from .frame import IRModuleFrame
+from .ir import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/ir/_ffi_api.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/ir/_ffi_api.py
index 555659d0c5..874cc278af 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/ir/_ffi_api.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""FFI APIs"""
+import tvm._ffi
 
-from . import tir
-
-from .parser import ir_module, from_source
+tvm._ffi._init_api("script.ir_builder.ir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/ir_builder/ir/frame.py
similarity index 63%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/ir_builder/ir/frame.py
index 2f2b4bbc25..e16d86dc22 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/frame.py
@@ -14,18 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+"""Package tvm.script.ir_builder.ir.frame"""
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
+from tvm._ffi import register_object as _register_object
 
-from .prim_func import prim_func
+from ..base import IRBuilderFrame
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+@_register_object("script.ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):
+    """Registered as `script.ir_builder.IRModuleFrame`; adds no Python-side members."""
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/ir/ir.py
similarity index 79%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/ir/ir.py
index 555659d0c5..df92036435 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -14,8 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""Package tvm.script.ir_builder.ir.ir"""
 
-from . import tir
+from . import _ffi_api
+from .frame import IRModuleFrame
 
-from .parser import ir_module, from_source
+
+def ir_module() -> IRModuleFrame:
+    return _ffi_api.IRModule()  # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/tir/__init__.py
index 555659d0c5..1e43d1af34 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/tir/__init__.py
@@ -14,8 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""Package tvm.script.ir_builder.tir"""
+from .ir import *  # pylint: disable=wildcard-import,redefined-builtin
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/ir_builder/tir/_ffi_api.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/ir_builder/tir/_ffi_api.py
index 555659d0c5..876f5f3a35 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/ir_builder/tir/_ffi_api.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+"""FFI APIs"""
+import tvm._ffi
 
-from . import tir
-
-from .parser import ir_module, from_source
+tvm._ffi._init_api("script.ir_builder.tir", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
new file mode 100644
index 0000000000..22b03ccdd4
--- /dev/null
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder for TIR"""
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.tir import Buffer, Var
+
+from ..base import IRBuilderFrame
+
+
+@_register_object("script.ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+    """Common base of the TIR frame classes below; adds no members of its own."""
+
+
+@_register_object("script.ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.BlockFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.BlockInitFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    def __enter__(self) -> Union[Var, List[Var]]:
+        super().__enter__()  # enter the frame first (base-class __enter__)
+        return self.vars if len(self.vars) > 1 else self.vars[0]  # unwrap a single var
+
+
+@_register_object("script.ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.PrimFuncFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.AssertFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.LetFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    def __enter__(self) -> Buffer:
+        super().__enter__()  # enter the frame first (base-class __enter__)
+        return self.buffer  # `with` binds the frame's buffer rather than the frame
+
+
+@_register_object("script.ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    def __enter__(self) -> Buffer:
+        super().__enter__()  # enter the frame first (base-class __enter__)
+        return self.buffer  # `with` binds the frame's buffer rather than the frame
+
+
+@_register_object("script.ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.LaunchThreadFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.RealizeFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.AttrFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.WhileFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.IfFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.ThenFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    """Registered as `script.ir_builder.tir.ElseFrame`; adds no Python-side members."""
+
+
+@_register_object("script.ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    def __enter__(self) -> Buffer:
+        super().__enter__()  # enter the frame first (base-class __enter__)
+        return self.buffer  # `with` binds the frame's buffer rather than the frame
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
new file mode 100644
index 0000000000..ebd764cf1d
--- /dev/null
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -0,0 +1,954 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.tir import Broadcast as broadcast
+from tvm.tir import (
+    Buffer,
+    BufferLoad,
+    BufferRegion,
+    Cast,
+    CommReducer,
+    IntImm,
+    IterVar,
+    Let,
+    PrimExpr,
+)
+from tvm.tir import Ramp as ramp
+from tvm.tir import Select, Shuffle, StringImm, Var, cast
+from tvm.tir import op as _tir_op
+from tvm.tir import type_annotation
+
+from . import _ffi_api, frame
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    return _ffi_api.BufferDecl(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        "",
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def ptr(dtype, storage_scope="global"):
+    return _ffi_api.Ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
+    return _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
+
+
+def init() -> frame.BlockInitFrame:
+    return _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+
+
+def where(predicate) -> None:
+    if isinstance(predicate, bool):
+        predicate = IntImm("bool", predicate)
+    if isinstance(predicate, int):
+        if predicate in [0, 1]:
+            predicate = IntImm("bool", predicate)
+        else:
+            raise ValueError("Invalid value for predicate: {}".format(predicate))
+    _ffi_api.Where(predicate)  # pylint: disable=no-member # type: ignore
+
+
+def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
+    if len(buffer_slices) == 1:
+        if isinstance(buffer_slices[0], tuple):
+            buffer_slices = list(buffer_slices[0])
+        elif isinstance(buffer_slices[0], list):
+            buffer_slices = buffer_slices[0]  # type: ignore
+        else:
+            buffer_slices = [buffer_slices[0]]  # type: ignore
+    else:
+        buffer_slices = list(buffer_slices)  # type: ignore
+    _ffi_api.Reads(buffer_slices)  # pylint: disable=no-member # type: ignore
+
+
+def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
+    if len(buffer_slices) == 1:
+        if isinstance(buffer_slices[0], tuple):
+            buffer_slices = list(buffer_slices[0])
+        elif isinstance(buffer_slices[0], list):
+            buffer_slices = buffer_slices[0]  # type: ignore
+        else:
+            buffer_slices = [buffer_slices[0]]
+    else:
+        buffer_slices = list(buffer_slices)  # type: ignore
+    _ffi_api.Writes(buffer_slices)  # pylint: disable=no-member # type: ignore
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    return _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    return _ffi_api.AllocBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def _as_range(dom) -> Range:
+    if isinstance(dom, Range):
+        return dom
+    if isinstance(dom, (list, tuple)):
+        return Range(dom[0], dom[1])
+    return Range(0, dom)
+
+
+class axis:  # pylint: disable=invalid-name
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        return _ffi_api.AxisSpatial(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        return _ffi_api.AxisReduce(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        return _ffi_api.AxisScan(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        return _ffi_api.AxisOpaque(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: ignore
+            kinds, bindings, dtype
+        )
+        return iter_vars[0] if len(iter_vars) == 1 else iter_vars
+
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Serial(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Parallel(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Unroll(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    if thread is None:
+        if not isinstance(stop, str):
+            raise ValueError("Thread cannot be None for thread_binding")
+        thread = stop
+        stop = start
+        start = 0
+    elif stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents) -> frame.ForFrame:
+    return _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+
+
+def prim_func() -> frame.PrimFuncFrame:
+    return _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+
+
+def arg(name, obj):
+    return _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+
+
+def func_name(name: str) -> str:
+    return _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:
+    return _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+
+
+def func_ret(ret_type) -> Type:
+    return _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    return _ffi_api.MatchBuffer(  # pylint: disable=no-member # type: ignore
+        param,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def preflattened_buffer(
+    postflattened,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> None:
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    _ffi_api.PreflattenedBuffer(  # pylint: disable=no-member # type: ignore
+        postflattened,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    return _ffi_api.Assert(condition, message)  # pylint: disable=no-member # type: ignore
+
+
+def let(
+    v: Var,
+    value: PrimExpr,
+    body: PrimExpr = None,
+) -> frame.LetFrame:
+    if body is None:
+        return _ffi_api.Let(v, value)  # pylint: disable=no-member # type: ignore
+    return Let(v, value, body)
+
+
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: PrimExpr = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.Allocate(  # pylint: disable=no-member # type: ignore
+        extents, dtype, scope, condition, annotations
+    )
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+
+    return _ffi_api.AllocateConst(  # pylint: disable=no-member # type: ignore
+        ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+    )
+
+
+def realize(
+    buffer_slice: BufferRegion,
+    storage_scope: str,
+    condition: PrimExpr = True,
+) -> frame.RealizeFrame:
+    return _ffi_api.Realize(  # pylint: disable=no-member # type: ignore
+        buffer_slice, storage_scope, condition
+    )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    node = convert(node)
+    value = convert(value)
+    return _ffi_api.Attr(node, attr_key, value)  # pylint: disable=no-member # type: ignore
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.While(condition)  # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    if isinstance(condition, bool):
+        condition = IntImm("bool", condition)
+    return _ffi_api.If(condition)  # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    return _ffi_api.Then()  # pylint: disable=no-member # type: ignore
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    return _ffi_api.Else()  # pylint: disable=no-member # type: ignore
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    return _ffi_api.DeclBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        "",
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def launch_thread(
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    return _ffi_api.LaunchThread(iter_var, extent)  # pylint: disable=no-member # type: ignore
+
+
+def env_thread(thread_tag: str) -> IterVar:
+    return _ffi_api.EnvThread(thread_tag)  # pylint: disable=no-member # type: ignore
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    expr_indices = []
+    for index in indices:
+        if isinstance(index, slice):
+            step = 1 if index.step is None else index.step
+            lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step)
+            if lanes == 1:
+                expr_indices.append(index.start)
+            else:
+                expr_indices.append(ramp(index.start, step, int(lanes)))
+        else:
+            expr_indices.append(index)
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    return _ffi_api.BufferStore(  # pylint: disable=no-member # type: ignore
+        buffer, value, expr_indices
+    )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+    return _ffi_api.Prefetch(buffer, indices)  # pylint: disable=no-member # type: ignore
+
+
+def evaluate(value: PrimExpr) -> None:
+    if isinstance(value, str):
+        value = StringImm(value)
+    return _ffi_api.Evaluate(value)  # pylint: disable=no-member # type: ignore
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float32(expr)  # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    if not isinstance(expr, PrimExpr):
+        expr = convert(expr)
+    return _ffi_api.Float64(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int32x4(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int32x8(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Int32x16(expr)  # pylint: disable=no-member # type: ignore
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    return _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+
+
+def min(a, b):  # pylint: disable=redefined-builtin
+    """Compute the minimum value of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api.min(a, b)  # pylint: disable=no-member # type: ignore
+
+
+def max(a, b):  # pylint: disable=redefined-builtin
+    """Compute the maximum value of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api.max(a, b)  # pylint: disable=no-member # type: ignore
+
+
+def var(dtype, name="") -> Var:
+    return Var(name, dtype)  # pylint: disable=no-member # type: ignore
+
+
+def iter_var(v, dom, iter_type, thread_tag):
+    iter_type = getattr(IterVar, iter_type)
+    return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+    """Create a CommReducer from lambda inputs/outputs and the identities"""
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    args = []
+    for name, i in zip(params.keys(), identity + identity):
+        if isinstance(i, int):
+            args.append(Var(name, "int32"))
+        else:
+            args.append(Var(name, i.dtype))
+    res = combiner(*args)
+    if not isinstance(res, tuple):
+        res = (res,)
+    return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+    # pylint: disable=import-outside-toplevel
+    from tvm.target.codegen import llvm_lookup_intrinsic_id as f
+
+    # pylint: enable=import-outside-toplevel
+    return f(name)
+
+
+def _op_wrapper(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            kwargs.pop("dtype")
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            args = (kwargs.pop("dtype"),) + args
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+# pylint: disable=invalid-name
+
+buffer_var = ptr
+abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+assume = _op_wrapper(_tir_op.assume)
+undef = _op_wrapper(_tir_op.undef)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+class inline:
+    def __init__(self, value) -> None:
+        self.value = value
+        self.i = 0
+
+    def __iter__(self):
+        def f():
+            for i in self.value:
+                yield inline(i)
+
+        return f()
+
+
+# pylint: enable=invalid-name
+
+
+__all__ = [
+    "Assert",
+    "Cast",
+    "Else",
+    "If",
+    "Let",
+    "Select",
+    "Shuffle",
+    "TVMBackendAllocWorkspace",
+    "TVMBackendFreeWorkspace",
+    "Then",
+    "While",
+    "abs",
+    "acos",
+    "acosh",
+    "address_of",
+    "alloc_buffer",
+    "allocate",
+    "allocate_const",
+    "arg",
+    "asin",
+    "asinh",
+    "assume",
+    "atan",
+    "atan2",
+    "atanh",
+    "attr",
+    "axis",
+    "block",
+    "block_attr",
+    "boolean",
+    "broadcast",
+    "buffer_decl",
+    "buffer_store",
+    "buffer_var",
+    "call_cpacked",
+    "call_cpacked_lowered",
+    "call_extern",
+    "call_intrin",
+    "call_llvm_intrin",
+    "call_llvm_pure_intrin",
+    "call_packed",
+    "call_packed_lowered",
+    "call_pure_extern",
+    "cast",
+    "ceil",
+    "ceildiv",
+    "clz",
+    "comm_reducer",
+    "copysign",
+    "cos",
+    "cosh",
+    "env_thread",
+    "erf",
+    "evaluate",
+    "exp",
+    "exp10",
+    "exp2",
+    "decl_buffer",
+    "fabs",
+    "float16",
+    "float32",
+    "float64",
+    "float8",
+    "floor",
+    "floordiv",
+    "floormod",
+    "fmod",
+    "func_attr",
+    "func_name",
+    "func_ret",
+    "grid",
+    "handle",
+    "hypot",
+    "if_then_else",
+    "infinity",
+    "init",
+    "inline",
+    "int16",
+    "int32",
+    "int32x16",
+    "int32x4",
+    "int32x8",
+    "int64",
+    "int8",
+    "isfinite",
+    "isinf",
+    "isnan",
+    "isnullptr",
+    "iter_var",
+    "launch_thread",
+    "ldexp",
+    "let",
+    "likely",
+    "llvm_lookup_intrinsic_id",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "lookup_param",
+    "match_buffer",
+    "max",
+    "max_value",
+    "min",
+    "min_value",
+    "mma_fill",
+    "mma_store",
+    "nearbyint",
+    "nextafter",
+    "parallel",
+    "popcount",
+    "power",
+    "prefetch",
+    "preflattened_buffer",
+    "prim_func",
+    "ptr",
+    "ptx_commit_group",
+    "ptx_cp_async",
+    "ptx_ldmatrix",
+    "ptx_mma",
+    "ptx_mma_sp",
+    "ptx_wait_group",
+    "q_multiply_shift",
+    "ramp",
+    "reads",
+    "realize",
+    "reinterpret",
+    "ret",
+    "round",
+    "rsqrt",
+    "serial",
+    "shift_left",
+    "shift_right",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "thread_binding",
+    "trunc",
+    "truncdiv",
+    "truncmod",
+    "tvm_access_ptr",
+    "tvm_bmma_sync",
+    "tvm_call_cpacked",
+    "tvm_call_cpacked_lowered",
+    "tvm_call_packed",
+    "tvm_call_packed_lowered",
+    "tvm_fill_fragment",
+    "tvm_load_matrix_sync",
+    "tvm_mma_sync",
+    "tvm_stack_alloca",
+    "tvm_stack_make_array",
+    "tvm_stack_make_shape",
+    "tvm_store_matrix_sync",
+    "tvm_struct_get",
+    "tvm_struct_set",
+    "tvm_thread_allreduce",
+    "tvm_throw_last_error",
+    "tvm_tuple",
+    "type_annotation",
+    "uint16",
+    "uint32",
+    "uint64",
+    "uint8",
+    "undef",
+    "unroll",
+    "var",
+    "vectorcombine",
+    "vectorhigh",
+    "vectorized",
+    "vectorlow",
+    "void",
+    "where",
+    "writes",
+]
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/meta_unparser.py
deleted file mode 100644
index b1472ccdc7..0000000000
--- a/python/tvm/script/meta_unparser.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Unparse meta AST node into a dict"""
-# pylint: disable=invalid-name
-
-from synr import Transformer
-
-
-class MetaUnparser(Transformer):
-    """Python AST Visitor to unparse meta AST node into a dict"""
-
-    def transform(self, node):
-        method = "transform_" + node.__class__.__name__
-        visitor = getattr(self, method, None)
-        if visitor is None:
-            self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span)
-        return visitor(node)
-
-    def transform_DictLiteral(self, node):
-        keys = [self.visit(key) for key in node.keys]
-        values = [self.visit(value) for value in node.values]
-        return dict(zip(keys, values))
-
-    def transform_Tuple(self, node):
-        return tuple(self.visit(element) for element in node.elts)
-
-    def transform_ArrayLiteral(self, node):
-        return [self.visit(element) for element in node.elts]
-
-    def transform_Constant(self, node):
-        return node.value
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
deleted file mode 100644
index 908af081c9..0000000000
--- a/python/tvm/script/parser.py
+++ /dev/null
@@ -1,1385 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser For TIR
-
-We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
-different python versions. Synr also provides an error handling context that we
-use for error reporting.
-"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
-import json
-import operator
-import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
-
-import tvm
-from tvm import IRModule
-from tvm._ffi.base import TVMError
-from tvm.ir import GlobalVar
-from tvm.ir.function import BaseFunc
-from tvm.tir import buffer
-from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
-
-from .context_maintainer import ContextMaintainer
-from .meta_unparser import MetaUnparser
-from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
-from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
-from .tir.special_stmt import SpecialStmt
-from .tir import ty
-
-
-class CallArgumentReader(object):
-    """Helper class to read required arguments from passed arguments.
-
-    When parsing a function call, we need to match the arguments provided in
-    the AST to the required arguments of the function. This class makes sure
-    all the positional arguments are filled and also fill keyword arguments
-    with thier default value if a different value was not provided.
-    """
-
-    def __init__(self, func_name, args, kwargs, parser, node):
-        self.func_name = func_name
-        self.args = args
-        self.kwargs = kwargs
-        self.parser = parser
-        self.node = node
-
-    def get_pos_only_arg(self, pos, name):
-        """Get corresponding position only function argument from argument list"""
-        if len(self.args) >= pos:
-            arg = self.args[pos - 1]
-        elif name not in self.kwargs:
-            # If no positional argument was found in the AST, we see if it was
-            # defined by name instead.
-            # TODO(tkonolige): this error message is not quite correct. The
-            # number of required arguments is >= pos
-            self.parser.report_error(
-                f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.",
-                self.node.span,
-            )
-        else:
-            arg = self.kwargs[name]
-
-        return arg
-
-    def get_kwarg(self, pos, name, default):
-        """Get corresponding keyword function argument from argument list.
-
-        If the user hasn't provided the argument, set it to the default value.
-        """
-        if len(self.args) >= pos:
-            arg = self.args[pos - 1]
-        elif name in self.kwargs:
-            arg = self.kwargs[name]
-        else:
-            return default
-
-        return arg
-
-    def get_varargs(self, pos):
-        """Get corresponding variable argument from argument list"""
-        if len(self.args) >= pos and len(self.kwargs) == 0:
-            return self.args[pos - 1 :]
-        return []
-
-
-class TVMScriptParser(Transformer):
-    """Synr AST visitor pass which finally lowers to TIR.
-
-    Notes for Extension
-    -------------------
-    1. To support a new type of AST node, add a function transform_xxx().
-    2. To support new functions, add the function to the appropriate registry:
-        We divide allowed function calls in TVM script into 3 categories,
-        intrin, scope_handler and special_stmt.
-        1. intrin functions are low level functions like mod, load, and
-           constants. They correspond to a tir `IRNode`. They must have a
-           return value. The user can register intrin functions for the parser to
-           use.
-        2. scope_handler functions have no return value. They take two
-           arguments: the parser and the AST node. scope_handler functions are
-           used in with and for statements.
-        3. special_stmt functions handle cases that do not have a corresponding
-           tir `IRNode`. These functions take the parser and the AST node as
-           arguments and may return a value.
-        When visiting a Call node, we check the special_stmt registry first. If
-        no registered function is found, we then check the intrin registry.
-        When visiting With node, we check the with_scope registry.
-        When visiting For node, we check the for_scope registry.
-    """
-
-    _binop_maker = {
-        ast.BuiltinOp.Add: tvm.tir.Add,
-        ast.BuiltinOp.Sub: tvm.tir.Sub,
-        ast.BuiltinOp.Mul: tvm.tir.Mul,
-        ast.BuiltinOp.Div: tvm.tir.Div,
-        ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv,
-        ast.BuiltinOp.Mod: tvm.tir.FloorMod,
-        ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs),
-        ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs),
-        ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs),
-        ast.BuiltinOp.GT: tvm.tir.GT,
-        ast.BuiltinOp.GE: tvm.tir.GE,
-        ast.BuiltinOp.LT: tvm.tir.LT,
-        ast.BuiltinOp.LE: tvm.tir.LE,
-        ast.BuiltinOp.Eq: tvm.tir.EQ,
-        ast.BuiltinOp.NotEq: tvm.tir.NE,
-        ast.BuiltinOp.And: tvm.tir.And,
-        ast.BuiltinOp.Or: tvm.tir.Or,
-    }
-
-    _unaryop_maker = {
-        ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs),
-        ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs),
-        ast.BuiltinOp.Not: tvm.tir.Not,
-    }
-
-    # pylint gets confused here with synr.Transformer which doesn't have a
-    # custom init, so just disable it
-    def __init__(
-        self, base_lineno, tir_namespace, closure_vars
-    ):  # pylint: disable=super-init-not-called
-        self.context = None
-
-        self.base_lineno = base_lineno
-        self.current_lineno = 0
-        self.current_col_offset = 0
-        self.tir_namespace = tir_namespace
-        self.closure_vars = closure_vars
-        self.meta = None
-        self._inside_buffer_sugar = False
-
-    def init_function_parsing_env(self):
-        """Initialize function parsing environment"""
-        self.context = ContextMaintainer(self.report_error, self.closure_vars)  # scope emitter
-
-    def init_meta(self, meta_dict):
-        if meta_dict is not None:
-            self.meta = tvm.ir.load_json(json.dumps(meta_dict))
-
-    def transform(self, node):
-        """Generic transformation for visiting the AST. Dispatches to
-        `transform_ClassName` for the appropriate ClassName."""
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-
-        if hasattr(node, "lineno"):
-            self.current_lineno = self.base_lineno + node.lineno - 1
-        if hasattr(node, "col_offset"):
-            self.current_col_offset = node.col_offset
-
-        method = "transform_" + node.__class__.__name__
-        visitor = getattr(self, method, self.generic_visit)
-        transform_res = visitor(node)
-
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-
-        return transform_res
-
-    def match_tir_namespace(self, identifier: str) -> bool:
-        """Check if the namespace is equal to tvm.script.tir"""
-        return identifier in self.tir_namespace
-
-    def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
-        """Report an error occuring at a location.
-
-        This just dispatches to synr's DiagnosticContext.
-
-        Parameters
-        ----------
-        message : str
-            Error message
-        span : Union[synr.ast.Span, tvm.ir.Span]
-            Location of the error
-        """
-        if isinstance(span, tvm.ir.Span):
-            span = synr_span_from_tvm(span)
-        self.error(message, span)
-
-    def parse_body(self, parent):
-        """Parse remaining statements in this scope.
-
-        Parameters
-        ----------
-        parent : synr.ast.Node
-            Parent node of this scope. Errors will be reported here.
-        """
-        body = []
-        spans = []
-        stmt = parent
-        while len(self.context.node_stack[-1]) > 0:
-            stmt = self.context.node_stack[-1].pop()
-            spans.append(stmt.span)
-            res = self.transform(stmt)
-            if res is not None:
-                body.append(res)
-        if len(body) == 0:
-            self.report_error(
-                "Expected another statement at the end of this block. Perhaps you "
-                "used a concise statement and forgot to include a body afterwards.",
-                stmt.span,
-            )
-        else:
-            return (
-                tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans)))
-                if len(body) > 1
-                else body[0]
-            )
-
-    def parse_arg_list(self, func, node_call):
-        """Match the arguments of a function call in the AST to the required
-        arguments of the function. This handles positional arguments,
-        positional arguments specified by name, keyword arguments, and varargs.
-
-        Parameters
-        ----------
-        func : Function
-            The function that provides the signature
-
-        node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
-            The AST call node that calls into the function.
-
-        Returns
-        -------
-        arg_list : list
-            The parsed positional argument.
-        """
-        assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
-        # collect arguments
-        args = [self.transform(arg) for arg in node_call.params]
-        if isinstance(node_call, ast.TypeApply):
-            kw_args = {}  # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
-        else:
-            kw_args = {
-                self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
-            }
-        # get the name and parameter list of func
-        if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
-            func_name, param_list = func.signature()
-        else:
-            self.report_error(
-                "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, "
-                f"but it is {type(func).__name__}",
-                node_call.span,
-            )
-        # check arguments and parameter list and get a list of arguments
-        reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
-        pos_only, kwargs, varargs = param_list
-        internal_args = list()
-
-        for i, arg_name in enumerate(pos_only):
-            internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
-        for i, arg_info in enumerate(kwargs):
-            arg_name, default = arg_info
-            internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default))
-        if varargs is not None:
-            internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1))
-        elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
-            self.report_error(
-                "Arguments mismatched. "
-                + f"Expected {len(pos_only) + len(kwargs)} args but got "
-                + f"{len(args) + len(kw_args)}",
-                node_call.span,
-            )
-        return internal_args
-
-    def parse_type(self, type_node, parent):
-        """Parse a type annotation.
-
-        We require the parent object to the type so that we have a place to
-        report the error message if the type does not exist.
-        """
-        if type_node is None:
-            self.report_error("A type annotation is required", parent.span)
-        res_type = self.transform(type_node)
-        return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate()
-
-    def generic_visit(self, node):
-        """Fallback visitor if node type is not handled. Reports an error."""
-
-        self.report_error(type(node).__name__ + " AST node is not supported", node.span)
-
-    def transform_Module(self, node):
-        """Module visitor
-
-        Right now, we only support two formats for TVM Script.
-
-        Example
-        -------
-        1. Generate a PrimFunc (If the code is printed, then it may also contain metadata)
-        .. code-block:: python
-
-            import tvm
-
-            @tvm.script
-            def A(...):
-                ...
-
-            # returns a PrimFunc
-            func = A
-
-        2. Generate an IRModule
-        .. code-block:: python
-
-            import tvm
-
-            @tvm.script.ir_module
-            class MyMod():
-                @T.prim_func
-                def A(...):
-                    ...
-                @T.prim_func
-                def B(...):
-                    ...
-
-                __tvm_meta__ = ...
-
-            # returns an IRModule
-            mod = MyMod
-        """
-        if len(node.funcs) == 1:
-            return self.transform(next(iter(node.funcs.values())))
-        elif len(node.funcs) == 0:
-            self.report_error(
-                "You must supply at least one class or function definition", node.span
-            )
-        else:
-            self.report_error(
-                "Only one-function, one-class or function-with-meta source code is allowed",
-                ast.Span.union([x.span for x in list(node.funcs.values())[1:]]),
-            )
-
-    def transform_Class(self, node):
-        """Class definition visitor.
-
-        A class can have multiple function definitions and a single
-        :code:`__tvm_meta__` statement. Each class corresponds to a single
-        :code:`IRModule`.
-
-        Example
-        -------
-        .. code-block:: python
-
-            @tvm.script.ir_module
-            class MyClass:
-                __tvm_meta__ = {}
-                def A():
-                    T.evaluate(0)
-        """
-        if len(node.assignments) == 1:
-            if not (
-                len(node.assignments[0].lhs) == 1
-                and isinstance(node.assignments[0].lhs[0], ast.Var)
-                and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
-            ):
-                self.report_error(
-                    "The only top level assignments allowed are `__tvm_meta__ = ...`",
-                    node.assignments[0].span,
-                )
-            self.init_meta(
-                MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
-            )
-        elif len(node.assignments) > 1:
-            self.report_error(
-                "Only a single top level `__tvm_meta__` is allowed",
-                ast.Span.union([x.span for x in node.assignments[1:]]),
-            )
-
-        return IRModule(
-            {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()}
-        )
-
-    def transform_Function(self, node):
-        """Function definition visitor.
-
-        Each function definition is translated to a single :code:`PrimFunc`.
-
-        There are a couple restrictions on TVM Script functions:
-        1. Function arguments must have their types specified.
-        2. The body of the function can contain :code:`func_attr` to specify
-           attributes of the function (like it's name).
-        3. The body of the function can also contain multiple :code:`buffer_bind`s,
-           which give shape and dtype information to arguments.
-        4. Return statements are implicit.
-
-        Example
-        -------
-        .. code-block:: python
-
-            @T.prim_func
-            def my_function(x: T.handle):  # 1. Argument types
-                T.func_attr({"global_symbol": "mmult"})  # 2. Function attributes
-                X_1 = tir.buffer_bind(x, [1024, 1024])  # 3. Buffer binding
-                T.evaluate(0)  # 4. This function returns 0
-        """
-
-        def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]):
-            if isinstance(decorator, ast.Call):
-                if len(decorator.params) != 1:
-                    return False
-                func_name = decorator.func_name
-            else:
-                func_name = decorator
-            if isinstance(func_name, ast.Var):
-                return func_name.id.name == "as_torch"
-
-        def check_decorator(decorators: List[ast.Expr]) -> bool:
-            """Check the decorator is `T.prim_func"""
-            if len(decorators) > 2 or len(decorators) == 0:
-                return False
-            if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
-                return False
-            d: ast.Expr = decorators[-1]
-            return (
-                isinstance(d, ast.Attr)
-                and isinstance(d.object, ast.Var)
-                and self.match_tir_namespace(d.object.id.name)
-                and d.field.name == "prim_func"
-            )
-
-        self.init_function_parsing_env()
-        self.context.enter_scope(nodes=node.body.stmts)
-
-        # add parameters of function
-        for arg in node.params:
-            # Note that this case is for T.match_buffer syntax sugar
-            if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
-                self.transform(arg.ty.func_name), ty.GenericBufferType
-            ):
-                result = self.handle_match_buffer_type(arg.ty, arg.name)
-                if not isinstance(result, buffer.Buffer):
-                    self.report_error(
-                        "The result type of evaluating TypeCall and TypeApply stmt"
-                        f" is wrong: {type(result)}. It should be a Buffer",
-                        node.span,
-                    )
-                arg_name_with_handle = arg.name + "_handle"
-                arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
-                self.context.func_buffer_map[arg_var] = result
-                self.context.update_symbol(arg.name, result, node)
-            else:
-                arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-                self.context.update_symbol(arg.name, arg_var, node)
-            self.context.func_params.append(arg_var)
-
-        if not check_decorator(node.decorators):
-            self.report_error(
-                "All functions should be decorated by `T.prim_func`",
-                node.span,
-            )
-
-        # fetch the body of root block
-        body = self.parse_body(node.body)
-
-        # return a tir.PrimFunc
-        dict_attr = self.context.func_dict_attr
-        ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None
-        func = tvm.tir.PrimFunc(
-            self.context.func_params,
-            body,
-            ret_type,
-            buffer_map=self.context.func_buffer_map,
-            preflattened_buffer_map=self.context.func_preflattened_buffer_map,
-            attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
-            span=tvm_span_from_synr(node.span),
-        )
-
-        # New Scope : Implicit root block
-        # Each function contains an implicit root block in TensorIR,
-        # so here we need a block scope for it.
-        # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
-        # the root block will not be added. The logic to add root block is in `_ffi_api.Complete`
-
-        # Fix the PrimFunc
-        # 1. generate root block if necessary
-        # 2. generate surrounding loops for blocks if necessary
-
-        func = call_with_error_reporting(
-            self.report_error,
-            node.span,
-            _ffi_api.Complete,
-            func,
-            self.context.root_alloc_buffers,
-        )
-
-        self.context.exit_scope()
-        return func
-
-    def transform_Lambda(self, node):
-        """Lambda visitor
-
-        Return an array of input parameters and the transformed lambda body.
-        """
-
-        self.context.enter_scope(nodes=[node.body])
-
-        # add parameters of the lambda
-        arg_vars = []
-        for arg in node.params:
-            # Use "void" for dtype here. The actual type is not yet known and will be
-            # determined later. Using void type will allow IRSubstitute to do the
-            # replacement without flagging a type-mismatch error.
-            arg_var = tvm.te.var(arg.name, dtype="")
-            arg_vars.append(arg_var)
-            self.context.update_symbol(arg.name, arg_var, node)
-
-        # the body of a lambda must be an expr
-        if not isinstance(node.body, ast.Expr):
-            self.report_error("The body of a lambda must be an expression", node.span)
-
-        # transform the body of the lambda
-        body = self.transform(node.body)
-
-        self.context.exit_scope()
-        return arg_vars, body
-
-    def transform_Assign(self, node):
-        """Assign visitor
-        AST abstract grammar:
-            Assign(expr* targets, expr value, string? type_comment)
-
-        By now 5 patterns of Assign is supported:
-            1. special stmts with return value
-                1.1 Buffer = T.match_buffer()/T.buffer_decl()
-                1.2 Var = T.var()
-                1.3 Var = T.env_thread()
-            2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
-            3. (Store)       Var[PrimExpr] = PrimExpr
-            4. with scope handlers with concise scoping and var def
-                4.1 var = T.allocate()
-            5. A call to a pure python function, consuming and producing TVMScript values.
-               The outputs are inlined into the following body (no variable is created).
-               x, y = f(...)
-        """
-
-        if isinstance(node.rhs, ast.Call):
-            # Pattern 1 & Pattern 4
-            if isinstance(node.rhs.func_name, ast.Op):
-                func = None
-            else:
-                func = self.transform(node.rhs.func_name)
-
-            if isinstance(func, WithScopeHandler):
-                if not func.concise_scope or not func.def_symbol:
-                    self.report_error(
-                        "with scope handler " + func.signature()[0] + " is not suitable here",
-                        node.rhs.span,
-                    )
-                # Pattern 4
-                arg_list = self.parse_arg_list(func, node.rhs)
-                func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-                func.body = self.parse_body(node)
-                return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-            elif isinstance(func, SpecialStmt):
-                # Pattern 1
-                arg_list = self.parse_arg_list(func, node.rhs)
-                func.handle(node, self.context, arg_list, node.rhs.func_name.span)
-                return self.parse_body(node)
-            elif isinstance(func, types.FunctionType):
-                # Pattern 5
-                args = [self.transform(arg) for arg in node.rhs.params]
-                try:
-                    out = func(*args)
-                except Exception as e:
-                    self.report_error(
-                        "Error occured when invoking the function "
-                        + func.__name__
-                        + ": \n"
-                        + str(e),
-                        node.rhs.span,
-                    )
-
-                if len(node.lhs) == 1 and not isinstance(out, list):
-                    out = [out]
-
-                assert len(out) == len(node.lhs)
-
-                for var, value in zip(node.lhs, out):
-                    self.context.update_symbol(var.id.name, value, node)
-
-                body = self.parse_body(node)
-
-                for var, value in zip(node.lhs, out):
-                    self.context.remove_symbol(var.id.name)
-
-                return body
-
-        if isinstance(node.rhs, (ast.Call, ast.Constant)):
-            # Pattern 4 of let binding
-            value = self.transform(node.rhs)
-            if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
-                # This is a little confusing because it only is true when
-                # we have taken this branch. We might need to clarify what
-                # exectly is allowed in Assignments in tvmscript.
-                self.report_error(
-                    "Left hand side of assignment must be an unqualified variable",
-                    node.span,
-                )
-            ast_var = node.lhs[0]
-
-            if node.ty is None and hasattr(value, "dtype"):
-                var_ty = value.dtype
-            else:
-                var_ty = self.parse_type(node.ty, ast_var)
-
-            var = tvm.te.var(
-                ast_var.id.name,
-                var_ty,
-                span=tvm_span_from_synr(ast_var.span),
-            )
-            self.context.update_symbol(var.name, var, node)
-            body = self.parse_body(node)
-            self.context.remove_symbol(var.name)
-            return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-        self.report_error(
-            """Assignments should be one of:
-            1. A "special statement" with return value
-                1.1 Buffer = T.match_buffer()/T.buffer_decl()
-                1.2 Var = T.var()
-                1.3 Var = T.env_thread()
-            2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
-            3. A store into a variable: Var[PrimExpr] = PrimExpr
-            4. A with scope handler with concise scoping and var def
-                4.1 var = T.allocate()
-            5. The right-hand side being a call to a pure python function, consuming and
-               producing TVMScript values.
-               x, y = f(...)""",
-            node.span,
-        )
-
-    def transform_SubscriptAssign(self, node):
-        """Visitor for statements of the form :code:`x[1] = 2`."""
-        symbol = self.transform(node.params[0])
-        indexes = self.transform(node.params[1])
-        rhs = self.transform(node.params[2])
-        rhs_span = tvm_span_from_synr(node.params[2].span)
-        if isinstance(symbol, tvm.tir.Buffer):
-            if len(indexes) != len(symbol.shape):
-                self.report_error(
-                    f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, "
-                    f"cannot be indexed by {len(indexes)}-dimensional indices.",
-                    node.params[1].span,
-                )
-
-            def __convert_index(x):
-                if isinstance(x, Slice):
-                    return x.as_index_expr(self.report_error)
-                return x
-
-            # BufferStore
-            indexes = [__convert_index(x) for x in indexes]
-            return tvm.tir.BufferStore(
-                symbol,
-                tvm.runtime.convert(rhs, span=rhs_span),
-                indexes,
-                span=tvm_span_from_synr(node.span),
-            )
-        else:
-            if symbol.dtype == "handle" and len(indexes) != 1:
-                self.report_error(
-                    "Handles only support one-dimensional indexing. Use `T.match_buffer` to "
-                    "construct a multidimensional buffer from a handle.",
-                    node.params[0].span,
-                )
-            if len(indexes) != 1:
-                self.report_error(
-                    f"Store is only allowed with one index, but {len(indexes)} were provided.",
-                    node.params[1].span,
-                )
-            self.report_error(
-                "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span
-            )
-
-    def transform_AttrAssign(self, node):
-        """Visitor for statements of the form :code:`x.y = 2`."""
-        obj = self.transform(node.params[0])
-        field = node.params[1]
-        value = self.transform(node.params[2])
-
-        if not hasattr(obj, field.name):
-            self.error(f"Field {field.name} does not exist", field.span)
-
-        var = getattr(obj, field.name)
-
-        if not isinstance(var, tvm.tir.Var):
-            self.error(
-                f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span
-            )
-
-        body = self.parse_body(node)
-        return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-    def transform_Assert(self, node):
-        """Assert visitor
-
-        Pattern corresponds to concise mode of :code:`with T.Assert()`.
-        """
-
-        condition = self.transform(node.condition)
-        if node.msg is None:
-            self.report_error("Assert statements must have an error message.", node.span)
-        message = self.transform(node.msg)
-        body = self.parse_body(node)
-        return tvm.tir.AssertStmt(
-            condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span)
-        )
-
-    def transform_For(self, node):
-        """For visitor
-        AST abstract grammar:
-            For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
-        By now 1 pattern of For is supported:
-            1. for scope handler
-                for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
-                            T.grid()/T.thread_binding()
-        """
-
-        if not isinstance(node.rhs, ast.Call):
-            self.report_error("The loop iterator should be a function call.", node.rhs.span)
-        func = self.transform(node.rhs.func_name)
-        if not isinstance(func, ForScopeHandler):
-            self.report_error(
-                "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span
-            )
-        # prepare for new for scope
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno = node.span.start_line
-        self.current_col_offset = node.span.start_column
-        self.context.enter_scope(nodes=node.body.stmts)
-        # for scope handler process the scope
-        arg_list = [
-            tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
-            for arg in self.parse_arg_list(func, node.rhs)
-        ]
-        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        func.body = self.parse_body(node)
-        res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        # exit the scope
-        self.context.exit_scope()
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-        return res
-
-    def transform_While(self, node):
-        """While visitor
-        AST abstract grammar:
-            While(expr condition, stmt* body)
-        """
-        condition = self.transform(node.condition)
-        # body
-        self.context.enter_scope(nodes=node.body.stmts)
-        body = self.parse_body(node)
-        self.context.exit_scope()
-
-        return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span))
-
-    def transform_With(self, node):
-        """With visitor
-        AST abstract grammar:
-            With(withitem* items, stmt* body, string? type_comment)
-            withitem = (expr context_expr, expr? optional_vars)
-        By now 2 patterns of With is supported:
-            1. with scope handler with symbol def
-                with T.allocate() as targets:
-            2. with scope handler without symbol def
-                with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize()
-        """
-
-        if not isinstance(node.rhs, ast.Call):
-            self.report_error(
-                "The context expression of a `with` statement should be a function call.",
-                node.rhs.span,
-            )
-
-        func = self.transform(node.rhs.func_name)
-
-        if not isinstance(func, WithScopeHandler):
-            self.report_error(
-                f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span
-            )
-        # prepare for new block scope
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno = node.body.span.start_line
-        self.current_col_offset = node.body.span.start_column
-        self.context.enter_block_scope(nodes=node.body.stmts)
-        # with scope handler process the scope
-        arg_list = self.parse_arg_list(func, node.rhs)
-        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        func.body = self.parse_body(node)
-        res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        # exit the scope
-        self.context.exit_block_scope()
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-        return res
-
-    def transform_If(self, node):
-        """If visitor
-        AST abstract grammar:
-            If(expr test, stmt* body, stmt* orelse)
-        """
-
-        condition = self.transform(node.condition)
-        # then body
-        self.context.enter_scope(nodes=node.true.stmts)
-        then_body = self.parse_body(node)
-        self.context.exit_scope()
-
-        # else body
-        if len(node.false.stmts) > 0:
-            self.context.enter_scope(nodes=node.false.stmts)
-            else_body = self.parse_body(node)
-            self.context.exit_scope()
-        else:
-            else_body = None
-
-        return tvm.tir.IfThenElse(
-            condition, then_body, else_body, span=tvm_span_from_synr(node.span)
-        )
-
-    def transform_Call(self, node):
-        """Call visitor
-
-        3 different Call patterns are allowed:
-            1. Intrin representing a PrimExpr/IterVar
-                1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
-                1.2 tir.range/reduce_axis/scan_axis/opaque_axis
-            2. tir.Op(dtype, ...)
-            3. other callable functions
-        """
-
-        if isinstance(node.func_name, ast.Op):
-            if node.func_name.name == ast.BuiltinOp.Subscript:
-                return self.transform_Subscript(node)
-            if node.func_name.name in self._binop_maker:
-                lhs = self.transform(node.params[0])
-                # There is no supertype for everything that can appear in
-                # an expression, so we manually add what we might get here.
-                if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
-                    # We would really like to report a more specific
-                    # error here, but this parser contains no distinction
-                    # between parsing statements and parsing expressions. All
-                    # rules just call `transform`.
-                    self.report_error(
-                        f"Left hand side of binary op must be a PrimExpr, "
-                        "but it is a {type(lhs).__name__}",
-                        node.params[0].span,
-                    )
-                rhs = self.transform(node.params[1])
-                if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
-                    self.report_error(
-                        f"Right hand side of binary op must be a PrimExpr, "
-                        "but it is a {type(rhs).__name__}",
-                        node.params[1].span,
-                    )
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.span,
-                    lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
-                        lhs, rhs, span=span
-                    ),
-                    node,
-                    lhs,
-                    rhs,
-                    tvm_span_from_synr(node.span),
-                )
-            if node.func_name.name in self._unaryop_maker:
-                rhs = self.transform(node.params[0])
-                return self._unaryop_maker[node.func_name.name](
-                    rhs, span=tvm_span_from_synr(node.span)
-                )
-            self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span)
-        else:
-            func = self.transform(node.func_name)
-            if isinstance(func, Intrin) and not func.stmt:
-                # pattern 1
-                arg_list = self.parse_arg_list(func, node)
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.func_name.span,
-                    func.handle,
-                    arg_list,
-                    node.func_name.span,
-                )
-            else:
-                args = [self.transform(arg) for arg in node.params]
-                kw_args = {
-                    self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
-                }
-                if isinstance(func, tvm.tir.op.Op):
-                    if not "dtype" in kw_args.keys():
-                        self.report_error(f"{func} requires a dtype keyword argument.", node.span)
-                    # pattern 2
-                    return tvm.tir.Call(
-                        kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span)
-                    )
-                elif callable(func):
-                    # pattern 3
-                    return func(*args, **kw_args)
-                else:
-                    self.report_error(
-                        f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).",
-                        node.func_name.span,
-                    )
-
-    def transform_UnassignedCall(self, node):
-        """Visitor for statements that are function calls.
-
-        This handles function calls that appear on thier own line like `tir.realize`.
-
-        Examples
-        --------
-        .. code-block:: python
-
-            @T.prim_func
-            def f():
-                A = T.buffer_decl([10, 10])
-                T.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
-                A[1, 1] = 2  # This is also an UnassignedCall
-        """
-        # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
-        if isinstance(node.call.func_name, ast.Op):
-            if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign:
-                return self.transform_SubscriptAssign(node.call)
-
-            if node.call.func_name.name == ast.BuiltinOp.AttrAssign:
-                return self.transform_AttrAssign(node.call)
-
-            self.report_error(
-                "Binary and unary operators are not allowed as a statement", node.span
-            )
-
-        # handle a regular function call
-        func = self.transform(node.call.func_name)
-        arg_list = self.parse_arg_list(func, node.call)
-
-        if isinstance(func, tir.scope_handler.AssertHandler):
-            self.report_error(
-                "A standalone `T.Assert` is not allowed. Use `assert condition, message` "
-                "instead.",
-                node.call.func_name.span,
-            )
-
-        if isinstance(func, Intrin):
-            if func.stmt:
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.call.func_name.span,
-                    func.handle,
-                    arg_list,
-                    node.call.func_name.span,
-                )
-            else:
-                self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span)
-        elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
-            func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
-            func.body = self.parse_body(node)
-            return func.exit_scope(node, self.context, arg_list, node.call.func_name.span)
-        elif isinstance(func, SpecialStmt) and not func.def_symbol:
-            func.handle(node, self.context, arg_list, node.call.func_name.span)
-            return
-
-        self.report_error(
-            "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
-            f"special statement, but got {type(func).__name__}.",
-            node.call.func_name.span,
-        )
-
-    def transform_Slice(self, node):
-        """Index slice visitor."""
-        start = self.transform(node.start)
-        end = self.transform(node.end)
-        if not (
-            isinstance(node.step, ast.Constant)
-            and isinstance(node.step.value, int)
-            and node.step.value > 0
-        ):
-            self.report_error(
-                "Only positive integer step size is supported for slices.", node.step.span
-            )
-        return Slice(start, end, node.step.value, tvm_span_from_synr(node.span))
-
-    def transform_Subscript(self, node):
-        """Array access visitor.
-
-        By now only 3 types of Subscript are supported:
-            1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
-               Var[index] Buffer element access()
-            2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
-            3. Array[index], Buffer element access
-        """
-
-        symbol = self.transform(node.params[0])
-        if symbol is None:
-            self.report_error(
-                f"Variable {node.params[0].id.name} is not defined.", node.params[0].span
-            )
-
-        indexes = [self.transform(x) for x in node.params[1].values]
-        if isinstance(symbol, tvm.tir.expr.Var):
-            if symbol.dtype == "handle":
-                self.report_error(
-                    "Cannot read directly from a handle, use `T.match_buffer` "
-                    "to create a buffer to read from.",
-                    node.params[0].span,
-                )
-            if len(indexes) > 1:
-                self.report_error(
-                    "Only a single index can be provided when indexing into a `var`.",
-                    node.params[1].span,
-                )
-            index = indexes[0]
-            if not isinstance(index, (tvm.tir.PrimExpr, int)):
-                self.report_error(
-                    "Var load index should be an int or PrimExpr, but it is a" + type(index),
-                    node.span,
-                )
-
-            self.report_error(
-                "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span
-            )
-        elif isinstance(symbol, tvm.tir.Buffer):
-            return BufferSlice(
-                symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
-            )
-        elif isinstance(symbol, tvm.container.Array):
-            if len(indexes) > 1:
-                self.report_error(
-                    "Array access should be one-dimension access, but the indices are "
-                    + str(indexes),
-                    node.span,
-                )
-            index = indexes[0]
-            if not isinstance(index, (int, tvm.tir.expr.IntImm)):
-                self.report_error(
-                    "Array access index expected int or IntImm, but got " + type(index),
-                    node.span,
-                )
-            if int(index) >= len(symbol):
-                self.report_error(
-                    f"Array access out of bound, size: {len(symbol)}, got index {index}.",
-                    node.span,
-                )
-            return symbol[int(index)]
-        else:
-            self.report_error(
-                f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
-                "buffers are supported.",
-                node.params[0].span,
-            )
-
-    def transform_Attr(self, node):
-        """Visitor for field access of the form `x.y`.
-
-        This visitor is used to lookup function and symbol names. We have two
-        cases to handle here:
-        1. If we have a statement of the form `tir.something`, then we lookup
-           `tir.something` in the `Registry`. If the function is not in the
-           registry, then we try to find a `tvm.ir.op.Op` with the same name.
-        2. All other names `tvm.something` are lookup up in this current python
-           namespace.
-        """
-
-        def get_full_attr_name(node: ast.Attr) -> str:
-            reverse_field_names = [node.field.name]
-            while isinstance(node.object, ast.Attr):
-                node = node.object
-                reverse_field_names.append(node.field.name)
-            if isinstance(node.object, ast.Var):
-                reverse_field_names.append(node.object.id.name)
-            return ".".join(reversed(reverse_field_names))
-
-        if isinstance(node.object, (ast.Var, ast.Attr)):
-            full_attr_name = get_full_attr_name(node)
-            attr_object, fields = full_attr_name.split(".", maxsplit=1)
-            if self.match_tir_namespace(attr_object):
-                func_name = "tir." + fields
-                res = Registry.lookup(func_name)
-                if res is not None:
-                    return res
-                try:
-                    return tvm.ir.op.Op.get(func_name)
-                except TVMError as e:
-                    # Check if we got an attribute error
-                    if e.args[0].find("AttributeError"):
-                        self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
-                    else:
-                        raise e
-
-        symbol = self.transform(node.object)
-        if symbol is None:
-            self.report_error("Unsupported Attribute expression.", node.object.span)
-        if not hasattr(symbol, node.field.name):
-            self.report_error(
-                f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span
-            )
-        res = getattr(symbol, node.field.name)
-        return res
-
-    def transform_TypeAttr(self, node):
-        """Visitor for field access of the form `x.y` for types.
-
-        We have two cases here:
-        1. If the type is of the form `T.something`, we look up the type in
-           the `tir` namespace in this module.
-        2. If the type is of the form `tvm.x.something` then we look up
-           `tvm.x.something` in this modules namespace.
-        """
-        if isinstance(node.object, ast.TypeVar):
-            if self.match_tir_namespace(node.object.id.name):
-                if not hasattr(tir, node.field.name):
-                    self.report_error(
-                        f"Invalid type annotation `tir.{node.field.name}`.", node.span
-                    )
-                return getattr(tir, node.field.name)
-
-        symbol = self.transform(node.object)
-        if symbol is None:
-            self.report_error("Unsupported Attribute expression", node.object.span)
-        if not hasattr(symbol, node.field):
-            self.report_error(
-                f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span
-            )
-        res = getattr(symbol, node.field)
-        return res
-
-    def transform_DictLiteral(self, node):
-        """Dictionary literal visitor.
-
-        Handles dictionary literals of the form `{x:y, z:2}`.
-        """
-
-        keys = [self.transform(key) for key in node.keys]
-        values = [self.transform(value) for value in node.values]
-
-        return dict(zip(keys, values))
-
-    def transform_Tuple(self, node):
-        """Tuple visitor.
-
-        Handles tuples of the form `(x, y, 2)`.
-        """
-
-        return tuple(self.transform(element) for element in node.values)
-
-    def transform_ArrayLiteral(self, node):
-        """List literal visitor.
-
-        Handles lists of the form `[x, 2, 3]`.
-        """
-
-        return [self.transform(element) for element in node.values]
-
-    def transform_Var(self, node):
-        """Variable visitor
-
-        Handles variables like `x` in `x = 2`.
-        """
-
-        name = node.id.name
-        if name == "meta":
-            return self.meta
-        symbol = Registry.lookup(name)
-        if symbol is not None:
-            return symbol
-        symbol = self.context.lookup_symbol(name)
-        if symbol is not None:
-            return symbol
-        self.report_error(f"Unknown identifier {name}.", node.span)
-
-    def transform_TypeVar(self, node):
-        """Type variable visitor.
-
-        Equivalent to `transform_Var` but for types.
-        """
-        name = node.id.name
-        symbol = Registry.lookup(name) or self.context.lookup_symbol(name)
-        if symbol is not None:
-            return symbol
-        self.report_error(f"Unknown identifier {name}.", node.span)
-
-    def transform_Constant(self, node):
-        """Constant value visitor.
-
-        Constant values include `None`, `"strings"`, `2` (integers), `4.2`
-        (floats), and `true` (booleans).
-        """
-        return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span))
-
-    def transform_TypeConstant(self, node):
-        """Constant value visitor for types.
-
-        See `transform_Constant`.
-        """
-        if self._inside_buffer_sugar:
-            return self.transform_Constant(node)
-
-        return node.value
-
-    def transform_TypeTuple(self, node):
-        """Tuple value visitor for types.
-
-        Mostly used in `transform_TypeCall` and `transform_TypeApply`.
-        """
-        return [self.transform(value) for value in node.values]
-
-    def transform_TypeCall(self, node):
-        """TypeCall visitor
-
-        This occurs when an expression is used inside a T.Buffer
-        parameter annotation.
-        """
-
-        # ast.Call has the BuiltinOp as node.func_name.name, where
-        # ast.TypeCall has the BuiltinOp as node.func_name.  So we can
-        # delegate to self.transform_Call, but the error messages for
-        # unsupported operations will highlight the entire expression
-        # and not just the function itself.
-        op = ast.Op(node.span, node.func_name)
-        call = ast.Call(node.span, op, node.params, node.keyword_params)
-        return self.transform_Call(call)
-
-    def transform_TypeApply(self, node):
-        """Visitor for Type[Type] expressions.
-
-        Mostly used for ``T.Ptr`` expressions.
-        """
-        func = self.transform(node.func_name)
-
-        if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
-            self.report_error(
-                f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
-                f"but found {type(func).__name__} instead.",
-                node.span,
-            )
-
-        param_types = []
-        for idx, param in enumerate(node.params):
-            param_type = self.transform(param)
-            if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx):
-                self.report_error(
-                    f"Expected a type but found {type(param).__name__} "
-                    f"at {idx}th type argument",
-                    param.span,
-                )
-
-            param_types.append(param_type)
-
-        if len(param_types) == 1:
-            return func[param_types[0]]
-        else:
-            return func[param_types]
-
-    def handle_match_buffer_type(self, node, buffer_name):
-        """special function to handle syntax sugar for match buffer.
-
-        This method is for buffer declarations in the function parameters.
-        """
-        func = self.transform(node.func_name)
-        assert isinstance(func, SpecialStmt)
-
-        # parse args and kwargs for TypeCall and TypeApply
-        self._inside_buffer_sugar = True
-        try:
-            arg_list = self.parse_arg_list(func, node)
-        finally:
-            self._inside_buffer_sugar = False
-
-        # Note that the third element in arg_list would always be the 'name'
-        # TODO: This index is hardcoded as a workaround. Better to make it programmatic
-        if arg_list[2] is None:
-            arg_list[2] = buffer_name
-        buf = func.handle(node, self.context, arg_list, node.func_name.span)
-        return buf
-
-    def transform_Return(self, node):
-        self.report_error(
-            "TVM script does not support return statements. Instead the last statement in any "
-            "block is implicitly returned.",
-            node.span,
-        )
-
-
-def get_tir_namespace(script: Union[Callable, type]) -> List[str]:
-    assert inspect.isfunction(script) or inspect.isclass(script)
-    env: Dict[str, Any] = script.__globals__
-    return [key for key in env.keys() if env[key] == tir]
-
-
-def from_source(
-    input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None
-) -> Union[PrimFunc, IRModule]:
-    """Parse function or string into PrimFunc or IRModule.
-
-    If possible, pass the TVM script in as a function so that line numbers and
-    filename will be accurate.
-
-    Parameters
-    ----------
-    input_module : Union[str, Callable]
-        The python function to be parsed.
-
-    tir_prefix : Optional[List[str]]
-        The tir prefix list. Only works for str input, default by "tir" and "T".
-
-    Returns
-    -------
-    output : Union[Function, Module]
-        The Function or Module in IR.
-    """
-    if isinstance(input_func, str):
-        tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
-        return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
-    elif inspect.isfunction(input_func):
-        _, start_line = inspect.getsourcelines(input_func)
-        env: Dict[str, Any] = input_func.__globals__
-        namespace = [key for key in env.keys() if env[key] is tir]
-        _closure_vars = inspect.getclosurevars(input_func)
-        closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
-        parser = TVMScriptParser(start_line, namespace, closure_vars)
-        result = to_ast(input_func, TVMDiagnosticCtx(), parser)
-        return result
-    else:
-        raise TypeError("Only function definitions are supported.")
-
-
-def ir_module(input_module: type) -> IRModule:
-    """Decorate a python class as tvm IRModule.
-
-    Parameters
-    ----------
-    input_module : type
-        The python class to be parsed.
-
-    Returns
-    -------
-    output : IRModule
-        The result IRModule.
-    """
-    if inspect.isclass(input_module):
-        func_dict = {
-            name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
-        }
-        return IRModule(func_dict)
-    raise TypeError("Only class definitions are supported.")
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 83%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..5161a2601c 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The parser"""
+from . import _core, ir, tir
+from ._core import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/_core.py
similarity index 76%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/_core.py
index 555659d0c5..4f5411dc36 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/_core.py
@@ -13,9 +13,10 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The core parser infra"""
+# pylint: disable=unused-import
+from .core import dispatch, doc, utils
+from .core.dispatch import OpMethod, register_op
+from .core.entry import parse
+from .core.parser import Parser
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/core/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/core/__init__.py
index 555659d0c5..94d8dab032 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/core/__init__.py
@@ -14,8 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""The core parser infra"""
+from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils
diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
new file mode 100644
index 0000000000..51c26bbc24
--- /dev/null
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+import re
+import sys
+from typing import Union
+
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+
+
+class Source:
+    source_name: str
+    start_line: int
+    start_column: int
+    source: str
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = getsourcelines(program)  # type: ignore
+        if lines:
+            self.start_column = len(lines[0]) - len(lines[0].lstrip())
+        else:
+            self.start_column = 0
+        if self.start_column and lines:
+            self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
+        else:
+            self.source = "".join(lines)
+        try:
+            # It will cause a problem when running in Jupyter Notebook.
+            # `mod` will be <module '__main__'>, which is a built-in module
+            # and `getsource` will throw a TypeError
+            mod = inspect.getmodule(program)
+            if mod:
+                self.full_source = inspect.getsource(mod)
+            else:
+                self.full_source = self.source
+        except TypeError:
+            # It's a work around for Jupyter problem.
+            # Since `findsource` is an internal API of inspect, we just use it
+            # as a fallback method.
+            src, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src)
+
+    def as_ast(self) -> doc.AST:
+        return doc.parse(self.source)
+
+
+_getfile = inspect.getfile  # pylint: disable=invalid-name
+_findsource = inspect.findsource  # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):  # class-aware replacement installed as inspect.getfile
+    if not inspect.isclass(obj):
+        return _getfile(obj)  # non-class objects go to the saved original inspect.getfile
+    mod = getattr(obj, "__module__", None)
+    if mod is not None:
+        file = getattr(sys.modules[mod], "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):  # fallback: locate a function defined on obj
+        if inspect.isfunction(member):
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    raise TypeError(f"Source for {obj!r} not found")  # `!r` is a conversion, not a format spec
+
+
+def findsource(obj):
+    import linecache  # pylint: disable=import-outside-toplevel
+
+    if not inspect.isclass(obj):
+        return _findsource(obj)
+
+    file = inspect.getsourcefile(obj)
+    if file:
+        linecache.checkcache(file)
+    else:
+        file = inspect.getfile(obj)
+        if not (file.startswith("<") and file.endswith(">")):
+            raise OSError("source code not available")
+
+    module = inspect.getmodule(obj, file)
+    if module:
+        lines = linecache.getlines(file, module.__dict__)
+    else:
+        lines = linecache.getlines(file)
+    if not lines:
+        raise OSError("could not get source code")
+    qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
+    pattern_list = []
+    for name in qual_names:
+        if name.endswith("<locals>"):
+            pattern_list.append(re.compile(r"^(\s*)def\s*" + name[:-8] + r"\b"))
+        else:
+            pattern_list.append(re.compile(r"^(\s*)class\s*" + name + r"\b"))
+    for i, line in enumerate(lines):
+        match = pattern_list[0].match(line)
+        if match:
+            pattern_list.pop(0)
+        if not pattern_list:
+            return lines, i
+    raise OSError("could not find class definition")
+
+
+def getsourcelines(obj):
+    obj = inspect.unwrap(obj)
+    lines, l_num = findsource(obj)
+    return inspect.getblock(lines[l_num:]), l_num + 1
+
+
+inspect.getfile = _patched_inspect_getfile
+
+
+class Diagnostics:
+
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        mod = IRModule()
+        mod.source_map.add(source.source_name, source.full_source)
+        self.source = source
+        self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        lineno = node.lineno or self.source.start_line
+        col_offset = node.col_offset or self.source.start_column
+        end_lineno = node.end_lineno or lineno
+        end_col_offset = node.end_col_offset or col_offset
+        lineno += self.source.start_line - 1
+        end_lineno += self.source.start_line - 1
+        col_offset += self.source.start_column + 1
+        end_col_offset += self.source.start_column + 1
+        self.ctx.emit(
+            diagnostics.Diagnostic(
+                level=level,
+                span=Span(
+                    source_name=SourceName(self.source.source_name),
+                    line=lineno,
+                    end_line=end_lineno,
+                    column=col_offset,
+                    end_column=end_col_offset,
+                ),
+                message=message,
+            )
+        )
+
+    def error(self, node: doc.AST, message: str) -> None:
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()
diff --git a/python/tvm/script/parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/core/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+
+# Signature of a parse handler: it receives the parser and the doc AST node.
+ParseMethod = Callable[["Parser", AST], None]
+# Parse handlers keyed by (dispatch token, node type name); written by `register`, read by `get`.
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+# Signature of an operator overload; accepts any arguments and returns any value.
+OpMethod = Callable[..., Any]
+# Operator overloads keyed by (operand type, operator, operand index);
+# written by `register_op`, read by `get_op`.
+OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Returns a decorator that stores the decorated method in ``ParseVTable`` under
+    ``(token, type_name)`` and hands the method back, so the decorated name stays
+    bound to the function instead of being rebound to ``None``.
+    """
+
+    def f(method: ParseMethod):
+        ParseVTable[(token, type_name)] = method
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Return the method stored in ``ParseVTable`` for ``(token, type_name)``, else ``default``."""
+    key = (token, type_name)
+    return ParseVTable.get(key, default)
+
+
+def register_op(ty: Type, op: AST, operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator overload for an operand type, operator and operand position.
+
+    Returns a decorator that stores the decorated method in ``OpVTable`` under
+    ``(ty, op, operand_index)`` and hands the method back, so the decorated name
+    stays bound to the function instead of being rebound to ``None``.
+    """
+
+    def f(method: OpMethod):
+        OpVTable[(ty, op, operand_index)] = method
+        return method
+
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type,
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Return the overload stored in ``OpVTable`` for ``(ty, op, operand_index)``, else ``default``."""
+    key = (ty, op, operand_index)
+    return OpVTable.get(key, default)
diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py
new file mode 100644
index 0000000000..f6a641cb64
--- /dev/null
+++ b/python/tvm/script/parser/core/doc.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import *  # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+# Converter from a standard-library ``ast`` node to a ``doc`` node.
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+# Converter from a ``doc`` node back to a standard-library ``ast`` node.
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+    """The pair of converters registered for one AST class name; both start unset."""
+
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        # Neither direction is available until a converter is registered.
+        self.to_doc = self.from_doc = None
+
+
+class Registry:
+    """Table of converter entries keyed by AST class name.
+
+    ``_inst`` holds the shared instance; it is ``None`` until assigned.
+    """
+
+    _inst: typing.Optional["Registry"] = None
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        # A defaultdict: subscripting an unknown name creates a fresh, empty Entry.
+        self.table = defaultdict(Entry)
+
+
+def register_to_doc(name: str):
+    """Return a decorator that registers an ``ast`` -> ``doc`` converter under ``name``.
+
+    The converter is stored on the shared registry's entry for ``name`` and is
+    returned unchanged, so the decorated name stays bound to the function instead
+    of being rebound to ``None``.
+    """
+
+    def f(to_doc: FnToDoc):  # pylint: disable=redefined-outer-name
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].to_doc = to_doc
+        return to_doc
+
+    return f
+
+
+def register_from_doc(name: str):
+    """Return a decorator that registers a ``doc`` -> ``ast`` converter under ``name``.
+
+    The converter is stored on the shared registry's entry for ``name`` and is
+    returned unchanged, so the decorated name stays bound to the function instead
+    of being rebound to ``None``.
+    """
+
+    # The parameter was previously called ``to_doc`` although it receives the
+    # from-doc converter; decorators are applied positionally, so callers are unaffected.
+    def f(from_doc_fn: FnFromDoc):
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = from_doc_fn
+        return from_doc_fn
+
+    return f
+
+
+def _is_atomic_type(node):
+    """Return True when ``node`` is a plain Python value that needs no conversion.
+
+    Atomic values are ``None``, ``Ellipsis`` and instances of ``bool``, ``int``,
+    ``float``, ``complex``, ``str`` and ``bytes``.
+
+    Only identity and ``isinstance`` checks are used.  A membership test such as
+    ``node in [..., True, False]`` falls back to ``==``, which misclassifies objects
+    whose ``__eq__`` is overloaded and raises for objects whose comparison result
+    has no truth value.
+    """
+    if node is None or node is Ellipsis:
+        return True
+    # ``True``/``False`` are covered by ``bool``.
+    return isinstance(node, (bool, int, float, complex, str, bytes))
+
+
+def _get_registry_entry(cls_name, attr):
+    """Return attribute ``attr`` of the registry entry for ``cls_name``, or ``None``.
+
+    Only the part of ``cls_name`` after the last dot is used as the key.  The lookup
+    never creates an entry for an unknown name.
+    """
+    short_name = cls_name.split(".")[-1]
+    table = Registry._inst.table  # pylint: disable=protected-access
+    # ``.get`` does not trigger the defaultdict factory, unlike subscripting.
+    entry = table.get(short_name)
+    if entry is None:
+        return None
+    return getattr(entry, attr, None)
+
+
+def from_doc(node):
+    """Convert a ``doc`` value to its ``ast`` counterpart.
+
+    Atomic values are returned as they are, tuples and lists are converted element
+    by element, and any other object is handed to the ``from_doc`` converter
+    registered for its class name.
+
+    Raises
+    ------
+    NotImplementedError
+        If no ``from_doc`` converter is registered for the object's class name.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(map(from_doc, node))
+    if isinstance(node, list):
+        return list(map(from_doc, node))
+    cls_name = node.__class__.__name__
+    converter = _get_registry_entry(cls_name, "from_doc")
+    if converter:
+        return converter(node)
+    raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}")
+
+
+def to_doc(node):
+    """Convert an ``ast`` value to its ``doc`` counterpart.
+
+    Atomic values are returned as they are, tuples and lists are converted element
+    by element, and any other object is handed to the ``to_doc`` converter
+    registered for its class name.
+
+    Raises
+    ------
+    NotImplementedError
+        If no ``to_doc`` converter is registered for the object's class name.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(map(to_doc, node))
+    if isinstance(node, list):
+        return list(map(to_doc, node))
+    cls_name = node.__class__.__name__
+    converter = _get_registry_entry(cls_name, "to_doc")
+    if converter:
+        return converter(node)
+    raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}")
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    """Parse Python source text into a ``doc`` AST.
+
+    The text is first parsed with ``feature_version=(3, 8)``.  If that call fails
+    (for example because the running interpreter's ``ast.parse`` does not accept
+    the keyword), the text is parsed again without it; an error from the second
+    attempt propagates to the caller.  The resulting tree is converted with ``to_doc``.
+    """
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except Exception:  # pylint: disable=broad-except
+        # Deliberately broad best-effort fallback, but not a bare ``except``:
+        # KeyboardInterrupt and SystemExit must not be swallowed into a re-parse.
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    """Read-only walker over ``doc`` AST nodes.
+
+    ``visit`` calls the method named ``visit_<ClassName>`` when the visitor defines
+    one and ``generic_visit`` otherwise.
+    """
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit ``node``; lists and tuples are visited item by item, other non-nodes are ignored."""
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+        elif isinstance(node, doc.AST):
+            method_name = "visit_" + node.__class__.__name__.split(".")[-1]
+            handler = getattr(self, method_name, self.generic_visit)
+            handler(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        """Visit every field of ``node`` that holds a node, a list or a tuple."""
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            if isinstance(child, (doc.AST, list, tuple)):
+                self.visit(child)
+
+
+class NodeTransformer:
+    """Rebuilding walker over ``doc`` AST nodes.
+
+    ``visit`` calls the method named ``visit_<ClassName>`` when the transformer
+    defines one and ``generic_visit`` otherwise, and returns that method's result.
+    """
+
+    def visit(self, node: doc.AST) -> doc.AST:
+        """Transform ``node``; lists and tuples are rebuilt item by item, other non-nodes pass through."""
+        if isinstance(node, list):
+            return [self.visit(item) for item in node]
+        if isinstance(node, tuple):
+            return tuple(self.visit(item) for item in node)
+        if not isinstance(node, doc.AST):
+            return node
+        method_name = "visit_" + node.__class__.__name__.split(".")[-1]
+        handler = getattr(self, method_name, self.generic_visit)
+        return handler(node)
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        """Return a new node of the same class whose fields have been transformed."""
+        rebuilt: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            # Only nodes and node containers are transformed; other values are copied as-is.
+            if isinstance(child, (doc.AST, list, tuple)):
+                child = self.visit(child)
+            rebuilt[field] = child
+        return node.__class__(**rebuilt)
+
+
+def _register_default():
+    """Create the shared registry and install field-by-field default converters.
+
+    For every name in the ``doc`` module that also exists in ``ast`` and is a
+    subclass of ``doc.AST``, a ``to_doc`` and a ``from_doc`` converter are registered.
+    """
+
+    class DefaultTranslator:
+        """Builds ``doc_cls`` from a node by applying ``func`` to each listed field."""
+
+        def __init__(self, doc_cls, func, fields):
+            self.doc_cls = doc_cls  # class to instantiate
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            converted = {}
+            for attr in self.fields:
+                # Fields missing on the node are converted as None.
+                converted[attr] = self.func(getattr(node, attr, None))
+            return self.doc_cls(**converted)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue
+        if not (inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST)):
+            continue
+        assert "." not in cls_name
+        fields = doc_cls._FIELDS  # pylint: disable=protected-access
+        register_to_doc(cls_name)(DefaultTranslator(doc_cls, to_doc, fields))
+        register_from_doc(cls_name)(DefaultTranslator(getattr(ast, cls_name), from_doc, fields))
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    """Return the running interpreter's ``(major, minor)`` version."""
+    major, minor = sys.version_info[:2]
+    return (major, minor)
+
+
+def _register_constant_handling():
+    """On Python 3.6/3.7, register converters that turn literal nodes into ``doc.Constant``.
+
+    Does nothing on any other interpreter version.
+    """
+    if _py_version() not in ((3, 6), (3, 7)):
+        return
+
+    def as_constant(f) -> FnToDoc:
+        # ``f`` is either the name of the attribute holding the value or a callable
+        # that computes the value from the node.
+        read_value = (lambda x: getattr(x, f)) if isinstance(f, str) else f
+
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            value = read_value(x)
+            # The end position is set to the start position.
+            return doc.Constant(
+                value=value,
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                end_lineno=x.lineno,
+                end_col_offset=x.col_offset,
+            )
+
+        return to_doc_func
+
+    legacy_nodes = (
+        ("Str", "s"),
+        ("NameConstant", "value"),
+        ("Num", "n"),
+        ("Bytes", "s"),
+        ("Ellipsis", lambda _: ...),
+    )
+    for node_name, value_source in legacy_nodes:
+        register_to_doc(node_name)(as_constant(value_source))
+
+
+def _register_subscription_handling():
+    """Before Python 3.9, register ``Subscript`` converters that handle the slice wrappers.
+
+    ``ast.Slice`` maps to ``doc.Slice``, ``ast.ExtSlice`` maps to a ``doc.Tuple`` of
+    its dims and ``ast.Index`` is replaced by its wrapped value; the reverse
+    converter re-creates the wrappers.  Does nothing on Python 3.9 or newer.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    position_attrs = ("lineno", "col_offset", "end_lineno", "end_col_offset")
+
+    def position_of(node):
+        # Position attributes that are missing on the node are reported as None.
+        return {attr: getattr(node, attr, None) for attr in position_attrs}
+
+    def slice_to_doc(x: ast.Subscript):
+        return doc.Slice(
+            lower=to_doc(x.slice.lower),
+            upper=to_doc(x.slice.upper),
+            step=to_doc(x.slice.step),
+            **position_of(x.slice),
+        )
+
+    def ext_slice_to_doc(x: ast.Subscript):
+        # The tuple takes its position from the enclosing subscript.
+        return doc.Tuple(
+            elts=[to_doc(i) for i in x.slice.dims],
+            ctx=doc.Load(
+                lineno=None,
+                col_offset=None,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            **position_of(x),
+        )
+
+    def index_value_to_doc(x: ast.Subscript):
+        return to_doc(x.slice.value)
+
+    slice_builders = (
+        (ast.Slice, slice_to_doc),
+        (ast.ExtSlice, ext_slice_to_doc),
+        (ast.Index, index_value_to_doc),
+    )
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        for slice_cls, build_slice in slice_builders:
+            if isinstance(x.slice, slice_cls):
+                return doc.Subscript(
+                    value=to_doc(x.value),
+                    slice=build_slice(x),
+                    ctx=to_doc(x.ctx),
+                    **position_of(x),
+                )
+        raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        doc_slice = x.slice
+        value = from_doc(x.value)
+        if isinstance(doc_slice, doc.Slice):
+            ast_slice = from_doc(doc_slice)
+        elif isinstance(doc_slice, doc.Tuple):
+            ast_slice = ast.ExtSlice(
+                dims=[from_doc(i) for i in doc_slice.elts],
+            )
+        else:
+            ast_slice = ast.Index(value=from_doc(doc_slice))
+        result = ast.Subscript(value=value, slice=ast_slice, ctx=from_doc(x.ctx))
+        for attr in position_attrs:
+            setattr(result, attr, getattr(x, attr))
+        return result
+
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+def _register_index_handling():
+    """Before Python 3.9, register converters for the ``ast.Index`` wrapper node.
+
+    Converting to ``doc`` unwraps the index to its value; converting back wraps the
+    expression in ``ast.Index`` and copies its position.  Does nothing on Python 3.9
+    or newer.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def index_to_doc(x: ast.Index) -> doc.Expr:
+        return to_doc(x.value)
+
+    def index_from_doc(x: doc.Expr) -> ast.Index:
+        wrapped = ast.Index(value=from_doc(x), ctx=from_doc(x.ctx))
+        for attr in ("lineno", "col_offset", "end_lineno", "end_col_offset"):
+            setattr(wrapped, attr, getattr(x, attr))
+        return wrapped
+
+    register_to_doc("Index")(index_to_doc)
+    register_from_doc("Index")(index_from_doc)
+
+
+# Populate the registry at import time.  ``_register_default`` must run first because
+# it creates ``Registry._inst``; the version-specific registrations that follow
+# overwrite the default converters for the names they handle.
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
+_register_index_handling()
diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/parser/core/doc_core.py
similarity index 100%
rename from python/tvm/script/printer/doc_core.py
rename to python/tvm/script/parser/core/doc_core.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/core/entry.py
similarity index 51%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/core/entry.py
index 923eb97d27..ccf42e8c15 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -14,32 +14,30 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
 
-import inspect
-from typing import Callable
+from ...ir_builder import IRBuilder
+from . import doc
+from .diagnostics import Source
+from .parser import Parser
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
 
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+    if extra_vars is None:
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+        extra_vars = {
+            "I": ir,
+            "ir": ir,
+            "T": tir,
+            "tir": tir,
+        }
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
-
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+    source = Source(program)
+    parser = Parser(source)
+    with IRBuilder() as builder:
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
new file mode 100644
index 0000000000..405281a65e
--- /dev/null
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+# Fallback implementation of each operator node type using the plain Python
+# operator; `_eval_op` uses it when no registered overload applies.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+    # Binary arithmetic and bitwise operators.
+    doc.Add: lambda a, b: a + b,
+    doc.Sub: lambda a, b: a - b,
+    doc.Mult: lambda a, b: a * b,
+    doc.Div: lambda a, b: a / b,
+    doc.FloorDiv: lambda a, b: a // b,
+    doc.Mod: lambda a, b: a % b,
+    doc.LShift: lambda a, b: a << b,
+    doc.RShift: lambda a, b: a >> b,
+    doc.BitOr: lambda a, b: a | b,
+    doc.BitXor: lambda a, b: a ^ b,
+    doc.BitAnd: lambda a, b: a & b,
+    doc.MatMult: lambda a, b: a @ b,
+    doc.Pow: lambda a, b: a ** b,
+    # Comparison operators.
+    doc.Eq: lambda a, b: a == b,
+    doc.NotEq: lambda a, b: a != b,
+    doc.Lt: lambda a, b: a < b,
+    doc.LtE: lambda a, b: a <= b,
+    doc.Gt: lambda a, b: a > b,
+    doc.GtE: lambda a, b: a >= b,
+    doc.Is: lambda a, b: a is b,
+    doc.IsNot: lambda a, b: a is not b,
+    doc.In: lambda a, b: a in b,
+    doc.NotIn: lambda a, b: a not in b,
+    # Boolean operators; both operands are already evaluated when these run.
+    doc.And: lambda a, b: a and b,
+    doc.Or: lambda a, b: a or b,
+    # Unary operators.
+    doc.Invert: lambda a: ~a,
+    doc.Not: lambda a: not a,
+    doc.UAdd: lambda a: +a,
+    doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+    """Evaluates a ``doc`` expression tree bottom-up against a table of known values.
+
+    Every evaluated sub-expression is stored in ``value_table`` under a generated
+    ``__tvm_tmp_value_<n>`` name and replaced in the tree by a ``doc.Name`` that
+    refers to it, so a parent expression is evaluated over already-computed operands.
+    """
+
+    # Parser used to report errors.
+    parser: "Parser"
+    # Name -> value mapping used for evaluation; extended with intermediate results.
+    value_table: Dict[str, Any]
+    # Number of intermediate results created so far; used to build unique names.
+    new_value_count: int
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` and return the resulting Python value.
+
+        Raises
+        ------
+        TypeError
+            If the visited result is neither a ``doc.Name`` nor a ``doc.Constant``.
+        """
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        """Store ``value`` under a fresh temporary name and return a ``doc.Name`` loading it."""
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        # The synthetic name carries a zero position rather than the position of
+        # the expression it replaces.
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        """Evaluate ``node`` bottom-up and return its replacement.
+
+        Lists and tuples are visited element-wise.  Names, constants, operator and
+        context nodes, and nodes that are neither expressions nor slices are
+        returned unchanged.  Any other node has its children visited first and is
+        then evaluated and replaced by a temporary name.
+        """
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            # Lambdas are evaluated as a whole, without visiting their children.
+            return self._eval_lambda(node)
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            # Never fall through with ``value`` unbound, even if reporting returns.
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        """Evaluate a lambda node and return a temporary name bound to the result."""
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            # Never fall through with ``value`` unbound, even if reporting returns.
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        """Fold the operands of an ``and``/``or`` expression from left to right.
+
+        Every operand is evaluated; there is no short-circuiting.
+        """
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate a comparison expression from left to right."""
+        # NOTE(review): each result becomes the left operand of the next comparison,
+        # so ``a < b < c`` is computed as ``(a < b) < c`` -- confirm this is intended.
+        value = self._eval_expr(fields["left"])
+        for op, rhs in zip(fields["ops"], fields["comparators"]):
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate the operand and apply the unary operator to it."""
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate both operands and apply the binary operator to them."""
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        """Build a Python ``slice``; bounds that are ``None`` stay ``None``."""
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        """Evaluate ``v`` with the module-level ``_eval_expr`` against ``value_table``."""
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate ``node`` with ``ExprEvaluator`` and return the resulting value.
+
+    The evaluator works on a copy of ``dict_globals`` (an empty table when it is
+    ``None``), so the caller's mapping is not modified.
+    """
+    value_table = {} if dict_globals is None else dict(dict_globals)
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Bind ``source`` to the assignment ``target`` and return the created name -> value map.
+
+    Any failure is reported through ``parser.report_error`` against ``target`` and
+    then re-raised.
+    """
+    try:
+        bindings = _eval_assign(target, source)
+    except Exception as e:  # pylint: disable=broad-except,invalid-name
+        parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
+        raise
+    return bindings
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Compile a ``doc`` expression to Python bytecode and evaluate it.
+
+    The node is converted back to ``ast``, wrapped in ``ast.Expression`` when it is
+    a bare expression, and evaluated with ``dict_globals`` as the globals (an empty
+    dict when ``None``).
+    """
+    py_node = doc.from_doc(node)
+    if isinstance(py_node, ast.expr):
+        py_node = ast.Expression(body=py_node)
+    assert isinstance(py_node, ast.Expression), "Expects an ast.Expression, but gets: " + str(
+        py_node
+    )
+    namespace = {} if dict_globals is None else dict_globals
+    code = compile(ast.fix_missing_locations(py_node), filename="<ast>", mode="eval")
+    return eval(code, namespace)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply operator ``op`` to ``values``, preferring a registered overload.
+
+    Operands are scanned from left to right.  The first operand whose type carries
+    a ``_dispatch_type`` with an overload registered for this operator and operand
+    position decides the result; otherwise the ``DEFAULT_OP`` entry is used.
+    """
+    op_type = type(op)
+    for index, operand in enumerate(values):
+        dispatch_type = getattr(type(operand), "_dispatch_type", None)
+        if dispatch_type is not None:
+            overload = dispatch.get_op(
+                ty=dispatch_type, op=op_type, operand_index=index, default=None
+            )
+            if overload is not None:
+                return overload(*values)
+    fallback = DEFAULT_OP[op_type]
+    return fallback(*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Execute ``<target> = <source>`` and return the names the assignment created.
+
+    The target is converted back to ``ast`` and assigned from a placeholder
+    variable that holds ``source``.  The statement is executed with a fresh
+    namespace, and the placeholder is removed from the result.
+    """
+    py_target = doc.from_doc(target)
+    assert isinstance(py_target, ast.expr)
+    placeholder = "__tvm_rhs_var__"
+    namespace = {placeholder: source}
+    assignment = ast.Assign(
+        targets=[py_target],
+        value=ast.Name(id=placeholder, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assignment], type_ignores=[]))
+    code = compile(module, filename="<ast>", mode="exec")
+    exec(code, {}, namespace)  # pylint: disable=exec-used
+    del namespace[placeholder]
+    return namespace
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
new file mode 100644
index 0000000000..e26324262f
--- /dev/null
+++ b/python/tvm/script/parser/core/parser.py
@@ -0,0 +1,273 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+from tvm.error import DiagnosticError
+
+from . import dispatch, doc
+from .diagnostics import Diagnostics, Source
+from .evaluator import eval_assign, eval_expr
+
+# Node type names that `Parser.visit` sends straight to `generic_visit`
+# instead of looking up a `visit_<name>` method.
+DEFAULT_VISIT = {
+    "Interactive",
+    "Module",
+    "Expression",
+    "Pass",
+}
+
+
+def _deferred(f: Callable[[], None]):
+    """Return a context manager that calls ``f`` when the ``with`` block is left.
+
+    ``f`` runs whether the block finishes normally or raises.
+    """
+
+    @contextmanager
+    def run_on_exit():
+        try:
+            yield
+        finally:
+            f()
+
+    manager = run_on_exit()
+    return manager
+
+
+class VarTableFrame:
+    """The set of variable names defined in one scope."""
+
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        """Record ``var`` in this scope; raise ``ValueError`` if it is already recorded."""
+        if var not in self.vars:
+            self.vars.add(var)
+            return
+        raise ValueError(f"Variable {var} already defined in current scope")
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        """Call ``fn_pop`` once for each recorded name, then forget all names."""
+        for name in self.vars:
+            fn_pop(name)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped variable table.
+
+    Each name maps to a stack of values, and each frame remembers which names it
+    defined.  Leaving a frame pops one value for every name that frame defined.
+    """
+
+    # Open scopes, innermost last.
+    frames: List["VarTableFrame"]
+    # Per-name stack of bound values, most recent binding last.
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        """Open a new frame and return a context manager that closes it on exit."""
+
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return _deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        """Bind ``var`` to ``value`` in the innermost frame."""
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        """Return the most recent value of every name that currently has a binding."""
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        """Return True if ``value`` (compared by identity) is bound to any name."""
+        # ``name2value`` maps each name to a *list* of values, so the identity test
+        # has to look inside every list; comparing the list itself never matches.
+        return any(
+            bound is value for stack in self.name2value.values() for bound in stack
+        )
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap ``func`` so that its failures are reported through the parser.
+
+    A ``DiagnosticError`` passes through untouched.  Any other exception is reported
+    against the node being handled and then re-raised.
+    """
+
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            result = func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+        else:
+            return result
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Pick the parse method for ``type_name`` and wrap it for error reporting.
+
+    The parser's current dispatch token is tried first, then ``"default"``.  When
+    neither has a registered method, the parser's ``generic_visit`` is used.
+    """
+    candidates = (self.dispatch_tokens[-1], "default")
+    for token in candidates:
+        handler = dispatch.get(token=token, type_name=type_name, default=None)
+        if handler is None:
+            continue
+        return _dispatch_wrapper(handler)
+
+    def fallback(parser, node):
+        return parser.generic_visit(node)
+
+    return _dispatch_wrapper(fallback)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser"""
+
+    diag: Diagnostics
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        """Set up diagnostics for ``source``, the dispatch-token stack and an empty variable table."""
+        self.diag = Diagnostics(source)
+        # "default" stays at the bottom of the stack as the fallback token.
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Parse the source held by the diagnostics object.
+
+        A variable frame is opened for the whole run, ``extra_vars`` (if any) are
+        added to it, and the source's AST is visited.
+        """
+        initial_vars = {} if extra_vars is None else extra_vars
+        with self.var_table.with_frame():
+            for name, value in initial_vars.items():
+                self.var_table.add(name, value)
+            self.visit(self.diag.source.as_ast())
+
+    def with_dispatch_token(self, token: str):
+        """Push ``token`` onto the dispatch stack; the returned context manager pops it on exit."""
+        self.dispatch_tokens.append(token)
+        return _deferred(lambda: self.dispatch_tokens.pop())
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate ``node`` against the current variables.
+
+        Entries in ``extra_vars`` take precedence over same-named table variables.
+        """
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            var_values.update(extra_vars)
+        return eval_expr(self, node, var_values)
+
+    def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
+        """Check an assignment target for repeated names.
+
+        Returns ``True`` as soon as a name occurs twice; otherwise returns the set
+        of names found.  A target that is not a name, tuple or list is reported as
+        an error, followed by ``NotImplementedError``.
+        """
+        if isinstance(target, doc.Name):
+            return {target.id}
+        if not isinstance(target, (doc.Tuple, doc.List)):
+            self.report_error(target, "Invalid type in assign statement")
+            raise NotImplementedError
+        seen: Set[str] = set()
+        for element in target.elts:
+            names = self._duplicate_lhs_check(element)
+            if isinstance(names, bool) and names:
+                return True
+            assert isinstance(names, set)
+            if seen & names:
+                return True
+            seen |= names
+        return seen
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        if self._duplicate_lhs_check(target) is True:
+            self.report_error(target, "Duplicate vars assigned.")
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        try:
+            func(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
+
+    def visit_Return(self, node: doc.Return) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Return")(self, node)
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/core/utils.py
similarity index 58%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/parser/core/utils.py
index 2f2b4bbc25..aae45fe6ff 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -14,18 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Any, Callable, Dict
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
 
-from .prim_func import prim_func
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    """Collect the variables ``func`` can see from outside its own body.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to inspect.
+
+    Returns
+    -------
+    captured : Dict[str, Any]
+        The module globals of ``func`` merged with its closure (nonlocal) variables.
+    """
+    # Closure variables are unpacked last so that, as in Python's own name resolution, a
+    # name bound in an enclosing function shadows a module-level global of the same name.
+    captured = {
+        **func.__globals__,  # type: ignore
+        **inspect.getclosurevars(func).nonlocals,
+    }
+    return captured
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    """Merge the captured variables of every plain function defined directly on ``cls``.
+
+    Only entries of ``cls.__dict__`` for which ``inspect.isfunction`` is true are inspected.
+    When two functions capture the same name, the one defined later wins.
+    """
+    merged: Dict[str, Any] = {}
+    methods = (member for member in cls.__dict__.values() if inspect.isfunction(member))
+    for method in methods:
+        merged.update(**inspect_function_capture(method))
+    return merged
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/__init__.py
index 555659d0c5..4cbd9910a2 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,8 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
 
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = ["ir_module"]
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/ir/entry.py
similarity index 54%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser/ir/entry.py
index 923eb97d27..3c1e4de5a7 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,32 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
-
+# pylint: disable=missing-docstring
 import inspect
-from typing import Callable
+from typing import Type
+
+from tvm.ir import IRModule
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from .._core import parse, utils
 
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+def ir_module(f: Type) -> IRModule:
+    if not inspect.isclass(f):
+        raise TypeError(f"Expect a class, but got: {f}")
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
+    return parse(f, utils.inspect_class_capture(f))
 
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
 
-    raise TypeError("Only function definitions are supported.")
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 55%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser/ir/parser.py
index 2f2b4bbc25..8871d3b415 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,18 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder import ir as I
+from .._core import Parser, dispatch, doc
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
 
-from .prim_func import prim_func
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Visit the body of a class definition inside an ``I.ir_module()`` frame."""
+    # Entered left to right: a new variable frame, the IR-module frame, then the "ir" token
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    """Ignore assignment statements seen under the "ir" dispatch token."""
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    """Ignore expression statements seen under the "ir" dispatch token."""
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 71%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..930764f73d 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder.tir import *  # pylint: disable=redefined-builtin
+from ...ir_builder.tir import ir as _tir
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
 
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..db4e2dd9a3
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+
+from ...ir_builder.tir import buffer_decl, ptr
+from .._core import parse, utils
+
+
+def _is_defined_in_class(frames):
+    """Guess from stack records whether the decorated function sits in a class body.
+
+    Parameters
+    ----------
+    frames : sequence
+        Stack records such as those returned by ``inspect.stack()``. Only the record at
+        index 2 is examined, and item 4 of it is read as its list of source context lines.
+
+    Returns
+    -------
+    bool
+        True when the first context line, stripped, starts with ``class `` or is a
+        decorator line mentioning ``ir_module``; False otherwise, including when fewer
+        than three records or no context lines are available.
+    """
+    if len(frames) <= 2:
+        return False
+    context_lines = frames[2][4]
+    if context_lines is None:
+        return False
+    head = context_lines[0].strip()
+    return head.startswith("class ") or (head.startswith("@") and "ir_module" in head)
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    """Decorator that hands a Python function to ``parse``.
+
+    Parameters
+    ----------
+    f : Callable
+        The function to parse; anything that is not a plain function is rejected.
+
+    Returns
+    -------
+    Union[PrimFunc, Callable]
+        ``f`` itself when it appears to be defined inside a class body, otherwise the
+        result of ``parse`` called with ``f`` and its captured variables.
+
+    Raises
+    ------
+    TypeError
+        If ``f`` is not a function.
+    """
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    # `inspect.stack()` has to be called from this frame: `_is_defined_in_class` inspects a
+    # fixed position (index 2) of the returned list.
+    if _is_defined_in_class(inspect.stack()):
+        return f
+    return parse(f, utils.inspect_function_capture(f))
+
+
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Callable and subscriptable front end that forwards to ``buffer_decl``."""
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        """Declare a buffer by passing every argument through to ``buffer_decl``."""
+        options = {
+            "dtype": dtype,
+            "data": data,
+            "strides": strides,
+            "elem_offset": elem_offset,
+            "scope": scope,
+            "align": align,
+            "offset_factor": offset_factor,
+            "buffer_type": buffer_type,
+            "axis_separators": axis_separators,
+        }
+        return buffer_decl(shape, **options)
+
+    def __getitem__(self, keys) -> Buffer:
+        """Subscript form of the call.
+
+        A tuple is unpacked into positional arguments only when it has fewer than two
+        items or its second item is a string; any other key is passed whole as the shape.
+        """
+        unpack = isinstance(keys, tuple) and (len(keys) < 2 or isinstance(keys[1], str))
+        if unpack:
+            return self(*keys)  # pylint: disable=no-member # type: ignore
+        return self(keys)
+
+
+class PtrProxy:
+    """Callable and subscriptable front end that forwards to ``ptr``."""
+
+    def __call__(self, dtype, storage_scope="global"):
+        """Call ``ptr``; a callable ``dtype`` is first called and its ``.dtype`` is used."""
+        resolved = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        """Subscript form of the call: a tuple key is unpacked into positional arguments."""
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+Buffer = BufferProxy()  # pylint: disable=invalid-name
+Ptr = PtrProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..716525b984
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .._core import OpMethod, doc, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    """Register operator implementations for the type ``ty`` through ``register_op``.
+
+    Binary, comparison and boolean operators are registered with index 0 and index 1,
+    unary operators with index 0 only.
+    NOTE(review): the index is presumably the operand position - confirm against
+    ``register_op``.
+    """
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool_imm(x):
+        # Python bools become boolean immediates; every other operand is used as is
+        return IntImm("bool", x) if isinstance(x, bool) else x
+
+    def _and(a, b):
+        lhs = _as_bool_imm(a)
+        rhs = _as_bool_imm(b)
+        return tir.And(lhs, rhs)
+
+    def _or(a, b):
+        lhs = _as_bool_imm(a)
+        rhs = _as_bool_imm(b)
+        return tir.Or(lhs, rhs)
+
+    two_operand_ops = (
+        # Case 1. binop
+        (doc.Add, tir.Add),
+        (doc.Sub, tir.Sub),
+        (doc.Mult, tir.Mul),
+        (doc.Div, tir.Div),
+        (doc.FloorDiv, tir.FloorDiv),
+        (doc.Mod, tir.FloorMod),
+        (doc.LShift, lambda a, b: a << b),
+        (doc.RShift, lambda a, b: a >> b),
+        (doc.BitOr, lambda a, b: a | b),
+        (doc.BitXor, lambda a, b: a ^ b),
+        (doc.BitAnd, lambda a, b: a & b),
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop
+        (doc.Eq, tir.EQ),
+        (doc.NotEq, tir.NE),
+        (doc.Lt, tir.LT),
+        (doc.LtE, tir.LE),
+        (doc.Gt, tir.GT),
+        (doc.GtE, tir.GE),
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop
+        (doc.And, _and),
+        (doc.Or, _or),
+    )
+    one_operand_ops = (
+        #  Case 4. unaryop
+        (doc.Invert, lambda a: ~a),
+        (doc.Not, tir.Not),
+        (doc.UAdd, lambda a: +a),
+        (doc.USub, lambda a: -a),
+    )
+
+    for index in (0, 1):
+        for op, method in two_operand_ops:
+            register_op(ty, op, index)(method)
+    for op, method in one_operand_ops:
+        register_op(ty, op, 0)(method)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..351238c06f
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,268 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+
+from ...ir_builder import tir as T
+from ...ir_builder.base import IRBuilderFrame as Frame
+from ...ir_builder.base import name
+from .._core import Parser, dispatch, doc
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name a value bound by a with-statement and return it.
+
+    Lists and tuples are walked recursively, each element named ``{var_name}_{index}``.
+    ``Buffer`` and ``Var`` values are named ``var_name``. Any other type is reported as an
+    error and ``NotImplementedError`` is raised.
+    """
+    if isinstance(value, (list, tuple)):
+        for index, element in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{index}", element)
+        return value
+    if isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+    raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name a value bound by a for-statement and return it.
+
+    Parameters
+    ----------
+    self : Parser
+        The parser, used to report errors.
+    node : doc.expr
+        The node an error is reported against.
+    var_name : str
+        The name to give to the value.
+    value : Any
+        A ``Var``, or a (possibly nested) list/tuple of them.
+
+    Returns
+    -------
+    value : Any
+        The value that was passed in.
+
+    Raises
+    ------
+    NotImplementedError
+        If the value, or any nested element, has an unsupported type.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the for-statement rules so nested elements are held to the same
+            # type check and error message as top-level loop variables.
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Var):
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    """Decide what an assigned name is bound to and return that object.
+
+    Parameters
+    ----------
+    self : Parser
+        The parser; its variable table is consulted for ``Var`` values.
+    _node : doc.expr
+        The assignment target, passed along when recursing.
+    var_name : str
+        The name being assigned.
+    value : Any
+        The evaluated right-hand side.
+
+    Returns
+    -------
+    Any
+        ``value.value`` for ``T.inline``; the result of ``__enter__`` for a frame; a new
+        variable bound by a let-frame for any other ``PrimExpr``; otherwise ``value``.
+    """
+    if isinstance(value, T.inline):
+        return value.value
+    elif isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Elements of an assigned list/tuple follow the assignment rules too; the
+            # with-statement binder would reject anything that is not a Buffer or Var.
+            bind_assign_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        # Register the frame's own `__exit__` as a callback on the frame, then enter it
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        # Bind the expression to a fresh variable through a let-frame
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """Visit a for-loop whose iterator expression evaluates to a ``T.frame.ForFrame``."""
+    loop_frame = self.eval_expr(node.iter)
+    if not isinstance(loop_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    # A new variable frame is opened first, then the loop frame is entered inside it
+    with self.var_table.with_frame(), loop_frame as loop_vars:
+        self.eval_assign(target=node.target, source=loop_vars, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """Visit a while-loop: evaluate its test and visit the body inside ``T.While``."""
+    # The test is evaluated after the new variable frame has been opened
+    with self.var_table.with_frame(), T.While(self.eval_expr(node.test)):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """Visit an assignment with exactly one target.
+
+    A subscripted target becomes a ``T.buffer_store``; any other target is bound through
+    ``eval_assign`` with ``bind_assign_value``.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    target = node.targets[0]
+    value = self.eval_expr(node.value)
+    if not isinstance(target, doc.Subscript):
+        self.eval_assign(target=target, source=value, bind_value=bind_assign_value)
+        return
+    subscript = target.slice
+    # A tuple subscript contributes one index per element, anything else a single index
+    index_nodes = subscript.elts if isinstance(subscript, doc.Tuple) else [subscript]
+    indices = [self.eval_expr(index_node) for index_node in index_nodes]
+    T.buffer_store(self.eval_expr(target.value), value, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """Visit an augmented assignment such as ``a += b``.
+
+    The target and the value are evaluated, combined through a synthetic ``doc.BinOp``
+    built from two temporary names, and the result is then stored like a plain assignment.
+    """
+    # Source positions reused for the synthetic nodes created below
+    lhs_pos = (
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    # Switch the target to a load context so it can be evaluated as an expression
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        # The temporary names are added inside a variable frame of their own
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        # Build `<lhs_name> <op> <rhs_name>` and evaluate it to get the value to store
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)
+    lhs = node.target
+    # Switch the target back to a store context before assigning to it
+    lhs.ctx = doc.Store(*lhs_pos)
+    if isinstance(lhs, doc.Subscript):
+        # A tuple subscript contributes one index per element, anything else a single index
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """Visit an annotated assignment.
+
+    The annotation must evaluate to a ``Var``; that variable is bound to the target name
+    and then tied to the evaluated right-hand side through a ``T.let`` frame.
+    """
+    value = self.eval_expr(node.value)
+    annotated = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(annotated, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=node.target, source=annotated, bind_value=bind_assign_value)
+    let_frame = T.let(annotated, value)
+    # Register the frame's own `__exit__` as a callback on the frame, then enter it
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """Visit a with-statement: enter every context frame in order, then visit the body."""
+    with contextlib.ExitStack() as exit_stack:
+        exit_stack.enter_context(self.var_table.with_frame())
+        for with_item in node.items:
+            context = self.eval_expr(with_item.context_expr)
+            if not isinstance(context, Frame):
+                self.report_error(
+                    with_item.context_expr, "Invalid context expression in the with-statement."
+                )
+            entered = exit_stack.enter_context(context)
+            target = with_item.optional_vars
+            if target is None:
+                # `with frame:` without an `as` clause binds nothing
+                continue
+            self.eval_assign(target=target, source=entered, bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """Visit a function definition inside a ``T.prim_func()`` frame.
+
+    The name ``range`` is bound to ``T.serial`` for the function, the function name and
+    optional return type are recorded, and the arguments and body are visited under the
+    "tir" dispatch token.
+    """
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            returns = node.returns
+            if returns is not None:
+                evaluated = self.eval_expr(returns)
+                # A callable return annotation is called and its dtype wrapped in PrimType
+                T.func_ret(PrimType(evaluated().dtype) if callable(evaluated) else evaluated)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """Declare each positional parameter with ``T.arg`` and add it to the variable table."""
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    parameter: doc.arg
+    for parameter in node.args:
+        annotation = parameter.annotation
+        if annotation is None:
+            self.report_error(parameter, "Type annotation is required for function parameters.")
+        declared = T.arg(parameter.arg, self.visit_tvm_annotation(annotation))
+        self.var_table.add(parameter.arg, declared)
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """Evaluate an annotation expression; a callable result is called with no arguments."""
+    evaluated = self.eval_expr(node)
+    return evaluated() if callable(evaluated) else evaluated
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """Evaluate an expression statement; a resulting frame is entered on the spot."""
+    result = self.eval_expr(node.value)
+    if not isinstance(result, Frame):
+        return
+    # Register the frame's own `__exit__` as a callback on the frame, then enter it
+    result.add_callback(partial(result.__exit__, None, None, None))
+    result.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """Visit an if-statement: the body under ``T.Then``, a non-empty else under ``T.Else``."""
+    # The test is evaluated after the new variable frame has been opened
+    with self.var_table.with_frame(), T.If(self.eval_expr(node.test)):
+        with T.Then():
+            self.visit_body(node.body)
+        else_body = node.orelse
+        if else_body:
+            with T.Else():
+                self.visit_body(else_body)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """Visit an assert statement by entering a ``T.Assert`` frame."""
+    condition = self.eval_expr(node.test)
+    # NOTE(review): `node.msg` is presumably None for an assert without a message - confirm
+    # that `eval_expr` and `T.Assert` accept that.
+    message = self.eval_expr(node.msg)
+    assert_frame = T.Assert(condition, message)
+    # Register the frame's own `__exit__` as a callback on the frame, then enter it
+    assert_frame.add_callback(partial(assert_frame.__exit__, None, None, None))
+    assert_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+    """Report an error: return statements are rejected under the "tir" token."""
+    message = "Return is not allowed."
+    self.report_error(node, message)
diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py
deleted file mode 100644
index e7d90dd515..0000000000
--- a/python/tvm/script/registry.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Function Registry """
-# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
-import types
-from typing import Union, Callable, Dict, Optional, Any
-
-
-class Registry(object):
-    """Registration map
-    All these maps are static
-    """
-
-    registrations: Dict[str, type] = dict()
-
-    @staticmethod
-    def lookup(name: str) -> Optional[Any]:
-        if name in Registry.registrations:
-            # every time we create a new handler
-            # since we may want to keep some local info inside it
-            return Registry.registrations[name]()
-        return None
-
-
-def register(inputs: Union[Callable, type]) -> type:
-    """Register Intrin/ScopeHandler/SpecialStmt"""
-    registration: type
-    if isinstance(inputs, types.FunctionType):
-        # is function
-        from .tir.intrin import Intrin
-
-        def create_new_intrin(func) -> type:
-            class NewIntrin(Intrin):
-                def __init__(self):
-                    super().__init__(func)
-
-            return NewIntrin
-
-        registration = create_new_intrin(inputs)
-    elif isinstance(inputs, type):
-        # is class
-        registration = inputs
-    else:
-        raise ValueError()
-
-    key: str = registration().signature()[0]
-    Registry.registrations[key] = registration
-    return registration
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
deleted file mode 100644
index a62fb102be..0000000000
--- a/python/tvm/script/tir/__init__.pyi
+++ /dev/null
@@ -1,477 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Dict,
-    Iterable,
-    Optional,
-    Tuple,
-    Union,
-    Sequence,
-    List,
-    Mapping,
-    overload,
-)
-from numbers import Number
-import builtins
-
-from tvm.tir.function import PrimFunc
-from tvm.tir import Range
-from tvm.runtime import Object
-from tvm.target import Target
-from .node import BufferSlice
-
-"""
-redefine types
-"""
-
-class PrimExpr:
-    def __init__(self: PrimExpr) -> None: ...
-    @overload
-    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __mod__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __floordiv__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __index__(self: PrimExpr) -> int: ...  # so range doesn't complain
-
-class Var(PrimExpr): ...
-class IterVar(Var): ...
-
-class Buffer:
-    @overload
-    def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
-    @overload
-    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
-    @overload
-    def __setitem__(
-        self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
-    ) -> None: ...
-    @overload
-    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
-    @property
-    def data(self: Buffer) -> Ptr: ...
-
-"""
-Intrinsic
-"""
-
-def min_value(dtype: str) -> PrimExpr: ...
-def max_value(dtype: str) -> PrimExpr: ...
-def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def abs(x: PrimExpr) -> PrimExpr: ...
-def load(
-    dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
-) -> PrimExpr: ...
-def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
-def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
-def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
-def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
-def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
-def evaluate(value: PrimExpr) -> None: ...
-def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def store(
-    var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...
-def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ...
-def preflattened_buffer(
-    buf: Buffer,
-    shape: Sequence[PrimExpr],
-    dtype: str = "float32",
-    data: Optional[Ptr] = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-) -> Buffer: ...
-
-"""
-Intrinsics - tvm builtin
-"""
-
-def tvm_thread_allreduce(
-    *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
-) -> PrimExpr: ...
-
-"""
-Unary operator
-Note that any intrinsics not registered in script.tir.intrin
-should add "dtype" as an argument. This is different from their
-definition but intentional.
-"""
-
-def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-
-"""
-special_stmt - Buffers
-"""
-
-def match_buffer(
-    param: Union[Var, BufferSlice],
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def decl_buffer(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def buffer_decl(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def alloc_buffer(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-
-"""
-special_stmt - Reads/Writes
-"""
-
-@overload
-def reads(read_regions: List[BufferSlice]) -> None: ...
-@overload
-def reads(*read_regions: BufferSlice) -> None: ...
-@overload
-def writes(write_region: List[BufferSlice]) -> None: ...
-@overload
-def writes(*write_region: BufferSlice) -> None: ...
-def block_attr(attrs: Mapping[str, Object]) -> None: ...
-
-"""
-special_stmt - Axis
-"""
-
-class axis:
-    @overload
-    @staticmethod
-    def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def spatial(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def reduce(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def scan(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def opaque(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @staticmethod
-    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
-
-def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
-
-"""
-special_stmt - Annotations
-"""
-
-def buffer_var(dtype: str, storage_scope: str) -> Var: ...
-def func_attr(attrs: Mapping[str, Union[Object, str, bool, int, float]]) -> None: ...
-def prim_func(input_func: Callable) -> PrimFunc: ...
-
-"""
-special_stmt - Threads and Bindings
-"""
-
-def env_thread(env_name: str) -> IterVar: ...
-def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
-
-"""
-Scope handler
-"""
-
-class block(ContextManager):
-    def __init__(self, name_hint: str = "") -> None: ...
-    def __enter__(self) -> Sequence[IterVar]: ...
-
-class init(ContextManager):
-    def __init__(self) -> None: ...
-
-class let(ContextManager):
-    def __init__(self, var: Var, value: PrimExpr) -> None: ...
-
-def where(cond: PrimExpr) -> None: ...
-def allocate(
-    extents: List[PrimExpr],
-    dtype: str,
-    scope: str,
-    condition: Union[PrimExpr, builtins.bool] = True,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Buffer: ...
-def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
-def realize(
-    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
-def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
-
-"""
-Scope handler - Loops
-"""
-
-@overload
-def serial(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def serial(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    thread: str,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
-    end: Union[PrimExpr, int],
-    thread: str,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...
-
-"""
-ty - redefine types
-"""
-
-class boolean: ...
-
-class handle(Var):
-    @overload
-    def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ...
-    @overload
-    def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ...
-    @overload
-    def __setitem__(
-        self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer
-    ) -> None: ...
-    @overload
-    def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ...
-    @property
-    def data(self: handle) -> Ptr: ...
-
-class Ptr: ...
-
-def target(target_str: Union[str, Mapping[str, Object]]) -> Target: ...
-
-class var(Var):
-    def __init__(self: Var, dtype: str): ...
-
-class bool(PrimExpr):
-    def __init__(self: bool, imm: Union[PrimExpr, builtins.bool, builtins.int]): ...
-
-class int8(PrimExpr):
-    def __init__(self: int8, imm: Union[PrimExpr, int]): ...
-
-class int16(PrimExpr):
-    def __init__(self: int16, imm: Union[PrimExpr, int]): ...
-
-class int32(PrimExpr):
-    def __init__(self: int32, imm: Union[PrimExpr, int]): ...
-
-class int64(PrimExpr):
-    def __init__(self: int64, imm: Union[PrimExpr, int]): ...
-
-class uint8(PrimExpr):
-    def __init__(self: uint8, imm: Union[PrimExpr, int]): ...
-
-class uint16(PrimExpr):
-    def __init__(self: uint16, imm: Union[PrimExpr, int]): ...
-
-class uint32(PrimExpr):
-    def __init__(self: uint32, imm: Union[PrimExpr, int]): ...
-
-class uint64(PrimExpr):
-    def __init__(self: uint64, imm: Union[PrimExpr, int]): ...
-
-class float8(PrimExpr):
-    def __init__(self: float8, imm: Union[PrimExpr, int, float]): ...
-
-class float16(PrimExpr):
-    def __init__(self: float16, imm: Union[PrimExpr, int, float]): ...
-
-class float32(PrimExpr):
-    def __init__(self: float32, imm: Union[PrimExpr, int, float]): ...
-
-class float64(PrimExpr):
-    def __init__(self: float64, imm: Union[PrimExpr, int, float]): ...
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
deleted file mode 100644
index 382431c229..0000000000
--- a/python/tvm/script/tir/intrin.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Intrinsic Classes"""
-# pylint: disable=redefined-builtin, relative-beyond-top-level
-import builtins
-from typing import List, Any
-
-import tvm.tir
-from ..registry import register
-from ...target import codegen
-from ..utils import get_param_list, tvm_span_from_synr
-
-
-class Intrin:
-    def __init__(self, intrin, stmt=False):
-        self.intrin = intrin
-        self.stmt = stmt
-
-    def signature(self):
-        return "tir." + self.intrin.__name__, get_param_list(self.intrin)
-
-    def handle(self, arg_list: List[Any], span: tvm.ir.Span):
-        return self.intrin(*arg_list, span=tvm_span_from_synr(span))
-
-
-@register
-def bool(imm, span):
-    return imm.astype("bool", span)
-
-
-# register all datatypes
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            _name = _dtype + _size + _lanes
-
-            # nest closures so we copy the name string
-            def wrap(name):
-                def f(imm, span):
-                    return imm.astype(name, span)
-
-                f.__name__ = name
-                return f
-
-            _intrin = wrap(_name)
-            register(_intrin)
-
-
-@register
-def min_value(dtype, span):
-    return tvm.tir.min_value(dtype, span)
-
-
-@register
-def max_value(dtype, span):
-    return tvm.tir.max_value(dtype, span)
-
-
-@register
-def floordiv(x, y, span):
-    return tvm.tir.floordiv(x, y, span)
-
-
-@register
-def floormod(x, y, span):
-    return tvm.tir.floormod(x, y, span)
-
-
-@register
-def truncmod(x, y, span):
-    return tvm.tir.truncmod(x, y, span)
-
-
-@register
-def ceildiv(x, y, span):
-    return tvm.tir.ceildiv(x, y, span)
-
-
-@register
-def abs(x, span):
-    return tvm.tir.abs(x, span)
-
-
-@register
-def load(dtype, var, index, predicate=None, span=None):
-    return tvm.tir.Load(dtype, var, index, predicate, span)
-
-
-@register
-def cast(value, dtype, span):
-    return tvm.tir.Cast(dtype, value, span)
-
-
-@register
-def ramp(base, stride, lanes, span):
-    return tvm.tir.Ramp(base, stride, lanes.value, span)
-
-
-@register
-def broadcast(value, lanes, span):
-    return tvm.tir.Broadcast(value, lanes.value, span)
-
-
-@register
-def iter_var(var, dom, iter_type, thread_tag, span):
-    iter_type = getattr(tvm.tir.IterVar, iter_type)
-    return tvm.tir.IterVar(dom, var, iter_type, thread_tag, span)
-
-
-@register
-def max(a, b, span):  # pylint: disable=redefined-builtin
-    return tvm.tir.Max(a, b, span)
-
-
-@register
-def min(a, b, span):  # pylint: disable=redefined-builtin
-    return tvm.tir.Min(a, b, span)
-
-
-def get_axis(begin, end, iter_type, span):
-    ana = tvm.arith.Analyzer()
-    extent = ana.simplify(end - begin)
-    block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
-
-    iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
-    return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)
-
-
-@register
-def range(begin, end, span):
-    return get_axis(begin, end, "data_par", span)
-
-
-@register
-def reduce_axis(begin, end, span):
-    return get_axis(begin, end, "reduce", span)
-
-
-@register
-def scan_axis(begin, end, span):
-    return get_axis(begin, end, "scan", span)
-
-
-@register
-def opaque_axis(begin, end, span):
-    return get_axis(begin, end, "opaque", span)
-
-
-@register
-def Select(cond, if_body, else_body, span):  # pylint: disable=invalid-name
-    return tvm.tir.Select(cond, if_body, else_body, span)
-
-
-@register
-def Let(var, value, body, span):  # pylint: disable=invalid-name
-    return tvm.tir.Let(var, value, body, span)
-
-
-@register
-class EvaluateIntrin(Intrin):
-    def __init__(self):
-        def evaluate(value, span):
-            return tvm.tir.Evaluate(value, span)
-
-        super().__init__(evaluate, stmt=True)
-
-
-@register
-class StoreIntrin(Intrin):
-    def __init__(self):
-        def store(var, index, value, predicate=True, span=None):
-            return tvm.tir.Store(var, value, index, predicate, span)
-
-        super().__init__(store, stmt=True)
-
-
-@register
-class AssumeIntrin(Intrin):
-    def __init__(self):
-        def assume(constraint, span):
-            return tvm.tir.Evaluate(
-                tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
-            )
-
-        super().__init__(assume, stmt=True)
-
-
-@register
-def comm_reducer(lambda_io, identities, span):
-    """Create a CommReducer from lambda inputs/outputs and the identities"""
-    lambda_input = lambda_io[0]
-    lambda_output = lambda_io[1]
-
-    num_args = len(lambda_input)
-    num_arg_per_group = num_args // 2
-    x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)]
-    y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)]
-
-    if not isinstance(lambda_output, tuple):
-        lambda_output = (lambda_output,)
-
-    return tvm.tir.CommReducer(x, y, lambda_output, identities, span)
-
-
-@register
-def llvm_lookup_intrinsic_id(name, span):
-    # pylint: disable=unused-argument
-    return codegen.llvm_lookup_intrinsic_id(name)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
deleted file mode 100644
index 29e79607fb..0000000000
--- a/python/tvm/script/tir/node.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-"""TVM Script nodes."""
-
-from typing import Optional, Union, List, Callable
-import synr
-from tvm.arith import Analyzer
-from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
-
-
-class Slice:
-    """A helper class to present slice information for BufferSlice
-
-    Parameters
-    ----------
-    start : Union[PrimExpr, int]
-        The start index.
-
-    stop : Optional[Union[PrimExpr, int]]
-        The stop index, None means the Slice is an element-wise index
-
-    step : int
-        The slice step
-
-    span : Optional[Span]
-        The location of the slice in the source.
-    """
-
-    start: Union[PrimExpr, int]
-    stop: Optional[Union[PrimExpr, int]]
-    step: int
-    span: Optional[Span]
-
-    def __init__(
-        self,
-        start: Union[PrimExpr, int],
-        stop: Optional[Union[PrimExpr, int]] = None,
-        step: int = 1,
-        span: Optional[Span] = None,
-    ):
-        self.start = start
-        self.stop = stop
-        self.step = step
-        self.span = span
-
-    def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
-        """Helper to create index PrimExpr from slice object
-        Parameters
-        ----------
-        report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-            The error report func
-        """
-        if self.stop is None:
-            # scalar index
-            return self.start
-        if self.step < 1:
-            report_error("Slice's step should be positive integer", self.span)
-        lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step)
-        if not isinstance(lanes, (int, IntImm)):
-            report_error("Slice's lanes should be constant for buffer indices", self.span)
-        if lanes == 1:
-            return self.start
-        return Ramp(self.start, self.step, int(lanes), self.span)
-
-
-class BufferSlice(ObjectGeneric):
-    """A generic object for representing general buffer access. Following cases are supported:
-        - element wise access buffer[i, j], which can be converted to BufferLoad if necessary
-        - slice access buffer[i: i + 1, j : j + 2]
-        - union of element and slice buffer[i, j: j + 2]
-
-        This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize
-
-    Parameters
-    ----------
-    buffer : Buffer
-        The buffer.
-
-    indices : List[Union[Slice, PrimExpr, int]]
-        The access indexes can be slice, PrimExpr or int.
-
-    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-        The error report func
-
-    span : Optional[Span]
-        The location of the buffer access in the source.
-    """
-
-    buffer: Buffer
-    slices: List[Slice]
-    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-    span: Optional[Span]
-
-    def __init__(
-        self,
-        buffer: Buffer,
-        indices: List[Union[Slice, PrimExpr, int]],
-        report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
-        span: Optional[Span] = None,
-    ):
-        def check_index(index: Union[int, PrimExpr]):
-            """Check input index is non-negative integer or PrimExpr"""
-            if isinstance(index, int):
-                if index < 0:
-                    report_error("Negative index is not allowed during buffer access", span)
-            elif isinstance(index, PrimExpr):
-                element_dtype = index.dtype.split("x", maxsplit=1)[0]
-                if element_dtype[:3] != "int":
-                    report_error(
-                        "index expected an integer type PrimExpr but got " + str(index.dtype),
-                        index.span,
-                    )
-            else:
-                report_error(
-                    "Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
-                    + str(type(index)),
-                    span,
-                )
-
-        slices: List[Union[Slice, BufferSlice]] = []
-        for index in indices:
-            if isinstance(index, Slice):
-                index.start, index.stop = [convert(_) for _ in [index.start, index.stop]]
-                check_index(index.start)
-                check_index(index.stop)
-                slices.append(index)
-            elif isinstance(index, (PrimExpr, int)):
-                check_index(index)
-                slices.append(Slice(index))
-            elif isinstance(index, BufferSlice):
-                buffer_load = index.asobject()
-                check_index(buffer_load)
-                slices.append(Slice(buffer_load))
-            else:
-                report_error(
-                    "Unsupported index type for BufferSlice, "
-                    + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
-                    + str(type(index)),
-                    span,
-                )
-
-        self.buffer = buffer
-        self.slices = slices
-        self.report_error = report_error
-        self.span = span
-
-    def __str__(self):
-        regions: List[str] = []
-        for s in self.slices:
-            if s.stop is None:
-                regions.append(str(s.start))
-            else:
-                regions.append(str(s.start) + ": " + str(s.stop))
-
-        return self.buffer.name + "[" + ", ".join(regions) + "]"
-
-    def asobject(self) -> BufferLoad:
-        """Convert object."""
-        indices = [s.as_index_expr(self.report_error) for s in self.slices]
-        return BufferLoad(self.buffer, indices, span=self.span)
-
-    def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
-        """Construct BufferRegion from BufferSlice
-
-        Parameters
-        ----------
-        analyzer : Optional[tvm.arith.Analyzer]
-            The analyzer for simplifying. If not provided, the method will construct a new one
-
-        Returns
-        -------
-        buffer_region : BufferRegion
-            The constructed BufferRegion.
-        """
-        region: List[Range] = []
-        for s in self.slices:
-            start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
-            extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
-            if not analyzer:
-                analyzer = Analyzer()
-            if isinstance(extent, PrimExpr):
-                extent = analyzer.simplify(extent)
-            if s.step != 1:
-                self.report_error("BufferRegion do not support non-trivial stride", s.span)
-            region.append(Range.from_min_extent(start, extent, span=s.span))
-        return BufferRegion(self.buffer, region)
-
-    def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
-        return self.asobject().astype(dtype, span)
-
-    @property
-    def dtype(self) -> str:
-        """Return the dtype referenced by the slice.
-
-        Implemented as a property so that ``slice.dtype`` has the same
-        calling convention as ``primexpr.dtype``.  This allows a
-        BufferSlice object can be assigned to a variable without
-        requiring a type annotation on the variable, similar to other
-        expressions.
-        """
-        return self.asobject().dtype
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
deleted file mode 100644
index da7545c9a9..0000000000
--- a/python/tvm/script/tir/scope_handler.py
+++ /dev/null
@@ -1,788 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Scope Handler Classes"""
-# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
-
-import synr
-import numpy as np
-import tvm.tir
-from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
-
-from ..context_maintainer import ContextMaintainer
-from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
-
-
-class ScopeHandler:
-    """Base class for all scope handlers"""
-
-    def __init__(self, func: Callable):
-        self.func: Callable = func
-        self.body: Optional[Stmt] = None
-        self.node: Optional[synr.ast.Node] = None
-        self.context: Optional[ContextMaintainer] = None
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir." + self.func.__name__, get_param_list(self.func)
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        pass
-
-    def exit_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        self.node = node
-        self.context = context
-        return call_with_error_reporting(
-            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
-        )
-
-
-class WithScopeHandler(ScopeHandler):
-    """Base class for all with scope handlers"""
-
-    def __init__(self, func, concise_scope, def_symbol):
-        super().__init__(func)
-        self.concise_scope = concise_scope
-        self.def_symbol = def_symbol
-
-    @staticmethod
-    def get_optional_vars(node, context):
-        """Get a list synr.ast.With's optional_vars"""
-        assert isinstance(
-            node, synr.ast.With
-        ), f"WithScopeHandler expected synr.ast.With but got {type(node)}"
-
-        if isinstance(node.lhs, list):
-            for var in node.lhs:
-                if not isinstance(var, synr.ast.Var):
-                    context.report_error(
-                        f"Invalid optional var definition, expected Var but got {type(var)}",
-                        node.span,
-                    )
-            vars = node.lhs
-        else:
-            context.report_error(
-                f"Invalid optional var definition, expected list of Var but got {type(node.lhs)}",
-                node.span,
-            )
-        return vars
-
-
-@register
-class Allocate(WithScopeHandler):
-    """With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""
-
-    def __init__(self):
-        def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
-            condition = tvm.runtime.convert(condition)
-            scope = tvm.runtime.convert(scope)
-
-            return tvm.tir.Allocate(
-                self.buffer.data,
-                self.buffer.dtype,
-                self.buffer.shape,
-                condition,
-                self.body,
-                annotations=annotations,
-                span=span,
-            )
-
-        super().__init__(allocate, concise_scope=True, def_symbol=True)
-        self.buffer = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(
-            extents, dtype, scope, condition=True, annotations=None, span: Span = None
-        ):
-            """Setup buffer object for a given type."""
-            self.buffer = tvm.tir.decl_buffer(
-                shape=extents,
-                dtype=dtype,
-                name=name,
-                scope=scope,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class AllocateConst(WithScopeHandler):
-    """With scope handler T.allocate_const(data, extents, dtype, condition)
-
-    TIR constant node to represent non-scalar constant
-    """
-
-    def __init__(self):
-        def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
-            list_data = []
-            for i in raw_data:
-                list_data.append(i.value)
-            nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
-            n = tvm.tir.AllocateConst(
-                self.buffer.data,
-                dtype,
-                shape,
-                nd_data,
-                self.body,
-                annotations=annotations,
-                span=span,
-            )
-            return n
-
-        super().__init__(allocate_const, concise_scope=True, def_symbol=True)
-        self.buffer = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None):
-            """Setup buffer var for a given type."""
-            self.buffer = tvm.tir.decl_buffer(
-                shape=shape,
-                dtype=dtype,
-                name=name,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class DeclBuffer(WithScopeHandler):
-    """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type, axis_separators)
-    Example
-    -------
-    .. code-block:: python
-        A = T.decl_buffer((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def decl_buffer(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
-
-        super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(
-            shape,
-            dtype,
-            data,
-            strides,
-            elem_offset,
-            scope,
-            align,
-            offset_factor,
-            buffer_type,
-            axis_separators,
-            span: Span = None,
-        ):
-            self.buffer = tvm.tir.decl_buffer(
-                shape=shape,
-                dtype=dtype,
-                data=data,
-                strides=strides,
-                elem_offset=elem_offset,
-                scope=scope,
-                data_alignment=align,
-                offset_factor=offset_factor,
-                buffer_type=buffer_type,
-                axis_separators=axis_separators,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class LaunchThread(WithScopeHandler):
-    """With scope handler T.launch_thread(env_var, extent)"""
-
-    def __init__(self):
-        def launch_thread(env_var, extent, span):
-            extent = tvm.runtime.convert(extent, span=span)
-            thread_id = self.context.func_var_env_dict[env_var]
-            attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
-            return tvm.tir.AttrStmt(
-                IterVar(
-                    (0, extent),
-                    env_var,
-                    getattr(IterVar, "ThreadIndex"),
-                    thread_id,
-                    span=span,
-                ),
-                attr_key,
-                extent,
-                self.body,
-                span=span,
-            )
-
-        super().__init__(launch_thread, concise_scope=True, def_symbol=False)
-
-
-@register
-class Realize(WithScopeHandler):
-    """With scope handler T.realize(buffer_bounds, scope, condition)"""
-
-    def __init__(self):
-        def realize(
-            buffer_slice: BufferSlice, scope: str, condition: bool = True, span: bool = None
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            buffer: Buffer = buffer_slice.buffer
-            bounds: List[Range] = []
-            for s in buffer_slice.slices:
-                min: Union[PrimExpr, int] = s.start
-                extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
-                if isinstance(extent, PrimExpr):
-                    extent = self.context.analyzer.simplify(extent)
-                bounds.append(Range.from_min_extent(min, extent, span=s.span))
-
-            scope = tvm.runtime.convert(scope, span=span)
-            return tvm.tir.AttrStmt(
-                buffer,
-                "realize_scope",
-                scope,
-                tvm.tir.BufferRealize(buffer, bounds, condition, self.body, span=span),
-                span=span,
-            )
-
-        super().__init__(realize, concise_scope=True, def_symbol=False)
-
-
-@register
-class Attr(WithScopeHandler):
-    """With scope handler T.attr(attr_node, attr_key, value)"""
-
-    def __init__(self):
-        def attr(attr_node, attr_key, value, span):
-            attr_node = tvm.runtime.convert(attr_node, span=span)
-            value = tvm.runtime.convert(value, span=span)
-            return tvm.tir.AttrStmt(attr_node, attr_key, value, self.body, span=span)
-
-        super().__init__(attr, concise_scope=True, def_symbol=False)
-
-
-@register
-class AssertHandler(WithScopeHandler):
-    """With scope handler T.Assert(condition, message)"""
-
-    def __init__(self):
-        def Assert(condition, message, span):
-            return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.body, span=span)
-
-        super().__init__(Assert, concise_scope=True, def_symbol=False)
-
-
-@register
-class Let(WithScopeHandler):
-    """With scope handler T.let(var, value)"""
-
-    def __init__(self):
-        def let(var, value, span):
-            return tvm.tir.LetStmt(var, value, self.body, span=span)
-
-        super().__init__(let, concise_scope=False, def_symbol=False)
-
-    def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr):
-        return tvm.tir.Let(var, value, body)
-
-
-@register
-class Block(WithScopeHandler):
-    """With scope handler T.block(name)"""
-
-    def __init__(self):
-        def block(name_hint: str = "", span: Optional[Span] = None):
-            assert (
-                self.node and self.context and self.body
-            ), "call 'exit_scope' before 'enter_scope'"
-            block_info = self.context.block_info_stack[-1]
-
-            # create block read/write regions
-            reads: List[BufferRegion] = (
-                [read.as_buffer_region() for read in block_info.reads] if block_info.reads else []
-            )
-            writes: List[BufferRegion] = (
-                [write.as_buffer_region() for write in block_info.writes]
-                if block_info.writes
-                else []
-            )
-
-            region_detect_mask: int = (block_info.reads is None) | (
-                (block_info.writes is None) << 1
-            )
-            annotations = {} if block_info.annotations is None else block_info.annotations
-            if region_detect_mask != 0:
-                annotations["tir.script_parsing_detect_access"] = region_detect_mask
-            inner = tvm.tir.Block(
-                block_info.iter_vars,
-                reads,
-                writes,
-                name_hint,
-                self.body,
-                block_info.init,
-                block_info.alloc_buffers,
-                block_info.match_buffers,
-                annotations,
-                span,
-            )
-            assert len(block_info.iter_vars) == len(block_info.iter_values)
-            predicate = (
-                tvm.tir.const(True, "bool")
-                if block_info.predicate is None
-                else block_info.predicate
-            )
-            body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span)
-            return body
-
-        super().__init__(func=block, concise_scope=False, def_symbol=True)
-        self.block_vars = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define block vars
-        assert isinstance(
-            node, synr.ast.With
-        ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}"
-
-        optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)]
-        if optional_vars:
-            context.report_error(
-                f"Block expected no optional_vars (e.g., `x` in `with block() as x`), "
-                f"but got {optional_vars}",
-                node.span,
-            )
-
-
-@register
-class InitBlock(WithScopeHandler):
-    """With scope handler T.init()"""
-
-    def __init__(self):
-        def init(span: Span = None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            if self.context.block_info_stack[-2].init is not None:
-                self.context.report_error("Duplicate init block declaration", span)
-            self.context.block_info_stack[-2].init = self.body
-
-        super().__init__(func=init, concise_scope=False, def_symbol=True)
-
-
-class LoopInfo:
-    """Helper class for loop information"""
-
-    loop_var: Var
-    begin: PrimExpr
-    extent: PrimExpr
-    kind: ForKind
-    thread_binding: Optional[str]
-    annotations: Optional[Mapping[str, Object]]
-
-    def __init__(
-        self,
-        begin: PrimExpr,
-        extent: PrimExpr,
-        kind: ForKind,
-        thread_binding: Optional[str] = None,
-        annotations: Optional[Mapping[str, Object]] = None,
-    ) -> None:
-        self.begin = begin
-        self.extent = extent
-        self.kind = kind
-        self.thread_binding = thread_binding
-        self.annotations = annotations
-
-
-class ForScopeHandler(ScopeHandler):
-    """Base class for all for scope handlers"""
-
-    def __init__(self, func):
-        super().__init__(func)
-        self.loop_vars: List[Var] = []
-        self.loop_info: List[LoopInfo] = []
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        assert isinstance(
-            node, synr.ast.For
-        ), f"ForScopeHandler expected synr.ast.For but got {type(node)}"
-
-        loop_var_names = list()
-        spans = list()
-        if isinstance(node.lhs, synr.ast.Var):
-            loop_var_names.append(node.lhs.id.name)
-            spans.append(tvm_span_from_synr(node.lhs.id.span))
-        elif isinstance(node.lhs, list):
-            for elt in node.lhs:
-                if not isinstance(elt, synr.ast.Var):
-                    context.report_error(
-                        f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span
-                    )
-                loop_var_names.append(elt.id.name)
-                spans.append(tvm_span_from_synr(elt.id.span))
-        else:
-            context.report_error(
-                f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}",
-                span,
-            )
-
-        self.node = node
-        self.context = context
-        # collect loop infos by calling self.func
-        call_with_error_reporting(context.report_error, span, self.func, *arg_list)
-        if len(loop_var_names) != len(self.loop_info):
-            self.context.report_error(
-                f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
-                + f"vs {len(self.loop_info)}",
-                self.node.span,
-            )
-        # generate loop vars
-        self.loop_vars = []
-        for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
-            if not li.begin.dtype.startswith("int"):
-                raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}")
-            if not li.extent.dtype.startswith("int"):
-                raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}")
-            dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32"
-            self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))
-
-        for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
-            context.update_symbol(loop_var.name, loop_var, node)
-            context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent)
-
-    def exit_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
-        for loop_var in self.loop_vars:
-            context.loop_stack.pop(loop_var)
-        # Use assert here since we have check it in `enter_scope`
-        assert len(self.loop_vars) == len(self.loop_info)
-
-        body = self.body
-        for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)):
-            body = tvm.tir.For(
-                var,
-                info.begin,
-                info.extent,
-                info.kind,
-                body,
-                info.thread_binding,
-                info.annotations,
-                span=tvm_span_from_synr(span),
-            )
-
-        return body
-
-    def create_loop_info(
-        self,
-        begin: PrimExpr,
-        end: PrimExpr,
-        kind: ForKind,
-        thread_binding: Optional[str] = None,
-        annotations: Optional[Mapping[str, Object]] = None,
-    ) -> None:
-        """
-        Helper function for creating For in TVM Script parser.
-
-        Parameters
-        ----------
-        begin : PrimExpr
-            The beginning value.
-
-        end : PrimExpr
-            The endding value.
-
-        kind : ForKind
-            The type of the for.
-
-        thread_binding: Optional[str]
-            The thread this loop binds to.
-
-        annotations : Optional[Mapping[str, Object]]
-            Additional annotation hints.
-
-        span : Optional[Span]
-            The location of this for in the source code.
-
-        Returns
-        -------
-        for : For
-            The constructed For.
-        """
-        begin, end = [convert(_) for _ in [begin, end]]
-        assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
-        extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
-        self.annotations: Mapping[str, Object] = {}
-        if annotations is not None:
-            self.annotations = {
-                key: String(val) if isinstance(val, str) else val
-                for key, val in annotations.items()
-            }
-
-        self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
-
-
-@register
-class Serial(ForScopeHandler):
-    """For scope handler T.serial(begin, end, annotations)"""
-
-    def __init__(self):
-        def serial(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
-
-        super().__init__(serial)
-
-
-@register
-class Parallel(ForScopeHandler):
-    """For scope handler T.parallel(begin, end, annotations)"""
-
-    def __init__(self):
-        def parallel(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations)
-
-        super().__init__(parallel)
-
-
-@register
-class Vectorized(ForScopeHandler):
-    """For scope handler T.vectorized(begin, end, annotations)"""
-
-    def __init__(self):
-        def vectorized(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations)
-
-        super().__init__(vectorized)
-
-
-@register
-class Unroll(ForScopeHandler):
-    """For scope handler T.unroll(begin, end, annotations)"""
-
-    def __init__(self):
-        def unroll(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations)
-
-        super().__init__(unroll)
-
-
-@register
-class ThreadBinding(ForScopeHandler):
-    """For scope handler T.thread_binding(begin, end, thread, annotations)"""
-
-    def __init__(self):
-        def thread_binding(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            thread: str = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if thread is None:
-                if isinstance(end, str):  # handle case like thread_binding(128, "threadIdx.x")
-                    thread = end
-                    end = None
-                else:
-                    raise ValueError("Thread cannot be None for thread_binding")
-            if end is None:
-                end = begin
-                begin = 0
-            thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread)
-            self.create_loop_info(
-                begin,
-                end,
-                ForKind.THREAD_BINDING,
-                thread_binding=thread_iter_var,
-                annotations=annotations,
-            )
-
-        super().__init__(thread_binding)
-
-
-@register
-class RangeHandler(ForScopeHandler):
-    """For scope handler range(begin, end, annotations)
-    Note that tir.range is totally the same as T.serial
-    """
-
-    def __init__(self):
-        def for_range(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
-
-        super().__init__(for_range)
-
-    def signature(self):
-        return "range", get_param_list(self.func)
-
-
-@register
-class Grid(ForScopeHandler):
-    """For scope handler T.grid(extents)"""
-
-    def __init__(self):
-        def grid(*extents: List[PrimExpr]):
-            for extent in extents:
-                self.create_loop_info(0, extent, ForKind.SERIAL)
-
-        super().__init__(grid)
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
deleted file mode 100644
index 15502055b7..0000000000
--- a/python/tvm/script/tir/special_stmt.py
+++ /dev/null
@@ -1,964 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Special Stmt Classes"""
-# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
-# pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
-
-import synr
-from synr import ast
-from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
-from tvm.runtime import Object, String
-from tvm.target import Target
-from tvm.ir import Span
-from tvm.tir import IntImm, IterVar, Var
-
-from .node import BufferSlice
-
-from ..context_maintainer import BlockInfo, ContextMaintainer
-from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
-
-
-def convert_to_int(
-    value: Union[IntImm, int],
-    arg_name: str,
-    report_error: Callable,
-    span: Union[Span, synr.ast.Span],
-) -> int:
-    """convert a const int or TVM IntImm to Python int.
-    Reports an error when input cannot be converted to int.
-
-    Parameters
-    ----------
-    value : Union[tvm.tir.IntImm, int]
-        The input value to be converted.
-    arg_name : str
-        Function argument name for error reporting.
-    report_error: Callable
-        The report error function handle
-    span : Union[synr.ast.Span, tvm.ir.Span]
-        Location of the error
-    """
-    if isinstance(value, IntImm):
-        return value.value
-    if isinstance(value, int):
-        return value
-    report_error(
-        f"Expected int or IntImm for {arg_name}, but got {str(type(value))}",
-        span,
-    )
-
-
-class SpecialStmt:
-    """Base class for all Special Stmts"""
-
-    def __init__(self, func: Callable, def_symbol: bool):
-        self.func: Callable = func
-        self.def_symbol: bool = def_symbol
-        self.node: Optional[synr.ast.Node] = None
-        self.context: Optional[ContextMaintainer] = None
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir." + self.func.__name__, get_param_list(self.func)
-
-    def handle(
-        self,
-        node: ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        self.node = node
-        self.context = context
-        return call_with_error_reporting(
-            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
-        )
-
-
-@register
-class MatchBuffer(SpecialStmt):
-    """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
-                                 offset_factor, buffer_type, axis_separators)
-
-    Note
-    ----
-    This Special Stmt will perform different behavior depends on the type of param.
-    If the param is a var in function parameter, it will create a buffer from DLTensor.
-    Else if the param is a subregion of other buffers, then create a subregion match inside a block.
-
-    Example
-    -------
-    Match buffer from function parameter
-    .. code-block:: python
-        A = T.match_buffer(a, (128, 128), dtype="float32")
-
-    Match buffer from Buffer subregion
-    .. code-block:: python
-        A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def match_buffer(
-            param,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`match_buffer` must be assigned to a single buffer, "
-                    "e.g. A = match_buffer(...)",
-                    self.node.span,
-                )
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            if isinstance(param, tvm.tir.Var):
-                if param not in self.context.func_params:
-                    self.context.report_error(
-                        "Can not bind non-input param to buffer", self.node.rhs.params[0].span
-                    )
-                self.context.func_buffer_map[param] = buffer
-            elif isinstance(param, BufferSlice):
-                buffer_region = param.as_buffer_region()
-                self.context.current_block_scope().match_buffers.append(
-                    tvm.tir.MatchBufferRegion(buffer, buffer_region)
-                )
-            else:
-                self.context.report_error(
-                    "The source of match_buffer expected Var or BufferSlice, but got "
-                    + str(type(param)),
-                    self.node.rhs.params[0].span,
-                )
-            self.context.update_symbol(buffer_name, buffer, self.node)
-
-        super().__init__(match_buffer, def_symbol=True)
-
-
-@register
-class BufferDeclare(SpecialStmt):
-    """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type, axis_separators)
-    Example
-    -------
-    .. code-block:: python
-        A = T.buffer_decl((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def buffer_decl(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)",
-                    self.node.span,
-                )
-
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            self.context.update_symbol(buffer_name, buffer, self.node)
-            return buffer
-
-        super().__init__(buffer_decl, def_symbol=True)
-
-
-@register
-class AllocBuffer(SpecialStmt):
-    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                     offset_factor, buffer_type, axis_separators)
-
-    Example
-    -------
-    .. code-block:: python
-
-        A = T.alloc_buffer((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def alloc_buffer(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`alloc_buffer` must be assigned to a single buffer, "
-                    "e.g. A = alloc_buffer(...)",
-                    self.node.span,
-                )
-
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            if self.context.current_block_scope():
-                self.context.current_block_scope().alloc_buffers.append(buffer)
-            else:
-                # If it is allocated outside all blocks, allocate it under root block.
-                self.context.root_alloc_buffers.append(buffer)
-            self.context.update_symbol(buffer_name, buffer, self.node)
-
-        super().__init__(alloc_buffer, def_symbol=True)
-
-
-@register
-class BlockReads(SpecialStmt):
-    """Special function reads([read_regions], *other_regions)
-
-    Note
-    ----
-    *other_region is an unpackable list of BufferSlice to support
-    reads syntax sugar like reads(BufferRegion1, BufferRegion2, ...)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
-    """
-
-    def __init__(self):
-        def reads(
-            *read_regions: Union[BufferSlice, List[BufferSlice]],
-            span: Span = None,
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare read regions inside a block.",
-                    span,
-                )
-            if block_scope.reads is not None:
-                self.context.report_error(
-                    "Duplicate write region declaration, "
-                    + "previous one is "
-                    + str(", ".join(str(x) for x in block_scope.reads)),
-                    span,
-                )
-            if len(read_regions) > 1:
-                for read_region in read_regions:
-                    if not isinstance(read_region, BufferSlice):
-                        self.context.report_error(
-                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
-                            + f" but got {type(read_regions)}",
-                            span,
-                        )
-            elif len(read_regions) == 1:
-                if isinstance(read_regions[0], list):
-                    read_regions = read_regions[0]
-
-            block_scope.reads = read_regions
-
-        super().__init__(reads, def_symbol=False)
-
-
-@register
-class BlockWrites(SpecialStmt):
-    """Special function writes([write_regions], *other_regions)
-
-    Note
-    ----
-    *other_region is an unpackable list of BufferSlice to support
-    writes syntax sugar like writes(BufferRegion1, BufferRegion2, ...)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.writes([C[vi: vi + 4, vj])
-    """
-
-    def __init__(self):
-        def writes(
-            *write_regions: Union[BufferSlice, List[BufferSlice]],
-            span: Span = None,
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare write regions inside a block.",
-                    span,
-                )
-            if block_scope.writes is not None:
-                self.context.report_error(
-                    "Duplicate write region declaration, "
-                    + "previous one is "
-                    + str(", ".join(str(x) for x in block_scope.writes)),
-                    span,
-                )
-            if len(write_regions) > 1:
-                for write_region in write_regions:
-                    if not isinstance(write_region, BufferSlice):
-                        self.context.report_error(
-                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
-                            + f" but got {type(write_regions)}",
-                            span,
-                        )
-            elif len(write_regions) == 1:
-                if isinstance(write_regions[0], list):
-                    write_regions = write_regions[0]
-            block_scope.writes = write_regions
-
-        super().__init__(writes, def_symbol=False)
-
-
-@register
-class BlockAttr(SpecialStmt):
-    """Special function block_attr({attr_key: attr_value})
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.block_attr({"double_buffer_scope": 1})
-    """
-
-    def __init__(self):
-        def block_attr(attrs: Mapping[str, Object], span: Span = None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare block annotations inside a block.",
-                    span,
-                )
-            if block_scope.annotations is not None:
-                self.context.report_error(
-                    "Duplicate block annotations declaration, "
-                    + "previous one is "
-                    + str(block_scope.annotations),
-                    span,
-                )
-            attrs = {
-                key: String(val) if isinstance(val, str) else val for key, val in attrs.items()
-            }
-            block_scope.annotations = attrs
-
-        super().__init__(block_attr, def_symbol=False)
-
-
-class BlockAxis(SpecialStmt):
-    """Special stmt for defining a spatial block axis
-    axis.S(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.S(128, i * 4 + j)
-    """
-
-    def axis(
-        self,
-        var_name: str,
-        dom: Union[PrimExpr, Range],
-        value: PrimExpr,
-        iter_type: int,
-        span: Optional[Span] = None,
-    ) -> None:
-        """
-        Helper function for creating block axis
-
-        Parameters
-        ----------
-        var_name : str
-            The name_hint of var
-
-        dom : Union[PrimExpr, Range]
-            The iter domain.
-
-        value : PrimExpr
-            The binding value
-
-        iter_type : int
-            The iteration type.
-
-        span : Optional[Span]
-            The location of this for in the source code.
-        """
-        assert self.context, "call 'exit_scope' before 'enter_scope'"
-        block_scope: BlockInfo = self.context.current_block_scope()
-        if block_scope is None:
-            self.context.report_error(
-                "Expected to declare block axes inside a block.",
-                self.node.span,
-            )
-        if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]:
-            self.context.report_error("Duplicate block axis " + var_name, self.node.span)
-
-        dom = tvm.runtime.convert(dom)
-        if isinstance(dom, PrimExpr):
-            dom = tvm.ir.Range(dom)
-        elif isinstance(dom, tvm.ir.container.Array) and len(dom) == 2:
-            dom = tvm.ir.Range(dom[0], dom[1])
-        elif not isinstance(dom, tvm.ir.Range):
-            self.context.report_error(
-                f"Block axis domain expected PrimExpr or Range, but got {type(dom)}",
-                self.node.span,
-            )
-        block_var = tvm.tir.Var(var_name, dtype=dom.extent.dtype)
-        value = tvm.runtime.convert(value)
-        if not isinstance(value, PrimExpr):
-            self.context.report_error(
-                f"Block axis value expected PrimExpr, but got {type(value)}",
-                self.node.span,
-            )
-        iter_var = tvm.tir.IterVar(dom, block_var, iter_type)
-        block_scope.iter_vars.append(iter_var)
-        block_scope.iter_values.append(value)
-        self.context.update_symbol(var_name, block_var, self.node)
-
-
-@register
-class BlockAxisSpatial(BlockAxis):
-    """Special stmt for defining a spatial block axis
-    axis.spatial(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.spatial(128, k)
-    """
-
-    def __init__(self):
-        def axis_spatial(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
-        super().__init__(axis_spatial, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.spatial", get_param_list(self.func)
-
-
-@register
-class BlockAxisS(BlockAxis):
-    """The sugar special stmt for defining a spatial block axis
-    axis.S(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.S(128, k)
-    """
-
-    def __init__(self):
-        def axis_spatial(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
-        super().__init__(axis_spatial, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.S", get_param_list(self.func)
-
-
-@register
-class BlockAxisReduce(BlockAxis):
-    """Special stmt for defining a reduce block axis
-    axis.reduce(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.reduce(128, k)
-    """
-
-    def __init__(self):
-        def axis_reduce(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
-        super().__init__(axis_reduce, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.reduce", get_param_list(self.func)
-
-
-@register
-class BlockAxisR(BlockAxis):
-    """The sugar special stmt for defining a reduce block axis
-    axis.R(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.R(128, k)
-    """
-
-    def __init__(self):
-        def axis_reduce(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
-        super().__init__(axis_reduce, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.R", get_param_list(self.func)
-
-
-@register
-class BlockAxisScan(BlockAxis):
-    """Special stmt for defining a ordered block axis
-    axis.scan(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.scan(128, k)
-    """
-
-    def __init__(self):
-        def axis_scan(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered)
-
-        super().__init__(axis_scan, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.scan", get_param_list(self.func)
-
-
-@register
-class BlockAxisOpaque(BlockAxis):
-    """Special stmt for defining a opaque block axis
-    axis.opaque(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.opaque(128, k)
-    """
-
-    def __init__(self):
-        def axis_opaque(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo)
-
-        super().__init__(axis_opaque, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.opaque", get_param_list(self.func)
-
-
-@register
-class BlockAxisRemap(BlockAxis):
-    """Special stmt for remapping loops vars to block axes.
-    axis.remap(iter_type, iter_value)
-
-    Note
-    ----
-    Iter_type is a string consisting of 'S' and 'R', where 'S' means
-    for spatial and 'R' means for reduce.
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi, vj = T.axis.remap("SS", [i, j])
-    """
-
-    def __init__(self):
-        def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1:
-                self.context.report_error(
-                    "`axis.remap` must be assigned to one or more vars, "
-                    "e.g. vi, vj = axis.remap(...)",
-                    self.node.span,
-                )
-            var_num: int = len(self.node.lhs)
-            if var_num != len(iter_types):
-                self.context.report_error(
-                    f"`iter_type` expected {var_num} charactor(s), "
-                    f"but got {len(iter_types)}: {iter_types}",
-                    span,
-                )
-            if var_num != len(loop_vars):
-                self.context.report_error(
-                    f"`iter_type` expected {var_num} loop var(s), "
-                    f"but got {len(loop_vars)}: {loop_vars}",
-                    span,
-                )
-            for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars):
-                iter_type: int
-                if iter_ty == "S":
-                    iter_type = IterVar.DataPar
-                elif iter_ty == "R":
-                    iter_type = IterVar.CommReduce
-                else:
-                    self.context.report_error(
-                        f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), '
-                        f'but got "{iter_ty}"',
-                        span,
-                    )
-
-                if not isinstance(loop_var, tvm.tir.expr.Var):
-                    self.context.report_error(
-                        f"Values of `axis.remap` expected single loop var, but got {loop_var}",
-                        loop_var.span,
-                    )
-                loops = self.context.loop_stack
-                if loop_var not in loops:
-                    self.context.report_error(
-                        f"Cannot find loop var {loop_var} in loop nesting.",
-                        span,
-                    )
-                self.axis(var.id.name, loops[loop_var], loop_var, iter_type)
-
-        super().__init__(axis_remap, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.remap", get_param_list(self.func)
-
-
-@register
-class BlockPredicate(SpecialStmt):
-    """Special function where(predicate)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.where(i < 4)
-    """
-
-    def __init__(self):
-        def where(predicate, span=None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare the predicate inside a block.",
-                    span,
-                )
-            if block_scope.predicate is not None:
-                self.context.report_error(
-                    "Duplicate block predicate declaration, "
-                    + "previous one is "
-                    + str(block_scope.predicate),
-                    span,
-                )
-
-            block_scope.predicate = predicate
-
-        super().__init__(where, def_symbol=False)
-
-
-@register
-class VarDef(SpecialStmt):
-    """Special function for defining a Var"""
-
-    def __init__(self):
-        def var(dtype, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"VarDef expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            v = Var(names[0], dtype, span=span)
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(var, def_symbol=True)
-
-
-@register
-class BufferVarDef(SpecialStmt):
-    """Special function for defining a variable of pointer type"""
-
-    def __init__(self):
-        def buffer_var(dtype, storage_scope, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"BufferVarDef expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
-            v = Var(names[0], ptr_type, span=span)
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(buffer_var, def_symbol=True)
-
-
-@register
-class EnvThread(SpecialStmt):
-    """Bind a var to thread env"""
-
-    def __init__(self):
-        def env_thread(env_name, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"EnvThread expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            v = Var(names[0], dtype="int32", span=span)
-            self.context.func_var_env_dict[v] = env_name
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(env_thread, def_symbol=True)
-
-
-@register
-class FuncAttr(SpecialStmt):
-    """Special Stmt for declaring the DictAttr of PrimFunc
-    Example
-    -------
-    .. code-block:: python
-         T.func_attr({"tir.noalias": True, "global_symbol"})
-    """
-
-    def __init__(self):
-        def func_attr(dict_attr, span):
-            self.context.func_dict_attr = dict_attr
-
-        super().__init__(func_attr, def_symbol=False)
-
-
-@register
-class PreflattenedBufferMap(SpecialStmt):
-    """Special Stmt for declaring the PrimFunc::preflattened_buffer_map
-
-    Example
-    -------
-    .. code-block:: python
-         A0 = T.match_buffer(A, (48,), dtype="float32")
-         T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
-    """
-
-    def __init__(self):
-        def preflattened_buffer(
-            postflattened,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            span=None,
-        ):
-
-            param = None
-            for key, value in self.context.func_buffer_map.items():
-                if value.same_as(postflattened):
-                    param = key
-                    break
-
-            assert (
-                param is not None
-            ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
-
-            if data is None:
-                data = self.context.func_buffer_map[param].data
-
-            buffer_name: str = f"{postflattened.name}_preflatten"
-            if align != -1:
-                if isinstance(align, IntImm):
-                    align = align.value
-                else:
-                    assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
-
-            if offset_factor != 0:
-                if isinstance(offset_factor, IntImm):
-                    offset_factor = offset_factor.value
-                else:
-                    assert isinstance(
-                        offset_factor, int
-                    ), f"offset_factor: want int or IntImm, got {offset_factor!r}"
-
-            preflattened = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                span=span,
-            )
-
-            self.context.func_preflattened_buffer_map[param] = preflattened
-
-        super().__init__(preflattened_buffer, def_symbol=False)
-
-
-@register
-class TargetAttrValue(SpecialStmt):
-    """Special Stmt for target attr value.
-    Example
-    -------
-    .. code-block:: python
-        T.target("llvm")
-    """
-
-    def __init__(self):
-        def target(*args, span):
-            self.context.report_error(f"T.target should not appear as a stmt", span)
-
-        super().__init__(target, def_symbol=False)
-
-    def __call__(self, target_config):
-        if not isinstance(target_config, (str, dict)):
-            raise ValueError(
-                f"T.target expected a config dict or string, but got {type(target_config)}"
-            )
-        return Target(target_config)
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
deleted file mode 100644
index 4548102a9e..0000000000
--- a/python/tvm/script/tir/ty.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Typing Class for TIR
-
-This module provides typing class for TVM script type annotation usage, it can be viewed as
-a wrapper for uniform Type system in IR
-"""
-# pylint: disable=invalid-name
-from numbers import Integral
-
-import tvm
-from .special_stmt import SpecialStmt, convert_to_int
-
-
-class TypeGeneric:  # pylint: disable=too-few-public-methods
-    """Base class for all the TVM script typing class"""
-
-    def evaluate(self):
-        """Return an actual ir.Type Object that this Generic class wraps"""
-        raise TypeError("Cannot get tvm.Type from a generic type")
-
-    def require_type_generic_at(self, idx):  # pylint: disable=unused-argument
-        """If True, the `idx`th type argument must be TypeGeneric"""
-        return True
-
-    # This function is added here to avoid a pylint error
-    # for T.int/float below not being callable
-    def __call__(self):
-        raise NotImplementedError()
-
-
-class ConcreteType(TypeGeneric):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for uniform Type objects
-
-    Params
-    ------
-    vtype: Union[str, tvm.ir.Type]
-
-        The IR type represented by the type annotation.  If a string
-        (e.g. "float32"), this represents a `ir.PrimType` generated
-        from that string.  If a `ir.Type` is provided, this represents
-        the type provided.
-    """
-
-    def __init__(self, vtype):
-        if isinstance(vtype, tvm.ir.Type):
-            self.type = vtype
-        else:
-            self.type = tvm.ir.PrimType(vtype)
-
-    def __call__(self, *args):  # pylint: disable=arguments-differ
-        pass
-
-    def evaluate(self):
-        return self.type
-
-
-class VoidType(ConcreteType):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for void type"""
-
-    def __init__(self):
-        super().__init__("")
-
-
-class GenericPtrType(TypeGeneric):  # pylint: disable=abstract-method
-    """TVM script typing class generator for PtrType
-
-    [] operator is overloaded, accepts a ConcreteType and an optional storage scope string,
-    returns a ConcreteType wrapping PtrType
-    """
-
-    def __getitem__(self, args):
-        if isinstance(args, TypeGeneric):
-            args = [args]
-        if len(args) == 1:
-            vtype, scope = args[0], "global"
-        elif len(args) == 2:
-            vtype, scope = args[0], args[1]
-        else:
-            raise TypeError(f"Illegal type argument num for Ptr")
-        if not isinstance(vtype, TypeGeneric):
-            raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
-        if not isinstance(scope, str):
-            raise TypeError(f"Ptr expects storage scope argument be a string")
-        return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope))
-
-    def require_type_generic_at(self, idx):
-        return idx != 1  # the second argument is storage scope for Ptr
-
-
-class GenericTupleType(TypeGeneric):  # pylint: disable=abstract-method
-    """TVM script typing class generator for TupleType
-
-    [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType
-    wrapping TupleType
-    """
-
-    def __getitem__(self, vtypes):
-        if isinstance(vtypes, TypeGeneric):
-            vtypes = [vtypes]
-        return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
-
-
-class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for uniform Type objects"""
-
-    def __init__(self, vtype):
-        def match_buffer_syntax_sugar(
-            shape,
-            dtype: str = "float32",
-            name: str = None,
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            return buffer
-
-        self.type = vtype
-        super().__init__(match_buffer_syntax_sugar, def_symbol=True)
-
-    def __call__(
-        self,
-        shape,
-        dtype="float32",
-        *,
-        name: str = None,
-        data=None,
-        strides=None,
-        elem_offset=None,
-        scope="global",
-        align=-1,
-        offset_factor=0,
-        buffer_type="default",
-        axis_separators=None,
-        span=None,
-    ):
-        """
-        This function is for Buffer(...) syntax sugar.
-        """
-        pass  # pylint: disable=unnecessary-pass
-
-    def __getitem__(self, args):
-        """
-        This function is for Buffer[...] syntax sugar
-        Note that args is the list of all arguments
-        """
-        if len(args) < 2:
-            raise ValueError("T.Buffer[...] needs at least two arguments: shape and dtype.")
-
-        shape = args[0]
-        dtype = args[1]
-
-        valid_shape = isinstance(shape, (tvm.ir.PrimExpr, Integral, tuple, list))
-        valid_dtype = isinstance(dtype, str)
-        if not (valid_shape and valid_dtype):
-            raise ValueError(
-                "The first argument of T.Buffer[...] needs to be a tuple, "
-                "followed by the second argument dtype as a string"
-            )
-
-
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            _name = _dtype + _size + _lanes
-            globals()[_name] = ConcreteType(_name)
-
-boolean = ConcreteType("bool")
-handle = ConcreteType("handle")
-void = VoidType()
-Ptr = GenericPtrType()
-Tuple = GenericTupleType()
-# we don't have 'buffer' type on the cpp side
-# thus 'handle' is used here for convenience's sake
-Buffer = GenericBufferType("handle")
diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py
deleted file mode 100644
index c655a62237..0000000000
--- a/python/tvm/script/utils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Helper functions in TVM Script Parser"""
-
-from typing import Callable, List, Any, Optional, Tuple
-
-import inspect
-import synr
-
-from tvm.ir import Span, SourceName
-from tvm.error import DiagnosticError
-
-
-def get_param_list(
-    func: Callable,
-) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]:
-    """Get the parameter list from definition of function"""
-    full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func)
-
-    args: List[str]
-    defaults: Optional[Tuple[Any, ...]]
-    kwonlyargs: List[str]
-    args, defaults, kwonlyargs = (
-        full_arg_spec.args,
-        full_arg_spec.defaults,
-        full_arg_spec.kwonlyargs,
-    )
-
-    if defaults is None:
-        defaults = tuple()
-
-    if full_arg_spec.varkw is not None:
-        raise RuntimeError(
-            "TVM Script register error : variable keyword argument is not supported now"
-        )
-
-    if len(kwonlyargs) == 1 and kwonlyargs[0] == "span":
-        pass
-    elif not len(kwonlyargs) == 0:
-        raise RuntimeError("TVM Script register error : keyword only argument is not supported now")
-
-    pos_only: List[str] = list()
-    for arg in args[: len(args) - len(defaults)]:
-        if arg != "span":
-            pos_only.append(arg)
-    kwargs: List[Tuple[str, Tuple[Any, ...]]] = list()
-    for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
-        if arg != "span":
-            kwargs.append((arg, default))
-
-    return pos_only, kwargs, full_arg_spec.varargs
-
-
-def tvm_span_from_synr(span: synr.ast.Span) -> Span:
-    """Convert a synr span to a TVM span"""
-    return Span(
-        SourceName(span.filename),
-        span.start_line,
-        span.end_line,
-        span.start_column,
-        span.end_column,
-    )
-
-
-def synr_span_from_tvm(span: Span) -> synr.ast.Span:
-    """Convert a TVM span to a synr span"""
-    return synr.ast.Span(
-        span.source_name.name,
-        span.line,
-        span.column,
-        span.end_line,
-        span.end_column,
-    )
-
-
-def call_with_error_reporting(
-    report_error,
-    node_span,
-    func,
-    *args,
-    **kwargs,
-):
-    """Call function with exception handling and report error using node_span"""
-    try:
-        return func(*args, **kwargs)
-    except DiagnosticError:
-        raise
-    except Exception as err:  # pylint: disable=broad-except
-        # printing last non-empty row of error message.
-        error_msg = list(filter(None, str(err).split("\n")))[-1]
-        report_error(error_msg, node_span)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..aec3eceacb 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,184 @@
 from tvm.ir import PrimExpr
 from tvm.runtime import const
 
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis, schedule, stmt_functor, transform, usmp
+from .generic import cast
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Any,
+    Broadcast,
+    BufferLoad,
+    Call,
+    CallEffectKind,
+    Cast,
+    CommReducer,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
+    IntImm,
+    IterVar,
+    Let,
+    Load,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
+    Select,
+    Shuffle,
+    SizeVar,
+    StringImm,
+    Sub,
+    Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+    TVMBackendAllocWorkspace,
+    TVMBackendFreeWorkspace,
+    abs,
+    acos,
+    acosh,
+    address_of,
+    all,
+    any,
+    asin,
+    asinh,
+    assume,
+    atan,
+    atan2,
+    atanh,
+    call_cpacked,
+    call_cpacked_lowered,
+    call_extern,
+    call_intrin,
+    call_llvm_intrin,
+    call_llvm_pure_intrin,
+    call_packed,
+    call_packed_lowered,
+    call_pure_extern,
+    ceil,
+    ceildiv,
+    clz,
+    comm_reducer,
+    copysign,
+    cos,
+    cosh,
+    div,
+    erf,
+    exp,
+    exp2,
+    exp10,
+    floor,
+    floordiv,
+    floormod,
+    fmod,
+    hypot,
+    if_then_else,
+    indexdiv,
+    indexmod,
+    isfinite,
+    isinf,
+    isnan,
+    isnullptr,
+    ldexp,
+    likely,
+    log,
+    log1p,
+    log2,
+    log10,
+    lookup_param,
+    max,
+    max_value,
+    min,
+    min_value,
+    mma_fill,
+    mma_store,
+    nearbyint,
+    nextafter,
+    popcount,
+    power,
+    ptx_commit_group,
+    ptx_cp_async,
+    ptx_ldmatrix,
+    ptx_mma,
+    ptx_mma_sp,
+    ptx_wait_group,
+    q_multiply_shift,
+    ret,
+    round,
+    rsqrt,
+    shift_left,
+    shift_right,
+    sigmoid,
+    sin,
+    sinh,
+    sqrt,
+    sum,
+    tan,
+    tanh,
+    trace,
+    trunc,
+    truncdiv,
+    truncmod,
+    tvm_access_ptr,
+    tvm_bmma_sync,
+    tvm_fill_fragment,
+    tvm_load_matrix_sync,
+    tvm_mma_sync,
+    tvm_stack_alloca,
+    tvm_stack_make_array,
+    tvm_stack_make_shape,
+    tvm_store_matrix_sync,
+    tvm_struct_get,
+    tvm_struct_set,
+    tvm_thread_allreduce,
+    tvm_throw_last_error,
+    tvm_tuple,
+    undef,
+    vectorcombine,
+    vectorhigh,
+    vectorlow,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
 from .stmt import (
-    BufferStore,
-    BufferRealize,
-    Store,
-    ProducerStore,
     Allocate,
     AllocateConst,
+    AssertStmt,
     AttrStmt,
+    Block,
+    BlockRealize,
+    BufferRealize,
+    BufferRegion,
+    BufferStore,
     DeclBuffer,
+    Evaluate,
+    For,
+    ForKind,
+    IfThenElse,
+    LetStmt,
+    MatchBufferRegion,
+    Prefetch,
+    ProducerRealize,
+    ProducerStore,
+    SeqStmt,
+    Stmt,
+    Store,
+    While,
+    stmt_list,
+    stmt_seq,
+    type_annotation,
 )
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..7959e82e7b 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -20,11 +20,11 @@ from typing import Dict, List, Union
 
 from tvm import Object
 from tvm.ir import IRModule
-from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
 
-from .. import Buffer, Stmt
+from ..buffer import Buffer
+from ..expr import Var
 from ..function import PrimFunc
+from ..stmt import Block, BufferRegion, PrimExpr, Stmt
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 from typing import Optional, Union
-from tvm import ir
+
 import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
 from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
 
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
 from . import _ffi_api
+from . import generic as _generic
 
 
 def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
 
-    # TODO(tkonolige): use inspect to add source information to these objects
-
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
     """
 
     def __init__(self, condition, true_value, false_value, span=None):
+        if isinstance(condition, bool):
+            condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
             _ffi_api.Select, condition, true_value, false_value, span  # type: ignore
         )
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..e2cad37bb6 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,17 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin,invalid-name,no-member,protected-access
 """Operators used in TIR expression."""
+import warnings
 from typing import Any, Optional
+
 import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
 from tvm.ir import Array, Op
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
 
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
 from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, PrimExprWithOp, StringImm, Var
 
 
 def _pack_buffer(buf, span=None):
@@ -100,6 +102,64 @@ def call_cpacked(*args, span=None):
     return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
 
 
+def call_packed_lowered(*args, span=None):
+    """Lowered version of call packed.
+
+    The argument to packed function can be Expr or Buffer.
+    The argument is the corresponding POD type when Expr is presented.
+
+    When the argument is Buffer, the corresponding PackedFunc
+    will receive a TVMArrayHandle whose content is valid during the callback period.
+    If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+    """Lowered version of call c-packed.
+
+    Same as call_packed, except that the first argument is the function name
+    (as in call_extern), and the last argument is the resource handle.
+
+    Parameters
+    ----------
+    args : list of Expr or Buffer.
+        Positional arguments.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    te.extern : Create tensor with extern function call.
+    """
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
+
+
 def call_intrin(dtype, func_name, *args, span=None):
     """Build expression by calling an intrinsic function.
 
@@ -151,7 +211,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+        dtype,
+        Op.get("tir.call_pure_extern"),
+        convert((StringImm(func_name),) + args),
+        span,
     )
 
 
@@ -178,7 +241,10 @@ def call_extern(dtype, func_name, *args, span=None):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+        dtype,
+        Op.get("tir.call_extern"),
+        convert((StringImm(func_name),) + args),
+        span=span,
     )
 
 
@@ -207,10 +273,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -239,8 +317,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
@@ -250,6 +336,326 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+    return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):
+    return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):
+    return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
+    return call_intrin(
+        "handle", "tir.tvm_stack_make_array", data, shape, strides, ndim, arr_dtype, elem_offset
+    )
+
+
+def address_of(buffer_load, span=None):
+    """Returns the address of an element in the buffer
+
+    Parameters
+    ----------
+    buffer_load: BufferLoad
+        The buffer load.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+
+
+def lookup_param(param_name, span=None):
+    """Returns the param by name
+
+    Parameters
+    ----------
+    param_name : str
+        The name of param.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.lookup_param", param_name, span=span)
+
+
+def tvm_tuple(*value):
+    return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):
+    return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field_id)
+
+
+def tvm_struct_set(arr, index, field_id, value):
+    return call_intrin("handle", "tir.tvm_struct_set", arr, index, field_id, value)
+
+
+def tvm_thread_allreduce(*freduce_args):
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    return call_intrin(
+        "handle",
+        "tir.tvm_load_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def tvm_mma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_mma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_bmma_sync(
+    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+    return call_intrin(
+        "handle",
+        "tir.tvm_bmma_sync",
+        fragment_d,
+        index_d,
+        fragment_a,
+        index_a,
+        fragment_b,
+        index_b,
+        fragment_c,
+        index_c,
+    )
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):
+    return call_intrin(
+        "handle",
+        "tir.tvm_fill_fragment",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        value,
+    )
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+    return call_intrin(
+        "handle",
+        "tir.tvm_store_matrix_sync",
+        fragment,
+        m,
+        n,
+        k,
+        index,
+        buffer_ptr,
+        stride,
+        layout,
+    )
+
+
+def ptx_mma(  # pylint: disable=missing-docstring
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    saturate,
+    operator=None,
+):
+    if operator is None:
+        return call_intrin(
+            dtype,
+            "tir.ptx_mma",
+            shape,
+            A_layout,
+            B_layout,
+            A_dtype,
+            B_dtype,
+            C_dtype,
+            multiplicand_a,
+            a_index,
+            multiplicand_b,
+            b_index,
+            accumulator,
+            c_index,
+            saturate,
+        )
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        saturate,
+        operator,
+    )
+
+
+def ptx_mma_sp(  # pylint: disable=missing-docstring
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    metadata,
+    meta_index,
+    sparse_selector,
+    saturate,
+):
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma_sp",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        metadata,
+        meta_index,
+        sparse_selector,
+        saturate,
+    )
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+    return call_intrin(
+        dtype,
+        "tir.ptx_ldmatrix",
+        trans,
+        num,
+        type,
+        local_ptr,
+        local_offset,
+        smem_ptr,
+        smem_offset,
+    )
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+    return call_intrin(
+        dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
+    )
+
+
+def ptx_commit_group():
+    return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):
+    return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+    return call_intrin(
+        dtype,
+        "tir.mma_store",
+        m,
+        n,
+        dst_ptr,
+        src_ptr,
+        src_offset,
+        dst_stride,
+    )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+    return call_intrin(
+        dtype,
+        "tir.mma_fill",
+        local_size,
+        local_ptr,
+        offset,
+    )
+
+
+def vectorlow(dtype, vec):
+    return call_intrin(dtype, "tir.vectorlow", vec)
+
+
+def vectorhigh(dtype, vec):
+    return call_intrin(dtype, "tir.vectorhigh", vec)
+
+
+def vectorcombine(dtype, vec1, vec2):
+    return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
+
+
+def assume(cond=None):
+    return call_intrin("int32", "tir.assume", cond)
+
+
+def undef():
+    return call_intrin("int32", "tir.undef")
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -286,9 +692,9 @@ def any(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore
+    val = _ffi_api._OpOr(args[0], args[1], span)  # type: ignore # pylint: disable=no-member,protected-access
     for i in range(2, len(args)):
-        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore
+        val = _ffi_api._OpOr(val, args[i], span)  # type: ignore # pylint: disable=no-member,protected-access
     return val
 
... 1655 lines suppressed ...