You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this text version.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/07 08:24:23 UTC
[tvm] 01/01: Squashed
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch ir-builder
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 3857a97931deb50e9ca6e33721a81476935d18ca
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon May 23 17:06:45 2022 -0700
Squashed
---
include/tvm/ir/expr.h | 20 +-
include/tvm/ir/ir_builder.h | 188 ++++
include/tvm/support/with.h | 2 +
include/tvm/tir/ir_builder.h | 138 +++
include/tvm/tir/ir_builder_frame.h | 454 ++++++++
include/tvm/tir/op.h | 34 +-
python/tvm/ir/__init__.py | 55 +-
.../_ffi_api.py => ir/_ffi_ir_builder_api.py} | 4 +-
python/tvm/ir/ir_builder.py | 84 ++
python/tvm/script/__init__.py | 17 +-
python/tvm/script/{ => parser}/__init__.py | 14 +-
python/tvm/script/parser/diagnostics.py | 60 ++
python/tvm/script/parser/dispatch.py | 63 ++
python/tvm/script/parser/doc.py | 341 ++++++
python/tvm/script/parser/doc_core.py | 1140 ++++++++++++++++++++
.../script/{tir/prim_func.py => parser/entry.py} | 46 +-
python/tvm/script/parser/evaluator.py | 282 +++++
.../script/{_ffi_api.py => parser/ir/__init__.py} | 7 +-
.../tvm/script/{__init__.py => parser/ir/entry.py} | 19 +-
.../script/{__init__.py => parser/ir/parser.py} | 24 +-
python/tvm/script/parser/parser.py | 182 ++++
python/tvm/script/parser/source.py | 89 ++
python/tvm/script/{ => parser/tir}/__init__.py | 9 +-
python/tvm/script/parser/tir/entry.py | 97 ++
python/tvm/script/parser/tir/operation.py | 85 ++
python/tvm/script/parser/tir/parser.py | 262 +++++
python/tvm/script/parser/utils.py | 63 ++
python/tvm/script/parser/var_table.py | 71 ++
python/tvm/script/{ => parser_v1}/__init__.py | 3 +-
python/tvm/script/{ => parser_v1}/_ffi_api.py | 0
.../script/{ => parser_v1}/context_maintainer.py | 8 +-
python/tvm/script/{ => parser_v1}/diagnostics.py | 6 +-
python/tvm/script/{ => parser_v1}/highlight.py | 0
python/tvm/script/{ => parser_v1}/meta_unparser.py | 0
python/tvm/script/{ => parser_v1}/parser.py | 23 +-
python/tvm/script/{ => parser_v1}/registry.py | 2 +-
python/tvm/script/{ => parser_v1}/tir/__init__.py | 0
python/tvm/script/{ => parser_v1}/tir/__init__.pyi | 0
python/tvm/script/{ => parser_v1}/tir/intrin.py | 5 +-
python/tvm/script/{ => parser_v1}/tir/node.py | 7 +-
python/tvm/script/{ => parser_v1}/tir/prim_func.py | 3 +-
.../script/{ => parser_v1}/tir/scope_handler.py | 17 +-
.../tvm/script/{ => parser_v1}/tir/special_stmt.py | 16 +-
python/tvm/script/{ => parser_v1}/tir/ty.py | 1 +
python/tvm/script/{ => parser_v1}/utils.py | 7 +-
python/tvm/tir/__init__.py | 212 +++-
.../_ffi_api.py => tir/_ffi_ir_builder_api.py} | 4 +-
python/tvm/tir/analysis/analysis.py | 4 +-
python/tvm/tir/buffer.py | 36 +-
python/tvm/tir/expr.py | 15 +-
python/tvm/tir/ir_builder_frame.py | 118 ++
python/tvm/tir/{ir_builder.py => ir_builder_v1.py} | 6 +-
python/tvm/tir/ir_builder_v2.py | 901 ++++++++++++++++
python/tvm/tir/op.py | 551 +++++++++-
python/tvm/tir/schedule/block_scope.py | 2 +-
python/tvm/tir/schedule/schedule.py | 6 +-
python/tvm/tir/schedule/state.py | 3 +-
python/tvm/tir/stmt.py | 4 +
python/tvm/tir/usmp/transform/transform.py | 5 +-
src/ir/expr.cc | 13 +
src/ir/ir_builder.cc | 134 +++
src/tir/ir/expr.cc | 19 +-
src/tir/ir/script/script_complete.cc | 5 +-
src/tir/ir/script/script_complete.h | 35 +
src/tir/ir/stmt.cc | 2 +
src/tir/ir_builder/ir_builder.cc | 637 +++++++++++
src/tir/ir_builder/ir_builder_frame.cc | 207 ++++
src/tir/ir_builder/utils.h | 92 ++
src/tir/op/op.cc | 24 +
tests/python/tvmscript/test_builder_basic.py | 227 ++++
tests/python/tvmscript/test_parse_basic.py | 118 ++
tests/python/tvmscript/test_parser_capture.py | 43 +
72 files changed, 7174 insertions(+), 197 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 5e358ed50e..dcabd7d3f1 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -764,16 +764,32 @@ struct PackedFuncValueConverter<PrimExpr> {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
- return PrimExpr(val.operator int());
+ return IntImm(runtime::DataType::Int(32), val.operator int());
}
if (val.type_code() == kDLFloat) {
- return PrimExpr(static_cast<float>(val.operator double()));
+ return FloatImm(runtime::DataType::Float(32), val.operator double());
}
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};
+template <>
+struct PackedFuncValueConverter<Array<PrimExpr>> {
+ /*!
+ * \brief Convert a packed-function argument into an Array<PrimExpr>.
+ *
+ * A null argument yields an undefined (null) Array. Otherwise the argument is read as an
+ * Array<ObjectRef>. Each element is re-wrapped as an object-handle argument and passed
+ * through PackedFuncValueConverter<PrimExpr>, so per-element conversion follows the same
+ * rules as a single PrimExpr argument.
+ */
+ static Array<PrimExpr> From(const TVMPODValue_& val) {
+ if (val.type_code() == kTVMNullptr) return Array<PrimExpr>(nullptr);
+ Array<ObjectRef> vals = val.AsObjectRef<Array<ObjectRef>>();
+ Array<PrimExpr> exprs;
+ // The output size is known up front: reserve once instead of regrowing on each push_back.
+ exprs.reserve(vals.size());
+ for (const ObjectRef& v : vals) {
+ TVMValue value;
+ value.v_handle = const_cast<void*>(static_cast<const void*>(v.get()));
+ exprs.push_back(
+ PackedFuncValueConverter<PrimExpr>::From(TVMArgValue(value, kTVMObjectHandle)));
+ }
+ return exprs;
+ }
+};
+
template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
diff --git a/include/tvm/ir/ir_builder.h b/include/tvm/ir/ir_builder.h
new file mode 100644
index 0000000000..5e099e7310
--- /dev/null
+++ b/include/tvm/ir/ir_builder.h
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_IR_IR_BUILDER_H_
+#define TVM_IR_IR_BUILDER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+namespace tvm {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame and IRBuilder //////////////////////////////
+
+/*!
+ * \brief Base node for a scope ("frame") managed by the IR builder.
+ *
+ * Subclasses override EnterWithScope/ExitWithScope to implement frame-specific behaviour.
+ * The definitions are in the .cc files added by this patch and are not shown in this header.
+ */
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+ /*! \brief Callbacks registered through AddCallback(). */
+ std::vector<runtime::TypedPackedFunc<void()>> callbacks;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `callbacks` is not visited.
+ // NOTE(review): presumably because packed functions are not reflectable attributes — confirm.
+ }
+
+ static constexpr const char* _type_key = "ir_builder.IRBuilderFrame";
+ TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+ virtual ~IRBuilderFrameNode() = default;
+ /*! \brief Hook run when the frame is entered; overridden by subclasses. */
+ virtual void EnterWithScope();
+ /*! \brief Hook run when the frame is exited; overridden by subclasses. */
+ virtual void ExitWithScope();
+
+ /*! \brief Store `callback` in `callbacks`. */
+ void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+/*!
+ * \brief Reference to an IRBuilderFrameNode.
+ *
+ * The default constructor is protected, so only subclasses can create an empty reference.
+ * ExitWithScope() releases this reference after delegating to the node.
+ */
+class IRBuilderFrame : public runtime::ObjectRef {
+ public:
+ virtual ~IRBuilderFrame() = default;
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+ IRBuilderFrame() = default;
+
+ public:
+ inline void EnterWithScope();
+ inline void ExitWithScope();
+};
+
+/*! \brief State of an IR builder: the open frames plus the final result. */
+class IRBuilderNode : public runtime::Object {
+ public:
+ /*! \brief Open frames; FindFrame() and GetLastFrame() inspect them starting from the back. */
+ runtime::Array<IRBuilderFrame> frames;
+ /*! \brief The built object, if any; retrieve it with Get<T>(). */
+ Optional<ObjectRef> result;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("frames", &frames);
+ v->Visit("result", &result);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.IRBuilder";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+ /*! \brief Return the frame nearest the back of `frames` whose node is a TFrame, or NullOpt. */
+ template <typename TFrame>
+ inline Optional<TFrame> FindFrame() const;
+ /*! \brief Return the back frame if its node is a TFrame, otherwise NullOpt. */
+ template <typename TFrame>
+ inline Optional<TFrame> GetLastFrame() const;
+ /*! \brief Return `result` as TObjectRef; CHECK-fails if it is unset or of another type. */
+ template <typename TObjectRef>
+ inline TObjectRef Get() const;
+};
+
+/*! \brief Reference to an IRBuilderNode. */
+class IRBuilder : public runtime::ObjectRef {
+ public:
+ explicit IRBuilder();
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+ void EnterWithScope();
+ void ExitWithScope();
+ // NOTE(review): definition not visible here; presumably returns the innermost builder
+ // entered via EnterWithScope() — confirm against src/ir/ir_builder.cc.
+ static IRBuilder Current();
+ /*! \brief Name `obj` through details::Namer and return it. */
+ template <class TObjectRef>
+ inline static TObjectRef Name(String name, TObjectRef obj);
+};
+
+////////////////////////////// Generic IRModule //////////////////////////////
+
+/*!
+ * \brief Frame collecting global variables and functions for an IRModule.
+ * NOTE(review): presumably global_vars[i] names functions[i], and ExitWithScope (defined in
+ * src/ir/ir_builder.cc) assembles them — confirm.
+ */
+class IRModuleFrameNode : public IRBuilderFrameNode {
+ public:
+ /*! \brief Global variables declared in this module frame. */
+ Array<GlobalVar> global_vars;
+ /*! \brief Functions declared in this module frame. */
+ Array<BaseFunc> functions;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ IRBuilderFrameNode::VisitAttrs(v);
+ v->Visit("global_vars", &global_vars);
+ v->Visit("functions", &functions);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.IRModuleFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an IRModuleFrameNode. */
+class IRModuleFrame : public IRBuilderFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+ IRModuleFrameNode);
+};
+
+/*! \brief Create an IRModuleFrame (definition not shown in this header). */
+TVM_DLL IRModuleFrame IRModule();
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+/*!
+ * \brief Type-dispatched naming hook used by IRBuilder::Name.
+ *
+ * `vtable()` exposes a NodeFunctor keyed on the node type. Name(node, name) presumably
+ * dispatches through it; the per-type handlers are registered outside this header.
+ */
+class Namer {
+ public:
+ using FType = NodeFunctor<void(const ObjectRef&, String)>;
+ static FType& vtable();
+ static void Name(ObjectRef node, String name);
+};
+
+} // namespace details
+
+/*!
+ * \brief Give `obj` the name `name` via details::Namer::Name, then return `obj` for chaining.
+ */
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+ details::Namer::Name(obj, name);
+ // `obj` already has static type TObjectRef; Downcast additionally re-checks the node type.
+ return Downcast<TObjectRef>(obj);
+}
+
+/*! \brief Delegate to the node's virtual EnterWithScope(); the reference must be non-null. */
+inline void IRBuilderFrame::EnterWithScope() {
+ ICHECK(data_ != nullptr);
+ static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
+}
+
+/*!
+ * \brief Delegate to the node's virtual ExitWithScope(), then drop this reference.
+ * After the call this IRBuilderFrame is null.
+ */
+inline void IRBuilderFrame::ExitWithScope() {
+ ICHECK(data_ != nullptr);
+ static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
+ data_.reset();
+}
+
+/*!
+ * \brief Search `frames` from the back towards the front and return the first frame whose
+ * node is TFrame::ContainerType, or NullOpt when none matches.
+ */
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+ using TFrameNode = typename TFrame::ContainerType;
+ for (int64_t i = static_cast<int64_t>(frames.size()) - 1; i >= 0; --i) {
+ const TFrameNode* node = frames[i].template as<TFrameNode>();
+ if (node != nullptr) {
+ return GetRef<TFrame>(node);
+ }
+ }
+ return NullOpt;
+}
+
+/*!
+ * \brief Return the back frame of `frames` as a TFrame when its node has that type;
+ * otherwise (including when `frames` is empty) return NullOpt. No deeper search is done.
+ */
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+ using TFrameNode = typename TFrame::ContainerType;
+ if (frames.empty()) {
+ return NullOpt;
+ }
+ IRBuilderFrame last = frames.back();
+ if (!last->IsInstance<TFrameNode>()) {
+ return NullOpt;
+ }
+ return Downcast<TFrame>(last);
+}
+
+/*!
+ * \brief Return the builder's `result` as a TObjectRef.
+ * \throws Error prefixed "IndexError" when no result has been set, or "TypeError" when the
+ * result's node is not TObjectRef::ContainerType.
+ */
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+ using TObject = typename TObjectRef::ContainerType;
+ CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+ const auto* n = result.as<TObject>();
+ // A result of the wrong node type is a type mismatch, not a missing entry: report TypeError
+ // so the FFI raises the matching exception class, and include the actual type for debugging.
+ CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key
+ << ", but got: " << result.value()->GetTypeKey();
+ return GetRef<TObjectRef>(n);
+}
+
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_IR_IR_BUILDER_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 3651e05e74..bbc36419ff 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -75,6 +75,8 @@ class With {
ContextType& operator*() { return *get(); }
-const ContextType* operator*() const { return *get(); }
+const ContextType& operator*() const { return *get(); }
+ /*! \brief Return a copy of the managed context; also callable on a const With. */
+ ContextType operator()() const { return ctx_; }
+
private:
/*! \brief internal context type. */
ContextType ctx_;
diff --git a/include/tvm/tir/ir_builder.h b/include/tvm/tir/ir_builder.h
new file mode 100644
index 0000000000..5d4c61635c
--- /dev/null
+++ b/include/tvm/tir/ir_builder.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_H_
+#define TVM_TIR_IR_BUILDER_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/ir_builder_frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;
+using tvm::tir::Buffer;
+using tvm::tir::IterVar;
+using tvm::tir::Var;
+
+// ---- Buffers and blocks ----
+// Declarations only; presumably defined in src/tir/ir_builder/ir_builder.cc (added by this patch).
+// NOTE(review): unlike IRModule() in tvm/ir/ir_builder.h, none of these builders is marked
+// TVM_DLL — confirm they need not be exported from the shared library (e.g. on Windows).
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+ Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+ String storage_scope, int align, int offset_factor, String buffer_type,
+ Optional<Array<IntImm>> axis_separators);
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+// Block() opens a BlockFrame and Init() a BlockInitFrame. Where/Reads/Writes/BlockAttrs return
+// nothing; presumably they fill the corresponding fields of the innermost BlockFrameNode — confirm.
+BlockFrame Block(String name, bool no_realize = false);
+BlockInitFrame Init();
+void Where(PrimExpr predicate);
+void Reads(Array<ObjectRef> buffer_slices);
+void Writes(Array<ObjectRef> buffer_slices);
+void BlockAttrs(Map<String, ObjectRef> attrs);
+// NOTE(review): storage_scope defaults to "" here but to "global" in MatchBuffer,
+// PreflattenedBuffer and Ptr — confirm the difference is intentional.
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+ Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+ PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+ int offset_factor = 0, String buffer_type = "default",
+ Array<IntImm> axis_separators = {});
+
+// Block iteration-variable helpers; each returns IterVar(s) built from a domain and a binding.
+// NOTE(review): presumably Spatial/Reduce/Scan/Opaque select the IterVar kind, and Remap reads
+// one kind character per entry of `bindings` from `kinds` — confirm against the .cc.
+namespace axis {
+IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+IterVar Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Array<IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+} // namespace axis
+
+// ---- Loops ----
+// Each function opens a ForFrame. Serial/Parallel/Vectorized/Unroll/ThreadBinding take
+// `start`/`stop` bounds (presumably half-open — confirm); Grid takes only extents.
+ForFrame Serial(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Grid(Array<PrimExpr> extents);
+
+// ---- PrimFunc ----
+// PrimFunc() opens a PrimFuncFrame. Arg/FuncName/FuncAttrs/FuncRet/MatchBuffer/
+// PreflattenedBuffer presumably populate the enclosing PrimFuncFrameNode's fields
+// (args, name, attrs, ret_type, buffer_map, preflattened_buffer_map) — confirm in the .cc.
+PrimFuncFrame PrimFunc();
+Var Arg(String name, Var var);
+Buffer Arg(String name, Buffer buffer);
+void FuncName(String name);
+void FuncAttrs(Map<String, ObjectRef> attrs);
+Type FuncRet(Type ret_type);
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+ Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+ PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
+ int align = -1, int offset_factor = 0, String buffer_type = "default",
+ Array<IntImm> axis_separators = {});
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+ DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+ Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+ String storage_scope = "global", int align = -1, int offset_factor = 0,
+ String buffer_type = "default", Array<IntImm> axis_separators = {});
+
+// ---- Statements ----
+// Functions returning a *Frame open a scope whose node type is declared in
+// tvm/tir/ir_builder_frame.h. BufferStore/Prefetch/Evaluate return nothing and presumably
+// append a statement to the current frame — confirm.
+// NOTE(review): ir_builder_frame.h declares DeclBufferFrame, but no DeclBuffer() builder is
+// declared here.
+AssertFrame Assert(PrimExpr condition, String message);
+LetFrame Let(Var var, PrimExpr value);
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+ Optional<PrimExpr> condition = NullOpt,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+AllocateConstFrame AllocateConst(
+ NDArray data, DataType dtype, Array<PrimExpr> extents,
+ Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+WhileFrame While(PrimExpr condition);
+IfFrame If(PrimExpr condition);
+ThenFrame Then();
+ElseFrame Else();
+LaunchThreadFrame LaunchThread(IterVar iter_var, PrimExpr extent);
+IterVar EnvThread(String thread_tag);
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+void Prefetch(Buffer buffer, Array<Range> bounds);
+void Evaluate(PrimExpr value);
+
+/*!
+ * \brief Define `inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt)` for dtype `DType`.
+ *
+ * With an expression, the helper casts it to `DType` via tvm::cast. Without one, it returns a
+ * fresh tvm::tir::Var with an empty name hint and dtype `DType`.
+ */
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
+ inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
+ const DataType target_dtype = DType; \
+ if (!expr.defined()) { \
+ return tvm::tir::Var("", target_dtype); \
+ } \
+ return tvm::cast(target_dtype, expr.value()); \
+ }
+
+// Signed integers.
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+// Unsigned integers.
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+// Floating point.
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+// Other dtypes.
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_TIR_IR_BUILDER_H_
diff --git a/include/tvm/tir/ir_builder_frame.h b/include/tvm/tir/ir_builder_frame.h
new file mode 100644
index 0000000000..549847da6c
--- /dev/null
+++ b/include/tvm/tir/ir_builder_frame.h
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_FRAME_H_
+#define TVM_TIR_IR_BUILDER_FRAME_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Base node for all TIR builder frames.
+ * NOTE(review): `stmts` presumably accumulates the statements emitted while the frame is open,
+ * to be consumed by each subclass's ExitWithScope — confirm in ir_builder_frame.cc.
+ */
+class TIRFrameNode : public IRBuilderFrameNode {
+ public:
+ /*! \brief Statements held by this frame. */
+ Array<tvm::tir::Stmt> stmts;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ IRBuilderFrameNode::VisitAttrs(v);
+ v->Visit("stmts", &stmts);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.TIRFrame";
+ TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+/*!
+ * \brief Reference to a TIRFrameNode. The default constructor is protected, so only
+ * subclasses can construct an empty TIRFrame.
+ */
+class TIRFrame : public IRBuilderFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+ TIRFrame() = default;
+};
+
+/*!
+ * \brief Frame for a TIR block, opened by Block(name, no_realize).
+ *
+ * The first group of fields mirrors the parts of a tvm::tir::Block. iter_values/predicate/
+ * no_realize describe its realization. ExitWithScope presumably assembles them — confirm in the .cc.
+ */
+class BlockFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Block name. */
+ String name;
+ /*! \brief Block iteration variables. */
+ Array<tvm::tir::IterVar> iter_vars;
+ /*! \brief Read regions; NullOpt when not set explicitly. */
+ Optional<Array<tvm::tir::BufferRegion>> reads;
+ /*! \brief Write regions; NullOpt when not set explicitly. */
+ Optional<Array<tvm::tir::BufferRegion>> writes;
+ /*! \brief Init statement, if any. */
+ Optional<tvm::tir::Stmt> init;
+ /*! \brief Buffers allocated by the block. */
+ Array<tvm::tir::Buffer> alloc_buffers;
+ /*! \brief Match-buffer regions of the block. */
+ Array<tvm::tir::MatchBufferRegion> match_buffers;
+ /*! \brief Block annotations. */
+ Map<String, ObjectRef> annotations;
+
+ /*! \brief Iteration values (presumably one per entry of iter_vars — confirm). */
+ Array<PrimExpr> iter_values;
+ /*! \brief Optional block predicate. */
+ Optional<PrimExpr> predicate;
+ /*!
+ * \brief Mirrors Block()'s `no_realize` argument.
+ * Initialized to false so the value is well defined even before Block() sets it;
+ * VisitAttrs reads it unconditionally.
+ */
+ bool no_realize = false;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ v->Visit("iter_vars", &iter_vars);
+ v->Visit("reads", &reads);
+ v->Visit("writes", &writes);
+ v->Visit("init", &init);
+ v->Visit("alloc_buffers", &alloc_buffers);
+ v->Visit("match_buffers", &match_buffers);
+ v->Visit("annotations", &annotations);
+ v->Visit("iter_values", &iter_values);
+ v->Visit("predicate", &predicate);
+ v->Visit("no_realize", &no_realize);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.BlockFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a BlockFrameNode. */
+class BlockFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+/*!
+ * \brief Frame opened by Init(); it has no fields of its own and overrides both
+ * EnterWithScope and ExitWithScope (definitions not shown here).
+ */
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+ void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+ static constexpr const char* _type_key = "ir_builder.tir.BlockInitFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+ void EnterWithScope() final;
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a BlockInitFrameNode. */
+class BlockInitFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+/*!
+ * \brief Frame opened by the loop builders (Serial/Parallel/Vectorized/Unroll/ThreadBinding/Grid).
+ * NOTE(review): presumably one loop per entry of vars/doms — confirm in the .cc.
+ */
+class ForFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Signature of the loop constructor: (vars, doms, body) -> loop statement. */
+ using FMakeForLoop =
+ runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+ /*! \brief Loop variables. */
+ Array<tvm::tir::Var> vars;
+ /*! \brief Loop domains (presumably parallel to `vars`). */
+ Array<Range> doms;
+ /*! \brief Loop constructor; presumably invoked on exit with the frame body. */
+ FMakeForLoop f_make_for_loop;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("vars", &vars);
+ v->Visit("doms", &doms);
+ // `f_make_for_loop` is not visited.
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.ForFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a ForFrameNode. */
+class ForFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+/*! \brief Frame opened by PrimFunc(); collects the pieces of a tvm::tir::PrimFunc. */
+class PrimFuncFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Function name, if set. */
+ Optional<String> name;
+ /*! \brief Function parameters. */
+ Array<tvm::tir::Var> args;
+ /*! \brief Return type, if set. */
+ Optional<Type> ret_type;
+ /*! \brief Parameter-to-buffer map. */
+ Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+ /*! \brief Parameter-to-preflattened-buffer map. */
+ Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+ /*! \brief Function attributes. */
+ Map<String, ObjectRef> attrs;
+ /*! \brief Buffers allocated at root level (presumably in an implicit root block — confirm). */
+ Array<tvm::tir::Buffer> root_alloc_buffers;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ v->Visit("args", &args);
+ v->Visit("ret_type", &ret_type);
+ v->Visit("buffer_map", &buffer_map);
+ v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+ v->Visit("attrs", &attrs);
+ v->Visit("root_alloc_buffers", &root_alloc_buffers);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.PrimFuncFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a PrimFuncFrameNode. */
+class PrimFuncFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+/*! \brief Frame opened by Assert(condition, message). */
+class AssertFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Asserted condition. */
+ PrimExpr condition;
+ /*!
+ * \brief Assertion message.
+ * NOTE(review): Assert() takes a String; presumably it is wrapped into a PrimExpr there — confirm.
+ */
+ PrimExpr message;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ v->Visit("message", &message);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AssertFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an AssertFrameNode. */
+class AssertFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+/*! \brief Frame opened by Let(var, value). */
+class LetFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Variable being bound. */
+ tvm::tir::Var var;
+ /*! \brief Value bound to `var`. */
+ PrimExpr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("var", &var);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.LetFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a LetFrameNode. */
+class LetFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+/*! \brief Frame opened by Allocate(extents, dtype, storage_scope, condition, annotations). */
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Allocation extents. */
+ Array<PrimExpr> extents;
+ /*! \brief Element dtype. */
+ DataType dtype;
+ /*! \brief Storage scope of the allocation. */
+ String storage_scope;
+ /*! \brief Allocation condition. */
+ PrimExpr condition;
+ /*! \brief Allocation annotations. */
+ Map<String, ObjectRef> annotations;
+ /*! \brief Buffer describing the allocation (presumably created by Allocate() — confirm). */
+ tvm::tir::Buffer buffer;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("extents", &extents);
+ v->Visit("dtype", &dtype);
+ v->Visit("storage_scope", &storage_scope);
+ v->Visit("condition", &condition);
+ v->Visit("annotations", &annotations);
+ v->Visit("buffer", &buffer);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AllocateFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an AllocateFrameNode. */
+class AllocateFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+/*! \brief Frame opened by AllocateConst(data, dtype, extents, annotations). */
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Element dtype. */
+ DataType dtype;
+ /*! \brief Allocation extents. */
+ Array<PrimExpr> extents;
+ /*! \brief Constant contents. */
+ tvm::runtime::NDArray data;
+ /*! \brief Buffer describing the allocation (presumably created by AllocateConst() — confirm). */
+ tvm::tir::Buffer buffer;
+ /*! \brief Allocation annotations. */
+ Map<String, ObjectRef> annotations;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("dtype", &dtype);
+ v->Visit("extents", &extents);
+ v->Visit("data", &data);
+ v->Visit("buffer", &buffer);
+ v->Visit("annotations", &annotations);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AllocateConstFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an AllocateConstFrameNode. */
+class AllocateConstFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+ AllocateConstFrameNode);
+};
+
+/*! \brief Frame opened by LaunchThread(iter_var, extent). */
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Thread extent. */
+ PrimExpr extent;
+ /*! \brief Attribute key (presumably chosen from the thread tag — confirm in the .cc). */
+ String attr_key;
+ /*! \brief Thread iteration variable. */
+ tvm::tir::IterVar iter_var;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("extent", &extent);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("iter_var", &iter_var);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.LaunchThreadFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a LaunchThreadFrameNode. */
+class LaunchThreadFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+ LaunchThreadFrameNode);
+};
+
+/*! \brief Frame opened by Realize(buffer_slice, storage_scope, condition). */
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Buffer region being realized. */
+ tvm::tir::BufferRegion buffer_slice;
+ /*! \brief Storage scope of the realization. */
+ String storage_scope;
+ /*! \brief Realization condition. */
+ PrimExpr condition;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("buffer_slice", &buffer_slice);
+ v->Visit("storage_scope", &storage_scope);
+ v->Visit("condition", &condition);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.RealizeFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a RealizeFrameNode. */
+class RealizeFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+/*! \brief Frame opened by Attr(node, attr_key, value). */
+class AttrFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Object the attribute is attached to. */
+ ObjectRef node;
+ /*! \brief Attribute key. */
+ String attr_key;
+ /*! \brief Attribute value. */
+ PrimExpr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("node", &node);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AttrFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an AttrFrameNode. */
+class AttrFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+/*! \brief Frame opened by While(condition). */
+class WhileFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Loop condition. */
+ PrimExpr condition;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.WhileFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a WhileFrameNode. */
+class WhileFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+/*!
+ * \brief Frame opened by If(condition).
+ * NOTE(review): then_stmts/else_stmts are presumably filled by nested Then()/Else() frames — confirm.
+ */
+class IfFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Branch condition. */
+ PrimExpr condition;
+ /*! \brief Statements of the then-branch, if any. */
+ Optional<Array<tvm::tir::Stmt>> then_stmts;
+ /*! \brief Statements of the else-branch, if any. */
+ Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ v->Visit("then_stmts", &then_stmts);
+ v->Visit("else_stmts", &else_stmts);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.IfFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an IfFrameNode. */
+class IfFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+/*!
+ * \brief Frame opened by Then(). It has no fields of its own (VisitAttrs is inherited from
+ * TIRFrameNode) and overrides both EnterWithScope and ExitWithScope.
+ */
+class ThenFrameNode : public TIRFrameNode {
+ public:
+ static constexpr const char* _type_key = "ir_builder.tir.ThenFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+ void EnterWithScope() final;
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a ThenFrameNode. */
+class ThenFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+/*!
+ * \brief Frame opened by Else(). It has no fields of its own (VisitAttrs is inherited from
+ * TIRFrameNode) and overrides both EnterWithScope and ExitWithScope.
+ */
+class ElseFrameNode : public TIRFrameNode {
+ public:
+ static constexpr const char* _type_key = "ir_builder.tir.ElseFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+ void EnterWithScope() final;
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to an ElseFrameNode. */
+class ElseFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+/*!
+ * \brief Frame holding a declared buffer.
+ * NOTE(review): tvm/tir/ir_builder.h declares no builder function that returns this frame —
+ * confirm how it is meant to be created.
+ */
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief The declared buffer. */
+ tvm::tir::Buffer buffer;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("buffer", &buffer);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.DeclBufferFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference to a DeclBufferFrameNode. */
+class DeclBufferFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_TIR_IR_BUILDER_FRAME_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());

+/*!
+ * \brief Check if x is nullptr.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
+
/*!
* \brief Check if x is infinite.
* \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
Span span = Span());
+/*!
+ * \brief Calculate fmod(x, y).
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ * \note Declaration only; presumably defined in src/tir/op/op.cc (changed by this patch) —
+ * confirm its behaviour for integer operands there.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
/*!
* \brief Calculate floor(x)
* \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());
+/*!
+ * \brief Returns the address of an element in the buffer
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Look up a parameter by its name.
+ * \param param_name The name of the parameter.
+ * \param span The location of this operation in the source.
+ * \return The handle of the parameter.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4e847c0310..51dff32ac6 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -16,29 +16,46 @@
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .base import structural_equal, assert_structural_equal, structural_hash
-from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .affine_type import TensorAffineType, TupleAffineType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .op import Op, register_op_attr, register_intrin_lowering
-from .function import CallingConv, BaseFunc
+from . import diagnostics, instrument, ir_builder, transform
from .adt import Constructor, TypeData
-from .module import IRModule
+from .affine_type import TensorAffineType, TupleAffineType
from .attrs import Attrs, DictAttrs, make_node
+from .base import (
+ EnvFunc,
+ Node,
+ SourceName,
+ Span,
+ assert_structural_equal,
+ load_json,
+ save_json,
+ structural_equal,
+ structural_hash,
+)
from .container import Array, Map
+from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
+from .function import BaseFunc, CallingConv
from .memory_pools import (
- PoolInfo,
- WorkspacePoolInfo,
- ConstantPoolInfo,
- WorkspaceMemoryPools,
ConstantMemoryPools,
+ ConstantPoolInfo,
+ PoolInfo,
PoolInfoProperties,
+ WorkspaceMemoryPools,
+ WorkspacePoolInfo,
)
-
-from . import transform
-from . import instrument
-from . import diagnostics
+from .module import IRModule
+from .op import Op, register_intrin_lowering, register_op_attr
+from .tensor_type import TensorType
+from .type import (
+ FuncType,
+ GlobalTypeVar,
+ IncompleteType,
+ PointerType,
+ PrimType,
+ RelayRefType,
+ TupleType,
+ Type,
+ TypeConstraint,
+ TypeKind,
+ TypeVar,
+)
+from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/ir/_ffi_ir_builder_api.py
similarity index 88%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/ir/_ffi_ir_builder_api.py
index 926d17b166..9d08bc9b70 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/ir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir.ir_builder"""
import tvm._ffi
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/ir/ir_builder.py b/python/tvm/ir/ir_builder.py
new file mode 100644
index 0000000000..df05bf8361
--- /dev/null
+++ b/python/tvm/ir/ir_builder.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_ir_builder_api as _ffi_api
+
+
+@_register_object("ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+ def __enter__(self) -> "IRBuilderFrame":
+ _ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member # type: ignore
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore
+
+ def add_callback(self, callback) -> None: # pylint: disable=unused-argument
+ _ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member # type: ignore
+ self, callback
+ )
+
+
+@_register_object("ir_builder.IRBuilder")
+class IRBuilder(_Object):
+ def __init__(self) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.IRBuilder # pylint: disable=no-member # type: ignore
+ )
+
+ def __enter__(self) -> "IRBuilder":
+ _ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type: ignore
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ _ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type: ignore
+
+ @staticmethod
+ def current() -> "IRBuilder":
+ return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member # type: ignore
+
+ def get(self) -> _Object:
+ return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member # type: ignore
+
+
+DefType = TypeVar("DefType", bound=_Object)
+
+
+def name(s: str, v: DefType) -> DefType:
+ return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member # type: ignore
+
+
+def name_many(
+ s: List[str],
+ vs: List[DefType], # pylint: disable=invalid-name
+) -> List[DefType]:
+ assert len(s) == len(vs)
+ return [name(i, v) for i, v in zip(s, vs)]
+
+
+@_register_object("ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):
+    """Frame registered as ``ir_builder.IRModuleFrame``; the return type of ``ir_module()``.
+
+    All behavior is inherited from :class:`IRBuilderFrame`.
+    """
+
+    ...
+
+
+def ir_module() -> IRModuleFrame:
+ return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..d9605b2e70 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,20 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+from . import parser, parser_v1
-from . import tir
+#############
+from .parser import ir as ir_v2
+from .parser import ir_module as ir_module_v2
+from .parser import parse as from_source_v2
+from .parser import tir as tir_v2
-from .parser import ir_module, from_source
+#############
+from .parser_v1 import from_source as from_source_v1
+from .parser_v1 import ir_module as ir_module_v1
+from .parser_v1 import tir as tir_v1
+
+# The public names default to the new (v2) parser. The legacy entry points
+# stay importable as from_source_v1, ir_module_v1 and tir_v1.
+ir = ir_v2
+ir_module = ir_module_v2
+tir = tir_v2
+from_source = from_source_v2
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 77%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..d8530e0ab1 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,13 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+# under the License.
+"""The parser"""
+from . import dispatch as _dispatch
+from . import doc as _doc
+from . import ir
+from . import parser as _parser
from . import tir
-
-from .parser import ir_module, from_source
+from .entry import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/parser/diagnostics.py b/python/tvm/script/parser/diagnostics.py
new file mode 100644
index 0000000000..fd06d20e61
--- /dev/null
+++ b/python/tvm/script/parser/diagnostics.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+from .source import Source
+
+
+class Diagnostics:
+
+ source: Source
+ ctx: diagnostics.DiagnosticContext
+
+ def __init__(self, source: Source):
+ mod = IRModule()
+ mod.source_map.add(source.source_name, source.full_source)
+ self.source = source
+ self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+ lineno = node.lineno
+ col_offset = node.col_offset
+ end_lineno = node.end_lineno or lineno
+ end_col_offset = node.end_col_offset or col_offset
+ lineno += self.source.start_line - 1
+ end_lineno += self.source.start_line - 1
+ col_offset += self.source.start_column + 1
+ end_col_offset += self.source.start_column + 1
+ self.ctx.emit(
+ diagnostics.Diagnostic(
+ level=level,
+ span=Span(
+ source_name=SourceName(self.source.source_name),
+ line=lineno,
+ end_line=end_lineno,
+ column=col_offset,
+ end_column=end_col_offset,
+ ),
+ message=message,
+ )
+ )
+
+ def error(self, node: doc.AST, message: str) -> None:
+ self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+ self.ctx.render()
diff --git a/python/tvm/script/parser/dispatch.py b/python/tvm/script/parser/dispatch.py
new file mode 100644
index 0000000000..6237371a30
--- /dev/null
+++ b/python/tvm/script/parser/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+ from .parser import Parser
+
+
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+ """Register a method for a dispatch token and type name"""
+
+ def f(method: ParseMethod):
+ ParseVTable[(token, type_name)] = method
+
+ return f
+
+
+def get(
+ token: str,
+ type_name: str,
+ default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+ return ParseVTable.get((token, type_name), default)
+
+
+def register_op(ty: Type, op: AST, operand_index: int): # pylint: disable=invalid-name
+ def f(method: OpMethod):
+ OpVTable[(ty, op, operand_index)] = method
+
+ return f
+
+
+def get_op(
+ ty: Type, # pylint: disable=invalid-name
+ op: Type,
+ operand_index: int,
+ default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+ return OpVTable.get((ty, op, operand_index), default)
diff --git a/python/tvm/script/parser/doc.py b/python/tvm/script/parser/doc.py
new file mode 100644
index 0000000000..15ad166c33
--- /dev/null
+++ b/python/tvm/script/parser/doc.py
@@ -0,0 +1,341 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import * # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+ to_doc: typing.Optional[FnToDoc]
+ from_doc: typing.Optional[FnFromDoc]
+
+ def __init__(self):
+ self.to_doc = None
+ self.from_doc = None
+
+
+class Registry:
+ _inst: typing.Optional["Registry"] = None
+ table: typing.Dict[str, Entry]
+
+ def __init__(self):
+ self.table = defaultdict(Entry)
+
+
+def register_to_doc(name: str):
+ def f(to_doc: FnToDoc): # pylint: disable=redefined-outer-name
+ reg = Registry._inst # pylint: disable=protected-access
+ reg.table[name].to_doc = to_doc
+
+ return f
+
+
+def register_from_doc(name: str):
+ def f(to_doc: FnFromDoc): # pylint: disable=redefined-outer-name
+ reg = Registry._inst # pylint: disable=protected-access
+ reg.table[name].from_doc = to_doc
+
+ return f
+
+
+def _is_atomic_type(node):
+ return (
+ node is None
+ or node in [..., True, False]
+ or isinstance(
+ node,
+ (
+ int,
+ float,
+ str,
+ bool,
+ bytes,
+ complex,
+ ),
+ )
+ )
+
+
+def _get_registry_entry(cls_name, attr):
+ cls_name = cls_name.split(".")[-1]
+ reg = Registry._inst # pylint: disable=protected-access
+ if cls_name in reg.table:
+ entry = reg.table[cls_name]
+ return getattr(entry, attr, None)
+ return None
+
+
+def from_doc(node):
+ if _is_atomic_type(node):
+ return node
+ if isinstance(node, tuple):
+ return tuple(from_doc(n) for n in node)
+ if isinstance(node, list):
+ return [from_doc(n) for n in node]
+ func = _get_registry_entry(node.__class__.__name__, "from_doc")
+ if not func:
+ raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}")
+ return func(node)
+
+
+def to_doc(node):
+ if _is_atomic_type(node):
+ return node
+ if isinstance(node, tuple):
+ return tuple(to_doc(n) for n in node)
+ if isinstance(node, list):
+ return [to_doc(n) for n in node]
+ func = _get_registry_entry(node.__class__.__name__, "to_doc")
+ if not func:
+ raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}")
+ return func(node)
+
+
+def parse(
+ source,
+ filename="<unknown>",
+ mode="exec",
+) -> doc.AST:
+ try:
+ program = ast.parse(
+ source=source,
+ filename=filename,
+ mode=mode,
+ feature_version=(3, 8),
+ )
+ except: # pylint: disable=bare-except
+ program = ast.parse(
+ source=source,
+ filename=filename,
+ mode=mode,
+ )
+ return to_doc(program)
+
+
+class NodeVisitor:
+ def visit(self, node: doc.AST) -> None:
+ if isinstance(node, (list, tuple)):
+ for item in node:
+ self.visit(item)
+ return
+ if not isinstance(node, doc.AST):
+ return
+ getattr(
+ self,
+ "visit_" + node.__class__.__name__.split(".")[-1],
+ self.generic_visit,
+ )(node)
+
+ def generic_visit(self, node: doc.AST) -> None:
+ for field in node.__class__._FIELDS: # pylint: disable=protected-access
+ value = getattr(node, field, None)
+ if value is None:
+ pass
+ elif isinstance(value, (doc.AST, list, tuple)):
+ self.visit(value)
+
+
+class NodeTransformer:
+ def visit(self, node: doc.AST) -> doc.AST:
+ if isinstance(node, list):
+ return [self.visit(item) for item in node]
+ if isinstance(node, tuple):
+ return tuple(self.visit(item) for item in node)
+ if not isinstance(node, doc.AST):
+ return node
+ return getattr(
+ self,
+ "visit_" + node.__class__.__name__.split(".")[-1],
+ self.generic_visit,
+ )(node)
+
+ def generic_visit(self, node: doc.AST) -> doc.AST:
+ kv: typing.Dict[str, typing.Any] = {}
+ for field in node.__class__._FIELDS: # pylint: disable=protected-access
+ value = getattr(node, field, None)
+ if value is None:
+ pass
+ elif isinstance(value, (doc.AST, list, tuple)):
+ value = self.visit(value)
+ kv[field] = value
+ return node.__class__(**kv)
+
+
+def _register_default():
+ class DefaultTranslator:
+ def __init__(self, doc_cls, func, fields):
+ self.doc_cls = doc_cls # getattr(doc, name)
+ self.func = func
+ self.fields = fields
+
+ def __call__(self, node):
+ kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields}
+ return self.doc_cls(**kv)
+
+ Registry._inst = Registry() # pylint: disable=protected-access
+ for cls_name in dir(doc):
+ doc_cls = getattr(doc, cls_name)
+ if not hasattr(ast, cls_name):
+ continue
+ if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST):
+ assert "." not in cls_name
+ register_to_doc(cls_name)(
+ DefaultTranslator(
+ getattr(doc, cls_name),
+ to_doc,
+ doc_cls._FIELDS, # pylint: disable=protected-access
+ )
+ )
+ register_from_doc(cls_name)(
+ DefaultTranslator(
+ getattr(ast, cls_name),
+ from_doc,
+ doc_cls._FIELDS, # pylint: disable=protected-access
+ )
+ )
+
+
+def _py_version() -> typing.Tuple[int, int]:
+ return (sys.version_info.major, sys.version_info.minor)
+
+
+def _register_constant_handling():
+ if _py_version() not in [(3, 6), (3, 7)]:
+ return
+
+ def as_constant(f) -> doc.Constant:
+ def to_doc_func(x: ast.AST) -> doc.Constant:
+ return doc.Constant(
+ value=getattr(x, f) if isinstance(f, str) else f(x),
+ kind=None,
+ s=None,
+ n=None,
+ lineno=x.lineno,
+ col_offset=x.col_offset,
+ end_lineno=x.lineno,
+ end_col_offset=x.col_offset,
+ )
+
+ return to_doc_func
+
+ register_to_doc("Str")(as_constant("s"))
+ register_to_doc("NameConstant")(as_constant("value"))
+ register_to_doc("Num")(as_constant("n"))
+ register_to_doc("Bytes")(as_constant("s"))
+ register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+ if _py_version() >= (3, 9):
+ return
+
+ def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+ if isinstance(x.slice, ast.Slice):
+ return doc.Subscript(
+ value=to_doc(x.value),
+ slice=doc.Slice(
+ lower=to_doc(x.slice.lower),
+ upper=to_doc(x.slice.upper),
+ step=to_doc(x.slice.step),
+ lineno=getattr(x.slice, "lineno", None),
+ col_offset=getattr(x.slice, "col_offset", None),
+ end_lineno=getattr(x.slice, "end_lineno", None),
+ end_col_offset=getattr(x.slice, "end_col_offset", None),
+ ),
+ ctx=to_doc(x.ctx),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ )
+ if isinstance(x.slice, ast.ExtSlice):
+ return doc.Subscript(
+ value=to_doc(x.value),
+ slice=doc.Tuple(
+ elts=[to_doc(i) for i in x.slice.dims],
+ ctx=doc.Load(
+ lineno=None,
+ col_offset=None,
+ end_lineno=None,
+ end_col_offset=None,
+ ),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ ),
+ ctx=to_doc(x.ctx),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ )
+ if isinstance(x.slice, ast.Index):
+ return doc.Subscript(
+ value=to_doc(x.value),
+ slice=to_doc(x.slice.value),
+ ctx=to_doc(x.ctx),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ )
+ raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+ def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+ if isinstance(x.slice, doc.Slice):
+ result = ast.Subscript(
+ value=from_doc(x.value),
+ slice=from_doc(x.slice),
+ ctx=from_doc(x.ctx),
+ )
+ elif isinstance(x.slice, doc.Tuple):
+ result = ast.Subscript(
+ value=from_doc(x.value),
+ slice=ast.ExtSlice(
+ dims=[from_doc(i) for i in x.slice.elts],
+ ),
+ ctx=from_doc(x.ctx),
+ )
+ else:
+ result = ast.Subscript(
+ value=from_doc(x.value),
+ slice=ast.Index(value=from_doc(x.slice)),
+ ctx=from_doc(x.ctx),
+ )
+ result.lineno = x.lineno
+ result.col_offset = x.col_offset
+ result.end_lineno = x.end_lineno
+ result.end_col_offset = x.end_col_offset
+ return result
+
+ register_to_doc("Subscript")(subscript_to_doc)
+ register_from_doc("Subscript")(subscript_from_doc)
+
+
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
diff --git a/python/tvm/script/parser/doc_core.py b/python/tvm/script/parser/doc_core.py
new file mode 100644
index 0000000000..b88eef9a0e
--- /dev/null
+++ b/python/tvm/script/parser/doc_core.py
@@ -0,0 +1,1140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-outer-name,missing-docstring,invalid-name
+# pylint: disable=useless-super-delegation,redefined-builtin
+# pylint: disable=too-few-public-methods,too-many-arguments
+class AST:
+    """Base class of every doc node.
+
+    Every node records its source position. ``_FIELDS`` lists the node's
+    constructor arguments, in constructor order.
+    """
+
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__()
+        self.lineno = lineno
+        self.col_offset = col_offset
+        self.end_lineno = end_lineno
+        self.end_col_offset = end_col_offset
+
+
+class mod(AST):
+    """Base of module-level doc nodes (Module, Interactive, Expression)."""
+
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Module(mod):
+    """Doc counterpart of ``ast.Module``; ``body`` holds the top-level statements."""
+
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Interactive(mod):
+    """Doc counterpart of ``ast.Interactive`` (the "single" parse mode)."""
+
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class Expression(mod):
+    """Doc counterpart of ``ast.Expression`` (the "eval" parse mode)."""
+
+    _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+
+
+class stmt(AST):
+    """Base of all statement doc nodes."""
+
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class FunctionDef(stmt):
+    """Doc counterpart of ``ast.FunctionDef``: ``def name(args) -> returns: body``."""
+
+    _FIELDS = [
+        "name",
+        "args",
+        "body",
+        "decorator_list",
+        "returns",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self,
+        name,
+        args,
+        body,
+        decorator_list,
+        returns,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name = name
+        self.args = args
+        self.body = body
+        self.decorator_list = decorator_list
+        self.returns = returns
+
+
+class ClassDef(stmt):
+    """Doc counterpart of ``ast.ClassDef``: ``class name(bases, keywords): body``."""
+
+    _FIELDS = [
+        "name",
+        "bases",
+        "keywords",
+        "body",
+        "decorator_list",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self,
+        name,
+        bases,
+        keywords,
+        body,
+        decorator_list,
+        lineno,
+        col_offset,
+        end_lineno,
+        end_col_offset,
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.name = name
+        self.bases = bases
+        self.keywords = keywords
+        self.body = body
+        self.decorator_list = decorator_list
+
+
+class Return(stmt):
+    """Doc counterpart of ``ast.Return``: ``return value``."""
+
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Delete(stmt):
+    """Doc counterpart of ``ast.Delete``: ``del targets``."""
+
+    _FIELDS = ["targets", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, targets, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.targets = targets
+
+
+class Assign(stmt):
+    """Doc counterpart of ``ast.Assign``: ``t1 = t2 = ... = value``."""
+
+    _FIELDS = ["targets", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, targets, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.targets = targets
+        self.value = value
+
+
+class AugAssign(stmt):
+    """Doc counterpart of ``ast.AugAssign``: ``target op= value``."""
+
+    _FIELDS = ["target", "op", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, target, op, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.op = op
+        self.value = value
+
+
+class AnnAssign(stmt):
+    """Doc counterpart of ``ast.AnnAssign``: ``target: annotation = value``."""
+
+    _FIELDS = [
+        "target",
+        "annotation",
+        "value",
+        "simple",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self, target, annotation, value, simple, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.annotation = annotation
+        self.value = value
+        self.simple = simple
+
+
+class For(stmt):
+    """Doc counterpart of ``ast.For``: ``for target in iter: body else: orelse``."""
+
+    _FIELDS = [
+        "target",
+        "iter",
+        "body",
+        "orelse",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.target = target
+        self.iter = iter
+        self.body = body
+        self.orelse = orelse
+
+
+class While(stmt):
+    """Doc counterpart of ``ast.While``: ``while test: body else: orelse``."""
+
+    _FIELDS = ["test", "body", "orelse", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.body = body
+        self.orelse = orelse
+
+
+class If(stmt):
+    """Doc counterpart of ``ast.If``: ``if test: body else: orelse``."""
+
+    _FIELDS = ["test", "body", "orelse", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.body = body
+        self.orelse = orelse
+
+
+class With(stmt):
+    """Doc counterpart of ``ast.With``: ``with items: body``."""
+
+    _FIELDS = ["items", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.items = items
+        self.body = body
+
+
+class Raise(stmt):
+    """Doc counterpart of ``ast.Raise``: ``raise exc from cause``."""
+
+    _FIELDS = ["exc", "cause", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, exc, cause, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.exc = exc
+        self.cause = cause
+
+
+class Try(stmt):
+    """Doc counterpart of ``ast.Try``: try/except (``handlers``)/else/finally blocks."""
+
+    _FIELDS = [
+        "body",
+        "handlers",
+        "orelse",
+        "finalbody",
+        "lineno",
+        "col_offset",
+        "end_lineno",
+        "end_col_offset",
+    ]
+
+    def __init__(
+        self, body, handlers, orelse, finalbody, lineno, col_offset, end_lineno, end_col_offset
+    ):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.body = body
+        self.handlers = handlers
+        self.orelse = orelse
+        self.finalbody = finalbody
+
+
+class Assert(stmt):
+    """Doc counterpart of ``ast.Assert``: ``assert test, msg``."""
+
+    _FIELDS = ["test", "msg", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, test, msg, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.test = test
+        self.msg = msg
+
+
+class Import(stmt):
+    """Doc counterpart of ``ast.Import``: ``import names``."""
+
+    _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class ImportFrom(stmt):
+    """Doc counterpart of ``ast.ImportFrom``: ``from module import names``."""
+
+    _FIELDS = ["module", "names", "level", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, module, names, level, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.module = module
+        self.names = names
+        self.level = level
+
+
+class Global(stmt):
+    """Doc counterpart of ``ast.Global``: ``global names``."""
+
+    _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class Nonlocal(stmt):
+    """Doc counterpart of ``ast.Nonlocal``: ``nonlocal names``."""
+
+    _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.names = names
+
+
+class Expr(stmt):
+    """Doc counterpart of ``ast.Expr``: an expression used as a statement."""
+
+    _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+        self.value = value
+
+
+class Pass(stmt):
+    """Doc counterpart of ``ast.Pass``."""
+
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Break(stmt):
+    """Doc counterpart of ``ast.Break``."""
+
+    _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+    def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+        super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Continue(stmt):
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class expr(AST):
+ """Base class of expression nodes; counterpart of ``ast.expr``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BoolOp(expr):
+ """Counterpart of ``ast.BoolOp``: ``values`` joined by a single ``and``/``or`` ``op``."""
+ _FIELDS = ["op", "values", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, op, values, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.op = op
+ self.values = values
+
+
+class BinOp(expr):
+ """Counterpart of ``ast.BinOp``: ``left op right``."""
+ _FIELDS = ["left", "op", "right", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, left, op, right, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.left = left
+ self.op = op
+ self.right = right
+
+
+class UnaryOp(expr):
+ """Counterpart of ``ast.UnaryOp``: ``op operand``."""
+ _FIELDS = ["op", "operand", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, op, operand, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.op = op
+ self.operand = operand
+
+
+class Lambda(expr):
+ """Counterpart of ``ast.Lambda``: ``lambda args: body``."""
+ _FIELDS = ["args", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, args, body, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.args = args
+ self.body = body
+
+
+class IfExp(expr):
+ """Counterpart of ``ast.IfExp``: ``body if test else orelse``."""
+ _FIELDS = ["test", "body", "orelse", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.test = test
+ self.body = body
+ self.orelse = orelse
+
+
+class Dict(expr):
+ """Counterpart of ``ast.Dict``; a ``None`` entry in ``keys`` stands for ``**values[i]``."""
+ _FIELDS = ["keys", "values", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, keys, values, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.keys = keys
+ self.values = values
+
+
+class Set(expr):
+ """Counterpart of ``ast.Set``: a set display ``{elts...}``."""
+ _FIELDS = ["elts", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, elts, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.elts = elts
+
+
+class ListComp(expr):
+ """Counterpart of ``ast.ListComp``: ``[elt for ...]``; ``generators`` are comprehensions."""
+ _FIELDS = ["elt", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.elt = elt
+ self.generators = generators
+
+
+class SetComp(expr):
+ """Counterpart of ``ast.SetComp``: ``{elt for ...}``."""
+ _FIELDS = ["elt", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.elt = elt
+ self.generators = generators
+
+
+class DictComp(expr):
+ """Counterpart of ``ast.DictComp``: ``{key: value for ...}``."""
+ _FIELDS = ["key", "value", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, key, value, generators, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.key = key
+ self.value = value
+ self.generators = generators
+
+
+class GeneratorExp(expr):
+ """Counterpart of ``ast.GeneratorExp``: ``(elt for ...)``."""
+ _FIELDS = ["elt", "generators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.elt = elt
+ self.generators = generators
+
+
+class Yield(expr):
+ """Counterpart of ``ast.Yield``: ``yield value``."""
+ _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+
+
+class YieldFrom(expr):
+ """Counterpart of ``ast.YieldFrom``: ``yield from value``."""
+ _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+
+
+class Compare(expr):
+ """Counterpart of ``ast.Compare``: ``left ops[0] comparators[0] ops[1] comparators[1] ...``."""
+ _FIELDS = ["left", "ops", "comparators", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, left, ops, comparators, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.left = left
+ self.ops = ops
+ self.comparators = comparators
+
+
+class Call(expr):
+ """Counterpart of ``ast.Call``: ``func(args..., keywords...)``."""
+ _FIELDS = ["func", "args", "keywords", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, func, args, keywords, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.func = func
+ self.args = args
+ self.keywords = keywords
+
+
+class FormattedValue(expr):
+ """Counterpart of ``ast.FormattedValue``: one ``{value!conversion:format_spec}`` f-string field."""
+ _FIELDS = [
+ "value",
+ "conversion",
+ "format_spec",
+ "lineno",
+ "col_offset",
+ "end_lineno",
+ "end_col_offset",
+ ]
+
+ def __init__(
+ self, value, conversion, format_spec, lineno, col_offset, end_lineno, end_col_offset
+ ):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+ self.conversion = conversion
+ self.format_spec = format_spec
+
+
+class JoinedStr(expr):
+ """Counterpart of ``ast.JoinedStr``: an f-string made of the parts in ``values``."""
+ _FIELDS = ["values", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, values, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.values = values
+
+
+class Constant(expr):
+ """Counterpart of ``ast.Constant``; ``s``/``n`` mirror its legacy aliases of ``value``."""
+ _FIELDS = ["value", "kind", "s", "n", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, kind, s, n, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+ self.kind = kind
+ self.s = s
+ self.n = n
+
+
+class NamedExpr(expr):
+ """Counterpart of ``ast.NamedExpr``: ``target := value``."""
+ _FIELDS = ["target", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, target, value, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.target = target
+ self.value = value
+
+
+class Attribute(expr):
+ """Counterpart of ``ast.Attribute``: ``value.attr``."""
+ _FIELDS = ["value", "attr", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, attr, ctx, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+ self.attr = attr
+ self.ctx = ctx
+
+
+class slice(AST):
+ """Base class of the legacy slice nodes (``ast.slice``); shadows builtin ``slice`` here."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Slice(slice):
+ """Counterpart of ``ast.Slice``: ``lower:upper:step`` (each part may be ``None``)."""
+ _FIELDS = ["lower", "upper", "step", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lower, upper, step, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.lower = lower
+ self.upper = upper
+ self.step = step
+
+
+class ExtSlice(slice):
+ """Counterpart of the legacy ``ast.ExtSlice``: a multi-dimensional subscript (``dims``)."""
+ _FIELDS = ["dims", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, dims, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.dims = dims
+
+
+class Index(slice):
+ """Counterpart of the legacy ``ast.Index``: wraps a plain subscript ``value``."""
+ _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+
+
+class Subscript(expr):
+ """Counterpart of ``ast.Subscript``: ``value[slice]``."""
+ _FIELDS = ["value", "slice", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, slice, ctx, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+ self.slice = slice
+ self.ctx = ctx
+
+
+class Starred(expr):
+ """Counterpart of ``ast.Starred``: ``*value``."""
+ _FIELDS = ["value", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, value, ctx, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.value = value
+ self.ctx = ctx
+
+
+class Name(expr):
+ """Counterpart of ``ast.Name``: the identifier ``id``."""
+ _FIELDS = ["id", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, id, ctx, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.id = id
+ self.ctx = ctx
+
+
+class List(expr):
+ """Counterpart of ``ast.List``: ``[elts...]``."""
+ _FIELDS = ["elts", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.elts = elts
+ self.ctx = ctx
+
+
+class Tuple(expr):
+ """Counterpart of ``ast.Tuple``: ``(elts...)``."""
+ _FIELDS = ["elts", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.elts = elts
+ self.ctx = ctx
+
+
+class expr_context(AST):
+ """Base class of expression contexts; counterpart of ``ast.expr_context``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class AugLoad(expr_context):
+ """Counterpart of the legacy ``ast.AugLoad`` context."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class AugStore(expr_context):
+ """Counterpart of the legacy ``ast.AugStore`` context."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Param(expr_context):
+ """Counterpart of the legacy ``ast.Param`` context."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Suite(mod):
+ """Counterpart of the legacy ``ast.Suite`` module node holding statements in ``body``."""
+ _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.body = body
+
+
+class Del(expr_context):
+ """Counterpart of ``ast.Del``: context of a ``del`` target."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Load(expr_context):
+ """Counterpart of ``ast.Load``: context of an expression that is read."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Store(expr_context):
+ """Counterpart of ``ast.Store``: context of an assignment target."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class boolop(AST):
+ """Base class of boolean operator nodes; counterpart of ``ast.boolop``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class And(boolop):
+ """Counterpart of ``ast.And``: ``and``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Or(boolop):
+ """Counterpart of ``ast.Or``: ``or``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class operator(AST):
+ """Base class of binary operator nodes; counterpart of ``ast.operator``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Add(operator):
+ """Counterpart of ``ast.Add``: ``+``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitAnd(operator):
+ """Counterpart of ``ast.BitAnd``: ``&``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitOr(operator):
+ """Counterpart of ``ast.BitOr``: ``|``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class BitXor(operator):
+ """Counterpart of ``ast.BitXor``: ``^``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Div(operator):
+ """Counterpart of ``ast.Div``: ``/``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class FloorDiv(operator):
+ """Counterpart of ``ast.FloorDiv``: ``//``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class LShift(operator):
+ """Counterpart of ``ast.LShift``: ``<<``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Mod(operator):
+ """Counterpart of ``ast.Mod``: ``%``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Mult(operator):
+ """Counterpart of ``ast.Mult``: ``*``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class MatMult(operator):
+ """Counterpart of ``ast.MatMult``: ``@``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Pow(operator):
+ """Counterpart of ``ast.Pow``: ``**``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class RShift(operator):
+ """Counterpart of ``ast.RShift``: ``>>``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Sub(operator):
+ """Counterpart of ``ast.Sub``: binary ``-``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class unaryop(AST):
+ """Base class of unary operator nodes; counterpart of ``ast.unaryop``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Invert(unaryop):
+ """Counterpart of ``ast.Invert``: ``~``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Not(unaryop):
+ """Counterpart of ``ast.Not``: ``not``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class UAdd(unaryop):
+ """Counterpart of ``ast.UAdd``: unary ``+``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class USub(unaryop):
+ """Counterpart of ``ast.USub``: unary ``-``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class cmpop(AST):
+ """Base class of comparison operator nodes; counterpart of ``ast.cmpop``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Eq(cmpop):
+ """Counterpart of ``ast.Eq``: ``==``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Gt(cmpop):
+ """Counterpart of ``ast.Gt``: ``>``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class GtE(cmpop):
+ """Counterpart of ``ast.GtE``: ``>=``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class In(cmpop):
+ """Counterpart of ``ast.In``: ``in``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Is(cmpop):
+ """Counterpart of ``ast.Is``: ``is``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class IsNot(cmpop):
+ """Counterpart of ``ast.IsNot``: ``is not``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class Lt(cmpop):
+ """Counterpart of ``ast.Lt``: ``<``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class LtE(cmpop):
+ """Counterpart of ``ast.LtE``: ``<=``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class NotEq(cmpop):
+ """Counterpart of ``ast.NotEq``: ``!=``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class NotIn(cmpop):
+ """Counterpart of ``ast.NotIn``: ``not in``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class comprehension(AST):
+ """Counterpart of ``ast.comprehension``: one ``for target in iter if ifs...`` clause."""
+ _FIELDS = [
+ "target",
+ "iter",
+ "ifs",
+ "is_async",
+ "lineno",
+ "col_offset",
+ "end_lineno",
+ "end_col_offset",
+ ]
+
+ def __init__(self, target, iter, ifs, is_async, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.target = target
+ self.iter = iter
+ self.ifs = ifs
+ self.is_async = is_async
+
+
+class excepthandler(AST):
+ """Base class of exception handler nodes; counterpart of ``ast.excepthandler``."""
+ _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+
+
+class ExceptHandler(excepthandler):
+ """Counterpart of ``ast.ExceptHandler``: ``except type as name: body``."""
+ _FIELDS = ["type", "name", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, type, name, body, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.type = type
+ self.name = name
+ self.body = body
+
+
+class arguments(AST):
+ """Counterpart of ``ast.arguments``: the parameter list of a function or lambda."""
+ _FIELDS = [
+ "args",
+ "vararg",
+ "kwonlyargs",
+ "kw_defaults",
+ "kwarg",
+ "defaults",
+ "posonlyargs",
+ "lineno",
+ "col_offset",
+ "end_lineno",
+ "end_col_offset",
+ ]
+
+ def __init__(
+ self,
+ args,
+ vararg,
+ kwonlyargs,
+ kw_defaults,
+ kwarg,
+ defaults,
+ posonlyargs,
+ lineno,
+ col_offset,
+ end_lineno,
+ end_col_offset,
+ ):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.args = args
+ self.vararg = vararg
+ self.kwonlyargs = kwonlyargs
+ self.kw_defaults = kw_defaults
+ self.kwarg = kwarg
+ self.defaults = defaults
+ self.posonlyargs = posonlyargs
+
+
+class arg(AST):
+ """Counterpart of ``ast.arg``: one parameter ``arg`` with an optional ``annotation``."""
+ _FIELDS = ["arg", "annotation", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, arg, annotation, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.arg = arg
+ self.annotation = annotation
+
+
+class keyword(AST):
+ """Counterpart of ``ast.keyword``: ``arg=value`` in a call (``arg`` is None for ``**value``)."""
+ _FIELDS = ["arg", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, arg, value, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.arg = arg
+ self.value = value
+
+
+class alias(AST):
+ """Counterpart of ``ast.alias``: ``name as asname`` in an import."""
+ _FIELDS = ["name", "asname", "lineno", "col_offset", "end_lineno", "end_col_offset"]
+
+ def __init__(self, name, asname, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.name = name
+ self.asname = asname
+
+
+class withitem(AST):
+ """Counterpart of ``ast.withitem``: ``context_expr as optional_vars`` in a ``with``."""
+ _FIELDS = [
+ "context_expr",
+ "optional_vars",
+ "lineno",
+ "col_offset",
+ "end_lineno",
+ "end_col_offset",
+ ]
+
+ def __init__(self, context_expr, optional_vars, lineno, col_offset, end_lineno, end_col_offset):
+ super().__init__(lineno, col_offset, end_lineno, end_col_offset)
+ self.context_expr = context_expr
+ self.optional_vars = optional_vars
+
+__all__ = [
+ # Root and module nodes
+ "AST",
+ "mod",
+ "Module",
+ "Interactive",
+ "Expression",
+ # Statements
+ "stmt",
+ "FunctionDef",
+ "ClassDef",
+ "Return",
+ "Delete",
+ "Assign",
+ "AugAssign",
+ "AnnAssign",
+ "For",
+ "While",
+ "If",
+ "With",
+ "Raise",
+ "Try",
+ "Assert",
+ "Import",
+ "ImportFrom",
+ "Global",
+ "Nonlocal",
+ "Expr",
+ "Pass",
+ "Break",
+ "Continue",
+ # Expressions
+ "expr",
+ "BoolOp",
+ "BinOp",
+ "UnaryOp",
+ "Lambda",
+ "IfExp",
+ "Dict",
+ "Set",
+ "ListComp",
+ "SetComp",
+ "DictComp",
+ "GeneratorExp",
+ "Yield",
+ "YieldFrom",
+ "Compare",
+ "Call",
+ "FormattedValue",
+ "JoinedStr",
+ "Constant",
+ "NamedExpr",
+ "Attribute",
+ # Slices (note: "slice" shadows the builtin for star-importers)
+ "slice",
+ "Slice",
+ "ExtSlice",
+ "Index",
+ "Subscript",
+ "Starred",
+ "Name",
+ "List",
+ "Tuple",
+ # Expression contexts (Suite, a module node, is listed among them)
+ "expr_context",
+ "AugLoad",
+ "AugStore",
+ "Param",
+ "Suite",
+ "Del",
+ "Load",
+ "Store",
+ # Boolean operators
+ "boolop",
+ "And",
+ "Or",
+ # Binary operators
+ "operator",
+ "Add",
+ "BitAnd",
+ "BitOr",
+ "BitXor",
+ "Div",
+ "FloorDiv",
+ "LShift",
+ "Mod",
+ "Mult",
+ "MatMult",
+ "Pow",
+ "RShift",
+ "Sub",
+ # Unary operators
+ "unaryop",
+ "Invert",
+ "Not",
+ "UAdd",
+ "USub",
+ # Comparison operators
+ "cmpop",
+ "Eq",
+ "Gt",
+ "GtE",
+ "In",
+ "Is",
+ "IsNot",
+ "Lt",
+ "LtE",
+ "NotEq",
+ "NotIn",
+ # Auxiliary nodes
+ "comprehension",
+ "excepthandler",
+ "ExceptHandler",
+ "arguments",
+ "arg",
+ "keyword",
+ "alias",
+ "withitem",
+]
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/entry.py
similarity index 50%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/entry.py
index 923eb97d27..ed736ba282 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/entry.py
@@ -14,32 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
-import inspect
-from typing import Callable
+from tvm.ir.ir_builder import IRBuilder
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from . import doc
+from .parser import Parser
+from .source import Source
-def prim_func(input_func: Callable) -> PrimFunc:
- """Decorate a python function as tvm script.
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+ """Parse a TVMScript program and return what the IRBuilder constructed.
+
+ Parameters
+ ----------
+ program : Union[doc.AST, Any, str]
+ The program to parse; it is wrapped in ``Source`` (which inputs ``Source``
+ accepts is defined in ``source.py``).
+ extra_vars : Optional[Dict[str, Any]]
+ Names made visible to the program. When ``program`` is a ``str`` and this is
+ ``None``, it defaults to the IR/TIR script modules under ``I``/``ir`` and
+ ``T``/``tir``. NOTE(review): for non-string programs ``None`` is passed
+ through to ``Parser.parse`` unchanged.
+
+ Returns
+ -------
+ The result of ``IRBuilder.get()`` after parsing.
+ """
+ if isinstance(program, str) and extra_vars is None:
+ # Imported lazily, only for the bare-source case that needs them.
+ from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
+ from tvm.script.parser import tir # pylint: disable=import-outside-toplevel
- Parameters
- ----------
- func : input_func
- The function to be parsed.
-
- Returns
- -------
- output : PrimFunc
- The result functions.
- """
- if inspect.isfunction(input_func):
- result = from_source(input_func)
- result.__name__ = input_func.__name__
- result.__qualname__ = input_func.__qualname__
- return result
-
- raise TypeError("Only function definitions are supported.")
+ extra_vars = {
+ "I": ir,
+ "ir": ir,
+ "T": tir,
+ "tir": tir,
+ }
+ source = Source(program)
+ parser = Parser(source)
+ # Parsing happens inside the builder's context; its result is the parsed IR.
+ with IRBuilder() as builder:
+ parser.parse(extra_vars=extra_vars)
+ return builder.get()
diff --git a/python/tvm/script/parser/evaluator.py b/python/tvm/script/parser/evaluator.py
new file mode 100644
index 0000000000..3899531b21
--- /dev/null
+++ b/python/tvm/script/parser/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+import builtins
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+ from .parser import Parser
+
+# Fallback implementation of each operator node type, used by ``_eval_op`` when no
+# operand type provides a dispatched override. Two-operand entries receive
+# (lhs, rhs); unary entries receive a single operand. ``And``/``Or`` are applied to
+# operands that were already evaluated, so they never short-circuit.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+ # binary operators
+ doc.Add: lambda a, b: a + b,
+ doc.Sub: lambda a, b: a - b,
+ doc.Mult: lambda a, b: a * b,
+ doc.Div: lambda a, b: a / b,
+ doc.FloorDiv: lambda a, b: a // b,
+ doc.Mod: lambda a, b: a % b,
+ doc.LShift: lambda a, b: a << b,
+ doc.RShift: lambda a, b: a >> b,
+ doc.BitOr: lambda a, b: a | b,
+ doc.BitXor: lambda a, b: a ^ b,
+ doc.BitAnd: lambda a, b: a & b,
+ doc.MatMult: lambda a, b: a @ b,
+ doc.Pow: lambda a, b: a**b,
+ # comparison operators
+ doc.Eq: lambda a, b: a == b,
+ doc.NotEq: lambda a, b: a != b,
+ doc.Lt: lambda a, b: a < b,
+ doc.LtE: lambda a, b: a <= b,
+ doc.Gt: lambda a, b: a > b,
+ doc.GtE: lambda a, b: a >= b,
+ doc.Is: lambda a, b: a is b,
+ doc.IsNot: lambda a, b: a is not b,
+ doc.In: lambda a, b: a in b,
+ doc.NotIn: lambda a, b: a not in b,
+ # boolean operators
+ doc.And: lambda a, b: a and b,
+ doc.Or: lambda a, b: a or b,
+ # unary operators
+ doc.Invert: lambda a: ~a,
+ doc.Not: lambda a: not a,
+ doc.UAdd: lambda a: +a,
+ doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+    """Evaluate a doc AST expression bottom-up against a table of values.
+
+    Each evaluated sub-expression is stored in ``value_table`` under a fresh
+    temporary name and replaced by a ``doc.Name`` referring to it. Boolean,
+    comparison, unary and binary operators go through ``_eval_op`` so that an
+    operand's type can supply its own implementation via ``dispatch``; every
+    other expression node is evaluated by Python with ``value_table`` as its
+    globals.
+    """
+
+    parser: "Parser"  # reports errors at the offending node
+    value_table: Dict[str, Any]  # visible names; temporaries are added in place
+    new_value_count: int  # suffix of the next temporary name
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` and return its Python value.
+
+        Parameters
+        ----------
+        parser : Parser
+            Used to report errors.
+        value_table : Dict[str, Any]
+            Names visible to the expression; temporaries are inserted into it.
+        node : doc.AST
+            An expression node, or a ``doc.Expression`` wrapping one.
+
+        Returns
+        -------
+        Any
+            The value of the expression.
+
+        Raises
+        ------
+        TypeError
+            If evaluation does not reduce ``node`` to a name or a constant.
+        """
+        if isinstance(node, doc.Expression):
+            # ``_visit`` returns non-expression nodes untouched, which would send the
+            # wrapper straight to the TypeError below.
+            node = node.body
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table and hasattr(builtins, result.id):
+                # A bare builtin such as ``range``; value_table entries take precedence.
+                return getattr(builtins, result.id)
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        """Store ``value`` under a fresh temporary name and return a Name node for it."""
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        # Synthetic node: there is no real source position.
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        """Evaluate the sub-expressions of ``node`` bottom-up.
+
+        Returns a ``doc.Name`` bound in ``value_table`` for every evaluated
+        expression. Names, constants, operator/context nodes and non-expression
+        nodes are returned unchanged. Lists and tuples are visited element-wise.
+        """
+        if node is None:
+            # Optional entries inside child lists, e.g. the ``None`` key that
+            # ``{**mapping}`` produces in ``Dict.keys``.
+            return None
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            # Builtins (``range``, ``len``, ...) are absent from value_table but are
+            # resolved by ``eval`` when the enclosing expression is evaluated.
+            if node.id not in self.value_table and not hasattr(builtins, node.id):
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(
+            node,
+            (doc.Lambda, doc.ListComp, doc.SetComp, doc.DictComp, doc.GeneratorExp),
+        ):
+            # These bind their own names (lambda parameters, comprehension targets),
+            # which are not in value_table, so their children cannot be visited on
+            # their own. Evaluate the whole expression at once instead.
+            return self._eval_lambda(node)
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.expr) -> Any:
+        """Evaluate a name-binding expression (lambda or comprehension) in one step.
+
+        Python evaluates it directly, so operators inside it are not routed
+        through ``dispatch``.
+        """
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        """Fold ``values`` left to right with ``and``/``or``.
+
+        All operands were already evaluated by ``_visit``, so there is no
+        short-circuiting.
+        """
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate a possibly chained comparison with Python semantics.
+
+        ``a < b < c`` means ``(a < b) and (b < c)`` (not ``(a < b) < c``). The
+        pairwise results are combined through ``_eval_op`` with ``And`` so that
+        dispatched operand types can override the combination. There is no
+        short-circuiting because every operand is already evaluated.
+        """
+        and_op = doc.And(lineno=0, col_offset=0, end_lineno=None, end_col_offset=None)
+        lhs = self._eval_expr(fields["left"])
+        value = lhs
+        for index, (op, rhs_node) in enumerate(zip(fields["ops"], fields["comparators"])):
+            rhs = self._eval_expr(rhs_node)
+            pair = _eval_op(op, values=[lhs, rhs])
+            value = pair if index == 0 else _eval_op(and_op, values=[value, pair])
+            lhs = rhs
+        return value
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply the unary ``op`` to the evaluated ``operand``."""
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply the binary ``op`` to the evaluated ``left`` and ``right``."""
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        """Build a builtin ``slice`` from the evaluated bounds (missing parts stay None)."""
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        """Evaluate a doc expression with ``value_table`` as the globals."""
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate ``node`` with the names in ``dict_globals`` in scope.
+
+    ``dict_globals`` is copied first (``None`` means no names), so temporaries
+    created during evaluation do not leak into the caller's dictionary.
+    """
+    value_table = {} if dict_globals is None else dict(dict_globals)
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Bind ``source`` to the assignment ``target`` and return the names bound.
+
+    Failures are reported through ``parser`` at ``target``; the original
+    exception is re-raised if reporting does not raise by itself.
+    """
+    try:
+        bound = _eval_assign(target, source)
+    except Exception as err:  # pylint: disable=broad-except
+        parser.report_error(target, f"Failed to evaluate assignment: {str(err)}")
+        raise
+    return bound
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Convert ``node`` to a Python ``ast`` expression and ``eval`` it.
+
+    ``dict_globals`` is used directly as the evaluation globals (a new empty
+    dict when ``None``).
+    """
+    tree = doc.from_doc(node)
+    if isinstance(tree, ast.expr):
+        tree = ast.Expression(body=tree)
+    assert isinstance(tree, ast.Expression), "Expects an ast.Expression, but gets: " + str(tree)
+    scope = {} if dict_globals is None else dict_globals
+    code = compile(ast.fix_missing_locations(tree), filename="<ast>", mode="eval")
+    return eval(code, scope)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply operator node ``op`` to the already-evaluated ``values``.
+
+    Operands are inspected left to right. The first one whose type carries a
+    ``_dispatch_type`` with an implementation registered in ``dispatch`` for this
+    operator and operand position is used; otherwise ``DEFAULT_OP`` applies.
+    """
+    op_type = type(op)
+    for index, operand in enumerate(values):
+        dispatch_type = getattr(type(operand), "_dispatch_type", None)
+        if dispatch_type is not None:
+            impl = dispatch.get_op(ty=dispatch_type, op=op_type, operand_index=index, default=None)
+            if impl is not None:
+                return impl(*values)
+    return DEFAULT_OP[op_type](*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Execute ``<target> = source`` and return the names it binds.
+
+    The assignment runs with empty globals and a fresh locals dictionary whose
+    only initial entry is the right-hand side; that entry is removed afterwards.
+    """
+    target = doc.from_doc(target)
+    assert isinstance(target, ast.expr)
+    rhs_name = "__tvm_rhs_var__"
+    namespace = {rhs_name: source}
+    assign = ast.Assign(targets=[target], value=ast.Name(id=rhs_name, ctx=ast.Load()))
+    module = ast.fix_missing_locations(ast.Module(body=[assign], type_ignores=[]))
+    code = compile(module, filename="<ast>", mode="exec")
+    exec(code, {}, namespace)  # pylint: disable=exec-used
+    del namespace[rhs_name]
+    return namespace
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser/ir/__init__.py
similarity index 89%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser/ir/__init__.py
index 926d17b166..bea08cfb1b 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script"""
-import tvm._ffi
-
-tvm._ffi._init_api("script", __name__)
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 66%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 555659d0c5..353963f29b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,8 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
-from . import tir
+from tvm.ir import IRModule
-from .parser import ir_module, from_source
+from ..entry import parse
+from ..utils import inspect_class_capture
+
+
+def ir_module(f: Type) -> IRModule:
+    """Decorator that parses the class ``f`` with the TVMScript parser.
+
+    The names captured from the class's methods are passed to ``parse``.
+    Raises ``TypeError`` when ``f`` is not a class.
+    """
+    if inspect.isclass(f):
+        return parse(f, inspect_class_capture(f))
+    raise TypeError(f"Expect a class, but got: {f}")
+
+
+# Tag the decorator with its dialect token; the core parser looks up `dispatch_token`
+# on decorators when choosing how to parse a definition.
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 54%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/parser.py
index 555659d0c5..aec203c7d9 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,8 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.ir import ir_builder as I
-from . import tir
+from .. import dispatch, doc
+from ..parser import Parser
-from .parser import ir_module, from_source
+
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Parse a class body as an IR module.
+
+    Opens a new variable-table scope and the ``I.ir_module()`` context, then visits the
+    body with the ``"ir"`` dispatch token active.
+    """
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    """Ignore assignments in an IR-module class body (intentionally a no-op)."""
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    """Ignore expression statements in an IR-module class body (intentionally a no-op)."""
diff --git a/python/tvm/script/parser/parser.py b/python/tvm/script/parser/parser.py
new file mode 100644
index 0000000000..c0eb79f144
--- /dev/null
+++ b/python/tvm/script/parser/parser.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from ...error import DiagnosticError
+from . import dispatch, doc
+from .diagnostics import Diagnostics
+from .evaluator import eval_assign, eval_expr
+from .source import Source
+from .utils import deferred
+from .var_table import VarTable
+
+# AST node types that `Parser.visit` handles with `generic_visit` (i.e. it simply visits
+# their children) instead of looking up a `visit_<Name>` method.
+DEFAULT_VISIT = {
+ "Interactive",
+ "Module",
+ "Expression",
+ "Pass",
+}
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap a parse method so unexpected failures are reported against the node.
+
+    ``DiagnosticError`` propagates untouched. Any other exception is first reported
+    through ``Parser.report_error`` with its message and then re-raised.
+    """
+
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            return func(self, node)
+        except Exception as err:  # pylint: disable=broad-except
+            if not isinstance(err, DiagnosticError):
+                self.report_error(node, str(err))
+            raise
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Resolve the (wrapped) parse method for ``type_name``.
+
+    The innermost dispatch token is tried first, then ``"default"``. When neither has a
+    registration, the node's children are visited with ``generic_visit``.
+    """
+    func = dispatch.get(token=self.dispatch_tokens[-1], type_name=type_name, default=None)
+    if func is None:
+        func = dispatch.get(token="default", type_name=type_name, default=None)
+    if func is None:
+        return _dispatch_wrapper(lambda self, node: self.generic_visit(node))
+    return _dispatch_wrapper(func)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser.
+
+    Statement-level ``visit_*`` methods forward to the parse method registered (through
+    ``dispatch``) for the innermost dispatch token, falling back to ``"default"``.
+
+    Attributes
+    ----------
+    diag : Diagnostics
+        Reports errors against the parsed ``Source``.
+    dispatch_tokens : List[str]
+        Stack of dialect tokens; the last entry is the active one.
+    var_table : VarTable
+        Scoped table of the values visible while evaluating expressions.
+    """
+
+    diag: Diagnostics
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        self.diag = Diagnostics(source)
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Visit the whole program with ``extra_vars`` bound in an outermost scope."""
+        initial_vars = {} if extra_vars is None else extra_vars
+        with self.var_table.with_frame():
+            for var_name, value in initial_vars.items():
+                self.var_table.add(var_name, value)
+            self.visit(self.diag.source.as_ast())
+
+    def with_dispatch_token(self, token: str):
+        """Push ``token`` immediately; the returned context manager pops it on exit."""
+
+        def _pop_token():
+            self.dispatch_tokens.pop()
+
+        self.dispatch_tokens.append(token)
+        return deferred(_pop_token)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate ``node`` against the visible variables, overridden by ``extra_vars``."""
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            var_values.update(extra_vars)
+        return eval_expr(self, node, var_values)
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        """Assign ``source`` to ``target`` Python-style, then bind each resulting name.
+
+        Each name produced by the assignment is passed through ``bind_value``. Whatever
+        ``bind_value`` returns is added to the current scope. The raw assigned values
+        are returned.
+        """
+        var_values = eval_assign(self, target, source)
+        for var_name, value in var_values.items():
+            self.var_table.add(var_name, bind_value(self, target, var_name, value))
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:
+        """Report ``msg`` at ``node`` through the diagnostics engine."""
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit a node or a list/tuple of nodes; non-AST values are ignored."""
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+        elif isinstance(node, doc.AST):
+            node_name = node.__class__.__name__.split(".")[-1]
+            if node_name in DEFAULT_VISIT:
+                visitor = self.generic_visit
+            else:
+                visitor = getattr(self, "visit_" + node_name, None)
+                if visitor is None:
+                    raise NotImplementedError(
+                        f"Visitor of AST node is not implemented: {node_name}"
+                    )
+            visitor(node)
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        """Visit each statement of a body in order."""
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        """Parse a function with the dialect named by its last decorator's token."""
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        handler = dispatch.get(
+            token=decorator.dispatch_token, type_name="FunctionDef", default=None
+        )
+        if handler is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        handler(self, node)
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        """Parse a class; classes are always handled by the ``"ir"`` dialect."""
+        handler = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if handler is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        handler(self, node)
+
+    # The visitors below forward to the parse method for the active dispatch token.
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
diff --git a/python/tvm/script/parser/source.py b/python/tvm/script/parser/source.py
new file mode 100644
index 0000000000..9674dc5494
--- /dev/null
+++ b/python/tvm/script/parser/source.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+import sys
+from typing import Union
+
+from . import doc
+
+
+class Source:
+    """Program text handed to the TVMScript parser.
+
+    A ``str`` is used verbatim. Any other ``program`` is handed to ``inspect`` to
+    retrieve its source, so it is presumably a function or class.
+    """
+
+    # "<str>" for string input, otherwise the file reported by inspect.getsourcefile
+    source_name: str
+    # 1-based first line and indentation column of the program in that file
+    start_line: int
+    start_column: int
+    # the program text, with the first line's indentation stripped from every line
+    source: str
+    # text of the whole enclosing module when it can be retrieved
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = inspect.getsourcelines(program)  # type: ignore
+        first_line = lines[0] if lines else ""
+        self.start_column = len(first_line) - len(first_line.lstrip())
+        if self.start_column:
+            # A non-zero indent implies `lines` is non-empty.
+            self.source = "\n".join(line[self.start_column :].rstrip() for line in lines)
+        else:
+            self.source = "".join(lines)
+        try:
+            # In a Jupyter notebook the module is the built-in `__main__`, for which
+            # `inspect.getsource` raises TypeError.
+            module = inspect.getmodule(program)
+            self.full_source = inspect.getsource(module) if module else self.source
+        except TypeError:
+            # Work-around for the Jupyter case; `findsource` is an internal inspect API.
+            src_lines, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src_lines)
+
+    def as_ast(self) -> doc.AST:
+        """Parse ``self.source`` into a doc AST."""
+        return doc.parse(self.source)
+
+
+# The original `inspect.getfile`, used by the replacement below for non-class objects.
+# Re-executing this module (e.g. via `importlib.reload`) reuses its globals while
+# `inspect.getfile` is already the patched function; keeping the previously saved
+# original prevents the replacement from calling itself forever.
+_getfile = globals().get("_getfile", inspect.getfile)
+
+
+def _patched_inspect_getfile(obj):
+    """Replacement for ``inspect.getfile`` that can also locate classes.
+
+    Non-class objects are delegated to the original ``inspect.getfile``. For a class:
+
+    1. the ``__file__`` of its defining module is returned when available;
+    2. otherwise, the file of the first plain function defined directly in the class
+       body is returned.
+
+    Raises
+    ------
+    TypeError
+        If no source file can be determined for the class.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    module_name = getattr(obj, "__module__", None)
+    if module_name is not None:
+        # The module may be absent from sys.modules (e.g. a dynamically created class).
+        module = sys.modules.get(module_name)
+        file = getattr(module, "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        # Only functions whose qualname places them directly inside `obj`.
+        if (
+            inspect.isfunction(member)
+            and member.__qualname__ == obj.__qualname__ + "." + member.__name__
+        ):
+            return inspect.getfile(member)
+    # `!r` is a conversion, not a format spec (`{obj:!r}` fails to format).
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+# Process-wide monkeypatch: from here on, every lookup of `inspect.getfile` (in any
+# module) resolves to the class-aware replacement above.
+inspect.getfile = _patched_inspect_getfile
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 78%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..caa51744c8 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.tir.ir_builder_v2 import * # pylint: disable=redefined-builtin
-from . import tir
-
-from .parser import ir_module, from_source
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..b92341aecd
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+from tvm.tir.ir_builder_v2 import buffer_decl, ptr
+
+from ..entry import parse
+from ..utils import inspect_function_capture
+
+
+def _is_defined_in_class(frames):
+    """Heuristically decide whether the decorated function sits in a class body.
+
+    ``frames`` is the result of ``inspect.stack()`` taken inside ``prim_func``. The
+    source line currently executing in ``frames[2]`` is inspected: if it starts a
+    class (``class ...``) or is an ``@...ir_module`` decorator, the function is
+    assumed to be defined inside a class.
+
+    Returns ``False`` when that frame does not exist or its source is unavailable.
+    """
+    if len(frames) <= 2:
+        return False
+    code_context = frames[2][4]
+    # `code_context` is None (or empty) when inspect cannot retrieve the source.
+    if not code_context:
+        return False
+    line = code_context[0].strip()
+    return line.startswith("class ") or (line.startswith("@") and "ir_module" in line)
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    """Decorator that parses the function ``f`` as a TIR ``PrimFunc``.
+
+    If ``f`` appears to be defined inside a class (see ``_is_defined_in_class``), it is
+    returned unchanged (presumably so the enclosing class is parsed as a whole).
+    Raises ``TypeError`` if ``f`` is not a function.
+    """
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    in_class = _is_defined_in_class(inspect.stack())
+    return f if in_class else parse(f, inspect_function_capture(f))
+
+
+# `Parser.visit_FunctionDef` evaluates a function's decorator and dispatches on this token.
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Helper exposed as ``Buffer``; both the call and subscript forms call ``buffer_decl``."""
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        """Declare a buffer; every argument is forwarded to ``buffer_decl``."""
+        return buffer_decl(
+            shape,
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+
+    def __getitem__(self, keys) -> Buffer:
+        """Subscript form: ``Buffer[shape, dtype, ...]`` maps positionally onto ``__call__``.
+
+        A single non-tuple key (e.g. ``Buffer[n]``) is forwarded as the only argument
+        (``shape``), mirroring ``PtrProxy.__getitem__``; unpacking it with ``*`` would fail.
+        """
+        if not isinstance(keys, tuple):
+            return self(keys)
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    """Helper exposed as ``Ptr``; both the call and subscript forms call ``ptr``."""
+
+    def __call__(self, dtype, storage_scope="global"):
+        # A callable dtype (presumably a dtype constructor such as `T.float32`) is
+        # instantiated so its `.dtype` can be read.
+        resolved_dtype = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved_dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        # `Ptr[dtype]` or `Ptr[dtype, scope]`
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+# Module-level singletons. `Buffer` rebinds (shadows) the `tvm.tir.Buffer` class imported
+# at the top of this file; the `-> Buffer` annotations in `BufferProxy` were evaluated
+# at class-definition time and still refer to that class.
+Buffer = BufferProxy()
+Ptr = PtrProxy()
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..11ee92ad29
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .. import doc
+from ..dispatch import OpMethod, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    """Register TIR implementations of Python operators for operands of type ``ty``.
+
+    Sets ``ty._dispatch_type = ty``. Binary, comparison and boolean operators are
+    registered for operand positions 0 and 1; unary operators for position 0 only.
+    ``MatMult``, ``Pow``, ``Is``/``IsNot`` and ``In``/``NotIn`` are not registered.
+    """
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool_expr(value):
+        # Plain Python bools become `bool` immediates before entering And/Or nodes.
+        return IntImm("bool", value) if isinstance(value, bool) else value
+
+    def _and(a, b):
+        return tir.And(_as_bool_expr(a), _as_bool_expr(b))
+
+    def _or(a, b):
+        return tir.Or(_as_bool_expr(a), _as_bool_expr(b))
+
+    def _register(op: Type, index: int, method: OpMethod):
+        register_op(ty, op, index)(method)
+
+    binary_table = (
+        # binop
+        (doc.Add, tir.Add),
+        (doc.Sub, tir.Sub),
+        (doc.Mult, tir.Mul),
+        (doc.Div, tir.Div),
+        (doc.FloorDiv, tir.FloorDiv),
+        (doc.Mod, tir.FloorMod),
+        (doc.LShift, lambda a, b: a << b),
+        (doc.RShift, lambda a, b: a >> b),
+        (doc.BitOr, lambda a, b: a | b),
+        (doc.BitXor, lambda a, b: a ^ b),
+        (doc.BitAnd, lambda a, b: a & b),
+        # cmpop
+        (doc.Eq, tir.EQ),
+        (doc.NotEq, tir.NE),
+        (doc.Lt, tir.LT),
+        (doc.LtE, tir.LE),
+        (doc.Gt, tir.GT),
+        (doc.GtE, tir.GE),
+        # boolop
+        (doc.And, _and),
+        (doc.Or, _or),
+    )
+    unary_table = (
+        (doc.Invert, lambda a: ~a),
+        (doc.Not, tir.Not),
+        (doc.UAdd, lambda a: +a),
+        (doc.USub, lambda a: -a),
+    )
+    for index in (0, 1):
+        for op, method in binary_table:
+            _register(op, index, method)
+    for op, method in unary_table:
+        _register(op, 0, method)
+
+
+# Operands of these types get the TIR operator implementations registered above. Because
+# `_dispatch_type` is a class attribute, subclasses inherit it as well.
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..c66abf2013
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.ir.ir_builder import IRBuilderFrame as Frame
+from tvm.ir.ir_builder import name
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+from tvm.tir import ir_builder_v2 as T
+
+from .. import dispatch, doc
+from ..parser import Parser
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind ``var_name`` to a value produced by a ``with`` statement.
+
+    Lists/tuples are bound element-wise as ``<var_name>_<i>``. ``Buffer`` and ``Var``
+    values are named ``var_name`` via ``name``. Any other type is reported as an
+    error at ``node``.
+
+    Returns ``value`` itself.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Forward `node` (required positional) so nested errors point at the target.
+            bind_with_value(self, node, f"{var_name}_{i}", v)
+        return value
+    if isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+    raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind ``var_name`` to loop variable(s) produced by a ``for`` statement.
+
+    Lists/tuples are bound element-wise as ``<var_name>_<i>`` with the same rules.
+    ``Var`` values are named ``var_name`` via ``name``. Any other type is reported as
+    an error at ``node``.
+
+    Returns ``value`` itself.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the for-loop binder and forward `node` for error locations.
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    if isinstance(value, Var):
+        name(var_name, value)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+    raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind ``var_name`` to the value of an assignment; return what goes in the var table.
+
+    - list/tuple: elements are bound as ``<var_name>_<i>``; the sequence is returned.
+    - IR builder ``Frame``: the frame is entered, with its own ``__exit__`` registered as
+      a callback. The result of ``__enter__`` is named and returned.
+    - ``Buffer``/``IterVar``, or a ``Var`` not yet in the var table: named and returned.
+    - any other ``PrimExpr``: a fresh ``T.var`` of the same dtype is named ``var_name``
+      and bound to the expression by an entered ``T.let`` frame; the new var is returned.
+    - anything else is returned unchanged.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the assignment binder and forward the node (required positional).
+            bind_assign_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    if isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    if isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    if isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """Parse ``for <target> in <iter>:``; ``<iter>`` must evaluate to a ``ForFrame``.
+
+    In a new scope the frame is entered, the loop variable(s) it yields are bound to the
+    target, and the body is visited inside the frame.
+    """
+    loop_frame = self.eval_expr(node.iter)
+    if not isinstance(loop_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame(), loop_frame as loop_vars:
+        self.eval_assign(target=node.target, source=loop_vars, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """Parse ``while <test>:`` into a ``T.While`` frame, in a new variable scope."""
+    # The condition is evaluated after the scope is opened, as part of the same `with`.
+    with self.var_table.with_frame(), T.While(self.eval_expr(node.test)):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """Parse ``<lhs> = <rhs>``.
+
+    A subscript target becomes ``T.buffer_store(buffer, rhs, indices)``. Any other
+    target is bound through ``eval_assign`` with ``bind_assign_value``.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    lhs = node.targets[0]
+    rhs = self.eval_expr(node.value)
+    if not isinstance(lhs, doc.Subscript):
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+        return
+    index_nodes = lhs.slice.elts if isinstance(lhs.slice, doc.Tuple) else [lhs.slice]
+    indices = [self.eval_expr(index) for index in index_nodes]
+    T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """Parse ``<target> op= <value>`` as ``<target> = <target> op <value>``.
+
+    Both sides are evaluated and bound to temporary names in a scratch scope. A
+    ``BinOp`` over those names is then evaluated. The result is stored back into a
+    buffer (subscript target) or bound through ``eval_assign``.
+    """
+    target, value = node.target, node.value
+    lhs_pos = (target.lineno, target.col_offset, target.end_lineno, target.end_col_offset)
+    rhs_pos = (value.lineno, value.col_offset, value.end_lineno, value.end_col_offset)
+    # The target is read first, so it temporarily gets a Load context.
+    target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(target)
+        rhs_expr = self.eval_expr(value)
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        binop = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(binop)
+    # Restore a Store context before writing the result back.
+    target.ctx = doc.Store(*lhs_pos)
+    if isinstance(target, doc.Subscript):
+        index_nodes = target.slice.elts if isinstance(target.slice, doc.Tuple) else [target.slice]
+        indices = [self.eval_expr(index) for index in index_nodes]
+        T.buffer_store(self.eval_expr(target.value), rhs, indices)
+    else:
+        self.eval_assign(target=target, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """Parse ``<target>: <annotation> = <value>``.
+
+    The annotation must produce a ``Var``. That var is bound to the target, and an
+    entered ``T.let(var, value)`` frame (with its own ``__exit__`` registered as a
+    callback) binds it to the value.
+    """
+    value_expr = self.eval_expr(node.value)
+    ann_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(ann_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=node.target, source=ann_var, bind_value=bind_assign_value)
+    let_frame = T.let(ann_var, value_expr)
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """Parse ``with a as x, b as y: ...`` where every context must be an IR builder frame.
+
+    A new scope and each frame are entered in order on one ``ExitStack``. Optional
+    ``as`` targets are bound with ``bind_with_value`` before the body is visited.
+    """
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(self.var_table.with_frame())
+        for item in node.items:
+            ctx_value = self.eval_expr(item.context_expr)
+            if not isinstance(ctx_value, Frame):
+                self.report_error(
+                    item.context_expr, "Invalid context expression in the with-statement."
+                )
+            entered = stack.enter_context(ctx_value)
+            if item.optional_vars is not None:
+                self.eval_assign(
+                    target=item.optional_vars, source=entered, bind_value=bind_with_value
+                )
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """Parse a ``@T.prim_func`` function into a ``T.prim_func()`` frame.
+
+    Inside the function, ``range`` is bound to ``T.serial``. A callable return
+    annotation is instantiated and converted to ``PrimType(<its dtype>)``. Arguments
+    and body are visited with the ``"tir"`` dispatch token active.
+    """
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                T.func_ret(PrimType(ret_type().dtype) if callable(ret_type) else ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """Declare each positional parameter as a ``T.arg`` and add it to the current scope.
+
+    Every parameter must carry a type annotation.
+    """
+    # TODO: only `node.args` is handled; vararg, kwonlyargs, kw_defaults, kwarg,
+    # defaults and posonlyargs are ignored for now.
+    for param in node.args:
+        if param.annotation is None:
+            self.report_error(param, "Type annotation is required for function parameters.")
+        annotation = self.visit_tvm_annotation(param.annotation)
+        self.var_table.add(param.arg, T.arg(param.arg, annotation))
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """Evaluate a type annotation; a callable result (e.g. a dtype helper) is called once."""
+    annotation = self.eval_expr(node)
+    return annotation() if callable(annotation) else annotation
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """Evaluate an expression statement; a resulting IR builder ``Frame`` is entered.
+
+    The frame is not exited here. Its own ``__exit__`` is registered as a callback
+    (presumably run when the enclosing frame closes).
+    """
+    result = self.eval_expr(node.value)
+    if not isinstance(result, Frame):
+        return
+    result.add_callback(partial(result.__exit__, None, None, None))
+    result.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """Parse ``if``/``else`` into ``T.If`` with ``T.Then`` and, if present, ``T.Else``."""
+    with self.var_table.with_frame(), T.If(self.eval_expr(node.test)):
+        with T.Then():
+            self.visit_body(node.body)
+        if node.orelse:
+            with T.Else():
+                self.visit_body(node.orelse)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """Parse ``assert <cond>[, <msg>]`` into an entered ``T.Assert`` frame.
+
+    The frame's own ``__exit__`` is registered as a callback instead of being called here.
+    The message is optional in Python; when it is omitted, an empty message is used
+    rather than evaluating a missing (``None``) node.
+    """
+    cond = self.eval_expr(node.test)
+    msg = "" if node.msg is None else self.eval_expr(node.msg)
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
diff --git a/python/tvm/script/parser/utils.py b/python/tvm/script/parser/utils.py
new file mode 100644
index 0000000000..9d681236c8
--- /dev/null
+++ b/python/tvm/script/parser/utils.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from contextlib import contextmanager
+from typing import Any, Callable, Dict
+
+
+def deferred(f: Callable[[], None]):
+    """Return a single-use context manager that calls ``f()`` on exit.
+
+    ``f`` runs whether the ``with`` body finishes normally or raises.
+    """
+
+    @contextmanager
+    def _run_on_exit():
+        try:
+            yield None
+        finally:
+            f()
+
+    return _run_on_exit()
+
+
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    """Collect the outer variables of ``func`` that the TVMScript parser may use.
+
+    Globals and closure (nonlocal) variables of ``func`` are scanned. A closure
+    variable shadows a global of the same name, as in Python's own name resolution.
+    Only these values are kept:
+
+    - modules named ``tvm`` or ``tvm.*``;
+    - objects whose ``__module__`` is a string starting with ``tvm.``;
+    - atomic values: ``None``, ``int``, ``float``, ``str``, ``bool``.
+    """
+    prefix = "tvm."
+    result = {}
+    captured = {
+        **func.__globals__,
+        # Closure variables take precedence over globals.
+        **inspect.getclosurevars(func).nonlocals,
+    }
+    for k, v in captured.items():
+        # Case 1: a module like `tvm`, `T` or `tvm.tir.ir_builder`
+        if inspect.ismodule(v) and (v.__name__ == "tvm" or v.__name__.startswith(prefix)):
+            result[k] = v
+            continue
+        # Case 2: a function like `T.match_buffer`. `__module__` may be missing,
+        # None, or a non-string on arbitrary objects.
+        module_name = getattr(v, "__module__", None)
+        if isinstance(module_name, str) and module_name.startswith(prefix):
+            result[k] = v
+            continue
+        # Case 3: atomic types
+        if v is None or isinstance(v, (int, float, str, bool)):
+            result[k] = v
+    return result
+
+
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    """Merge ``inspect_function_capture`` over every plain function in ``cls.__dict__``.
+
+    Functions are processed in definition order; on a name clash the later one wins.
+    """
+    merged: Dict[str, Any] = {}
+    for member in vars(cls).values():
+        if inspect.isfunction(member):
+            merged.update(inspect_function_capture(member))
+    return merged
diff --git a/python/tvm/script/parser/var_table.py b/python/tvm/script/parser/var_table.py
new file mode 100644
index 0000000000..32fced625a
--- /dev/null
+++ b/python/tvm/script/parser/var_table.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The symbol table of variable values"""
+
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Set
+
+from .utils import deferred
+
+
+class VarTableFrame:
+    """One lexical scope of a ``VarTable``: the set of names defined in that scope."""
+
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        """Record ``var``; raise ``ValueError`` if this scope already defines it."""
+        if var not in self.vars:
+            self.vars.add(var)
+            return
+        raise ValueError(f"Variable {var} already defined in current scope")
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        """Call ``fn_pop`` once for every recorded name, then forget all names."""
+        for recorded in self.vars:
+            fn_pop(recorded)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped symbol table from variable names to values.
+
+    Each name maps to a stack of values. A scope opened with ``with_frame`` may shadow
+    a name, and closing the scope pops exactly what it added.
+    """
+
+    frames: List[VarTableFrame]
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        """Open a new scope now; the returned context manager closes it on exit."""
+
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        """Bind ``var`` to ``value`` in the innermost scope (ValueError if already bound there)."""
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        """Return a new dict of each name's innermost (visible) value."""
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        """Return whether ``value`` (by identity) is bound to any name in any open scope."""
+        # `name2value` holds a *stack* of values per name, so look inside each stack.
+        return any(bound is value for values in self.name2value.values() for bound in values)
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser_v1/__init__.py
similarity index 95%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser_v1/__init__.py
index 555659d0c5..004e947bf6 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser_v1/__init__.py
@@ -17,5 +17,4 @@
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
from . import tir
-
-from .parser import ir_module, from_source
+from .parser import from_source, ir_module
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py
similarity index 100%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser_v1/_ffi_api.py
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
similarity index 98%
rename from python/tvm/script/context_maintainer.py
rename to python/tvm/script/parser_v1/context_maintainer.py
index f7f16855c7..400baacc4b 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/parser_v1/context_maintainer.py
@@ -16,16 +16,16 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
+from typing import Callable, Dict, List, Mapping, Optional, Union
+import synr
import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
+from tvm.tir import Buffer, MatchBufferRegion, PrimExpr, Stmt, Var
from tvm.tir.expr import IterVar
+
from .tir.node import BufferSlice
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py
similarity index 95%
rename from python/tvm/script/diagnostics.py
rename to python/tvm/script/parser_v1/diagnostics.py
index e676461ab3..b15997552f 100644
--- a/python/tvm/script/diagnostics.py
+++ b/python/tvm/script/parser_v1/diagnostics.py
@@ -17,11 +17,11 @@
"""Bridge from synr's (the library used for parsing the python AST)
DiagnosticContext to TVM's diagnostics
"""
-from synr import DiagnosticContext, ast
-
import tvm
+from synr import DiagnosticContext, ast
+from tvm.ir.diagnostics import Diagnostic
from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
+from tvm.ir.diagnostics import DiagnosticLevel, get_renderer
class TVMDiagnosticCtx(DiagnosticContext):
diff --git a/python/tvm/script/highlight.py b/python/tvm/script/parser_v1/highlight.py
similarity index 100%
rename from python/tvm/script/highlight.py
rename to python/tvm/script/parser_v1/highlight.py
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py
similarity index 100%
rename from python/tvm/script/meta_unparser.py
rename to python/tvm/script/parser_v1/meta_unparser.py
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py
similarity index 99%
rename from python/tvm/script/parser.py
rename to python/tvm/script/parser_v1/parser.py
index 908af081c9..b2b7e388f1 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser_v1/parser.py
@@ -20,35 +20,34 @@ We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
different python versions. Synr also provides an error handling context that we
use for error reporting.
"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
+import inspect
import json
import operator
-import inspect
+
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
+import types
from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
import tvm
+from synr import Transformer, ast, to_ast
from tvm import IRModule
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
from tvm.tir import buffer
from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
+from . import _ffi_api, tir
from .context_maintainer import ContextMaintainer
+from .diagnostics import TVMDiagnosticCtx
from .meta_unparser import MetaUnparser
from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
+from .tir import ty
from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
+from .tir.node import BufferSlice, Slice
+from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler
from .tir.special_stmt import SpecialStmt
-from .tir import ty
+from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr
class CallArgumentReader(object):
diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py
similarity index 97%
rename from python/tvm/script/registry.py
rename to python/tvm/script/parser_v1/registry.py
index e7d90dd515..e816b90f5d 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/parser_v1/registry.py
@@ -17,7 +17,7 @@
"""TVM Script Parser Function Registry """
# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
import types
-from typing import Union, Callable, Dict, Optional, Any
+from typing import Any, Callable, Dict, Optional, Union
class Registry(object):
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 100%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
similarity index 100%
rename from python/tvm/script/tir/__init__.pyi
rename to python/tvm/script/parser_v1/tir/__init__.pyi
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 99%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index 627b89086a..008d7495a0 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -17,11 +17,12 @@
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import builtins
-from typing import List, Any
+from typing import Any, List
import tvm.tir
+from tvm.target import codegen
+
from ..registry import register
-from ...target import codegen
from ..utils import get_param_list, tvm_span_from_synr
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 98%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
index 29e79607fb..cfaf9df476 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/parser_v1/tir/node.py
@@ -17,12 +17,13 @@
# pylint: disable=redefined-builtin
"""TVM Script nodes."""
-from typing import Optional, Union, List, Callable
+from typing import Callable, List, Optional, Union
+
import synr
from tvm.arith import Analyzer
+from tvm.ir import Range, Span
from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
+from tvm.tir import Buffer, BufferLoad, BufferRegion, IntImm, PrimExpr, Ramp
class Slice:
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 97%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
index 923eb97d27..a5fdcc15c5 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser_v1/tir/prim_func.py
@@ -19,7 +19,8 @@
import inspect
from typing import Callable
-from tvm.tir.function import PrimFunc
+from tvm.tir import PrimFunc
+
from ..parser import from_source
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 98%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
index da7545c9a9..cd2167f4ab 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/parser_v1/tir/scope_handler.py
@@ -16,24 +16,19 @@
# under the License.
"""TVM Script Parser Scope Handler Classes"""
# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
-import synr
import numpy as np
+import synr
import tvm.tir
+from tvm.ir import Range, Span
from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
+from tvm.tir import Buffer, BufferRegion, ForKind, IterVar, PrimExpr, Stmt, Var
from ..context_maintainer import ContextMaintainer
from ..registry import register
-from ..utils import (
- get_param_list,
- tvm_span_from_synr,
- call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
class ScopeHandler:
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 99%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
index 15502055b7..42a90f647f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -17,27 +17,21 @@
"""TVM Script Parser Special Stmt Classes"""
# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
# pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
import synr
+import tvm.tir
from synr import ast
+from tvm.ir import Span
from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
from tvm.runtime import Object, String
from tvm.target import Target
-from tvm.ir import Span
from tvm.tir import IntImm, IterVar, Var
-from .node import BufferSlice
-
from ..context_maintainer import BlockInfo, ContextMaintainer
from ..registry import register
-from ..utils import (
- get_param_list,
- tvm_span_from_synr,
- call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
def convert_to_int(
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 99%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
index a64485b215..8a048c07b5 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/parser_v1/tir/ty.py
@@ -23,6 +23,7 @@ a wrapper for uniform Type system in IR
from numbers import Integral
import tvm
+
from .special_stmt import SpecialStmt, convert_to_int
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 97%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
index c655a62237..f358a90081 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/parser_v1/utils.py
@@ -16,13 +16,12 @@
# under the License.
"""Helper functions in TVM Script Parser"""
-from typing import Callable, List, Any, Optional, Tuple
-
import inspect
-import synr
+from typing import Any, Callable, List, Optional, Tuple
-from tvm.ir import Span, SourceName
+import synr
from tvm.error import DiagnosticError
+from tvm.ir import SourceName, Span
def get_param_list(
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..0d949234d8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,180 @@
from tvm.ir import PrimExpr
from tvm.runtime import const
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis
+from . import ir_builder_v1 as ir_builder
+from . import schedule, stmt_functor, transform, usmp
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+ EQ,
+ GE,
+ GT,
+ LE,
+ LT,
+ NE,
+ Add,
+ And,
+ Any,
+ Broadcast,
+ BufferLoad,
+ Call,
+ CallEffectKind,
+ Cast,
+ CommReducer,
+ Div,
+ FloatImm,
+ FloorDiv,
+ FloorMod,
+ IntImm,
+ IterVar,
+ Let,
+ Load,
+ Max,
+ Min,
+ Mod,
+ Mul,
+ Not,
+ Or,
+ ProducerLoad,
+ Ramp,
+ Reduce,
+ Select,
+ Shuffle,
+ SizeVar,
+ StringImm,
+ Sub,
+ Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+ TVMBackendAllocWorkspace,
+ TVMBackendFreeWorkspace,
+ abs,
+ acos,
+ acosh,
+ address_of,
+ all,
+ any,
+ asin,
+ asinh,
+ atan,
+ atan2,
+ atanh,
+ call_cpacked,
+ call_cpacked_lowered,
+ call_extern,
+ call_intrin,
+ call_llvm_intrin,
+ call_llvm_pure_intrin,
+ call_packed,
+ call_packed_lowered,
+ call_pure_extern,
+ ceil,
+ ceildiv,
+ clz,
+ comm_reducer,
+ copysign,
+ cos,
+ cosh,
+ div,
+ erf,
+ exp,
+ exp2,
+ exp10,
+ floor,
+ floordiv,
+ floormod,
+ fmod,
+ hypot,
+ if_then_else,
+ indexdiv,
+ indexmod,
+ isfinite,
+ isinf,
+ isnan,
+ isnullptr,
+ ldexp,
+ likely,
+ log,
+ log1p,
+ log2,
+ log10,
+ lookup_param,
+ max,
+ max_value,
+ min,
+ min_value,
+ mma_fill,
+ mma_store,
+ nearbyint,
+ nextafter,
+ popcount,
+ power,
+ ptx_commit_group,
+ ptx_cp_async,
+ ptx_ldmatrix,
+ ptx_mma,
+ ptx_mma_sp,
+ ptx_wait_group,
+ q_multiply_shift,
+ ret,
+ round,
+ rsqrt,
+ shift_left,
+ shift_right,
+ sigmoid,
+ sin,
+ sinh,
+ sqrt,
+ sum,
+ tan,
+ tanh,
+ trace,
+ trunc,
+ truncdiv,
+ truncmod,
+ tvm_access_ptr,
+ tvm_bmma_sync,
+ tvm_fill_fragment,
+ tvm_load_matrix_sync,
+ tvm_mma_sync,
+ tvm_stack_alloca,
+ tvm_stack_make_array,
+ tvm_stack_make_shape,
+ tvm_store_matrix_sync,
+ tvm_struct_get,
+ tvm_struct_set,
+ tvm_thread_allreduce,
+ tvm_throw_last_error,
+ tvm_tuple,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
from .stmt import (
- BufferStore,
- BufferRealize,
- Store,
- ProducerStore,
Allocate,
AllocateConst,
+ AssertStmt,
AttrStmt,
+ Block,
+ BlockRealize,
+ BufferRealize,
+ BufferRegion,
+ BufferStore,
DeclBuffer,
+ Evaluate,
+ For,
+ ForKind,
+ IfThenElse,
+ LetStmt,
+ MatchBufferRegion,
+ Prefetch,
+ ProducerRealize,
+ ProducerStore,
+ SeqStmt,
+ Stmt,
+ Store,
+ While,
+ stmt_list,
+ stmt_seq,
+ type_annotation,
)
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/tir/_ffi_ir_builder_api.py
similarity index 88%
rename from python/tvm/script/_ffi_api.py
rename to python/tvm/tir/_ffi_ir_builder_api.py
index 926d17b166..61b288d498 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/tir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for the TIR IRBuilder (``ir_builder.tir``)."""
import tvm._ffi
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder.tir", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..ea220fea22 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Union
from tvm import Object
from tvm.ir import IRModule
from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
+from tvm.tir.stmt import Block, BufferRegion, PrimExpr, Stmt
-from .. import Buffer, Stmt
+from ..buffer import Buffer
from ..function import PrimFunc
from . import _ffi_api
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index d9b0aec76a..6752082fa1 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -20,7 +20,7 @@ import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
-from tvm.ir import PrimExpr, PointerType, PrimType
+from tvm.ir import PrimExpr, PointerType, PrimType, Range
from . import _ffi_api
@@ -176,6 +176,40 @@ class Buffer(Object):
"""
return _ffi_api.BufferOffsetOf(self, indices) # type: ignore
+    def __getitem__(self, indices):
+        """Index or slice the buffer.
+
+        Parameters
+        ----------
+        indices : PrimExpr, int, slice, or a tuple/list of them
+            One entry per indexed dimension. Omitted slice bounds default to
+            0 (start) and the extent of that dimension (stop).
+
+        Returns
+        -------
+        BufferRegion or BufferLoad
+            A ``BufferRegion`` when any entry is a slice without a step;
+            otherwise a ``BufferLoad`` in which stepped slices become
+            ``Ramp`` indices (or a scalar index when they cover one lane).
+        """
+        # Imported inside the method, presumably to avoid circular imports
+        # between buffer / expr / stmt / arith.
+        from ..arith import Analyzer
+        from .expr import BufferLoad, Ramp
+        from .stmt import BufferRegion
+
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+        analyzer = Analyzer()
+
+        def _bounds(dim, index):
+            # ``A[:n]`` / ``A[i:]`` / ``A[::s]``: fill omitted bounds from the shape.
+            start = 0 if index.start is None else index.start
+            stop = self.shape[dim] if index.stop is None else index.stop
+            return start, stop
+
+        if any(isinstance(index, slice) and index.step is None for index in indices):
+            region = []
+            for dim, index in enumerate(indices):
+                if isinstance(index, slice):
+                    start, stop = _bounds(dim, index)
+                    region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
+                else:
+                    region.append(Range.from_min_extent(index, 1))
+            return BufferRegion(self, region)
+
+        expr_indices = []
+        for dim, index in enumerate(indices):
+            if isinstance(index, slice):
+                start, stop = _bounds(dim, index)
+                lanes = analyzer.simplify((stop - start + index.step - 1) // index.step)
+                if lanes == 1:
+                    expr_indices.append(start)
+                else:
+                    expr_indices.append(Ramp(start, index.step, int(lanes)))
+            else:
+                expr_indices.append(index)
+        return BufferLoad(self, expr_indices)
+
def decl_buffer(
shape,
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(y.a == x)
"""
from typing import Optional, Union
-from tvm import ir
+
import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
from . import _ffi_api
+from . import generic as _generic
def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
class ExprOp(object):
"""Operator overloading for Expr like expressions."""
- # TODO(tkonolige): use inspect to add source information to these objects
-
def __add__(self, other):
return _generic.add(self, other)
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
"""
def __init__(self, condition, true_value, false_value, span=None):
+ # A plain Python bool is wrapped as a boolean IntImm before construction.
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
 self.__init_handle_by_constructor__(
 _ffi_api.Select, condition, true_value, false_value, span # type: ignore
 )
diff --git a/python/tvm/tir/ir_builder_frame.py b/python/tvm/tir/ir_builder_frame.py
new file mode 100644
index 0000000000..a1f457aad2
--- /dev/null
+++ b/python/tvm/tir/ir_builder_frame.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder for TIR"""
+
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.ir.ir_builder import IRBuilderFrame
+
+from .buffer import Buffer
+from .expr import Var
+
+
+@_register_object("ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+    """Common base class of the TIR IRBuilder frames in this module."""
+
+
+@_register_object("ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+    """Frame returned by ``block()`` in ``tvm.tir.ir_builder_v2``."""
+
+
+@_register_object("ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    """Frame returned by ``init()`` in ``tvm.tir.ir_builder_v2``."""
+
+
+@_register_object("ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    """Frame returned by the loop helpers of ``tvm.tir.ir_builder_v2``
+    (``serial``, ``parallel``, ``grid``, ...)."""
+
+    def __enter__(self) -> Union[Var, List[Var]]:
+        """Enter the frame and return its loop variable(s).
+
+        Returns the single ``Var`` when the frame has exactly one loop
+        variable; otherwise returns the list of variables (possibly empty).
+        """
+        super().__enter__()
+        return self.vars[0] if len(self.vars) == 1 else self.vars
+
+
+@_register_object("ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+    """Frame returned by ``prim_func()`` in ``tvm.tir.ir_builder_v2``."""
+
+
+@_register_object("ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+    """Frame returned by ``Assert()`` in ``tvm.tir.ir_builder_v2``."""
+
+
+@_register_object("ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+    """Frame returned by ``let()`` in ``tvm.tir.ir_builder_v2`` when no body is given."""
+
+
+@_register_object("ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    """Frame returned by ``allocate()`` in ``tvm.tir.ir_builder_v2``."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame and hand back its ``buffer``."""
+        super().__enter__()
+        allocated = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    """Frame returned by ``allocate_const()`` in ``tvm.tir.ir_builder_v2``."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame and hand back its ``buffer``."""
+        super().__enter__()
+        allocated = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+    """Frame registered as ``ir_builder.tir.LaunchThreadFrame`` (presumably a thread-launch scope)."""
+
+
+@_register_object("ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+    """Frame returned by ``realize()`` in ``tvm.tir.ir_builder_v2``."""
+
+
+@_register_object("ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+    """Frame returned by ``attr()`` in ``tvm.tir.ir_builder_v2``."""
+
+
+@_register_object("ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+    """Frame registered as ``ir_builder.tir.WhileFrame`` (presumably a ``while`` loop)."""
+
+
+@_register_object("ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+    """Frame registered as ``ir_builder.tir.IfFrame`` (presumably an ``if`` statement)."""
+
+
+@_register_object("ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+    """Frame registered as ``ir_builder.tir.ThenFrame`` (presumably the then-branch of an if)."""
+
+
+@_register_object("ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+    """Frame registered as ``ir_builder.tir.ElseFrame`` (presumably the else-branch of an if)."""
+
+
+@_register_object("ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    """Frame registered as ``ir_builder.tir.DeclBufferFrame``; entering yields its buffer."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame and hand back its ``buffer``."""
+        super().__enter__()
+        declared = self.buffer
+        return declared
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder_v1.py
similarity index 99%
rename from python/tvm/tir/ir_builder.py
rename to python/tvm/tir/ir_builder_v1.py
index ce8cd1b403..8a68faaac6 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder_v1.py
@@ -17,13 +17,13 @@
"""Developer API of IR node builder make function."""
import tvm
from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, convert, const
from tvm.ir import container as _container
+from tvm.runtime import ObjectGeneric, const, convert
-from . import stmt as _stmt
-from . import expr as _expr
from . import buffer as _buffer
+from . import expr as _expr
from . import op
+from . import stmt as _stmt
class WithScope(object):
diff --git a/python/tvm/tir/ir_builder_v2.py b/python/tvm/tir/ir_builder_v2.py
new file mode 100644
index 0000000000..f297ec69ff
--- /dev/null
+++ b/python/tvm/tir/ir_builder_v2.py
@@ -0,0 +1,901 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.target import Target as target
+
+from . import _ffi_ir_builder_api as _ffi_api
+from . import ir_builder_frame as frame
+from . import op as _tir_op
+from .buffer import Buffer
+from .expr import Broadcast as broadcast
+from .expr import BufferLoad, CommReducer, IntImm, IterVar, Let, PrimExpr
+from .expr import Ramp as ramp
+from .expr import Select, Shuffle, StringImm, Var
+from .generic import cast
+from .stmt import BufferRegion, type_annotation
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    """Declare a buffer through the ``BufferDecl`` IRBuilder API.
+
+    A scalar ``shape`` (a ``PrimExpr`` or integral number) is treated as a
+    one-dimensional shape. The buffer name is always passed as ``""``.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    unnamed = ""
+    return _ffi_api.BufferDecl(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        unnamed,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def ptr(dtype, storage_scope="global"):
+    """Return ``Ptr(dtype, storage_scope)`` from the IRBuilder API (presumably a pointer var)."""
+    pointer = _ffi_api.Ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+    return pointer
+
+
+# ``buffer_var`` is an alias of ``ptr``.
+buffer_var = ptr
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
+    """Open a block frame named ``name``; ``no_realize`` is forwarded unchanged."""
+    block_frame = _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
+    return block_frame
+
+
+def init() -> frame.BlockInitFrame:
+    """Open a block-init frame via the ``Init`` IRBuilder API."""
+    init_frame = _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+    return init_frame
+
+
+def where(predicate) -> None:
+    """Forward ``predicate`` to the ``Where`` IRBuilder API.
+
+    A plain Python ``bool`` is wrapped as ``IntImm("bool", ...)`` first.
+    """
+    _ffi_api.Where(  # pylint: disable=no-member # type: ignore
+        IntImm("bool", predicate) if isinstance(predicate, bool) else predicate
+    )
+
+
+def _flatten_buffer_slices(buffer_slices):
+    """Normalize the ``*buffer_slices`` of ``reads`` / ``writes`` into a list.
+
+    A single tuple or list argument is unpacked, so ``reads([A[i], B[i]])``
+    is the same as ``reads(A[i], B[i])``; a single other argument becomes a
+    one-element list; several arguments are collected into a list as-is.
+    """
+    if len(buffer_slices) == 1:
+        (only,) = buffer_slices
+        if isinstance(only, tuple):
+            return list(only)
+        if isinstance(only, list):
+            return only
+        return [only]
+    return list(buffer_slices)
+
+
+def reads(
+    *buffer_slices: Union[BufferRegion, BufferLoad, List[Union[BufferRegion, BufferLoad]]]
+) -> None:
+    """Forward the given buffer regions/loads to the ``Reads`` IRBuilder API."""
+    _ffi_api.Reads(_flatten_buffer_slices(buffer_slices))  # pylint: disable=no-member # type: ignore
+
+
+def writes(
+    *buffer_slices: Union[BufferRegion, BufferLoad, List[Union[BufferRegion, BufferLoad]]]
+) -> None:
+    """Forward the given buffer regions/loads to the ``Writes`` IRBuilder API."""
+    _ffi_api.Writes(_flatten_buffer_slices(buffer_slices))  # pylint: disable=no-member # type: ignore
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """Forward ``attrs`` to the ``BlockAttrs`` IRBuilder API."""
+    outcome = _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Allocate a buffer through the ``AllocBuffer`` IRBuilder API.
+
+    A scalar ``shape`` is treated as one-dimensional; ``strides=None`` is
+    passed on as an empty list.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    return _ffi_api.AllocBuffer(  # pylint: disable=no-member # type: ignore
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def _as_range(dom) -> Range:
+    """Coerce ``dom`` into a ``Range``.
+
+    A ``Range`` is returned unchanged, a list/tuple ``(begin, end, ...)``
+    becomes ``Range(begin, end)``, and any other value ``x`` becomes
+    ``Range(0, x)``.
+    """
+    if isinstance(dom, Range):
+        result = dom
+    elif isinstance(dom, (list, tuple)):
+        begin, end = dom[0], dom[1]
+        result = Range(begin, end)
+    else:
+        result = Range(0, dom)
+    return result
+
+
+class axis:  # pylint: disable=invalid-name
+    """Constructors for block iteration variables.
+
+    Each kind-specific method converts ``dom`` with ``_as_range`` and
+    forwards ``(dom, binding, dtype)`` to the matching ``Axis*`` API.
+    """
+
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        """Spatial iteration variable over ``dom`` bound to ``binding``."""
+        dom = _as_range(dom)
+        return _ffi_api.AxisSpatial(dom, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        """Reduction iteration variable over ``dom`` bound to ``binding``."""
+        dom = _as_range(dom)
+        return _ffi_api.AxisReduce(dom, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        """Scan iteration variable over ``dom`` bound to ``binding``."""
+        dom = _as_range(dom)
+        return _ffi_api.AxisScan(dom, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        """Opaque iteration variable over ``dom`` bound to ``binding``."""
+        dom = _as_range(dom)
+        return _ffi_api.AxisOpaque(dom, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        """Create several iteration variables at once via ``AxisRemap``.
+
+        Returns the sole ``IterVar`` when exactly one is created, else the list.
+        """
+        iter_vars = _ffi_api.AxisRemap(kinds, bindings, dtype)  # pylint: disable=no-member # type: ignore
+        if len(iter_vars) == 1:
+            return iter_vars[0]
+        return iter_vars
+
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def _normalize_loop_range(start, stop):
+    """Resolve the ``(start, stop)`` pair accepted by ``serial``/``parallel``/
+    ``vectorized``/``unroll``.
+
+    When ``stop`` is omitted the single argument is the upper bound and the
+    loop starts at 0, i.e. ``serial(n)`` means ``serial(0, n)``.
+    """
+    if stop is None:
+        return 0, start
+    return start, stop
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Serial loop from ``start`` to ``stop`` (``Serial`` IRBuilder API)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Serial(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Parallel loop from ``start`` to ``stop`` (``Parallel`` IRBuilder API)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Parallel(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Vectorized loop from ``start`` to ``stop`` (``Vectorized`` IRBuilder API)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Unrolled loop from ``start`` to ``stop`` (``Unroll`` IRBuilder API)."""
+    start, stop = _normalize_loop_range(start, stop)
+    return _ffi_api.Unroll(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    """Loop bound to a thread (``ThreadBinding`` IRBuilder API).
+
+    Accepted call forms:
+
+    * ``thread_binding(start, stop, thread)``
+    * ``thread_binding(stop, thread)`` -- ``thread`` given positionally as a str
+    * ``thread_binding(stop, thread=thread)``
+
+    The two short forms start the loop at 0.
+
+    Raises
+    ------
+    ValueError
+        If ``thread`` is not given and ``stop`` is not a string.
+    """
+    if thread is not None:
+        if stop is None:
+            start, stop = 0, start
+    elif isinstance(stop, str):
+        start, stop, thread = 0, start, stop
+    else:
+        raise ValueError("Thread cannot be None for thread_binding")
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents) -> frame.ForFrame:
+    """Pass the loop ``extents`` (as a tuple) to the ``Grid`` IRBuilder API."""
+    grid_frame = _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+    return grid_frame
+
+
+def prim_func() -> frame.PrimFuncFrame:
+    """Open a PrimFunc frame via the ``PrimFunc`` IRBuilder API."""
+    func_frame = _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+    return func_frame
+
+
+def arg(name, obj):
+    """Forward ``(name, obj)`` to the ``Arg`` IRBuilder API and return its result."""
+    registered = _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+    return registered
+
+
+def func_name(name: str) -> str:
+    """Forward ``name`` to the ``FuncName`` IRBuilder API and return its result."""
+    result = _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:
+    """Forward ``attrs`` to the ``FuncAttrs`` IRBuilder API."""
+    result = _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def func_ret(ret_type) -> Type:
+    """Forward ``ret_type`` to the ``FuncRet`` IRBuilder API and return its result."""
+    result = _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Match ``param`` against a buffer via the ``MatchBuffer`` IRBuilder API.
+
+    A scalar ``shape`` is treated as one-dimensional; ``strides=None`` is
+    passed on as an empty list.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    return _ffi_api.MatchBuffer(  # pylint: disable=no-member # type: ignore
+        param,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+
+
+def preflattened_buffer(
+ postflattened,
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="global",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ axis_separators=None,
+) -> None:
+ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+ if strides is None:
+ strides = []
+ _ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore
+ postflattened,
+ shape,
+ dtype,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ axis_separators,
+ )
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name
+ return _ffi_api.Assert(condition, message) # pylint: disable=no-member # type: ignore
+
+
+def let(
+ v: Var,
+ value: PrimExpr,
+ body: PrimExpr = None,
+) -> frame.LetFrame:
+ if body is None:
+ return _ffi_api.Let(v, value) # pylint: disable=no-member # type: ignore
+ return Let(v, value, body)
+
+
+def allocate(
+ extents: List[PrimExpr],
+ dtype: str,
+ scope: str = "",
+ condition: PrimExpr = None,
+ annotations=None,
+) -> frame.AllocateFrame:
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
+ return _ffi_api.Allocate( # pylint: disable=no-member # type: ignore
+ extents, dtype, scope, condition, annotations
+ )
+
+
+def allocate_const(
+ data: List[PrimExpr],
+ dtype: str,
+ extents: List[PrimExpr],
+ annotations=None,
+) -> frame.AllocateConstFrame:
+
+ return _ffi_api.AllocateConst( # pylint: disable=no-member # type: ignore
+ ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+ )
+
+
+def realize(
+ buffer_slice: BufferRegion,
+ storage_scope: str,
+ condition: PrimExpr = True,
+) -> frame.RealizeFrame:
+ return _ffi_api.Realize( # pylint: disable=no-member # type: ignore
+ buffer_slice, storage_scope, condition
+ )
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+ node = convert(node)
+ value = convert(value)
+ return _ffi_api.Attr(node, attr_key, value) # pylint: disable=no-member # type: ignore
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame: # pylint: disable=invalid-name
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
+ return _ffi_api.While(condition) # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame: # pylint: disable=invalid-name
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
+ return _ffi_api.If(condition) # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame: # pylint: disable=invalid-name
+ return _ffi_api.Then() # pylint: disable=no-member # type: ignore
+
+
+def Else() -> frame.ElseFrame: # pylint: disable=invalid-name
+ return _ffi_api.Else() # pylint: disable=no-member # type: ignore
+
+
+def decl_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=0,
+ offset_factor=0,
+ buffer_type="",
+ axis_separators=None,
+) -> frame.DeclBufferFrame:
+
+ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+ return _ffi_api.DeclBuffer( # pylint: disable=no-member # type: ignore
+ shape,
+ dtype,
+ "",
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ axis_separators,
+ )
+
+
+def launch_thread(
+ iter_var: IterVar, # pylint: disable=redefined-outer-name
+ extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+ return _ffi_api.LaunchThread(iter_var, extent) # pylint: disable=no-member # type: ignore
+
+
+def env_thread(thread_tag: str) -> IterVar:
+ return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel
+
+ expr_indices = []
+ for index in indices:
+ if isinstance(index, slice):
+ step = 1 if index.step is None else index.step
+ lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step)
+ if lanes == 1:
+ expr_indices.append(index.start)
+ else:
+ expr_indices.append(ramp(index.start, step, int(lanes)))
+ else:
+ expr_indices.append(index)
+ if isinstance(value, bool) and buffer.dtype == "bool":
+ value = IntImm("bool", value)
+ return _ffi_api.BufferStore( # pylint: disable=no-member # type: ignore
+ buffer, value, expr_indices
+ )
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+ return _ffi_api.Prefetch(buffer, indices) # pylint: disable=no-member # type: ignore
+
+
+def evaluate(value: PrimExpr) -> None:
+ if isinstance(value, str):
+ value = StringImm(value)
+ return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Int8(expr) # pylint: disable=no-member # type: ignore
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Int16(expr) # pylint: disable=no-member # type: ignore
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Int32(expr) # pylint: disable=no-member # type: ignore
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Int64(expr) # pylint: disable=no-member # type: ignore
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.UInt8(expr) # pylint: disable=no-member # type: ignore
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.UInt16(expr) # pylint: disable=no-member # type: ignore
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.UInt32(expr) # pylint: disable=no-member # type: ignore
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.UInt64(expr) # pylint: disable=no-member # type: ignore
+
+
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ if not isinstance(expr, PrimExpr):
+ expr = convert(expr)
+ return _ffi_api.Float8(expr) # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ if not isinstance(expr, PrimExpr):
+ expr = convert(expr)
+ return _ffi_api.Float16(expr) # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ if not isinstance(expr, PrimExpr):
+ expr = convert(expr)
+ return _ffi_api.Float32(expr) # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ if not isinstance(expr, PrimExpr):
+ expr = convert(expr)
+ return _ffi_api.Float64(expr) # pylint: disable=no-member # type: ignore
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Boolean(expr) # pylint: disable=no-member # type: ignore
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Handle(expr) # pylint: disable=no-member # type: ignore
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
+ return _ffi_api.Void(expr) # pylint: disable=no-member # type: ignore
+
+
+def min(a, b): # pylint: disable=redefined-builtin
+ """Compute the minimum value of two expressions.
+
+ Parameters
+ ----------
+ a : PrimExpr
+ The left hand operand
+
+ b : PrimExpr
+ The right hand operand
+
+ Returns
+ -------
+ res : PrimExpr
+ The result expression.
+ """
+ return _ffi_api.min(a, b) # pylint: disable=no-member # type: ignore
+
+
+def max(a, b): # pylint: disable=redefined-builtin
+ """Compute the maximum value of two expressions.
+
+ Parameters
+ ----------
+ a : PrimExpr
+ The left hand operand
+
+ b : PrimExpr
+ The right hand operand
+
+ Returns
+ -------
+ res : PrimExpr
+ The result expression.
+ """
+ return _ffi_api.max(a, b) # pylint: disable=no-member # type: ignore
+
+
+def var(dtype, name="") -> Var:
+ return Var(name, dtype) # pylint: disable=no-member # type: ignore
+
+
+def iter_var(v, dom, iter_type, thread_tag):
+ iter_type = getattr(IterVar, iter_type)
+ return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+ """Create a CommReducer from lambda inputs/outputs and the identities"""
+ params = inspect.signature(combiner).parameters
+ num_args = len(params)
+ args = []
+ for name, i in zip(params.keys(), identity + identity):
+ if isinstance(i, int):
+ args.append(Var(name, "int32"))
+ else:
+ args.append(Var(name, i.dtype))
+ res = combiner(*args)
+ if not isinstance(res, tuple):
+ res = (res,)
+ return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+ # pylint: disable=import-outside-toplevel
+ from tvm.target.codegen import llvm_lookup_intrinsic_id as f
+
+ # pylint: enable=import-outside-toplevel
+ return f(name)
+
+
+def _op_wrapper(func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if "dtype" in kwargs:
+ kwargs.pop("dtype")
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
+def _dtype_forward(func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if "dtype" in kwargs:
+ args = (kwargs.pop("dtype"),) + args
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
+abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+__all__ = [
+ "Assert",
+ "Else",
+ "If",
+ "Let",
+ "Select",
+ "Shuffle",
+ "TVMBackendAllocWorkspace",
+ "TVMBackendFreeWorkspace",
+ "Then",
+ "While",
+ "abs",
+ "acos",
+ "acosh",
+ "address_of",
+ "alloc_buffer",
+ "allocate",
+ "allocate_const",
+ "arg",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "attr",
+ "axis",
+ "block",
+ "block_attr",
+ "boolean",
+ "broadcast",
+ "buffer_decl",
+ "buffer_store",
+ "buffer_var",
+ "call_cpacked",
+ "call_cpacked_lowered",
+ "call_extern",
+ "call_intrin",
+ "call_llvm_intrin",
+ "call_llvm_pure_intrin",
+ "call_packed",
+ "call_packed_lowered",
+ "call_pure_extern",
+ "cast",
+ "ceil",
+ "ceildiv",
+ "clz",
+ "comm_reducer",
+ "copysign",
+ "cos",
+ "cosh",
+ "env_thread",
+ "erf",
+ "evaluate",
+ "exp",
+ "exp10",
+ "exp2",
+ "decl_buffer",
+ "fabs",
+ "float16",
+ "float32",
+ "float64",
+ "float8",
+ "floor",
+ "floordiv",
+ "floormod",
+ "fmod",
+ "func_attr",
+ "func_name",
+ "func_ret",
+ "grid",
+ "handle",
+ "hypot",
+ "if_then_else",
+ "infinity",
+ "init",
+ "int16",
+ "int32",
+ "int64",
+ "int8",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "isnullptr",
+ "iter_var",
+ "launch_thread",
+ "ldexp",
+ "let",
+ "likely",
+ "llvm_lookup_intrinsic_id",
+ "log",
+ "log10",
+ "log1p",
+ "log2",
+ "lookup_param",
+ "match_buffer",
+ "max",
+ "max_value",
+ "min",
+ "min_value",
+ "mma_fill",
+ "mma_store",
+ "nearbyint",
+ "nextafter",
+ "parallel",
+ "popcount",
+ "power",
+ "prefetch",
+ "preflattened_buffer",
+ "prim_func",
+ "ptr",
+ "ptx_commit_group",
+ "ptx_cp_async",
+ "ptx_ldmatrix",
+ "ptx_mma",
+ "ptx_mma_sp",
+ "ptx_wait_group",
+ "q_multiply_shift",
+ "ramp",
+ "reads",
+ "realize",
+ "reinterpret",
+ "ret",
+ "round",
+ "rsqrt",
+ "serial",
+ "shift_left",
+ "shift_right",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "sqrt",
+ "tan",
+ "tanh",
+ "target",
+ "thread_binding",
+ "trunc",
+ "truncdiv",
+ "truncmod",
+ "tvm_access_ptr",
+ "tvm_bmma_sync",
+ "tvm_call_cpacked",
+ "tvm_call_cpacked_lowered",
+ "tvm_call_packed",
+ "tvm_call_packed_lowered",
+ "tvm_fill_fragment",
+ "tvm_load_matrix_sync",
+ "tvm_mma_sync",
+ "tvm_stack_alloca",
+ "tvm_stack_make_array",
+ "tvm_stack_make_shape",
+ "tvm_store_matrix_sync",
+ "tvm_struct_get",
+ "tvm_struct_set",
+ "tvm_thread_allreduce",
+ "tvm_throw_last_error",
+ "tvm_tuple",
+ "type_annotation",
+ "uint16",
+ "uint32",
+ "uint64",
+ "uint8",
+ "unroll",
+ "var",
+ "vectorized",
+ "void",
+ "where",
+ "writes",
+]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..7b5f55fd5f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -100,6 +100,64 @@ def call_cpacked(*args, span=None):
return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
+def call_packed_lowered(*args, span=None):
+ """Lowered version of call packed.
+
+ The argument to packed function can be Expr or Buffer.
+ The argument is the corresponding POD type when Expr is presented.
+
+ When the argument is Buffer, the corresponding PackedFunc
+ will recieve an TVMArrayHandle whose content is valid during the callback period.
+ If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+ Parameters
+ ----------
+ args : list of Expr or Buffer.
+ Positional arguments.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+
+ See Also
+ --------
+ te.extern : Create tensor with extern function call.
+ """
+ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+ return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+ """Lowered version of call c-packed.
+
+ Same as call_packed, except that the first argument is the function name
+ (as in call_extern), and the last argument is the resource handle.
+
+ Parameters
+ ----------
+ args : list of Expr or Buffer.
+ Positional arguments.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+
+ See Also
+ --------
+ te.extern : Create tensor with extern function call.
+ """
+ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+ return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
+
+
def call_intrin(dtype, func_name, *args, span=None):
"""Build expression by calling an intrinsic function.
@@ -151,7 +209,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
The call expression.
"""
return Call(
- dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+ dtype,
+ Op.get("tir.call_pure_extern"),
+ convert((StringImm(func_name),) + args),
+ span,
)
@@ -178,7 +239,10 @@ def call_extern(dtype, func_name, *args, span=None):
The call expression.
"""
return Call(
- dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+ dtype,
+ Op.get("tir.call_extern"),
+ convert((StringImm(func_name),) + args),
+ span=span,
)
@@ -206,11 +270,21 @@ def call_llvm_intrin(dtype, name, *args, span=None):
"""
# pylint: disable=import-outside-toplevel
from tvm.target import codegen
-
- llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ from .expr import IntImm
+
+ if isinstance(name, str):
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ elif isinstance(name, IntImm):
+ llvm_id = name.value
+ else:
+ llvm_id = name
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_intrin(
- dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+ dtype,
+ Op.get("tir.call_llvm_intrin"),
+ tvm.tir.const(llvm_id, "uint32"),
+ *args,
+ span=span,
)
@@ -238,8 +312,14 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
"""
# pylint: disable=import-outside-toplevel
from tvm.target import codegen
-
- llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ from .expr import IntImm
+
+ if isinstance(name, str):
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ elif isinstance(name, IntImm):
+ llvm_id = name.value
+ else:
+ llvm_id = name
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_intrin(
dtype,
@@ -250,6 +330,310 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+ return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+ return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):
+ return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):
+ return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
+ return call_intrin(
+ "handle", "tir.tvm_stack_make_array", data, shape, strides, ndim, arr_dtype, elem_offset
+ )
+
+
+def address_of(buffer_load, span=None):
+ """Returns the address of an element in the buffer
+
+ Parameters
+ ----------
+ buffer_load: BufferLoad
+ The buffer load.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.address_of", buffer_load)
+
+
+def lookup_param(param_name, span=None):
+ """Returns the param by name
+
+ Parameters
+ ----------
+ param_name : str
+ The name of param.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.lookup_param", param_name)
+
+
+def tvm_access_ptr(dtype, data, offset, extent, rw_mask):
+ return call_intrin("handle", "tir.tvm_access_ptr", dtype, data, offset, extent, rw_mask)
+
+
+def tvm_tuple(*value):
+ return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):
+ return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field_id)
+
+
+def tvm_struct_set(arr, index, field_id, value):
+ return call_intrin("handle", "tir.tvm_struct_set", arr, index, field_id, value)
+
+
+def tvm_thread_allreduce(*freduce_args):
+ return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+ return call_intrin(
+ "handle",
+ "tir.tvm_load_matrix_sync",
+ fragment,
+ m,
+ n,
+ k,
+ index,
+ buffer_ptr,
+ stride,
+ layout,
+ )
+
+
+def tvm_mma_sync(
+ fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+ return call_intrin(
+ "handle",
+ "tir.tvm_mma_sync",
+ fragment_d,
+ index_d,
+ fragment_a,
+ index_a,
+ fragment_b,
+ index_b,
+ fragment_c,
+ index_c,
+ )
+
+
+def tvm_bmma_sync(
+ fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+ return call_intrin(
+ "handle",
+ "tir.tvm_bmma_sync",
+ fragment_d,
+ index_d,
+ fragment_a,
+ index_a,
+ fragment_b,
+ index_b,
+ fragment_c,
+ index_c,
+ )
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):
+ return call_intrin(
+ "handle",
+ "tir.tvm_fill_fragment",
+ fragment,
+ m,
+ n,
+ k,
+ index,
+ value,
+ )
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+ return call_intrin(
+ "handle",
+ "tir.tvm_store_matrix_sync",
+ fragment,
+ m,
+ n,
+ k,
+ index,
+ buffer_ptr,
+ stride,
+ layout,
+ )
+
+
+def ptx_mma(
+ dtype,
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ saturate,
+ operator=None,
+):
+ if operator is None:
+ return call_intrin(
+ dtype,
+ "tir.ptx_mma",
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ saturate,
+ )
+ return call_intrin(
+ dtype,
+ "tir.ptx_mma",
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ saturate,
+ operator,
+ )
+
+
+def ptx_mma_sp(
+ dtype,
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ metadata,
+ meta_index,
+ sparse_selector,
+ saturate,
+):
+ return call_intrin(
+ dtype,
+ "tir.ptx_mma_sp",
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ metadata,
+ meta_index,
+ sparse_selector,
+ saturate,
+ )
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+ return call_intrin(
+ dtype,
+ "tir.ptx_ldmatrix",
+ trans,
+ num,
+ type,
+ local_ptr,
+ local_offset,
+ smem_ptr,
+ smem_offset,
+ )
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+ return call_intrin(
+ dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
+ )
+
+
+def ptx_commit_group():
+ return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):
+ return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+ return call_intrin(
+ dtype,
+ "tir.mma_store",
+ m,
+ n,
+ dst_ptr,
+ src_ptr,
+ src_offset,
+ dst_stride,
+ )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+ return call_intrin(
+ dtype,
+ "tir.mma_fill",
+ local_size,
+ local_ptr,
+ offset,
+ )
+
+
def ret(val):
"""Create a tir return expression
@@ -394,6 +778,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
return _ffi_api.max_value(dtype, span) # type: ignore
+def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+ """infinity value of dtype
+
+ Parameters
+ ----------
+ dtype : str
+ The data type.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ value : tvm.Expr
+ The infinity value of dtype.
+ """
+ return _ffi_api.infinity(dtype, span) # type: ignore
+
+
+def reinterpret(dtype, value, span=None) -> Any:
+ """infinity value of dtype
+
+ Parameters
+ ----------
+ dtype : str
+ The data type.
+
+ value : PrimExpr
+ The input value.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ value : tvm.Expr
+ The reinterpret cast value of dtype.
+ """
+ return _ffi_api.reinterpret(dtype, value, span) # type: ignore
+
+
def exp(x):
"""Take exponential of input x.
@@ -998,6 +1423,25 @@ def ldexp(x1, x2):
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore
+def likely(cond, span=None):
+ """Mark condition as likely.
+
+ Parameters
+ ----------
+ cond : PrimExpr
+ Input argument.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ y : PrimExpr
+ The marked expression.
+ """
+ return _ffi_api.likely(cond, span) # type: ignore
+
+
def isnan(x, span=None):
"""Check if input value is Nan.
@@ -1017,6 +1461,25 @@ def isnan(x, span=None):
return _ffi_api.isnan(x, span) # type: ignore
+def isnullptr(x, span=None):
+ """Check if input value is nullptr.
+
+ Parameters
+ ----------
+ x : PrimExpr
+ Input argument.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ y : PrimExpr
+ The result.
+ """
+ return call_intrin("bool", "tir.isnullptr", x) # type: ignore
+
+
def isfinite(x, span=None):
"""Check if input value is finite.
@@ -1122,6 +1585,42 @@ def q_multiply_shift(x, y, q, s):
return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
+def shift_left(x, y, span=None):
+ """Return the result of x left shifted by y bits.
+
+ Parameters
+ ----------
+ x : PrimExpr
+ Input argument.
+ y : PrimExpr
+ Input argument.
+
+ Returns
+ -------
+ z : PrimExpr
+ The result.
+ """
+ return _ffi_api.left_shift(x, y, span)
+
+
+def shift_right(x, y, span=None):
+ """Return the result of x right shifted by y bits.
+
+ Parameters
+ ----------
+ x : PrimExpr
+ Input argument.
+ y : PrimExpr
+ Input argument.
+
+ Returns
+ -------
+ z : PrimExpr
+ The result.
+ """
+ return _ffi_api.right_shift(x, y, span)
+
+
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
@@ -1306,6 +1805,28 @@ def truncmod(a, b, span=None):
return _ffi_api._OpTruncMod(a, b, span) # type: ignore
+def ceildiv(a, b, span=None):
+ """Compute the ceildiv of two expressions.
+
+ Parameters
+ ----------
+ a : PrimExpr
+ The left hand operand
+
+ b : PrimExpr
+ The right hand operand
+
+ span : Optional[Span]
+ The location of this operator in the source.
+
+ Returns
+ -------
+ res : PrimExpr
+ The result expression.
+ """
+ return _ffi_api._OpCeilDiv(a, b, span) # type: ignore
+
+
def floordiv(a, b, span=None):
"""Compute the floordiv of two expressions.
@@ -1523,6 +2044,22 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+ return call_intrin(
+ "handle",
+ "tir.TVMBackendAllocWorkspace",
+ device_type,
+ device_id,
+ nbytes,
+ dtype_code_hint,
+ dtype_bits_hint,
+ )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+ call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
from tvm._ffi import register_object
from tvm.runtime import Object
-from tvm.tir import Block, For
+from ..stmt import Block, For
from . import _ffi_api
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index cf031c014c..f26f954d51 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
from . import _ffi_api
from ._type_checker import type_checked
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
from . import _ffi_api
from .block_scope import BlockScope, StmtSRef
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,7 @@ def stmt_list(stmt):
res += stmt_list(x)
return res
return [stmt]
+
+
+def type_annotation(dtype, span=None):
+ return _ffi_api.TypeAnnotation(dtype, span)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
from typing import Dict
import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
from . import _ffi_api
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index a3318bf94f..a7d95848da 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -21,6 +21,7 @@
* \file src/ir/expr.cc
* \brief The expression AST nodes for the common IR infra.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/runtime/registry.h>
@@ -49,6 +50,18 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImm(GetRef<runtime::String>(ptr));
}
+ if (auto* ptr = ref.as<tir::BufferRegionNode>()) {
+ tir::BufferRegion buffer_region = GetRef<tir::BufferRegion>(ptr);
+ Array<PrimExpr> indices;
+ for (Range r : buffer_region->region) {
+ if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+ indices.push_back(r->min);
+ } else {
+ indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+ }
+ }
+ return tir::BufferLoad(buffer_region->buffer, indices);
+ }
Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get());
ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but got " << actual_type.value();
diff --git a/src/ir/ir_builder.cc b/src/ir/ir_builder.cc
new file mode 100644
index 0000000000..9f42cdb168
--- /dev/null
+++ b/src/ir/ir_builder.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ir/ir_builder.h>
+#include <tvm/ir/module.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ir_builder {
+
+void IRBuilderFrameNode::EnterWithScope() {
+  // Become the innermost frame of the builder that is currently in scope.
+  IRBuilderFrame self = GetRef<IRBuilderFrame>(this);
+  IRBuilder::Current()->frames.push_back(self);
+}
+
+void IRBuilderFrameNode::ExitWithScope() {
+  // Fire the exit callbacks newest-first, forget them, then pop this frame off the builder.
+  for (size_t idx = callbacks.size(); idx-- > 0;) {
+    callbacks[idx]();
+  }
+  callbacks.clear();
+  IRBuilder::Current()->frames.pop_back();
+}
+
+void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
+  // NOTE: the callback is attached to the builder's innermost frame, which is not
+  // necessarily `this`.
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    LOG(FATAL) << "ValueError: No frames in Builder to add callback";
+  }
+  builder->frames.back()->callbacks.push_back(callback);
+}
+
+IRBuilder::IRBuilder() {
+  // A fresh builder starts with no open frames and no result.
+  ObjectPtr<IRBuilderNode> node = make_object<IRBuilderNode>();
+  node->result = NullOpt;
+  node->frames.clear();
+  data_ = node;
+}
+
+std::vector<IRBuilder>* ThreadLocalBuilderStack() {
+  // One builder stack per thread, initialized the first time this thread calls here.
+  thread_local std::vector<IRBuilder> builder_stack;
+  return &builder_stack;
+}
+
+void IRBuilder::EnterWithScope() {
+  IRBuilderNode* node = this->get();
+  // Builders are single-use: refuse to enter one that still holds open frames.
+  CHECK(node->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
+                              << node->frames.size()
+                              << ". Please use a fresh new builder every time building IRs";
+  node->result = NullOpt;
+  ThreadLocalBuilderStack()->push_back(*this);
+}
+
+void IRBuilder::ExitWithScope() {
+  // Pop this thread's innermost builder.
+  std::vector<IRBuilder>& builders = *ThreadLocalBuilderStack();
+  ICHECK(!builders.empty());
+  builders.pop_back();
+}
+
+IRBuilder IRBuilder::Current() {
+  // The innermost builder entered on this thread; it is an error if none is active.
+  const std::vector<IRBuilder>& builders = *ThreadLocalBuilderStack();
+  CHECK(!builders.empty()) << "ValueError: No builder in current scope";
+  return builders.back();
+}
+
+IRModuleFrame IRModule() {
+  // An empty module frame: GlobalVars and functions are collected while it is in scope.
+  ObjectPtr<IRModuleFrameNode> frame = make_object<IRModuleFrameNode>();
+  frame->functions.clear();
+  frame->global_vars.clear();
+  return IRModuleFrame(frame);
+}
+
+void IRModuleFrameNode::ExitWithScope() {
+  // NOTE(review): unlike the TIR frames, this override never calls
+  // IRBuilderFrameNode::ExitWithScope(), so registered callbacks are not run and this frame
+  // is not popped from the builder. Confirm whether that is intended.
+  ICHECK_EQ(functions.size(), global_vars.size());
+  // global_vars[k] names functions[k].
+  Map<GlobalVar, BaseFunc> func_map;
+  for (size_t idx = 0; idx < functions.size(); ++idx) {
+    func_map.Set(global_vars[idx], functions[idx]);
+  }
+  IRBuilder builder = IRBuilder::Current();
+  ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+  builder->result = tvm::IRModule(func_map);
+}
+
+namespace details {
+
+/*! \brief The dispatch table mapping node types to their naming routine. */
+Namer::FType& Namer::vtable() {
+  static FType inst;
+  return inst;
+}
+
+/*!
+ * \brief Assign `name` to `node` using the handler registered for its type in vtable().
+ * \param node The object to name; must be defined.
+ * \param name The name to assign.
+ * \note Fails (via CHECK) if `node` is undefined or no handler is registered for its type.
+ */
+void Namer::Name(ObjectRef node, String name) {
+  static const FType& f = vtable();
+  CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
+  // Close the quote opened around the type key so the message is well-formed.
+  CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
+                              << node->GetTypeKey() << "\"";
+  f(node, name);
+}
+
+}  // namespace details
+
+// Reflection registration for the node types implemented in this file.
+TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode);
+TVM_REGISTER_NODE_TYPE(IRBuilderNode);
+TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
+// Frame scoping and callback registration, exposed through the FFI.
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameEnter")
+ .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameExit")
+ .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameAddCallback")
+ .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::AddCallback);
+// Builder construction, scoping, current-builder lookup, and result access.
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); });
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderGet")
+ .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
+// Generic naming hook (dispatched through details::Namer) and the IRModule frame factory.
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRModule").set_body_typed(IRModule);
+
+} // namespace ir_builder
+} // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 7979c9f47a..9bbd9068e2 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -20,6 +20,7 @@
/*!
* \file expr.cc
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -828,12 +829,26 @@ TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
- ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>())
+ ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
+ it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>())
<< "Argument " << it << " is not a string or primexpr";
if (const auto* str = it.as<runtime::StringObj>()) {
prim_expr_args.push_back(StringImm(str->data));
+ } else if (const auto* expr = it.as<PrimExprNode>()) {
+ prim_expr_args.push_back(GetRef<PrimExpr>(expr));
+ } else if (const auto* br = it.as<BufferRegionNode>()) {
+ BufferRegion buffer_region = GetRef<BufferRegion>(br);
+ Array<PrimExpr> indices;
+ for (Range r : buffer_region->region) {
+ if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+ indices.push_back(r->min);
+ } else {
+ indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+ }
+ }
+ prim_expr_args.push_back(BufferLoad(buffer_region->buffer, indices));
} else {
- prim_expr_args.push_back(Downcast<PrimExpr>(it));
+ prim_expr_args.push_back(Downcast<IterVar>(it).operator PrimExpr());
}
}
return Call(type, op, prim_expr_args, span);
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 7e3d3d1075..b11ca6650a 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -22,11 +22,10 @@
* \brief Used by TVM Script parser to expand incomplete TIR input
*/
+#include "./script_complete.h"
+
#include <tvm/arith/int_set.h>
-#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/stmt_functor.h>
#include <utility>
diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h
new file mode 100644
index 0000000000..5392ea309d
--- /dev/null
+++ b/src/tir/ir/script/script_complete.h
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.h
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+#ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Complete a PrimFunc produced by the TVM Script parser (expand incomplete TIR input).
+ * \param func The parsed, possibly incomplete, function.
+ * \param root_allocates Buffers to allocate at the function root.
+ *        NOTE(review): presumably the PrimFunc frame's root-level T.alloc_buffer buffers; confirm.
+ * \return The completed function.
+ */
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates);
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..6dd5bf5886 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -1091,5 +1091,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
TVM_REGISTER_OP("tir.type_annotation")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+// Expose TypeAnnotation through the FFI; python/tvm/tir/stmt.py calls it as _ffi_api.TypeAnnotation.
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder.cc b/src/tir/ir_builder/ir_builder.cc
new file mode 100644
index 0000000000..d312f014e5
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder.cc
@@ -0,0 +1,637 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Construct a tir::Buffer from script-level buffer arguments.
+ *
+ * When `data` is absent, a fresh pointer Var named `buffer_name` is created in `storage_scope`.
+ * Bool buffers get int8 storage. When `elem_offset` is absent and `offset_factor` is non-zero,
+ * a symbolic "elem_offset" Var is created using the dtype of the first shape dimension
+ * (int32 for scalars).
+ */
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators) {
+  DataType storage_dtype = (dtype == DataType::Bool()) ? DataType::Int(8) : dtype;
+  Var buffer_data =
+      data.defined()
+          ? data.value()
+          : tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope));
+  if (offset_factor != 0 && !elem_offset.defined()) {
+    DataType offset_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
+    elem_offset = tvm::tir::Var("elem_offset", offset_dtype);
+  }
+  const auto kind =
+      (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault;
+  return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
+                elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, kind,
+                axis_separators.value_or(Array<IntImm>()));
+}
+
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  // Declare the buffer up front, then wrap it in a frame.
+  Buffer declared = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset,
+                               storage_scope, align, offset_factor, buffer_type, axis_separators);
+  ObjectPtr<DeclBufferFrameNode> frame = make_object<DeclBufferFrameNode>();
+  frame->buffer = declared;
+  return DeclBufferFrame(frame);
+}
+
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
+  // An unnamed handle Var typed as a pointer to `dtype` in `storage_scope`.
+  tvm::PointerType pointer_type(PrimType(dtype), storage_scope);
+  return tvm::tir::Var("", pointer_type);
+}
+
+BlockFrame Block(String name, bool no_realize) {
+  ObjectPtr<BlockFrameNode> frame = make_object<BlockFrameNode>();
+  // Identity.
+  frame->name = name;
+  frame->no_realize = no_realize;
+  // Signature: iteration vars, read/write regions, and init statement all start empty.
+  frame->iter_vars.clear();
+  frame->reads = NullOpt;
+  frame->writes = NullOpt;
+  frame->init = NullOpt;
+  // Buffers allocated or matched inside the block, plus annotations.
+  frame->alloc_buffers.clear();
+  frame->match_buffers.clear();
+  frame->annotations.clear();
+  // Realization: iter_var bindings and predicate.
+  frame->iter_values.clear();
+  frame->predicate = NullOpt;
+  return BlockFrame(frame);
+}
+
+BlockInitFrame Init() {
+  // A fresh frame for the block's init statement.
+  ObjectPtr<BlockInitFrameNode> frame = make_object<BlockInitFrameNode>();
+  return BlockInitFrame(frame);
+}
+
+void Where(PrimExpr predicate) {
+  // Record the block predicate; only one T.where is allowed per block.
+  BlockFrame frame = FindBlockFrame("T.where");
+  if (!frame->predicate.defined()) {
+    frame->predicate = predicate;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+             << frame->predicate.value();
+}
+
+void Reads(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  // Record the block's read regions; only one T.reads is allowed per block.
+  BlockFrame frame = FindBlockFrame("T.reads");
+  if (frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
+  }
+  // Each slice is either a BufferRegion or a BufferLoad, which is converted to a region.
+  Array<BufferRegion> regions;
+  for (const ObjectRef& slice : buffer_slices) {
+    if (const auto* load = slice.as<BufferLoadNode>()) {
+      regions.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(load)));
+    } else if (const auto* region = slice.as<BufferRegionNode>()) {
+      regions.push_back(GetRef<BufferRegion>(region));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer reads.";
+    }
+  }
+  frame->reads = regions;
+}
+
+void Writes(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  // Record the block's write regions; only one T.writes is allowed per block.
+  BlockFrame frame = FindBlockFrame("T.writes");
+  if (frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << frame->writes;
+  }
+  // Each slice is either a BufferRegion or a BufferLoad, which is converted to a region.
+  Array<BufferRegion> regions;
+  for (const ObjectRef& slice : buffer_slices) {
+    if (const auto* load = slice.as<BufferLoadNode>()) {
+      regions.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(load)));
+    } else if (const auto* region = slice.as<BufferRegionNode>()) {
+      regions.push_back(GetRef<BufferRegion>(region));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer writes.";
+    }
+  }
+  frame->writes = regions;
+}
+
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  // Set the block annotations; an already non-empty annotation map counts as a duplicate.
+  BlockFrame frame = FindBlockFrame("T.block_attr");
+  if (frame->annotations.empty()) {
+    frame->annotations = attrs;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+}
+
+/*!
+ * \brief Declare an unnamed buffer (T.alloc_buffer) and register it for allocation.
+ *
+ * If the builder's last frame is a BlockFrame, the buffer is appended to that block's
+ * alloc_buffers. Otherwise, if it is a PrimFuncFrame, the buffer is appended to the function's
+ * root_alloc_buffers. Any other last frame is an error.
+ * \return The declared buffer.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<BlockFrame> block_frame = builder->GetLastFrame<BlockFrame>()) {
+    block_frame.value()->alloc_buffers.push_back(buffer);
+  } else if (Optional<PrimFuncFrame> func_frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    func_frame.value()->root_alloc_buffers.push_back(buffer);
+  } else {
+    LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not found. Please ensure "
+                  "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  }
+  return buffer;
+}
+
+namespace axis {
+
+/*!
+ * \brief Append `iter_var` and its `binding` to the builder's last frame, which must be a
+ *        BlockFrame.
+ * \return `iter_var` itself.
+ */
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+  Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  if (!opt_frame.defined()) {
+    LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+  }
+  BlockFrame frame = opt_frame.value();
+  frame->iter_vars.push_back(iter_var);
+  frame->iter_values.push_back(binding);
+  return iter_var;
+}
+
+namespace {
+
+/*!
+ * \brief Shared body of T.axis.spatial/reduce/scan/opaque.
+ *
+ * The new block variable is widened to the largest bit-width among dom->min, dom->extent
+ * and `dtype`.
+ */
+IterVar PushAxis(tvm::tir::IterVarType kind, const char* kind_name, Range dom, PrimExpr binding,
+                 DataType dtype) {
+  ICHECK(dom.defined()) << kind_name << " axis must have a domain";
+  int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()});
+  IterVar block_var(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), /*iter_type=*/kind,
+                    /*thread_tag=*/"");
+  return PushBlockVar(block_var, binding);
+}
+
+}  // namespace
+
+IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) {
+  return PushAxis(tvm::tir::IterVarType::kDataPar, "Spatial", dom, binding, dtype);
+}
+
+IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) {
+  return PushAxis(tvm::tir::IterVarType::kCommReduce, "Reduction", dom, binding, dtype);
+}
+
+IterVar Scan(Range dom, PrimExpr binding, DataType dtype) {
+  return PushAxis(tvm::tir::IterVarType::kOrdered, "Scan", dom, binding, dtype);
+}
+
+IterVar Opaque(Range dom, PrimExpr binding, DataType dtype) {
+  return PushAxis(tvm::tir::IterVarType::kOpaque, "Opaque", dom, binding, dtype);
+}
+
+/*!
+ * \brief T.axis.remap: create one block var per loop var in `bindings`.
+ *
+ * `kinds[k]` is 'S' (spatial) or 'R' (reduce). The domain of each axis is copied from the
+ * enclosing for-frame that defines the corresponding loop var.
+ * NOTE(review): the `dtype` argument is unused; each axis takes the dtype of its loop var.
+ */
+Array<IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+  using namespace tvm::tir;
+  ICHECK_EQ(kinds.size(), bindings.size());
+  int num_axes = bindings.size();
+  Array<IterVar> results;
+  results.reserve(num_axes);
+  for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx) {
+    char kind = kinds.c_str()[axis_idx];
+    PrimExpr binding = bindings[axis_idx];
+    const VarNode* loop_var = binding.as<VarNode>();
+    ICHECK(loop_var) << "TypeError: Only Var is supported in T.axis.remap";
+    // Scan for-frames from outermost to innermost for the one that defines `loop_var`.
+    Range dom{nullptr};
+    for (const auto& frame : IRBuilder::Current()->frames) {
+      const auto* for_frame = frame.as<ForFrameNode>();
+      if (for_frame == nullptr) {
+        continue;
+      }
+      ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+      int num_loops = for_frame->doms.size();
+      for (int loop_idx = 0; loop_idx < num_loops && !dom.defined(); ++loop_idx) {
+        if (for_frame->vars[loop_idx].get() == loop_var) {
+          dom = for_frame->doms[loop_idx];
+        }
+      }
+      if (dom.defined()) {
+        break;
+      }
+    }
+    ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(loop_var);
+    IterVarType iter_type = IterVarType::kDataPar;
+    switch (kind) {
+      case 'S':
+        iter_type = IterVarType::kDataPar;
+        break;
+      case 'R':
+        iter_type = IterVarType::kCommReduce;
+        break;
+      default:
+        LOG(FATAL) << "Unknown axis kind: " << kind;
+    }
+    results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
+                                           /*var=*/Var("", loop_var->dtype),
+                                           /*iter_type=*/iter_type,
+                                           /*thread_tag=*/""),
+                                   binding));
+  }
+  return results;
+}
+
+}  // namespace axis
+
+namespace {
+
+/*!
+ * \brief Shared body of T.serial/T.parallel/T.vectorized/T.unroll.
+ * \param start Loop start; becomes the loop minimum.
+ * \param stop Loop end (exclusive); the extent is the simplified `stop - start`.
+ * \param kind The ForKind of the generated tir::For.
+ * \param annotations Optional loop annotations; an empty map when absent.
+ */
+ForFrame MakeSingleLoopFrame(PrimExpr start, PrimExpr stop, tvm::tir::ForKind kind,
+                             Optional<Map<String, ObjectRef>> annotations) {
+  PrimExpr extent = arith::Analyzer().Simplify(stop - start);
+  // The loop var is an int wide enough for both the start and the extent.
+  int bits = std::max(start.dtype().bits(), extent.dtype().bits());
+  ObjectPtr<ForFrameNode> frame = make_object<ForFrameNode>();
+  frame->vars = {Var("v", DataType::Int(bits))};
+  frame->doms = {Range::FromMinExtent(start, extent)};
+  frame->f_make_for_loop = [annotations, kind](Array<Var> vars, Array<Range> doms,
+                                               tvm::tir::Stmt body) {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, kind, body, NullOpt,
+                         annotations.value_or(Map<String, ObjectRef>()));
+  };
+  return ForFrame(frame);
+}
+
+}  // namespace
+
+ForFrame Serial(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {
+  return MakeSingleLoopFrame(start, stop, tvm::tir::ForKind::kSerial, annotations);
+}
+
+ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {
+  return MakeSingleLoopFrame(start, stop, tvm::tir::ForKind::kParallel, annotations);
+}
+
+ForFrame Vectorized(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {
+  return MakeSingleLoopFrame(start, stop, tvm::tir::ForKind::kVectorized, annotations);
+}
+
+ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {
+  return MakeSingleLoopFrame(start, stop, tvm::tir::ForKind::kUnrolled, annotations);
+}
+
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations) {
+  using namespace tvm::tir;
+  // Same loop shape as T.serial, but the generated For is bound to thread tag `thread`.
+  PrimExpr extent = arith::Analyzer().Simplify(stop - start);
+  int bits = std::max(start.dtype().bits(), extent.dtype().bits());
+  ObjectPtr<ForFrameNode> frame = make_object<ForFrameNode>();
+  frame->vars = {Var("v", DataType::Int(bits))};
+  frame->doms = {Range::FromMinExtent(start, extent)};
+  frame->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms,
+                                                 Stmt body) -> For {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    // The thread IterVar carries only its tag; its domain is left undefined.
+    IterVar thread_iter(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
+                        thread);
+    Map<String, ObjectRef> loop_annotations = annotations.value_or(Map<String, ObjectRef>());
+    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body,
+               thread_iter, loop_annotations);
+  };
+  return ForFrame(frame);
+}
+
+ForFrame Grid(Array<PrimExpr> extents) {
+  using namespace tvm::tir;
+  // One zero-based serial loop per extent; each loop var has its extent's dtype.
+  ObjectPtr<ForFrameNode> frame = make_object<ForFrameNode>();
+  frame->vars.reserve(extents.size());
+  frame->doms.reserve(extents.size());
+  for (const PrimExpr& extent : extents) {
+    frame->vars.push_back(Var("v", extent.dtype()));
+    frame->doms.push_back(Range(make_const(extent.dtype(), 0), extent));
+  }
+  // Wrap `body` from the innermost loop outward, so extents[0] ends up outermost.
+  frame->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    for (int level = static_cast<int>(vars.size()) - 1; level >= 0; --level) {
+      Range loop_dom = doms[level];
+      body = For(vars[level], loop_dom->min, loop_dom->extent, ForKind::kSerial, std::move(body),
+                 /*thread_binding=*/NullOpt, /*annotations=*/{});
+    }
+    return body;
+  };
+  return ForFrame(frame);
+}
+
+PrimFuncFrame PrimFunc() {
+  ObjectPtr<PrimFuncFrameNode> frame = make_object<PrimFuncFrameNode>();
+  // Signature: name, params and return type are all unset.
+  frame->name = NullOpt;
+  frame->args.clear();
+  frame->ret_type = NullOpt;
+  // Buffer bindings for the params.
+  frame->buffer_map.clear();
+  frame->preflattened_buffer_map.clear();
+  // Function attributes and root-level allocations.
+  frame->attrs.clear();
+  frame->root_alloc_buffers.clear();
+  return PrimFuncFrame(frame);
+}
+
+Var Arg(String name, Var var) {
+  // Name the scalar param and append it to the enclosing PrimFunc's parameter list.
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(var, name);
+  func_frame->args.push_back(var);
+  return var;
+}
+
+Buffer Arg(String name, Buffer buffer) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(buffer, name);
+  // A buffer param is passed as an opaque "<name>_handle" Var that buffer_map ties to the buffer.
+  Var handle_param(buffer->name + "_handle", DataType::Handle());
+  func_frame->args.push_back(handle_param);
+  func_frame->buffer_map.Set(handle_param, buffer);
+  return buffer;
+}
+
+void FuncName(String name) {
+  // Set the PrimFunc name; a second T.func_name is an error.
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_name");
+  if (!func_frame->name.defined()) {
+    func_frame->name = name;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is "
+             << func_frame->name.value();
+}
+
+void FuncAttrs(Map<String, ObjectRef> attrs) {
+  // Set the PrimFunc attributes; an already non-empty attribute map counts as a duplicate.
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_attr");
+  if (func_frame->attrs.empty()) {
+    func_frame->attrs = attrs;
+    return;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is "
+             << func_frame->attrs;
+}
+
+tvm::Type FuncRet(tvm::Type ret_type) {
+  // Set the PrimFunc return type (at most once) and echo it back.
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.ret_type");
+  if (!func_frame->ret_type.defined()) {
+    func_frame->ret_type = ret_type;
+    return ret_type;
+  }
+  LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
+             << func_frame->ret_type.value();
+  return ret_type;
+}
+
+/*!
+ * \brief Declare a buffer and bind it to `param` (T.match_buffer).
+ *
+ * - A Var param must already be a PrimFunc argument; the buffer is recorded in buffer_map.
+ * - A BufferLoad or BufferRegion param appends a MatchBufferRegion to the block frame found by
+ *   FindBlockFrame (a BufferLoad is first converted with BufferRegionFromLoad).
+ * Any other param type is an error.
+ * \return The declared buffer.
+ */
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  if (const auto* var = param.as<tvm::tir::VarNode>()) {
+    PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
+    Var v = GetRef<Var>(var);
+    for (const auto& arg : frame->args) {
+      if (arg.same_as(v)) {
+        frame->buffer_map.Set(v, buffer);
+        return buffer;
+      }
+    }
+    LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
+  } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) {
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(tvm::tir::MatchBufferRegion(
+        buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(buffer_load))));
+  } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
+    BlockFrame frame = FindBlockFrame("T.match_buffer");
+    frame->match_buffers.push_back(
+        tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(buffer_region)));
+  } else {
+    LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
+  }
+  return buffer;
+}
+
+/*!
+ * \brief Declare the pre-flattening view of a buffer already bound in the PrimFunc's buffer_map.
+ *
+ * The new buffer is named "<postflattened name>_preflatten". Unless `data` is given, it shares
+ * the data Var of `postflattened_buffer`. It is recorded in preflattened_buffer_map under the
+ * same param. Fails if `postflattened_buffer` is not a value of buffer_map.
+ */
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
+                        Optional<Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
+                        String storage_scope, int align, int offset_factor, String buffer_type_str,
+                        Array<IntImm> axis_separators) {
+  PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer");
+  for (const auto& p : frame->buffer_map) {
+    if (p.second.same_as(postflattened_buffer)) {
+      String buffer_name(postflattened_buffer->name + "_preflatten");
+      Buffer buffer =
+          BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset,
+                     storage_scope, align, offset_factor, buffer_type_str, axis_separators);
+      details::Namer::Name(buffer, buffer_name);
+      frame->preflattened_buffer_map.Set(p.first, buffer);
+      return;
+    }
+  }
+  LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
+             << " does not exist.";
+}
+
+AssertFrame Assert(PrimExpr condition, String message) {
+  // The message is stored as a StringImm expression.
+  tvm::tir::StringImm message_expr(message);
+  ObjectPtr<AssertFrameNode> frame = make_object<AssertFrameNode>();
+  frame->message = message_expr;
+  frame->condition = condition;
+  return AssertFrame(frame);
+}
+
+LetFrame Let(Var var, PrimExpr value) {
+  // Bind `var` to `value` for the statements emitted inside the frame.
+  ObjectPtr<LetFrameNode> frame = make_object<LetFrameNode>();
+  frame->value = value;
+  frame->var = var;
+  return LetFrame(frame);
+}
+
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  ObjectPtr<AllocateFrameNode> frame = make_object<AllocateFrameNode>();
+  // The allocation is described by an unnamed buffer with no explicit offset or alignment.
+  frame->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0,
+                             "default", NullOpt);
+  frame->extents = extents;
+  frame->dtype = dtype;
+  frame->storage_scope = storage_scope;
+  // Defaults: an always-true condition and no annotations.
+  frame->condition = condition.value_or(tvm::Bool(true));
+  frame->annotations = annotations.value_or(Map<String, ObjectRef>());
+  return AllocateFrame(frame);
+}
+
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> frame = make_object<AllocateConstFrameNode>();
+  // An unnamed buffer in the default ("") storage scope describes the constant.
+  frame->buffer =
+      BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+  frame->data = data;
+  frame->dtype = dtype;
+  frame->extents = extents;
+  frame->annotations = annotations;
+  return AllocateConstFrame(frame);
+}
+
+LaunchThreadFrame LaunchThread(IterVar iter_var, PrimExpr extent) {
+  if (iter_var->dom.defined()) {
+    // A previously launched thread var must keep a provably equal extent.
+    if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+      LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+                 << iter_var->dom->extent << " vs " << extent;
+    }
+  } else {
+    // NOTE(review): mutates the shared IterVar node in place to record its domain [0, extent).
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+  }
+  ObjectPtr<LaunchThreadFrameNode> frame = make_object<LaunchThreadFrameNode>();
+  frame->iter_var = iter_var;
+  frame->extent = extent;
+  bool is_virtual_thread = iter_var->thread_tag == "vthread";
+  frame->attr_key = is_virtual_thread ? "virtual_thread" : "thread_extent";
+  return LaunchThreadFrame(frame);
+}
+
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                     PrimExpr condition) {
+  // Store the realize arguments verbatim; no validation happens here.
+  ObjectPtr<RealizeFrameNode> frame = make_object<RealizeFrameNode>();
+  frame->condition = condition;
+  frame->storage_scope = storage_scope;
+  frame->buffer_slice = buffer_slice;
+  return RealizeFrame(frame);
+}
+
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  // Attach attribute `attr_key` = `value` to `node` for the statements inside the frame.
+  ObjectPtr<AttrFrameNode> frame = make_object<AttrFrameNode>();
+  frame->value = value;
+  frame->attr_key = attr_key;
+  frame->node = node;
+  return AttrFrame(frame);
+}
+
+WhileFrame While(PrimExpr condition) {
+  // A loop frame that repeats while `condition` holds.
+  ObjectPtr<WhileFrameNode> frame = make_object<WhileFrameNode>();
+  frame->condition = std::move(condition);
+  return WhileFrame(frame);
+}
+
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> frame = make_object<IfFrameNode>();
+  frame->condition = condition;
+  // Neither branch has been built yet.
+  frame->else_stmts = NullOpt;
+  frame->then_stmts = NullOpt;
+  return IfFrame(frame);
+}
+
+ThenFrame Then() {
+  // An empty frame for the then-branch of an If.
+  return ThenFrame(make_object<ThenFrameNode>());
+}
+
+ElseFrame Else() {
+  // An empty frame for the else-branch of an If.
+  return ElseFrame(make_object<ElseFrameNode>());
+}
+
+IterVar EnvThread(String thread_tag) {
+  // The domain starts undefined; LaunchThread sets it when this var is first launched.
+  Var thread_var("", DataType::Int(32));
+  return IterVar(Range{nullptr}, thread_var, tvm::tir::IterVarType::kThreadIndex, thread_tag);
+}
+
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  // Build buffer[indices] = value and emit it into the current parent frame.
+  tvm::tir::BufferStore store(buffer, value, indices);
+  AddToParent(store);
+}
+
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+  // Build a prefetch of `bounds` in `buffer` and emit it into the current parent frame.
+  tvm::tir::Prefetch prefetch(buffer, bounds);
+  AddToParent(prefetch);
+}
+
+void Evaluate(PrimExpr value) {
+  // Wrap `value` in an Evaluate statement and emit it into the current parent frame.
+  tvm::tir::Evaluate stmt(value);
+  AddToParent(stmt);
+}
+
+using tvm::ir_builder::details::Namer;
+
+// Per-type naming handlers for Namer::Name. Each one mutates the node's name in place via
+// const_cast.
+
+// Buffer: sets buffer->name, names the data Var "<name>_data", and names each symbolic stride
+// "<name>_s<i>".
+// NOTE(review): the "elem_offset" Var that BufferDecl may create is not renamed here.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
+ tvm::tir::BufferNode* buffer =
+ const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>());
+ buffer->name = name;
+ Namer::Name(buffer->data, name + "_data");
+ int n = buffer->strides.size();
+ for (int i = 0; i < n; ++i) {
+ PrimExpr e = buffer->strides[i];
+ // Only strides that are plain Vars get a name; constant strides are skipped.
+ if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) {
+ Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i));
+ }
+ }
+ });
+
+// SizeVar: sets name_hint. Registered separately from VarNode, presumably because the vtable
+// dispatches on the exact runtime type; confirm.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
+ using namespace tvm::tir;
+ SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>());
+ var->name_hint = name;
+ });
+
+// Var: sets name_hint.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
+ using namespace tvm::tir;
+ VarNode* var = const_cast<VarNode*>(node.as<VarNode>());
+ var->name_hint = name;
+ });
+
+// IterVar: forwards to the underlying Var.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
+ using namespace tvm::tir;
+ IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>());
+ Namer::Name(var->var, name);
+ });
+
+// FFI registrations for the TIR IR-builder API defined in this file.
+
+// Buffer and pointer helpers.
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Ptr").set_body_typed(Ptr);
+
+// Block frame and in-block declarations (T.block, T.init, T.where, T.reads, ..., T.axis.*).
+TVM_REGISTER_GLOBAL("ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Init").set_body_typed(Init);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Where").set_body_typed(Where);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Reads").set_body_typed(Reads);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Writes").set_body_typed(Writes);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
+// Loop frames.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Serial").set_body_typed(Serial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Parallel").set_body_typed(Parallel);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Vectorized").set_body_typed(Vectorized);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Unroll").set_body_typed(Unroll);
+TVM_REGISTER_GLOBAL("ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Grid").set_body_typed(Grid);
+
+// PrimFunc frame and function-level declarations.
+TVM_REGISTER_GLOBAL("ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
+// Arg is overloaded in C++ (Var / Buffer), so the FFI entry dispatches on the runtime type.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Arg")
+ .set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
+ using namespace tvm::tir;
+ if (const auto* var = obj.as<VarNode>()) {
+ return Arg(name, GetRef<tvm::tir::Var>(var));
+ }
+ if (const auto* buffer = obj.as<BufferNode>()) {
+ return Arg(name, GetRef<Buffer>(buffer));
+ }
+ LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
+ // Unreachable when LOG(FATAL) throws; the bare `throw;` only satisfies the missing-return
+ // check (with no active exception it would call std::terminate).
+ throw;
+ });
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncName").set_body_typed(FuncName);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncRet").set_body_typed(FuncRet);
+TVM_REGISTER_GLOBAL("ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
+
+// Statement-level frames and statements.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Prefetch").set_body_typed(Prefetch);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+
+// Scalar-type helpers and min/max, declared outside this file.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int8").set_body_typed(Int8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int16").set_body_typed(Int16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32").set_body_typed(Int32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int64").set_body_typed(Int64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt8").set_body_typed(UInt8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt16").set_body_typed(UInt16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt32").set_body_typed(UInt32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt64").set_body_typed(UInt64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float8").set_body_typed(Float8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float16").set_body_typed(Float16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float32").set_body_typed(Float32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float64").set_body_typed(Float64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Boolean").set_body_typed(Boolean);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Void").set_body_typed(Void);
+TVM_REGISTER_GLOBAL("ir_builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+ return tvm::min(a, b);
+});
+TVM_REGISTER_GLOBAL("ir_builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+ return tvm::max(a, b);
+});
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder_frame.cc b/src/tir/ir_builder/ir_builder_frame.cc
new file mode 100644
index 0000000000..cd0cd46b50
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder_frame.cc
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/function.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "../../tir/ir/script/script_complete.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+void BlockFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Gather the frame's allocated buffers into an Array<tvm::tir::Buffer> for the Block node.
+  Array<tvm::tir::Buffer> tir_alloc_buffers;
+  for (const tvm::tir::Buffer& alloc_buffer : alloc_buffers) {
+    tir_alloc_buffers.push_back(alloc_buffer);
+  }
+  // Record which access regions were left unspecified as a bit mask
+  // (bit 0: reads missing, bit 1: writes missing); nothing is recorded if both were given.
+  // NOTE(review): presumably consumed later (e.g. by ScriptComplete) to infer the regions.
+  const int missing_reads = !reads.defined();
+  const int missing_writes = !writes.defined();
+  if (int detect_access = missing_reads | (missing_writes << 1)) {
+    annotations.Set("tir.script_parsing_detect_access",
+                    tvm::IntImm(DataType::Int(64), detect_access));
+  }
+  const Array<tvm::tir::BufferRegion> block_reads =
+      reads.value_or(Array<tvm::tir::BufferRegion>());
+  const Array<tvm::tir::BufferRegion> block_writes =
+      writes.value_or(Array<tvm::tir::BufferRegion>());
+  tvm::tir::Block block(iter_vars, block_reads, block_writes, name, AsStmt(stmts), init,
+                        tir_alloc_buffers, match_buffers, annotations);
+  if (!no_realize) {
+    // Usual case: wrap the block in a BlockRealize carrying the bindings and the predicate.
+    AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
+    return;
+  }
+  // `no_realize=True`: the bare Block is emitted, so bindings and `T.where` are rejected.
+  CHECK(iter_values.empty())
+      << "ValueError: Block bindings are not allowed when `no_realize=True`";
+  CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
+  AddToParent(block);
+}
+
+void BlockInitFrameNode::EnterWithScope() {
+  // Only one `T.init()` section is allowed per enclosing block.
+  BlockFrame enclosing_block = FindBlockFrame("T.init");
+  if (enclosing_block->init.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The statements collected in this frame become the enclosing block's init body.
+  BlockFrame enclosing_block = FindBlockFrame("T.init");
+  enclosing_block->init = AsStmt(stmts);
+}
+
+void ForFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The frame's `f_make_for_loop` callback builds the loop(s) over `vars`/`doms` around the body.
+  tvm::tir::Stmt loop_body = AsStmt(stmts);
+  AddToParent(f_make_for_loop(vars, doms, loop_body));
+}
+
+void PrimFuncFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // An empty attribute map is encoded as a null DictAttrs.
+  DictAttrs func_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+  tvm::tir::PrimFunc func(/*params=*/args,
+                          /*body=*/AsStmt(stmts),
+                          /*ret_type=*/ret_type.value_or(TupleType::Empty()),
+                          /*buffer_map=*/buffer_map,
+                          /*preflattened_buffer_map=*/preflattened_buffer_map,
+                          /*attrs=*/func_attrs);
+  // Finish the function with ScriptComplete, passing along the root-level allocations.
+  func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
+  IRBuilder builder = IRBuilder::Current();
+  // Outermost PrimFunc: it becomes the builder's result.
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = func;
+    return;
+  }
+  // Otherwise it must live inside an IRModule frame, which receives a (GlobalVar, func) pair.
+  Optional<IRModuleFrame> opt_module = builder->FindFrame<IRModuleFrame>();
+  if (!opt_module.defined()) {
+    LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
+  }
+  IRModuleFrame module_frame = opt_module.value();
+  module_frame->global_vars.push_back(GlobalVar(name.value_or("")));
+  module_frame->functions.push_back(func);
+}
+
+void AssertFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The collected statements become the body of the AssertStmt.
+  tvm::tir::Stmt assert_body = AsStmt(stmts);
+  AddToParent(tvm::tir::AssertStmt(condition, message, assert_body));
+}
+
+void LetFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Bind `var` to `value` over the statements collected in this frame.
+  tvm::tir::Stmt let_body = AsStmt(stmts);
+  AddToParent(tvm::tir::LetStmt(var, value, let_body));
+}
+
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The allocation takes its data var, dtype and shape from the frame's buffer.
+  tvm::tir::Stmt alloc_body = AsStmt(stmts);
+  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition,
+                                 alloc_body, annotations));
+}
+
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Constant allocation over the buffer's data var, using the frame's dtype/extents/data.
+  tvm::tir::Stmt const_body = AsStmt(stmts);
+  AddToParent(
+      tvm::tir::AllocateConst(buffer->data, dtype, extents, data, const_body, annotations));
+}
+
+void LaunchThreadFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Emit an AttrStmt on the thread iter var, keyed by `attr_key`, with `extent` as its value.
+  tvm::tir::Stmt launch_body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, launch_body));
+}
+
+void RealizeFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // BufferRealize over the sliced region, wrapped in a "realize_scope" AttrStmt on the buffer.
+  const auto& target = buffer_slice->buffer;
+  tvm::tir::Stmt realize =
+      tvm::tir::BufferRealize(target, buffer_slice->region, condition, AsStmt(stmts));
+  AddToParent(
+      tvm::tir::AttrStmt(target, "realize_scope", tvm::tir::StringImm(storage_scope), realize));
+}
+
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Attach (`attr_key`, `value`) to `node` around the collected statements.
+  tvm::tir::Stmt attr_body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, attr_body));
+}
+
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // The collected statements form the loop body executed while `condition` holds.
+  tvm::tir::Stmt loop_body = AsStmt(stmts);
+  AddToParent(tvm::tir::While(condition, loop_body));
+}
+
+void IfFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Every statement under `T.if_()` must be placed inside `T.then_()` or `T.else_()`.
+  // User-facing errors carry the "ValueError: " prefix, like the other frames in this file,
+  // so they surface on the Python side as ValueError rather than a generic TVMError.
+  if (!stmts.empty()) {
+    LOG(FATAL) << "ValueError: stmt within IfThenElse frame should be either in ThenFrame or "
+                  "ElseFrame";
+  }
+  if (!then_stmts.defined()) {
+    LOG(FATAL) << "ValueError: IfThenElse frame should have at least one then branch";
+  }
+  // A missing else branch is encoded as a null Stmt.
+  AddToParent(tvm::tir::IfThenElse(
+      condition, AsStmt(then_stmts.value()),
+      else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr)));
+}
+
+void ThenFrameNode::EnterWithScope() {
+  // Reject a second `T.then_()` under the same `T.if_()`.
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if (if_frame->then_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+               << if_frame->then_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Publish the collected statements as the then branch of the enclosing if-frame.
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if_frame->then_stmts = stmts;
+}
+
+void ElseFrameNode::EnterWithScope() {
+  IfFrame frame = FindIfFrame("T.else_");
+  // `T.else_()` must follow the `T.then_()` of the same `T.if_()` and may appear only once.
+  // The "ValueError: " prefix matches the other user-facing errors in this file.
+  if (!frame->then_stmts.defined()) {
+    LOG(FATAL) << "ValueError: The else branch should follow then branch";
+  }
+  if (frame->else_stmts.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << frame->else_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Store the collected statements as the else branch of the enclosing if-frame.
+  FindIfFrame("T.else_")->else_stmts = stmts;
+}
+
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Declare `buffer` over the statements collected in this frame.
+  tvm::tir::Stmt decl_body = AsStmt(stmts);
+  AddToParent(tvm::tir::DeclBuffer(buffer, decl_body));
+}
+
+// Register each TIR IRBuilder frame node type with TVM's node type registry.
+// Registration order is left unchanged deliberately.
+TVM_REGISTER_NODE_TYPE(TIRFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
+TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
diff --git a/src/tir/ir_builder/utils.h b/src/tir/ir_builder/utils.h
new file mode 100644
index 0000000000..6b6271ab0e
--- /dev/null
+++ b/src/tir/ir_builder/utils.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_UTILS_H_
+#define TVM_TIR_IR_BUILDER_UTILS_H_
+
+#include <tvm/tir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Add a statement to the innermost frame of the current IRBuilder.
+ *
+ * With no open frame, the statement becomes the builder's result (which must not be set yet).
+ * If the innermost frame is a TIRFrame, the statement is appended to its `stmts`.
+ * Any other innermost frame type is reported as a TypeError.
+ * \param stmt The statement to add.
+ */
+inline void AddToParent(tvm::tir::Stmt stmt) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = stmt;
+  } else if (const auto* tir_frame = builder->frames.back().as<TIRFrameNode>()) {
+    GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt);
+  } else {
+    LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
+  }
+}
+
+/*!
+ * \brief Combine a list of statements into a single statement.
+ * \param stmts The statements to combine.
+ * \return `Evaluate(0)` for an empty list, the sole element for a one-element list, and a
+ *         `SeqStmt` of all elements otherwise.
+ */
+inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmts) {
+  if (stmts.empty()) {
+    return tvm::tir::Evaluate(0);
+  }
+  if (stmts.size() == 1) {
+    return stmts[0];
+  }
+  return tvm::tir::SeqStmt(stmts);
+}
+
+/*!
+ * \brief Fetch the current BlockFrame via `IRBuilder::GetLastFrame`.
+ * \param method The user-facing API name, quoted in the error message.
+ * \return The BlockFrame; a ValueError is raised when none is available.
+ */
+inline BlockFrame FindBlockFrame(const String& method) {
+  if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: Block frame not found. Please ensure '" << method
+             << "' is called under T.block()";
+  // LOG(FATAL) is not expected to return; `throw;` only silences missing-return warnings.
+  throw;
+}
+
+/*!
+ * \brief Fetch the current PrimFuncFrame via `IRBuilder::GetLastFrame`.
+ * \param method The user-facing API name, quoted in the error message.
+ * \return The PrimFuncFrame; a ValueError is raised when none is available.
+ */
+inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
+  if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: PrimFunc frame not found. Please ensure '" << method
+             << "' is called under T.prim_func()";
+  // LOG(FATAL) is not expected to return; `throw;` only silences missing-return warnings.
+  throw;
+}
+
+/*!
+ * \brief Fetch the current IfFrame via `IRBuilder::GetLastFrame`.
+ * \param method The user-facing API name, quoted in the error message.
+ * \return The IfFrame; a ValueError is raised when none is available.
+ */
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+             << "' is called under T.if_()";
+  // LOG(FATAL) is not expected to return; `throw;` only silences missing-return warnings.
+  throw;
+}
+
+/*!
+ * \brief Build a BufferRegion from a BufferLoad, one range per load index.
+ * \param buffer_load The load whose buffer and indices define the region.
+ * \return A region whose ranges start at each index with extent 1 (in the index's dtype).
+ */
+inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
+  Array<Range> ranges;
+  for (const PrimExpr& index : buffer_load->indices) {
+    ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1)));
+  }
+  return tvm::tir::BufferRegion(buffer_load->buffer, ranges);
+}
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_TIR_IR_BUILDER_UTILS_H_
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 114571218b..dd3c7a0b0f 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}
+// address_of: handle-typed call to `builtin::address_of` with the BufferLoad as its argument.
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+  Array<PrimExpr> call_args;
+  call_args.push_back(buffer_load);
+  return tir::Call(DataType::Handle(), tir::builtin::address_of(), call_args, span);
+}
+
+// lookup_param: handle-typed call to `builtin::lookup_param` carrying the name as a StringImm.
+PrimExpr lookup_param(String param_name, Span span) {
+  Array<PrimExpr> call_args;
+  call_args.push_back(tir::StringImm(param_name));
+  return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), call_args, span);
+}
+
// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -702,6 +713,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
}
+// isnullptr: scalar boolean call to `builtin::isnullptr` on `x`.
+PrimExpr isnullptr(PrimExpr x, Span span) {
+  Array<PrimExpr> call_args;
+  call_args.push_back(x);
+  return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), call_args, span);
+}
+
// isinf
PrimExpr isinf(PrimExpr x, Span span) {
DataType t = DataType::Bool(x.dtype().lanes());
@@ -931,6 +947,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
+TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
+
TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
@@ -949,6 +967,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);
+
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -997,6 +1019,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, right_shift);
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);
+
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
return if_then_else(cond, true_value, false_value, span);
diff --git a/tests/python/tvmscript/test_builder_basic.py b/tests/python/tvmscript/test_builder_basic.py
new file mode 100644
index 0000000000..7224d1b307
--- /dev/null
+++ b/tests/python/tvmscript/test_builder_basic.py
@@ -0,0 +1,227 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script.builder import Builder, name, name_many
+from tvm.script.builder import tir as T
+from tvm.ir import Range
+
+
+def test_builder_root_block():
+ """Build PrimFuncs whose blocks sit under an implicit root block or under an explicit
+ `T.block(name="root")`, and print each resulting script.
+ """
+ print("test_builder_root_block")
+ # implicit root block: blocks are opened directly inside T.prim_func()
+ with Builder() as b0:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"key": "value"})
+ with T.block(name="block"):
+ pass
+ print(b0.get().script())
+ with Builder() as b1:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"key": "value"})
+ A = name("A", T.alloc_buffer((128,)))
+ with T.block(name="block"):
+ pass
+ print(b1.get().script())
+ with Builder() as b2:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"key": "value"})
+ A = name("A", T.alloc_buffer((128,)))
+ with T.block(name="block0"):
+ pass
+ with T.block(name="block1"):
+ pass
+ print(b2.get().script())
+ # explicit root block: blocks are nested inside T.block(name="root")
+ with Builder() as b0_r:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"key": "value"})
+ with T.block(name="root"):
+ with T.block(name="block"):
+ pass
+ print(b0_r.get().script())
+ with Builder() as b1_r:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"key": "value"})
+ with T.block(name="root"):
+ A = name("A", T.alloc_buffer((128,)))
+ with T.block(name="block"):
+ pass
+ print(b1_r.get().script())
+ with Builder() as b2_r:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"key": "value"})
+ with T.block(name="root"):
+ A = name("A", T.alloc_buffer((128,)))
+ with T.block(name="block0"):
+ pass
+ with T.block(name="block1"):
+ pass
+ print(b2_r.get().script())
+
+
+def test_builder_axis():
+ """Bind block iter vars with each T.axis kind (spatial, reduce, scan, opaque) and with
+ T.axis.remap, then print the resulting script.
+ """
+ print("test_builder_axis")
+ with Builder() as b:
+ with T.prim_func():
+ T.func_name("main")
+ with T.grid(128, 128, 128, 128, 128) as (i, j, k, m, n):
+ name_many(["i", "j", "k", "m", "n"], [i, j, k, m, n])
+ with T.block(name="block"):
+ vi = name("vi", T.axis.spatial(128, i))
+ vj = name("vj", T.axis.spatial(128, j))
+ vk = name("vk", T.axis.reduce(128, k))
+ vm = name("vm", T.axis.scan(128, m))
+ vn = name("vn", T.axis.opaque(128, n))
+ # "SSR": two spatial axes followed by one reduction axis
+ x, y, z = name_many(["x", "y", "z"], T.axis.remap("SSR", [i, j, k]))
+ print(b.get().script())
+
+
+def test_builder_prim_func():
+ """Declare a PrimFunc's name, attrs, handle and buffer args, return type, match_buffer and
+ preflattened_buffer bindings, then print the resulting script.
+ """
+ print("test_builder_prim_func")
+ with Builder() as b:
+ with T.prim_func():
+ T.func_name("main")
+ T.func_attr({"global_symbol": "main"})
+ arg_a = T.arg("a", T.handle())
+ arg_b = T.arg("b", T.handle())
+ buffer_c = T.Buffer((128,), "float32")
+ buffer_d = T.Buffer((128,), "float32")
+ arg_c = T.arg("c", buffer_c)
+ arg_d = T.arg("d", buffer_d)
+ T.func_ret(tvm.ir.PrimType("int8"))
+ A = name("A", T.match_buffer(arg_a, (128, 128, 128), "int32"))
+ B = name("B", T.match_buffer(arg_b, (128, 128, 128), "int32"))
+ T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data)
+ T.preflattened_buffer(buffer_d, (128,), data=buffer_d.data)
+ print(b.get().script())
+
+
+def test_builder_block():
+ """Build a block with attrs, a T.where predicate, an init section, remapped iter vars,
+ read/write regions (including nested buffer loads) and block-local allocations, then print it.
+ """
+ print("test_builder_block")
+ with Builder() as b:
+ with T.prim_func():
+ arg_a = T.arg("a", T.handle())
+ arg_b = T.arg("b", T.handle())
+ A = name("A", T.match_buffer(arg_a, (128, 128, 128), "int32"))
+ B = name("B", T.match_buffer(arg_b, (128, 128, 128), "int32"))
+ with T.grid(128, 128, 128) as (i, j, k):
+ name_many(["i", "j", "k"], [i, j, k])
+ with T.block(name="block"):
+ T.block_attr({"axis": 1})
+ T.where(i > 1)
+ with T.init():
+ pass
+ vi, vj, vk = name_many(["vi", "vj", "vk"], T.axis.remap("SSR", [i, j, k]))
+ T.reads(A[vi, vj, vk : vk + B[1, 2, A[3, 4, 5]]])
+ T.writes(A[100, A[50, 51, 52], 102])
+ E = name("E", T.alloc_buffer((128, 128)))
+ F = name("F", T.alloc_buffer((128, 128)))
+ print(b.get().script())
+
+
+def test_builder_for():
+ """Build loops of every kind (grid, serial, parallel, vectorized, unroll and thread_binding
+ on blockIdx/vthread/threadIdx), then print the resulting script.
+ """
+ print("test_builder_for")
+ with Builder() as b:
+ with T.prim_func():
+ with T.grid(128, 128, 128) as (i, j, k):
+ name_many(["i", "j", "k"], [i, j, k])
+ with T.serial(0, 128) as w:
+ w = name("w", w)
+ with T.parallel(0, 128) as x:
+ x = name("x", x)
+ with T.vectorized(0, 128) as y:
+ y = name("y", y)
+ with T.unroll(0, 128) as z:
+ z = name("z", z)
+ with T.thread_binding(0, 32, thread="blockIdx.x") as bx:
+ bx = name("bx", bx)
+ with T.thread_binding(0, 2, thread="vthread.y") as vy:
+ vy = name("vy", vy)
+ with T.thread_binding(0, 8, thread="threadIdx.z") as tz:
+ tz = name("tz", tz)
+ print(b.get().script())
+
+
+def test_builder_stmt():
+ """Exercise the statement builders (Assert, let, allocate, allocate_const, realize, attr,
+ launch_thread, while_, if_/then_/else_, prefetch, buffer_store, evaluate), then print the script.
+ """
+ print("test_builder_stmt")
+ with Builder() as b:
+ with T.prim_func():
+ thread_x = name("thread_x", T.env_thread("threadIdx.x"))
+ thread_y = name("thread_y", T.env_thread("threadIdx.y"))
+ buffer_x = name("buffer_x", T.Buffer([128, 128]))
+ buffer_y = name("buffer_y", T.Buffer([128, 128]))
+ var_x = name("var_x", tvm.tir.Var("", dtype="int32"))
+ var_y = name("var_y", tvm.tir.Var("", dtype="int32"))
+ with T.Assert(var_x < var_y, ""):
+ with T.Assert(1, "true"):
+ pass
+ with T.let(var_x, var_y):
+ pass
+ with T.allocate([128], "uint8", "global") as alloc_x:
+ with T.allocate([128], "uint8", "global") as alloc_y:
+ alloc_x, alloc_y = name_many(["alloc_x", "alloc_y"], [alloc_x, alloc_y])
+ with T.allocate_const([1, 1, 1, 1, 1], "int32", [5]) as alloc_const_x:
+ with T.allocate_const([10, 10, 10], "float32", [3]) as alloc_const_y:
+ alloc_const_x, alloc_const_y = name_many(
+ ["alloc_const_x", "alloc_const_y"], [alloc_const_x, alloc_const_y]
+ )
+ with T.realize(buffer_x[0:var_x, 0:var_y], ""):
+ with T.realize(buffer_x[var_x:128, var_y:128], ""):
+ pass
+ with T.attr(buffer_x, "key_x", "value_x"):
+ with T.attr(buffer_y, "key_y", "value_y"):
+ pass
+ with T.launch_thread(thread_x, 4):
+ with T.launch_thread(thread_y, 4):
+ pass
+ with T.while_(var_x < var_y):
+ with T.while_(var_x > 0):
+ pass
+ # if/else with both branches, then an if without an else branch
+ with T.if_(var_x < var_y):
+ with T.then_():
+ T.evaluate(0)
+ T.evaluate(1)
+ with T.else_():
+ T.evaluate(0)
+ T.evaluate(1)
+ with T.if_(1):
+ with T.then_():
+ T.evaluate(1)
+ T.prefetch(buffer_x, [Range(0, 64), Range(64, 128)])
+ T.prefetch(buffer_y, [Range(0, var_x), Range(var_y, 128)])
+ T.buffer_store(buffer_x, 1, [0, 0])
+ T.buffer_store(buffer_x, var_x + var_y, [var_x, var_y])
+ T.evaluate(var_x + var_y)
+ T.evaluate(1)
+
+ print(b.get().script())
+
+
+if __name__ == "__main__":
+ test_builder_root_block()
+ test_builder_axis()
+ test_builder_prim_func()
+ test_builder_block()
+ test_builder_for()
+ test_builder_stmt()
diff --git a/tests/python/tvmscript/test_parse_basic.py b/tests/python/tvmscript/test_parse_basic.py
new file mode 100644
index 0000000000..ca3e908ce1
--- /dev/null
+++ b/tests/python/tvmscript/test_parse_basic.py
@@ -0,0 +1,118 @@
+import inspect
+
+import pytest
+import tvm
+from tvm.ir import structural_equal
+from tvm.script.builder import ir as I
+from tvm.script.builder import tir as T
+
+
+def test_parse_elementwise():
+ """Parse a PrimFunc with a starred loop target and T.axis bindings, then print it."""
+ # pylint: disable=unused-argument,unused-variable,invalid-name
+ @T.prim_func
+ def elementwise(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore
+ ) -> None:
+ for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128):
+ with T.block("inner_block"):
+ # vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ vi = T.axis.S(128, i + 1)
+ vj = T.axis.S(128, j + 20)
+ vk = T.axis.R(128, k - i)
+ A[vi + 1, vj] = A[vi, vk] * B[vvv[0], vvv[1], vvv[2]] + 2
+ B[vi, vj, vk] = A[vvv[0], vvv[-1]]
+
+ # pylint: enable=unused-argument,unused-variable,invalid-name
+
+ result = elementwise
+ print(result.script())
+
+
+def test_parse_skip():
+ """A `T.prim_func` defined inside a plain (non-ir_module) class body stays a Python function."""
+ class Skip:
+ @T.prim_func
+ def f(): # type: ignore
+ ...
+
+ assert inspect.isfunction(Skip.f)
+
+
+def test_parse_class():
+ """Parse an `I.ir_module` class containing a PrimFunc and print the module."""
+ # pylint: disable=unused-argument,unused-variable,invalid-name
+ @I.ir_module
+ class C:
+ @T.prim_func
+ def elementwise(
+ A: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore
+ B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore
+ ) -> None:
+ for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128):
+ with T.block("inner_block"):
+ # vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ vi = T.axis.S(128, i + 1)
+ vj = T.axis.S(128, j + 20)
+ vk = T.axis.R(128, k - i)
+
+ # pylint: enable=unused-argument,unused-variable,invalid-name
+
+ print(C.script())
+
+
+def test_parse_atomic():
+ @T.prim_func
+ def f(A: T.int32, B: T.int64, C: T.handle) -> None:
+ pass
+
+ assert f.params[0].name == "A"
+ assert f.params[0].dtype == "int32"
+ assert f.params[1].name == "B"
+ assert f.params[1].dtype == "int64"
+ assert f.params[2].name == "C"
+ assert f.params[2].dtype == "handle"
+
+
+def test_parse_report_error():
+ """Indexing past the starred loop vars (`vvv[10]` with only seven loop vars) must be
+ reported as a tvm.error.DiagnosticError.
+ """
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @T.prim_func
+ def elementwise() -> None:
+ for (*vvv,) in T.grid(128, 128, 128, 128, 128, 128, 128):
+ with T.block("inner_block"):
+ vj = T.axis.S(128, vvv[10] + 20)
+
+
+def test_parse_concise_scope():
+ """`x = T.allocate(...)` without `with` must parse to the same IR as the equivalent
+ nested `with T.allocate(...) as x:` blocks (checked with structural_equal).
+ """
+ # pylint: disable=unused-argument,unused-variable,invalid-name
+ @T.prim_func
+ def concise_scope(
+ A: T.handle,
+ ) -> None:
+ A_local = T.allocate([64], "float32", "local")
+ B_local = T.allocate([64], "float32", "local")
+ C_local = T.allocate([64], "float32", "local")
+ T.evaluate(1)
+ T.evaluate(2)
+ T.evaluate(3)
+
+ @T.prim_func
+ def normal_scope(
+ A: T.handle,
+ ) -> None:
+ with T.allocate([64], "float32", "local") as A_local:
+ with T.allocate([64], "float32", "local") as B_local:
+ with T.allocate([64], "float32", "local") as C_local:
+ T.evaluate(1)
+ T.evaluate(2)
+ T.evaluate(3)
+
+ assert structural_equal(normal_scope, concise_scope)
+
+
+if __name__ == "__main__":
+ test_parse_elementwise()
+ test_parse_skip()
+ test_parse_class()
+ test_parse_atomic()
+ test_parse_report_error()
+ test_parse_concise_scope()
diff --git a/tests/python/tvmscript/test_parser_capture.py b/tests/python/tvmscript/test_parser_capture.py
new file mode 100644
index 0000000000..b61d35137f
--- /dev/null
+++ b/tests/python/tvmscript/test_parser_capture.py
@@ -0,0 +1,43 @@
+from tvm.script.builder import ir as I
+from tvm.script.builder import tir as T
+
+
+def test_capture_func():
+ """Names imported into the enclosing test function (`ax`, `block`, `match_buffer`) are
+ usable inside a parsed PrimFunc; the parsed function is printed.
+ """
+ from tvm.script.builder.tir import axis as ax
+ from tvm.script.builder.tir import block, match_buffer
+
+ @T.prim_func
+ def scalar_func(a: T.handle, b: T.handle, c: T.Buffer((128,))):
+ A = match_buffer(a, (128, 128))
+ B = match_buffer(b, (128, 128))
+ with block():
+ for i, j in T.grid(128, 128):
+ with block("inner_block"):
+ vi, vj = ax.remap("SR", [i, j])
+ A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]
+
+ print(scalar_func.script())
+
+
+def test_capture_class():
+ """Same capture scenario as test_capture_func, but with the PrimFunc inside an
+ `I.ir_module` class; the parsed module is printed.
+ """
+ from tvm.script.builder.tir import axis as ax
+ from tvm.script.builder.tir import block, match_buffer
+
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def scalar_func(a: T.handle, b: T.handle, c: T.Buffer((128,))):
+ A = match_buffer(a, (128, 128))
+ B = match_buffer(b, (128, 128))
+ with block():
+ for i, j in T.grid(128, 128):
+ with block("inner_block"):
+ vi, vj = ax.remap("SR", [i, j])
+ A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]
+
+ print(Module.script())
+
+
+if __name__ == "__main__":
+ test_capture_func()
+ test_capture_class()