You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 22:31:05 UTC
[tvm] 01/01: Squashed
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch ir-builder
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 22477336093aaf860a59107c58caae99c37ad3bb
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon May 23 17:06:45 2022 -0700
Squashed
---
include/tvm/ir/expr.h | 20 +-
include/tvm/ir/ir_builder.h | 190 +++++
include/tvm/support/with.h | 2 +
include/tvm/tir/ir_builder.h | 140 +++
include/tvm/tir/ir_builder_frame.h | 456 ++++++++++
include/tvm/tir/op.h | 34 +-
python/tvm/ir/__init__.py | 55 +-
.../_ffi_api.py => ir/_ffi_ir_builder_api.py} | 4 +-
python/tvm/ir/ir_builder.py | 85 ++
python/tvm/meta_schedule/default_config.py | 2 +-
python/tvm/meta_schedule/testing/schedule_rule.py | 12 +-
python/tvm/script/__init__.py | 19 +-
python/tvm/script/{ => parser}/__init__.py | 14 +-
python/tvm/script/parser/diagnostics.py | 60 ++
python/tvm/script/parser/dispatch.py | 63 ++
python/tvm/script/parser/doc.py | 361 ++++++++
python/tvm/script/{printer => parser}/doc_core.py | 0
.../script/{tir/prim_func.py => parser/entry.py} | 45 +-
python/tvm/script/parser/evaluator.py | 282 ++++++
.../script/{_ffi_api.py => parser/ir/__init__.py} | 7 +-
.../tvm/script/{__init__.py => parser/ir/entry.py} | 19 +-
.../{tir/__init__.py => parser/ir/parser.py} | 30 +-
python/tvm/script/parser/parser.py | 214 +++++
python/tvm/script/parser/source.py | 134 +++
python/tvm/script/{ => parser/tir}/__init__.py | 9 +-
python/tvm/script/parser/tir/entry.py | 103 +++
python/tvm/script/parser/tir/operation.py | 85 ++
python/tvm/script/parser/tir/parser.py | 269 ++++++
.../script/{tir/prim_func.py => parser/utils.py} | 47 +-
python/tvm/script/parser/var_table.py | 71 ++
python/tvm/script/{ => parser_v1}/__init__.py | 3 +-
python/tvm/script/{ => parser_v1}/_ffi_api.py | 0
.../script/{ => parser_v1}/context_maintainer.py | 8 +-
python/tvm/script/{ => parser_v1}/diagnostics.py | 6 +-
python/tvm/script/{ => parser_v1}/meta_unparser.py | 0
python/tvm/script/{ => parser_v1}/parser.py | 23 +-
python/tvm/script/{ => parser_v1}/registry.py | 2 +-
python/tvm/script/{ => parser_v1}/tir/__init__.py | 0
python/tvm/script/{ => parser_v1}/tir/__init__.pyi | 0
python/tvm/script/{ => parser_v1}/tir/intrin.py | 5 +-
python/tvm/script/{ => parser_v1}/tir/node.py | 7 +-
python/tvm/script/{ => parser_v1}/tir/prim_func.py | 3 +-
.../script/{ => parser_v1}/tir/scope_handler.py | 17 +-
.../tvm/script/{ => parser_v1}/tir/special_stmt.py | 16 +-
python/tvm/script/{ => parser_v1}/tir/ty.py | 1 +
python/tvm/script/{ => parser_v1}/utils.py | 7 +-
python/tvm/te/operation.py | 15 +-
python/tvm/tir/__init__.py | 217 ++++-
.../_ffi_api.py => tir/_ffi_ir_builder_api.py} | 4 +-
python/tvm/tir/analysis/analysis.py | 4 +-
python/tvm/tir/buffer.py | 39 +-
python/tvm/tir/expr.py | 15 +-
python/tvm/tir/function.py | 2 +-
python/tvm/tir/ir_builder_frame.py | 118 +++
python/tvm/tir/{ir_builder.py => ir_builder_v1.py} | 6 +-
python/tvm/tir/ir_builder_v2.py | 949 +++++++++++++++++++++
python/tvm/tir/op.py | 601 ++++++++++++-
python/tvm/tir/schedule/block_scope.py | 2 +-
python/tvm/tir/schedule/schedule.py | 6 +-
python/tvm/tir/schedule/state.py | 3 +-
python/tvm/tir/stmt.py | 4 +
python/tvm/tir/tensor_intrin/__init__.py | 6 +-
python/tvm/tir/tensor_intrin/arm_cpu.py | 3 +-
python/tvm/tir/tensor_intrin/cuda.py | 14 +-
python/tvm/tir/tensor_intrin/rocm.py | 2 +-
python/tvm/tir/usmp/transform/transform.py | 5 +-
src/ir/diagnostic.cc | 6 +-
src/ir/expr.cc | 13 +
src/ir/ir_builder.cc | 134 +++
src/tir/ir/expr.cc | 24 +-
src/tir/ir/script/script_complete.cc | 5 +-
src/tir/ir/script/script_complete.h | 37 +
src/tir/ir/stmt.cc | 10 +
src/tir/ir_builder/ir_builder.cc | 664 ++++++++++++++
src/tir/ir_builder/ir_builder_frame.cc | 208 +++++
src/tir/ir_builder/utils.h | 92 ++
src/tir/op/op.cc | 24 +
src/tir/schedule/primitive/cache_read_write.cc | 2 +-
.../test_meta_schedule_auto_tensorize.py | 10 +-
.../unittest/test_aot_legalize_packed_call.py | 6 +-
...est_meta_schedule_postproc_rewrite_tensorize.py | 2 +-
...ta_schedule_schedule_rule_multi_level_tiling.py | 4 +-
.../unittest/test_meta_schedule_space_cuda.py | 2 +-
.../unittest/test_meta_schedule_tune_relay.py | 8 +-
tests/python/unittest/test_target_codegen_llvm.py | 15 +-
.../python/unittest/test_tir_lower_match_buffer.py | 44 +-
.../python/unittest/test_tir_schedule_tensorize.py | 11 +-
.../python/unittest/test_tir_transform_helpers.py | 7 +-
.../test_tir_transform_hoist_expression.py | 10 +-
.../test_tir_transform_inject_software_pipeline.py | 14 +-
.../test_tir_transform_inject_virtual_thread.py | 18 +-
.../python/unittest/test_tvmscript_error_report.py | 40 +-
tests/python/unittest/test_tvmscript_spans.py | 73 --
.../python/unittest/test_tvmscript_syntax_sugar.py | 39 +-
94 files changed, 5992 insertions(+), 475 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 5e358ed50e..cbc7db8686 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -764,16 +764,32 @@ struct PackedFuncValueConverter<PrimExpr> {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
- return PrimExpr(val.operator int());
+ return IntImm(runtime::DataType::Int(32), val.operator int());
}
if (val.type_code() == kDLFloat) {
- return PrimExpr(static_cast<float>(val.operator double()));
+ return FloatImm(runtime::DataType::Float(32), val.operator double());
}
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};
+// template <>
+// struct PackedFuncValueConverter<Array<PrimExpr>> {
+// static Array<PrimExpr> From(const TVMPODValue_& val) {
+// if (val.type_code() == kTVMNullptr) return Array<PrimExpr>(nullptr);
+// Array<ObjectRef> vals = val.AsObjectRef<Array<ObjectRef>>();
+// Array<PrimExpr> exprs;
+// for (const ObjectRef& v : vals) {
+// TVMValue value;
+// value.v_handle = const_cast<void*>(static_cast<const void*>(v.get()));
+// exprs.push_back(
+// PackedFuncValueConverter<PrimExpr>::From(TVMArgValue(value, kTVMObjectHandle)));
+// }
+// return exprs;
+// }
+// };
+
template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
diff --git a/include/tvm/ir/ir_builder.h b/include/tvm/ir/ir_builder.h
new file mode 100644
index 0000000000..2ecc774ec5
--- /dev/null
+++ b/include/tvm/ir/ir_builder.h
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_IR_IR_BUILDER_H_
+#define TVM_IR_IR_BUILDER_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+#include <vector>
+
+namespace tvm {
+namespace ir_builder {
+
+////////////////////////////// Core Infra: Frame and IRBuilder //////////////////////////////
+
+/*!
+ * \brief Base node type for IR builder frames.
+ *
+ * Stores a list of `void()` callbacks and declares the virtual EnterWithScope /
+ * ExitWithScope hooks that derived frame nodes override. The hooks and AddCallback
+ * are only declared in this header; their definitions live elsewhere.
+ */
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+ /*!
+ * \brief Callbacks held by this frame (see AddCallback).
+ * NOTE(review): when they are invoked is decided by the implementation file -- confirm there.
+ */
+ std::vector<runtime::TypedPackedFunc<void()>> callbacks;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `callbacks` is not visited.
+ }
+
+ static constexpr const char* _type_key = "ir_builder.IRBuilderFrame";
+ TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+ virtual ~IRBuilderFrameNode() = default;
+ /*! \brief Hook called through IRBuilderFrame::EnterWithScope. */
+ virtual void EnterWithScope();
+ /*! \brief Hook called through IRBuilderFrame::ExitWithScope. */
+ virtual void ExitWithScope();
+
+ /*! \brief Register a `void()` callback on this frame (defined outside this header). */
+ void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+/*!
+ * \brief Reference class paired with IRBuilderFrameNode.
+ *
+ * The default constructor is protected, so a frame cannot be default-constructed from
+ * outside the class hierarchy. EnterWithScope / ExitWithScope are defined inline later
+ * in this header and forward to the node.
+ */
+class IRBuilderFrame : public runtime::ObjectRef {
+ public:
+ virtual ~IRBuilderFrame() = default;
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+ IRBuilderFrame() = default;
+
+ public:
+ /*! \brief Forward to IRBuilderFrameNode::EnterWithScope. */
+ inline void EnterWithScope();
+ /*! \brief Forward to IRBuilderFrameNode::ExitWithScope, then drop the node reference. */
+ inline void ExitWithScope();
+};
+
+/*!
+ * \brief Node holding the state of an IR builder: a list of frames and an optional result.
+ *
+ * Both fields are exposed through VisitAttrs. The template helpers declared below are
+ * defined inline later in this header.
+ */
+class IRBuilderNode : public runtime::Object {
+ public:
+ /*! \brief The frames currently held by the builder; FindFrame scans them back to front. */
+ runtime::Array<IRBuilderFrame> frames;
+ /*! \brief The builder's result, if one has been set; read by Get(). */
+ Optional<ObjectRef> result;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("frames", &frames);
+ v->Visit("result", &result);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.IRBuilder";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+ /*! \brief Return the last frame in `frames` whose node type is TFrame's, or NullOpt. */
+ template <typename TFrame>
+ inline Optional<TFrame> FindFrame() const;
+ /*! \brief Return the final entry of `frames` if it is a TFrame, otherwise NullOpt. */
+ template <typename TFrame>
+ inline Optional<TFrame> GetLastFrame() const;
+ /*! \brief Return `result` as a TObjectRef; fails a CHECK if unset or of another type. */
+ template <typename TObjectRef>
+ inline TObjectRef Get() const;
+};
+
+/*!
+ * \brief Reference class paired with IRBuilderNode.
+ *
+ * The constructor, EnterWithScope, ExitWithScope and Current are declared here and
+ * defined outside this header; Name is defined inline further down.
+ */
+class IRBuilder : public runtime::ObjectRef {
+ public:
+ IRBuilder();
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+ void EnterWithScope();
+ void ExitWithScope();
+ /*! \brief NOTE(review): presumably returns the innermost active builder -- confirm in the .cc. */
+ static IRBuilder Current();
+ /*! \brief Pass `obj` and `name` to details::Namer::Name and return `obj`. */
+ template <class TObjectRef>
+ inline static TObjectRef Name(String name, TObjectRef obj);
+};
+
+////////////////////////////// Generic IRModule //////////////////////////////
+
+/*!
+ * \brief Frame node holding two arrays, `global_vars` and `functions`.
+ *
+ * Both arrays are exposed through VisitAttrs after the base-class attributes.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class IRModuleFrameNode : public IRBuilderFrameNode {
+ public:
+ // NOTE(review): presumably global_vars[i] names functions[i] -- confirm against ExitWithScope.
+ Array<GlobalVar> global_vars;
+ Array<BaseFunc> functions;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ IRBuilderFrameNode::VisitAttrs(v);
+ v->Visit("global_vars", &global_vars);
+ v->Visit("functions", &functions);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.IRModuleFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with IRModuleFrameNode; derives from IRBuilderFrame. */
+class IRModuleFrame : public IRBuilderFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
+ IRModuleFrameNode);
+};
+
+/*!
+ * \brief Exported function returning an IRModuleFrame.
+ * Only the declaration is in this header; the definition lives elsewhere.
+ */
+TVM_DLL IRModuleFrame IRModule();
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+/*!
+ * \brief Helper used by IRBuilder::Name to attach a name to an object.
+ *
+ * FType is a NodeFunctor over `(const ObjectRef&, String)`; `vtable()` returns a reference
+ * to one such functor. NOTE(review): presumably Name() dispatches through that table by
+ * the node's runtime type -- the definitions are not in this header, confirm there.
+ */
+class Namer {
+ public:
+ using FType = NodeFunctor<void(const ObjectRef&, String)>;
+ static FType& vtable();
+ static void Name(ObjectRef node, String name);
+};
+
+} // namespace details
+
+/*!
+ * \brief Name an object and hand it back to the caller.
+ * \param name The name passed on to details::Namer::Name.
+ * \param obj The object to be named.
+ * \return `obj`, downcast to its own reference type TObjectRef.
+ */
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+  details::Namer::Name(obj, name);
+  TObjectRef named = Downcast<TObjectRef>(obj);
+  return named;
+}
+
+/*!
+ * \brief Forward scope entry to the underlying frame node.
+ * The reference must hold a node; this is enforced with ICHECK.
+ */
+inline void IRBuilderFrame::EnterWithScope() {
+  ICHECK(data_ != nullptr);
+  auto* frame_node = static_cast<IRBuilderFrameNode*>(data_.get());
+  frame_node->EnterWithScope();
+}
+
+/*!
+ * \brief Forward scope exit to the underlying frame node, then release the node.
+ * The reference must hold a node on entry (ICHECK). After this call `data_` has been
+ * reset, so this reference no longer points at a node.
+ */
+inline void IRBuilderFrame::ExitWithScope() {
+  ICHECK(data_ != nullptr);
+  auto* frame_node = static_cast<IRBuilderFrameNode*>(data_.get());
+  frame_node->ExitWithScope();
+  // Drop our handle only after the node's exit hook has run.
+  data_.reset();
+}
+
+/*!
+ * \brief Search `frames` from the back for a frame of the requested type.
+ * \tparam TFrame The frame reference type to look for.
+ * \return The last frame whose node is a TFrame::ContainerType, or NullOpt if none is.
+ */
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
+    const TFrameNode* matched = (*it).template as<TFrameNode>();
+    if (matched == nullptr) {
+      continue;
+    }
+    return GetRef<TFrame>(matched);
+  }
+  return NullOpt;
+}
+
+/*!
+ * \brief Return the final frame only if it has the requested type.
+ * \tparam TFrame The frame reference type expected at the back of `frames`.
+ * \return The last frame as TFrame, or NullOpt when `frames` is empty or the last
+ *         frame's node is not a TFrame::ContainerType.
+ */
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  if (frames.empty()) {
+    return NullOpt;
+  }
+  if (!frames.back()->IsInstance<TFrameNode>()) {
+    return NullOpt;
+  }
+  return Downcast<TFrame>(frames.back());
+}
+
+/*!
+ * \brief Fetch the builder's result as the requested reference type.
+ * \tparam TObjectRef The reference type the stored result is expected to have.
+ * \return The stored result viewed as TObjectRef.
+ * \note Fails a CHECK with an "IndexError:" message when no result has been set, and
+ *       with a "TypeError:" message when the result is not a TObjectRef::ContainerType.
+ */
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  // A result of the wrong type is a type error, not an index error.
+  CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key;
+  return GetRef<TObjectRef>(n);
+}
+
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_IR_IR_BUILDER_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 3651e05e74..bbc36419ff 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -75,6 +75,8 @@ class With {
ContextType& operator*() { return *get(); }
const ContextType* operator*() const { return *get(); }
+ /*! \brief Call operator: returns the wrapped context `ctx_` by value (a copy). */
+ ContextType operator()() { return ctx_; }
+
private:
/*! \brief internal context type. */
ContextType ctx_;
diff --git a/include/tvm/tir/ir_builder.h b/include/tvm/tir/ir_builder.h
new file mode 100644
index 0000000000..19111b9b20
--- /dev/null
+++ b/include/tvm/tir/ir_builder.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_H_
+#define TVM_TIR_IR_BUILDER_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/ir_builder_frame.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+using tvm::runtime::NDArray;
+using tvm::tir::Buffer;
+using tvm::tir::Var;
+
+/*!
+ * \brief Declaration of BufferDecl: takes the buffer description (shape, dtype, name, optional
+ * data var / strides / elem_offset, storage scope, alignment, offset factor, buffer type and
+ * optional axis separators) and returns a Buffer. No parameter has a default value.
+ * The definition is not part of this header.
+ */
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+ Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+ String storage_scope, int align, int offset_factor, String buffer_type,
+ Optional<Array<IntImm>> axis_separators);
+/*! \brief Declaration of Ptr: returns a PrimExpr for `dtype`; `storage_scope` defaults to "global". */
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
+// Block-related builder functions. Only declarations are given here; the definitions
+// are not part of this header. Block returns a BlockFrame and Init a BlockInitFrame;
+// Where / Reads / Writes / BlockAttrs return nothing.
+// NOTE(review): presumably the void functions update the enclosing BlockFrame -- confirm in the .cc.
+BlockFrame Block(String name, bool no_realize = false);
+BlockInitFrame Init();
+void Where(PrimExpr predicate);
+void Reads(Array<ObjectRef> buffer_slices);
+void Writes(Array<ObjectRef> buffer_slices);
+void BlockAttrs(Map<String, ObjectRef> attrs);
+// AllocBuffer returns a Buffer. Note that `storage_scope` defaults to "" here, while
+// MatchBuffer / PreflattenedBuffer in this header default it to "global".
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+ Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+ PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+ int offset_factor = 0, String buffer_type = "default",
+ Array<IntImm> axis_separators = {});
+
+// Axis helpers. Spatial / Reduce / Scan / Opaque share one signature: a Range `dom`, a
+// PrimExpr `binding` and a `dtype` (default int32), each returning a single Var.
+// Remap takes a `kinds` string plus an array of bindings and returns an array of Vars.
+// NOTE(review): the encoding of `kinds` is not visible here -- see the definition.
+namespace axis {
+Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+} // namespace axis
+
+// Loop builders; every function here returns a ForFrame. Serial / Parallel / Vectorized /
+// Unroll take `start`, `stop` and optional annotations (default NullOpt). ThreadBinding
+// additionally takes a `thread` string. Grid takes an array of extents.
+// Declarations only; the definitions are not part of this header.
+ForFrame Serial(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+ForFrame Grid(Array<PrimExpr> extents);
+
+// PrimFunc-related builder functions (declarations only). PrimFunc returns a
+// PrimFuncFrame. Arg is overloaded for Var and Buffer and returns the same kind it is
+// given. FuncRet returns a Type. MatchBuffer returns a Buffer, PreflattenedBuffer nothing.
+// NOTE(review): presumably these act on the enclosing PrimFuncFrame -- confirm in the .cc.
+PrimFuncFrame PrimFunc();
+Var Arg(String name, Var var);
+Buffer Arg(String name, Buffer buffer);
+void FuncName(String name);
+void FuncAttrs(Map<String, ObjectRef> attrs);
+Type FuncRet(Type ret_type);
+// `param` is a generic ObjectRef; which concrete types are accepted is decided by the definition.
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+ Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+ PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
+ int align = -1, int offset_factor = 0, String buffer_type = "default",
+ Array<IntImm> axis_separators = {});
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
+ DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
+ Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
+ String storage_scope = "global", int align = -1, int offset_factor = 0,
+ String buffer_type = "default", Array<IntImm> axis_separators = {});
+
+// Statement-level builder functions (declarations only). The functions returning a
+// *Frame type hand back the frame class of the same name declared in
+// tvm/tir/ir_builder_frame.h. BufferStore / Prefetch / Evaluate return nothing, and
+// EnvThread returns a Var.
+// NOTE(review): presumably the void functions add a statement to the current frame -- confirm in the .cc.
+AssertFrame Assert(PrimExpr condition, String message);
+LetFrame Let(Var var, PrimExpr value);
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope = "",
+ Optional<PrimExpr> condition = NullOpt,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
+AllocateConstFrame AllocateConst(
+ NDArray data, DataType dtype, Array<PrimExpr> extents,
+ Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value);
+WhileFrame While(PrimExpr condition);
+IfFrame If(PrimExpr condition);
+ThenFrame Then();
+ElseFrame Else();
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+Var EnvThread(String thread_tag);
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+void Prefetch(Buffer buffer, Array<Range> bounds);
+void Evaluate(PrimExpr value);
+
+// Defines an inline function `FuncName(expr = NullOpt)` for the data type `DType`:
+//  - when `expr` is given, it returns `expr` cast to that data type;
+//  - otherwise it returns a fresh tir::Var of that data type with an empty name.
+// The macro is undefined again right after the instantiations below.
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
+ inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
+ DataType dtype = DType; \
+ return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
+ }
+
+// Scalar integer / unsigned / float types, three int32 vector types (4, 8 and 16 lanes),
+// then bool, handle and void.
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
+
+#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_TIR_IR_BUILDER_H_
diff --git a/include/tvm/tir/ir_builder_frame.h b/include/tvm/tir/ir_builder_frame.h
new file mode 100644
index 0000000000..d975710a05
--- /dev/null
+++ b/include/tvm/tir/ir_builder_frame.h
@@ -0,0 +1,456 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_FRAME_H_
+#define TVM_TIR_IR_BUILDER_FRAME_H_
+
+#include <tvm/ir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+/*!
+ * \brief Base node for the TIR frames declared in this header.
+ * Adds an array of statements, `stmts`, on top of IRBuilderFrameNode and exposes it
+ * through VisitAttrs.
+ */
+class TIRFrameNode : public IRBuilderFrameNode {
+ public:
+ // NOTE(review): presumably the statements collected while this frame is active -- confirm in the .cc.
+ Array<tvm::tir::Stmt> stmts;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ IRBuilderFrameNode::VisitAttrs(v);
+ v->Visit("stmts", &stmts);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.TIRFrame";
+ TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
+};
+
+/*!
+ * \brief Reference class paired with TIRFrameNode.
+ * The default constructor is protected, so only derived frame classes can use it.
+ */
+class TIRFrame : public IRBuilderFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);
+
+ protected:
+ TIRFrame() = default;
+};
+
+/*!
+ * \brief Frame node carrying the pieces of a block: name, iteration variables, optional
+ * read / write regions, optional init statement, allocated and matched buffers and
+ * optional annotations, plus iterator values, an optional predicate and a `no_realize` flag.
+ *
+ * Every field is exposed through VisitAttrs. ExitWithScope is overridden (final); its
+ * definition is outside this header.
+ */
+class BlockFrameNode : public TIRFrameNode {
+ public:
+ String name;
+ Array<tvm::tir::IterVar> iter_vars;
+ Optional<Array<tvm::tir::BufferRegion>> reads;
+ Optional<Array<tvm::tir::BufferRegion>> writes;
+ Optional<tvm::tir::Stmt> init;
+ Array<tvm::tir::Buffer> alloc_buffers;
+ Array<tvm::tir::MatchBufferRegion> match_buffers;
+ Optional<Map<String, ObjectRef>> annotations;
+
+ // NOTE(review): presumably iter_values[i] is the binding of iter_vars[i] -- confirm in the .cc.
+ Array<PrimExpr> iter_values;
+ Optional<PrimExpr> predicate;
+ bool no_realize;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ v->Visit("iter_vars", &iter_vars);
+ v->Visit("reads", &reads);
+ v->Visit("writes", &writes);
+ v->Visit("init", &init);
+ v->Visit("alloc_buffers", &alloc_buffers);
+ v->Visit("match_buffers", &match_buffers);
+ v->Visit("annotations", &annotations);
+ v->Visit("iter_values", &iter_values);
+ v->Visit("predicate", &predicate);
+ v->Visit("no_realize", &no_realize);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.BlockFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with BlockFrameNode. */
+class BlockFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
+};
+
+/*!
+ * \brief Frame node with no fields of its own beyond those of TIRFrameNode.
+ * Overrides both EnterWithScope and ExitWithScope (final); the definitions are outside
+ * this header.
+ */
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+ void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+ static constexpr const char* _type_key = "ir_builder.tir.BlockInitFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+ void EnterWithScope() final;
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with BlockInitFrameNode. */
+class BlockInitFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
+};
+
+/*!
+ * \brief Frame node holding loop variables, their ranges and a function that builds the
+ * loop statement.
+ *
+ * `vars` and `doms` are exposed through VisitAttrs; `f_make_for_loop` is not.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class ForFrameNode : public TIRFrameNode {
+ public:
+ /*! \brief Signature of the loop maker: (vars, doms, body statement) -> statement. */
+ using FMakeForLoop =
+ runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, Array<Range>, tvm::tir::Stmt)>;
+
+ Array<tvm::tir::Var> vars;
+ Array<Range> doms;
+ FMakeForLoop f_make_for_loop;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("vars", &vars);
+ v->Visit("doms", &doms);
+ // `f_make_for_loop` is not visited.
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.ForFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with ForFrameNode. */
+class ForFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
+/*!
+ * \brief Frame node carrying the pieces of a PrimFunc: optional name, argument variables,
+ * optional return type, two Var-to-Buffer maps, optional attributes, a Var-to-IterVar map
+ * of env threads and an array of root-level allocated buffers.
+ *
+ * Every field is exposed through VisitAttrs. ExitWithScope is overridden (final); its
+ * definition is outside this header.
+ */
+class PrimFuncFrameNode : public TIRFrameNode {
+ public:
+ Optional<String> name;
+ Array<tvm::tir::Var> args;
+ Optional<Type> ret_type;
+ Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
+ Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
+ Optional<Map<String, ObjectRef>> attrs;
+ Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
+ Array<tvm::tir::Buffer> root_alloc_buffers;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ v->Visit("args", &args);
+ v->Visit("ret_type", &ret_type);
+ v->Visit("buffer_map", &buffer_map);
+ v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
+ v->Visit("attrs", &attrs);
+ v->Visit("env_threads", &env_threads);
+ v->Visit("root_alloc_buffers", &root_alloc_buffers);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.PrimFuncFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with PrimFuncFrameNode. */
+class PrimFuncFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a `condition` and a `message`, both stored as PrimExpr and
+ * exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class AssertFrameNode : public TIRFrameNode {
+ public:
+ PrimExpr condition;
+ PrimExpr message;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ v->Visit("message", &message);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AssertFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with AssertFrameNode. */
+class AssertFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a variable `var` and an expression `value`, both exposed
+ * through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class LetFrameNode : public TIRFrameNode {
+ public:
+ tvm::tir::Var var;
+ PrimExpr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("var", &var);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.LetFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with LetFrameNode. */
+class LetFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
+};
+
+/*!
+ * \brief Frame node holding the description of an allocation: extents, dtype, storage
+ * scope, a condition expression, annotations and a Buffer.
+ *
+ * Every field is exposed through VisitAttrs. ExitWithScope is overridden (final); its
+ * definition is outside this header.
+ */
+class AllocateFrameNode : public TIRFrameNode {
+ public:
+ Array<PrimExpr> extents;
+ DataType dtype;
+ String storage_scope;
+ PrimExpr condition;
+ Map<String, ObjectRef> annotations;
+ // NOTE(review): presumably the buffer created for this allocation -- confirm in the .cc.
+ tvm::tir::Buffer buffer;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("extents", &extents);
+ v->Visit("dtype", &dtype);
+ v->Visit("storage_scope", &storage_scope);
+ v->Visit("condition", &condition);
+ v->Visit("annotations", &annotations);
+ v->Visit("buffer", &buffer);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AllocateFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with AllocateFrameNode. */
+class AllocateFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
+};
+
+/*!
+ * \brief Frame node holding the description of a constant allocation: dtype, extents,
+ * the NDArray data, a Buffer and annotations.
+ *
+ * Every field is exposed through VisitAttrs. ExitWithScope is overridden (final); its
+ * definition is outside this header.
+ */
+class AllocateConstFrameNode : public TIRFrameNode {
+ public:
+ DataType dtype;
+ Array<PrimExpr> extents;
+ tvm::runtime::NDArray data;
+ tvm::tir::Buffer buffer;
+ Map<String, ObjectRef> annotations;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("dtype", &dtype);
+ v->Visit("extents", &extents);
+ v->Visit("data", &data);
+ v->Visit("buffer", &buffer);
+ v->Visit("annotations", &annotations);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AllocateConstFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with AllocateConstFrameNode. */
+class AllocateConstFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
+ AllocateConstFrameNode);
+};
+
+/*!
+ * \brief Frame node holding an `extent` expression, an `attr_key` string and an IterVar,
+ * all exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class LaunchThreadFrameNode : public TIRFrameNode {
+ public:
+ PrimExpr extent;
+ String attr_key;
+ tvm::tir::IterVar iter_var;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("extent", &extent);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("iter_var", &iter_var);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.LaunchThreadFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with LaunchThreadFrameNode. */
+class LaunchThreadFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
+ LaunchThreadFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a BufferRegion `buffer_slice`, a `storage_scope` string and a
+ * `condition` expression, all exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class RealizeFrameNode : public TIRFrameNode {
+ public:
+ tvm::tir::BufferRegion buffer_slice;
+ String storage_scope;
+ PrimExpr condition;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("buffer_slice", &buffer_slice);
+ v->Visit("storage_scope", &storage_scope);
+ v->Visit("condition", &condition);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.RealizeFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with RealizeFrameNode. */
+class RealizeFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a generic object `node`, an `attr_key` string and a `value`
+ * expression, all exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class AttrFrameNode : public TIRFrameNode {
+ public:
+ ObjectRef node;
+ String attr_key;
+ PrimExpr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("node", &node);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.AttrFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with AttrFrameNode. */
+class AttrFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a single `condition` expression, exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class WhileFrameNode : public TIRFrameNode {
+ public:
+ PrimExpr condition;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.WhileFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with WhileFrameNode. */
+class WhileFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a `condition` expression and two optional statement arrays,
+ * `then_stmts` and `else_stmts`, all exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class IfFrameNode : public TIRFrameNode {
+ public:
+ PrimExpr condition;
+ // NOTE(review): presumably filled in by ThenFrame / ElseFrame on scope exit -- confirm in the .cc.
+ Optional<Array<tvm::tir::Stmt>> then_stmts;
+ Optional<Array<tvm::tir::Stmt>> else_stmts;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("condition", &condition);
+ v->Visit("then_stmts", &then_stmts);
+ v->Visit("else_stmts", &else_stmts);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.IfFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with IfFrameNode. */
+class IfFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
+};
+
+/*!
+ * \brief Frame node with no fields of its own and no VisitAttrs override; it relies on
+ * what TIRFrameNode provides.
+ * Overrides both EnterWithScope and ExitWithScope (final); the definitions are outside
+ * this header.
+ */
+class ThenFrameNode : public TIRFrameNode {
+ public:
+ static constexpr const char* _type_key = "ir_builder.tir.ThenFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);
+
+ public:
+ void EnterWithScope() final;
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with ThenFrameNode. */
+class ThenFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
+};
+
+/*!
+ * \brief Frame node with no fields of its own and no VisitAttrs override; it relies on
+ * what TIRFrameNode provides.
+ * Overrides both EnterWithScope and ExitWithScope (final); the definitions are outside
+ * this header.
+ */
+class ElseFrameNode : public TIRFrameNode {
+ public:
+ static constexpr const char* _type_key = "ir_builder.tir.ElseFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);
+
+ public:
+ void EnterWithScope() final;
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with ElseFrameNode. */
+class ElseFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
+};
+
+/*!
+ * \brief Frame node holding a single Buffer, exposed through VisitAttrs.
+ * ExitWithScope is overridden (final); its definition is outside this header.
+ */
+class DeclBufferFrameNode : public TIRFrameNode {
+ public:
+ tvm::tir::Buffer buffer;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ TIRFrameNode::VisitAttrs(v);
+ v->Visit("buffer", &buffer);
+ }
+
+ static constexpr const char* _type_key = "ir_builder.tir.DeclBufferFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);
+
+ public:
+ void ExitWithScope() final;
+};
+
+/*! \brief Reference class paired with DeclBufferFrameNode. */
+class DeclBufferFrame : public TIRFrame {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
+};
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_TIR_IR_BUILDER_FRAME_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..09758cc923 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check if x is nullptr.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
/*!
* \brief Check if x is infinite.
* \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
Span span = Span());
+/*!
+ * \brief Calculate fmod(x, y)
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
/*!
* \brief Calculate floor(x)
* \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());
+/*!
+ * \brief Returns the address of an element in the buffer
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Returns the param by name
+ * \param param_name The param name.
+ * \param span The location of this operation in the source.
+ * \return The handle of param.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
@@ -701,6 +732,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4e847c0310..51dff32ac6 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -16,29 +16,46 @@
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .base import structural_equal, assert_structural_equal, structural_hash
-from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .affine_type import TensorAffineType, TupleAffineType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .op import Op, register_op_attr, register_intrin_lowering
-from .function import CallingConv, BaseFunc
+from . import diagnostics, instrument, ir_builder, transform
from .adt import Constructor, TypeData
-from .module import IRModule
+from .affine_type import TensorAffineType, TupleAffineType
from .attrs import Attrs, DictAttrs, make_node
+from .base import (
+ EnvFunc,
+ Node,
+ SourceName,
+ Span,
+ assert_structural_equal,
+ load_json,
+ save_json,
+ structural_equal,
+ structural_hash,
+)
from .container import Array, Map
+from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
+from .function import BaseFunc, CallingConv
from .memory_pools import (
- PoolInfo,
- WorkspacePoolInfo,
- ConstantPoolInfo,
- WorkspaceMemoryPools,
ConstantMemoryPools,
+ ConstantPoolInfo,
+ PoolInfo,
PoolInfoProperties,
+ WorkspaceMemoryPools,
+ WorkspacePoolInfo,
)
-
-from . import transform
-from . import instrument
-from . import diagnostics
+from .module import IRModule
+from .op import Op, register_intrin_lowering, register_op_attr
+from .tensor_type import TensorType
+from .type import (
+ FuncType,
+ GlobalTypeVar,
+ IncompleteType,
+ PointerType,
+ PrimType,
+ RelayRefType,
+ TupleType,
+ Type,
+ TypeConstraint,
+ TypeKind,
+ TypeVar,
+)
+from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/ir/_ffi_ir_builder_api.py
similarity index 88%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/ir/_ffi_ir_builder_api.py
index 926d17b166..9d08bc9b70 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/ir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir"""
import tvm._ffi
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/ir/ir_builder.py b/python/tvm/ir/ir_builder.py
new file mode 100644
index 0000000000..db0b1ead6d
--- /dev/null
+++ b/python/tvm/ir/ir_builder.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A generic IRBuilder across the TVM stack"""
+from typing import List, TypeVar
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_ir_builder_api as _ffi_api
+
+
+@_register_object("ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+ """Context-manager wrapper for the object registered as "ir_builder.IRBuilderFrame".
+
+ Every method is a thin forwarder to a function in the `ir_builder` FFI namespace.
+ """
+
+ def __enter__(self) -> "IRBuilderFrame":
+ # Forward scope entry to the FFI side, then hand the frame to `with ... as`.
+ _ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member # type: ignore
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ # The exception triple is ignored; returning None lets an in-flight exception propagate.
+ _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore
+
+ def add_callback(self, callback) -> None: # pylint: disable=unused-argument
+ # Forward `callback` to the FFI side. When it is invoked is decided there
+ # (presumably on frame exit - TODO confirm against the C++ implementation).
+ _ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member # type: ignore
+ self, callback
+ )
+
+
+@_register_object("ir_builder.IRBuilder")
+class IRBuilder(_Object):
+ """Context-manager wrapper for the object registered as "ir_builder.IRBuilder".
+
+ Every method is a thin forwarder to a function in the `ir_builder` FFI namespace.
+ """
+
+ def __init__(self) -> None:
+ # Create the underlying handle through the FFI constructor `IRBuilder`.
+ self.__init_handle_by_constructor__(
+ _ffi_api.IRBuilder # pylint: disable=no-member # type: ignore
+ )
+
+ def __enter__(self) -> "IRBuilder":
+ # Forward scope entry to the FFI side, then hand the builder to `with ... as`.
+ _ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type: ignore
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ # The exception triple is ignored; returning None lets an in-flight exception propagate.
+ _ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type: ignore
+
+ @staticmethod
+ def current() -> "IRBuilder":
+ # Return whatever the FFI side reports as the current builder
+ # (presumably the innermost entered one - TODO confirm).
+ return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member # type: ignore
+
+ def get(self) -> _Object:
+ # Return the object the FFI side has associated with this builder.
+ return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member # type: ignore
+
+
+DefType = TypeVar("DefType", bound=_Object)
+
+
+def name(s: str, v: DefType) -> DefType:
+ """Pass the string `s` and the object `v` to the FFI function `IRBuilderName`.
+
+ Returns whatever that function returns (annotated as the same type as `v`).
+ Presumably this attaches the name `s` to `v` - TODO confirm against the C++ side.
+ """
+ return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member # type: ignore
+
+
+def name_many(  # pylint: disable=invalid-name
+    s: List[str],
+    vs: List[DefType],
+) -> List[DefType]:
+    """Apply `name` pairwise to two equally long lists.
+
+    Parameters
+    ----------
+    s : List[str]
+        The names, one per object.
+    vs : List[DefType]
+        The objects, in the same order as `s`.
+
+    Returns
+    -------
+    result : List[DefType]
+        The value returned by `name` for each (name, object) pair, in order.
+    """
+    assert len(s) == len(vs)
+    named: List[DefType] = []
+    for label, value in zip(s, vs):
+        named.append(name(label, value))
+    return named
+
+
+@_register_object("ir_builder.IRModuleFrame")
+class IRModuleFrame(IRBuilderFrame):
+ """Frame type registered as "ir_builder.IRModuleFrame"; adds nothing to IRBuilderFrame."""
+ ...
+
+
+def ir_module() -> IRModuleFrame:
+ """Return the frame produced by the FFI function `IRModule` (takes no arguments)."""
+ return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index 58f82a248b..105b3467de 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -357,7 +357,7 @@ class _DefaultCUDATensorCore:
@staticmethod
def schedule_rules():
from tvm.meta_schedule import schedule_rule as M
- from tvm.tir.tensor_intrin import get_wmma_intrin_group
+ from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group
return [
M.MultiLevelTilingTensorCore(
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index f5a936f491..3d90030bcf 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -16,6 +16,7 @@
# under the License.
"""Default schedule rules"""
from typing import List, Union
+
from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoBind,
@@ -27,8 +28,9 @@ from tvm.meta_schedule.schedule_rule import (
ReuseType,
ScheduleRule,
)
-from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
-from tvm.tir import tensor_intrin
+from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
+ MultiLevelTilingTensorCore,
+)
from tvm.target import Target
@@ -130,8 +132,12 @@ def multi_level_tiling_tensor_core(
trans_b = [trans_b]
if target.kind.name == "cuda":
+ from tvm.tir.tensor_intrin import ( # pylint: disable=import-outside-toplevel
+ cuda,
+ )
+
intrin_groups = [
- tensor_intrin.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
+ cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
for _in_dtype in in_dtype
for _out_dtype in out_dtype
for _trans_b in trans_b
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..8b132dcdf0 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,22 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+from . import parser, parser_v1
-from . import tir
+#############
+from .parser import ir as ir_v2
+from .parser import ir_module as ir_module_v2
+from .parser import parse as from_source_v2
+from .parser import tir as tir_v2
-from .parser import ir_module, from_source
+#############
+from .parser_v1 import from_source as from_source_v1
+from .parser_v1 import ir_module as ir_module_v1
+from .parser_v1 import tir as tir_v1
+
+# pylint: disable=invalid-name
+
+ir = ir_v2
+ir_module = ir_module_v2
+tir = tir_v2
+from_source = from_source_v2
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 77%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..d8530e0ab1 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,13 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
+# under the License.
+"""The parser"""
+from . import dispatch as _dispatch
+from . import doc as _doc
+from . import ir
+from . import parser as _parser
from . import tir
-
-from .parser import ir_module, from_source
+from .entry import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/parser/diagnostics.py b/python/tvm/script/parser/diagnostics.py
new file mode 100644
index 0000000000..bb4f05a254
--- /dev/null
+++ b/python/tvm/script/parser/diagnostics.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+from .source import Source
+
+
+class Diagnostics:
+ """Emit parser diagnostics against a single `Source`.
+
+ A fresh IRModule is created only to hold the source text in its source map;
+ the DiagnosticContext is built on that module with the renderer returned by
+ `diagnostics.get_renderer()`.
+ """
+
+ # The source the diagnostics refer to.
+ source: Source
+ # The context diagnostics are emitted into and rendered from.
+ ctx: diagnostics.DiagnosticContext
+
+ def __init__(self, source: Source):
+ # Register the full source text under its name in the module's source map.
+ mod = IRModule()
+ mod.source_map.add(source.source_name, source.full_source)
+ self.source = source
+ self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+ # Missing (or zero) positions fall back to the source's start position;
+ # a missing end position falls back to the start position of the node.
+ lineno = node.lineno or self.source.start_line
+ col_offset = node.col_offset or self.source.start_column
+ end_lineno = node.end_lineno or lineno
+ end_col_offset = node.end_col_offset or col_offset
+ # Shift node-relative positions by where the source starts.
+ # NOTE(review): the fallbacks above are already start_line/start_column and
+ # get shifted again here - confirm this is intended.
+ lineno += self.source.start_line - 1
+ end_lineno += self.source.start_line - 1
+ # NOTE(review): columns use "+ 1" where lines use "- 1"; presumably a
+ # 0-based to 1-based column conversion - confirm.
+ col_offset += self.source.start_column + 1
+ end_col_offset += self.source.start_column + 1
+ self.ctx.emit(
+ diagnostics.Diagnostic(
+ level=level,
+ span=Span(
+ source_name=SourceName(self.source.source_name),
+ line=lineno,
+ end_line=end_lineno,
+ column=col_offset,
+ end_column=end_col_offset,
+ ),
+ message=message,
+ )
+ )
+
+ def error(self, node: doc.AST, message: str) -> None:
+ # Emit at ERROR level, then render the context right away.
+ self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+ self.ctx.render()
diff --git a/python/tvm/script/parser/dispatch.py b/python/tvm/script/parser/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+ from .parser import Parser
+
+
+# A parse method takes the parser and a doc AST node and returns nothing.
+ParseMethod = Callable[["Parser", AST], None]
+# Parse methods keyed by (dispatch token, node type name); filled by `register`.
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+# An operator method may take any arguments and return anything.
+OpMethod = Callable[..., Any]
+# Operator methods keyed by (operand type, operator, operand index); filled by `register_op`.
+# NOTE(review): the key's second element is annotated `AST` here but `Type` in `get_op` - confirm which is meant.
+OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Parameters
+    ----------
+    token : str
+        The dispatch token.
+    type_name : str
+        The node type name the method handles.
+
+    Returns
+    -------
+    decorator : Callable[[ParseMethod], ParseMethod]
+        A decorator that stores the method in `ParseVTable` under
+        `(token, type_name)` and returns the method unchanged.
+    """
+
+    def f(method: ParseMethod) -> ParseMethod:
+        ParseVTable[(token, type_name)] = method
+        # Return the method so that `@register(...)` does not rebind the
+        # decorated name to None.
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Look up the parse method stored for `(token, type_name)`.
+
+    Returns `default` when nothing is registered under that key.
+    """
+    key = (token, type_name)
+    return ParseVTable.get(key, default)
+
+
+def register_op(ty: Type, op: AST, operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator method for an operand type, operator and operand index.
+
+    Parameters
+    ----------
+    ty : Type
+        The operand type the method applies to.
+    op : AST
+        The operator the method implements.
+    operand_index : int
+        The operand position the method is registered for.
+
+    Returns
+    -------
+    decorator : Callable[[OpMethod], OpMethod]
+        A decorator that stores the method in `OpVTable` under
+        `(ty, op, operand_index)` and returns the method unchanged.
+    """
+
+    def f(method: OpMethod) -> OpMethod:
+        OpVTable[(ty, op, operand_index)] = method
+        # Return the method so that `@register_op(...)` does not rebind the
+        # decorated name to None.
+        return method
+
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type,
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Look up the operator method stored for `(ty, op, operand_index)`.
+
+    Returns `default` when nothing is registered under that key.
+    """
+    key = (ty, op, operand_index)
+    return OpVTable.get(key, default)
diff --git a/python/tvm/script/parser/doc.py b/python/tvm/script/parser/doc.py
new file mode 100644
index 0000000000..f6a641cb64
--- /dev/null
+++ b/python/tvm/script/parser/doc.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import * # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+ """One registry slot: the pair of converters stored for a single class name."""
+
+ # Converter from a Python `ast` node to a doc node; None until registered.
+ to_doc: typing.Optional[FnToDoc]
+ # Converter from a doc node back to a Python `ast` node; None until registered.
+ from_doc: typing.Optional[FnFromDoc]
+
+ def __init__(self):
+ self.to_doc = None
+ self.from_doc = None
+
+
+class Registry:
+ """Table of `Entry` objects keyed by class name."""
+
+ # Shared instance; None until `_register_default` assigns it.
+ _inst: typing.Optional["Registry"] = None
+ # Class name -> Entry. A defaultdict, so indexing a new name creates an empty Entry.
+ table: typing.Dict[str, Entry]
+
+ def __init__(self):
+ self.table = defaultdict(Entry)
+
+
+def register_to_doc(name: str):
+    """Return a function that stores its argument as the `to_doc` converter for `name`.
+
+    The converter is written into the shared registry's entry for `name`; the
+    entry is created on first use. The returned function itself returns None.
+    """
+
+    def f(to_doc: FnToDoc):  # pylint: disable=redefined-outer-name
+        registry = Registry._inst  # pylint: disable=protected-access
+        entry = registry.table[name]
+        entry.to_doc = to_doc
+
+    return f
+
+
+def register_from_doc(name: str):
+    """Return a function that stores its argument as the `from_doc` converter for `name`.
+
+    The converter is written into the shared registry's entry for `name`; the
+    entry is created on first use. The returned function itself returns None.
+    """
+
+    # The parameter is named after what it actually receives (it was called
+    # `to_doc`, a leftover from `register_to_doc`).
+    def f(from_doc: FnFromDoc):  # pylint: disable=redefined-outer-name
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = from_doc
+
+    return f
+
+
+def _is_atomic_type(node):
+    """Return True when `node` is a plain value rather than a node to convert.
+
+    Plain values are None, anything equal to `...`, True or False, and any
+    instance of int, float, str, bool, bytes or complex.
+    """
+    if node is None:
+        return True
+    if node in [..., True, False]:
+        return True
+    return isinstance(node, (int, float, str, bool, bytes, complex))
+
+
+def _get_registry_entry(cls_name, attr):
+    """Fetch attribute `attr` of the registry entry for `cls_name`.
+
+    Only the part of `cls_name` after the last "." is used as the key. Returns
+    None when no entry exists for that key or the entry lacks the attribute.
+    """
+    short_name = cls_name.split(".")[-1]
+    table = Registry._inst.table  # pylint: disable=protected-access
+    # Test membership first: indexing the defaultdict would create an empty entry.
+    if short_name not in table:
+        return None
+    return getattr(table[short_name], attr, None)
+
+
+def from_doc(node):
+    """Convert a doc node (or nested tuples/lists of them) using the registered `from_doc` converters.
+
+    Plain values are returned unchanged, tuples and lists are converted
+    element by element, and any other object is dispatched on its class name.
+
+    Raises
+    ------
+    NotImplementedError
+        If no `from_doc` converter is registered for the node's class name.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (tuple, list)):
+        converted = [from_doc(item) for item in node]
+        return tuple(converted) if isinstance(node, tuple) else converted
+    cls_name = node.__class__.__name__
+    func = _get_registry_entry(cls_name, "from_doc")
+    if not func:
+        raise NotImplementedError(f"from_doc is not implemented for: {cls_name}")
+    return func(node)
+
+
+def to_doc(node):
+    """Convert a node (or nested tuples/lists of them) using the registered `to_doc` converters.
+
+    Plain values are returned unchanged, tuples and lists are converted
+    element by element, and any other object is dispatched on its class name.
+
+    Raises
+    ------
+    NotImplementedError
+        If no `to_doc` converter is registered for the node's class name.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (tuple, list)):
+        converted = [to_doc(item) for item in node]
+        return tuple(converted) if isinstance(node, tuple) else converted
+    cls_name = node.__class__.__name__
+    func = _get_registry_entry(cls_name, "to_doc")
+    if not func:
+        raise NotImplementedError(f"to_doc is not implemented for: {cls_name}")
+    return func(node)
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    """Parse Python source with `ast.parse` and convert the result with `to_doc`.
+
+    Parameters
+    ----------
+    source
+        The source to parse; passed straight to `ast.parse`.
+    filename
+        The file name passed to `ast.parse`.
+    mode
+        The compile mode passed to `ast.parse`.
+
+    Returns
+    -------
+    program : doc.AST
+        The converted tree.
+    """
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except Exception:  # pylint: disable=broad-except
+        # Retry without `feature_version`: the keyword does not exist before
+        # Python 3.8, and the pinned grammar may reject otherwise valid source.
+        # A genuine error is raised again by the retry. Only `Exception` is
+        # caught so that KeyboardInterrupt/SystemExit are not swallowed.
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    """Walker over doc AST nodes that dispatches to `visit_<ClassName>` methods.
+
+    Nodes without a dedicated method are handled by `generic_visit`.
+    """
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit `node`; lists and tuples are visited item by item, non-nodes are ignored."""
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+        elif isinstance(node, doc.AST):
+            handler_name = "visit_" + node.__class__.__name__.split(".")[-1]
+            handler = getattr(self, handler_name, self.generic_visit)
+            handler(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        """Visit every field of `node` that holds a node, a list or a tuple."""
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            # None and plain values fail this check and are skipped.
+            if isinstance(child, (doc.AST, list, tuple)):
+                self.visit(child)
+
+
+class NodeTransformer:
+ """Rewriter over doc AST nodes that dispatches to `visit_<ClassName>` methods.
+
+ Unlike `NodeVisitor`, every visit returns a value: lists and tuples are rebuilt
+ from their visited items and nodes without a dedicated method are rebuilt by
+ `generic_visit`.
+ """
+
+ def visit(self, node: doc.AST) -> doc.AST:
+ # Containers are rebuilt with the same container type.
+ if isinstance(node, list):
+ return [self.visit(item) for item in node]
+ if isinstance(node, tuple):
+ return tuple(self.visit(item) for item in node)
+ # Anything that is not a doc node is returned unchanged.
+ if not isinstance(node, doc.AST):
+ return node
+ return getattr(
+ self,
+ "visit_" + node.__class__.__name__.split(".")[-1],
+ self.generic_visit,
+ )(node)
+
+ def generic_visit(self, node: doc.AST) -> doc.AST:
+ # Collect the node's fields, visiting those that hold nodes or containers,
+ # then construct a new node of the same class from them.
+ kv: typing.Dict[str, typing.Any] = {}
+ for field in node.__class__._FIELDS: # pylint: disable=protected-access
+ value = getattr(node, field, None)
+ if value is None:
+ pass
+ elif isinstance(value, (doc.AST, list, tuple)):
+ value = self.visit(value)
+ # NOTE(review): indentation is lost in this archive; presumably this
+ # assignment runs for every field, not only visited ones - confirm.
+ kv[field] = value
+ return node.__class__(**kv)
+
+
+def _register_default():
+ """Create the shared registry and register field-by-field converters.
+
+ A pair of converters is registered for every name in `doc` that also exists
+ in `ast` and is a class derived from `doc.AST`.
+ """
+
+ class DefaultTranslator:
+ # Builds `doc_cls` from a node by converting each listed field with `func`.
+ def __init__(self, doc_cls, func, fields):
+ self.doc_cls = doc_cls # getattr(doc, name)
+ self.func = func
+ self.fields = fields
+
+ def __call__(self, node):
+ # Fields missing on `node` are passed through `func` as None.
+ kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields}
+ return self.doc_cls(**kv)
+
+ # Must happen before any register_* call: they all read Registry._inst.
+ Registry._inst = Registry() # pylint: disable=protected-access
+ for cls_name in dir(doc):
+ doc_cls = getattr(doc, cls_name)
+ # Skip names that have no counterpart in the standard `ast` module.
+ if not hasattr(ast, cls_name):
+ continue
+ if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST):
+ assert "." not in cls_name
+ # ast -> doc: build the doc class, converting fields with `to_doc`.
+ register_to_doc(cls_name)(
+ DefaultTranslator(
+ getattr(doc, cls_name),
+ to_doc,
+ doc_cls._FIELDS, # pylint: disable=protected-access
+ )
+ )
+ # doc -> ast: build the ast class, converting fields with `from_doc`.
+ register_from_doc(cls_name)(
+ DefaultTranslator(
+ getattr(ast, cls_name),
+ from_doc,
+ doc_cls._FIELDS, # pylint: disable=protected-access
+ )
+ )
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    """Return the running interpreter's version as a `(major, minor)` tuple."""
+    major, minor = sys.version_info[:2]
+    return (major, minor)
+
+
+def _register_constant_handling():
+ """On Python 3.6/3.7 only, map the legacy literal nodes to `doc.Constant`.
+
+ Registers `to_doc` converters for Str, NameConstant, Num, Bytes and Ellipsis.
+ Does nothing on any other Python version.
+ """
+ if _py_version() not in [(3, 6), (3, 7)]:
+ return
+
+ def as_constant(f) -> doc.Constant:
+ # `f` is either the attribute name holding the literal's value or a
+ # callable that computes the value from the node.
+ def to_doc_func(x: ast.AST) -> doc.Constant:
+ return doc.Constant(
+ value=getattr(x, f) if isinstance(f, str) else f(x),
+ kind=None,
+ s=None,
+ n=None,
+ lineno=x.lineno,
+ col_offset=x.col_offset,
+ # The end position is set equal to the start position.
+ end_lineno=x.lineno,
+ end_col_offset=x.col_offset,
+ )
+
+ return to_doc_func
+
+ register_to_doc("Str")(as_constant("s"))
+ register_to_doc("NameConstant")(as_constant("value"))
+ register_to_doc("Num")(as_constant("n"))
+ register_to_doc("Bytes")(as_constant("s"))
+ register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+ """Before Python 3.9 only, register custom converters for `Subscript`.
+
+ On those versions the slice of an `ast.Subscript` is an `ast.Slice`, an
+ `ast.ExtSlice` or an `ast.Index`. The converters map these to and from a doc
+ representation that uses `doc.Slice`, `doc.Tuple` or the bare index expression.
+ Does nothing on Python 3.9 and later.
+ """
+ if _py_version() >= (3, 9):
+ return
+
+ def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+ # Positions are read with getattr because a node may not carry them.
+ if isinstance(x.slice, ast.Slice):
+ return doc.Subscript(
+ value=to_doc(x.value),
+ slice=doc.Slice(
+ lower=to_doc(x.slice.lower),
+ upper=to_doc(x.slice.upper),
+ step=to_doc(x.slice.step),
+ lineno=getattr(x.slice, "lineno", None),
+ col_offset=getattr(x.slice, "col_offset", None),
+ end_lineno=getattr(x.slice, "end_lineno", None),
+ end_col_offset=getattr(x.slice, "end_col_offset", None),
+ ),
+ ctx=to_doc(x.ctx),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ )
+ # An extended slice becomes a doc.Tuple of its dimensions; the tuple takes
+ # its position from the enclosing subscript.
+ if isinstance(x.slice, ast.ExtSlice):
+ return doc.Subscript(
+ value=to_doc(x.value),
+ slice=doc.Tuple(
+ elts=[to_doc(i) for i in x.slice.dims],
+ ctx=doc.Load(
+ lineno=None,
+ col_offset=None,
+ end_lineno=None,
+ end_col_offset=None,
+ ),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ ),
+ ctx=to_doc(x.ctx),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ )
+ # A plain index is unwrapped: the doc slice is the index's value itself.
+ if isinstance(x.slice, ast.Index):
+ return doc.Subscript(
+ value=to_doc(x.value),
+ slice=to_doc(x.slice.value),
+ ctx=to_doc(x.ctx),
+ lineno=getattr(x, "lineno", None),
+ col_offset=getattr(x, "col_offset", None),
+ end_lineno=getattr(x, "end_lineno", None),
+ end_col_offset=getattr(x, "end_col_offset", None),
+ )
+ raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+ def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+ # Inverse mapping: doc.Slice stays a slice, doc.Tuple becomes ExtSlice,
+ # and anything else is wrapped in ast.Index.
+ if isinstance(x.slice, doc.Slice):
+ result = ast.Subscript(
+ value=from_doc(x.value),
+ slice=from_doc(x.slice),
+ ctx=from_doc(x.ctx),
+ )
+ elif isinstance(x.slice, doc.Tuple):
+ result = ast.Subscript(
+ value=from_doc(x.value),
+ slice=ast.ExtSlice(
+ dims=[from_doc(i) for i in x.slice.elts],
+ ),
+ ctx=from_doc(x.ctx),
+ )
+ else:
+ result = ast.Subscript(
+ value=from_doc(x.value),
+ slice=ast.Index(value=from_doc(x.slice)),
+ ctx=from_doc(x.ctx),
+ )
+ # Positions are copied onto the result after construction.
+ result.lineno = x.lineno
+ result.col_offset = x.col_offset
+ result.end_lineno = x.end_lineno
+ result.end_col_offset = x.end_col_offset
+ return result
+
+ register_to_doc("Subscript")(subscript_to_doc)
+ register_from_doc("Subscript")(subscript_from_doc)
+
+
+def _register_index_handling():
+ """Before Python 3.9 only, register converters for `Index`.
+
+ Does nothing on Python 3.9 and later.
+ """
+ if _py_version() >= (3, 9):
+ return
+
+ def index_to_doc(x: ast.Index) -> doc.Expr:
+ # The Index wrapper is dropped: only its wrapped value is converted.
+ return to_doc(x.value)
+
+ def index_from_doc(x: doc.Expr) -> ast.Index:
+ # NOTE(review): this reads `x.ctx`, which raises AttributeError for a node
+ # without a `ctx` attribute - confirm every node reaching here has one.
+ result = ast.Index(value=from_doc(x), ctx=from_doc(x.ctx))
+ result.lineno = x.lineno
+ result.col_offset = x.col_offset
+ result.end_lineno = x.end_lineno
+ result.end_col_offset = x.end_col_offset
+ return result
+
+ register_to_doc("Index")(index_to_doc)
+ register_from_doc("Index")(index_from_doc)
+
+
+# Populate the registry at import time. _register_default() must run first: it
+# creates Registry._inst, which every register_to_doc/register_from_doc call
+# reads. The later calls add or replace entries for older Python versions.
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
+_register_index_handling()
diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/parser/doc_core.py
similarity index 100%
rename from python/tvm/script/printer/doc_core.py
rename to python/tvm/script/parser/doc_core.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/entry.py
similarity index 51%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/entry.py
index 923eb97d27..b70e876d43 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/entry.py
@@ -14,32 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
-import inspect
-from typing import Callable
+from tvm.ir.ir_builder import IRBuilder
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+from . import doc
+from .parser import Parser
+from .source import Source
-def prim_func(input_func: Callable) -> PrimFunc:
- """Decorate a python function as tvm script.
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+ if extra_vars is None:
+ from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
+ from tvm.script.parser import tir # pylint: disable=import-outside-toplevel
- Parameters
- ----------
- func : input_func
- The function to be parsed.
+ extra_vars = {
+ "I": ir,
+ "ir": ir,
+ "T": tir,
+ "tir": tir,
+ }
- Returns
- -------
- output : PrimFunc
- The result functions.
- """
- if inspect.isfunction(input_func):
- result = from_source(input_func)
- result.__name__ = input_func.__name__
- result.__qualname__ = input_func.__qualname__
- return result
-
- raise TypeError("Only function definitions are supported.")
+ source = Source(program)
+ parser = Parser(source)
+ with IRBuilder() as builder:
+ parser.parse(extra_vars=extra_vars)
+ return builder.get()
diff --git a/python/tvm/script/parser/evaluator.py b/python/tvm/script/parser/evaluator.py
new file mode 100644
index 0000000000..3899531b21
--- /dev/null
+++ b/python/tvm/script/parser/evaluator.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+ from .parser import Parser
+
+# Fallback implementations keyed by `doc` operator node type. `_eval_op` uses an
+# entry from this table when no operand type has a handler registered through
+# `dispatch`. Binary, comparison and boolean operators take two operands; the
+# last four entries (Invert, Not, UAdd, USub) take one.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+ doc.Add: lambda a, b: a + b,
+ doc.Sub: lambda a, b: a - b,
+ doc.Mult: lambda a, b: a * b,
+ doc.Div: lambda a, b: a / b,
+ doc.FloorDiv: lambda a, b: a // b,
+ doc.Mod: lambda a, b: a % b,
+ doc.LShift: lambda a, b: a << b,
+ doc.RShift: lambda a, b: a >> b,
+ doc.BitOr: lambda a, b: a | b,
+ doc.BitXor: lambda a, b: a ^ b,
+ doc.BitAnd: lambda a, b: a & b,
+ doc.MatMult: lambda a, b: a @ b,
+ doc.Pow: lambda a, b: a**b,
+ doc.Eq: lambda a, b: a == b,
+ doc.NotEq: lambda a, b: a != b,
+ doc.Lt: lambda a, b: a < b,
+ doc.LtE: lambda a, b: a <= b,
+ doc.Gt: lambda a, b: a > b,
+ doc.GtE: lambda a, b: a >= b,
+ doc.Is: lambda a, b: a is b,
+ doc.IsNot: lambda a, b: a is not b,
+ doc.In: lambda a, b: a in b,
+ doc.NotIn: lambda a, b: a not in b,
+ doc.And: lambda a, b: a and b,
+ doc.Or: lambda a, b: a or b,
+ doc.Invert: lambda a: ~a,
+ doc.Not: lambda a: not a,
+ doc.UAdd: lambda a: +a,
+ doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+ """Bottom-up evaluator for `doc` expression nodes.
+
+ Each evaluated sub-expression is stored in `value_table` under a generated
+ name (`__tvm_tmp_value_<n>`) and replaced in the tree by a `doc.Name` that
+ refers to it, so parent nodes can be evaluated in terms of their already
+ evaluated children.
+ """
+
+ parser: "Parser"
+ value_table: Dict[str, Any]
+ new_value_count: int
+
+ def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+ super().__init__()
+ self.parser = parser
+ self.value_table = value_table
+ # Counter used to generate unique temporary names.
+ self.new_value_count = 0
+
+ @staticmethod
+ def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+ """Evaluate `node` against `value_table` and return the resulting value.
+
+ `value_table` is mutated: temporaries created during evaluation are added
+ to it. Raises TypeError if the visit result is neither a Name nor a Constant.
+ """
+ self = ExprEvaluator(parser, value_table)
+ result = self._visit(node) # pylint: disable=protected-access
+ if isinstance(result, doc.Name):
+ if result.id not in self.value_table:
+ self.parser.report_error(result, f"Undefined variable: {result.id}")
+ # NOTE(review): reached with a missing key only if report_error returns
+ # instead of raising -- confirm report_error always raises.
+ return self.value_table[result.id]
+ if isinstance(result, doc.Constant):
+ return result.value
+ raise TypeError(f"Unexpected result type: {type(result)}")
+
+ def _add_intermediate_result(self, value: Any) -> doc.Name:
+ """Store `value` under a fresh temporary name and return a Name node for it."""
+ name = f"__tvm_tmp_value_{self.new_value_count}"
+ self.new_value_count += 1
+ self.value_table[name] = value
+ # The placeholder node carries no real source position.
+ lineno = 0
+ col_offset = 0
+ return doc.Name(
+ id=name,
+ ctx=doc.Load(
+ lineno=lineno,
+ col_offset=col_offset,
+ end_lineno=None,
+ end_col_offset=None,
+ ),
+ lineno=lineno,
+ col_offset=col_offset,
+ end_lineno=None,
+ end_col_offset=None,
+ )
+
+ def _visit(self, node: doc.AST) -> Any:
+ """Evaluate `node` recursively and return its replacement node.
+
+ Lists and tuples are mapped element-wise. Names (after an existence
+ check), constants, operator/context nodes and anything that is not an
+ expression or slice are returned unchanged. Every other node is
+ evaluated and replaced by a temporary Name.
+ """
+ if isinstance(node, list):
+ return [self._visit(n) for n in node]
+ if isinstance(node, tuple):
+ return tuple(self._visit(n) for n in node)
+ assert isinstance(node, doc.AST)
+ if isinstance(node, doc.Name):
+ if node.id not in self.value_table:
+ self.parser.report_error(node, f"Undefined variable: {node.id}")
+ return node
+ if isinstance(
+ node,
+ (
+ doc.Constant,
+ doc.expr_context,
+ doc.operator,
+ doc.boolop,
+ doc.unaryop,
+ doc.cmpop,
+ ),
+ ):
+ return node
+ if not isinstance(node, (doc.expr, doc.slice)):
+ return node
+ # Lambdas are evaluated as a whole; their bodies are not visited.
+ if isinstance(node, doc.Lambda):
+ return self._eval_lambda(node)
+ # Visit children first so each field holds an evaluated placeholder.
+ fields = {}
+ for field in node.__class__._FIELDS: # pylint: disable=protected-access
+ attr = getattr(node, field)
+ if isinstance(attr, (doc.AST, tuple, list)):
+ fields[field] = self._visit(attr)
+ else:
+ fields[field] = attr
+ try:
+ if isinstance(node, doc.BoolOp):
+ value = self._eval_bool_op(fields)
+ elif isinstance(node, doc.Compare):
+ value = self._eval_compare(fields)
+ elif isinstance(node, doc.UnaryOp):
+ value = self._eval_unary_op(fields)
+ elif isinstance(node, doc.BinOp):
+ value = self._eval_bin_op(fields)
+ elif isinstance(node, doc.Slice):
+ value = self._eval_slice(fields)
+ else:
+ # Rebuild the node from its evaluated children and evaluate it.
+ value = self._eval_expr(node.__class__(**fields))
+ except Exception as e: # pylint: disable=broad-except,invalid-name
+ self.parser.report_error(node, str(e))
+ # NOTE(review): `value` is unbound here if report_error returned after an
+ # exception -- confirm report_error always raises.
+ return self._add_intermediate_result(value)
+
+ def _eval_lambda(self, node: doc.Lambda) -> Any:
+ """Evaluate a lambda node directly and return a placeholder for the result."""
+ try:
+ value = self._eval_expr(node)
+ except Exception as e: # pylint: disable=broad-except,invalid-name
+ self.parser.report_error(node, str(e))
+ return self._add_intermediate_result(value)
+
+ def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+ """Fold the operands of an `and`/`or` node left to right through `_eval_op`.
+
+ Every operand is evaluated; there is no short-circuiting.
+ Raises TypeError for any operator other than And/Or.
+ """
+ op = fields["op"]
+ if not isinstance(op, (doc.And, doc.Or)):
+ raise TypeError(f"Unexpected operator: {op}")
+ value = self._eval_expr(fields["values"][0])
+ for rhs in fields["values"][1:]:
+ value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+ return value
+
+ def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+ """Fold a comparison node left to right through `_eval_op`.
+
+ NOTE(review): a chain such as `a < b < c` is folded as `(a < b) < c`,
+ which is not Python's `a < b and b < c` -- confirm this is intended.
+ """
+ value = self._eval_expr(fields["left"])
+ for op, rhs in zip(fields["ops"], fields["comparators"]):
+ value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+ return value
+
+ def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+ """Evaluate the operand and apply the unary operator through `_eval_op`."""
+ value = self._eval_expr(fields["operand"])
+ value = _eval_op(fields["op"], values=[value])
+ return value
+
+ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+ """Evaluate both operands and apply the binary operator through `_eval_op`."""
+ return _eval_op(
+ fields["op"],
+ values=[
+ self._eval_expr(fields["left"]),
+ self._eval_expr(fields["right"]),
+ ],
+ )
+
+ def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+ """Build a Python `slice` from the evaluated bounds; absent bounds stay None."""
+ lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+ lower = self._eval_expr(lower) if lower is not None else None
+ upper = self._eval_expr(upper) if upper is not None else None
+ step = self._eval_expr(step) if step is not None else None
+
+ return slice(lower, upper, step)
+
+ def _eval_expr(self, v: Any) -> Any:
+ """Evaluate `v` with the module-level `_eval_expr`, using `value_table` as globals."""
+ return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate `node` with `dict_globals` as the visible variables.
+
+    The globals are copied into a fresh table first, so temporaries created by
+    the evaluator never end up in the caller's dictionary.
+    """
+    scratch: Dict[str, Any] = {}
+    if dict_globals is not None:
+        scratch.update(dict_globals)
+    return ExprEvaluator.eval(parser, scratch, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Bind `source` to the assignment target and return the resulting name/value map.
+
+    Any failure is reported through `parser.report_error` against `target`; the
+    original exception is then re-raised.
+    """
+    try:
+        bindings = _eval_assign(target, source)
+    except Exception as err:  # pylint: disable=broad-except,invalid-name
+        parser.report_error(target, f"Failed to evaluate assignment: {str(err)}")
+        raise
+    return bindings
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Convert a `doc` expression to a Python `ast` node, compile it and evaluate it.
+
+    `dict_globals` is used as the globals of the evaluation; None means an empty
+    namespace.
+    """
+    py_node = doc.from_doc(node)
+    if isinstance(py_node, ast.expr):
+        # `compile(..., mode="eval")` needs an Expression wrapper.
+        py_node = ast.Expression(body=py_node)
+    assert isinstance(py_node, ast.Expression), "Expects an ast.Expression, but gets: " + str(
+        py_node
+    )
+    namespace = {} if dict_globals is None else dict_globals
+    code = compile(ast.fix_missing_locations(py_node), filename="<ast>", mode="eval")
+    return eval(code, namespace)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply operator node `op` to `values`.
+
+    Operands are scanned in order; the first one whose type carries a
+    `_dispatch_type` with a handler registered for this operator and operand
+    position decides the implementation. Otherwise the `DEFAULT_OP` entry is used.
+    """
+    op_type = type(op)  # pylint: disable=protected-access
+    for position, operand in enumerate(values):
+        dispatch_type = getattr(type(operand), "_dispatch_type", None)
+        if dispatch_type is not None:
+            handler = dispatch.get_op(
+                ty=dispatch_type, op=op_type, operand_index=position, default=None
+            )
+            if handler is not None:
+                return handler(*values)
+    fallback = DEFAULT_OP[op_type]
+    return fallback(*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Execute `<target> = <source>` and return the names it bound.
+
+    A one-statement module assigning a reserved variable (holding `source`) to
+    `target` is compiled and executed; the reserved variable is removed from the
+    locals before they are returned.
+    """
+    py_target = doc.from_doc(target)
+    assert isinstance(py_target, ast.expr)
+    rhs_var_name = "__tvm_rhs_var__"
+    assignment = ast.Assign(
+        targets=[py_target],
+        value=ast.Name(id=rhs_var_name, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assignment], type_ignores=[]))
+    bound = {rhs_var_name: source}
+    exec(compile(module, filename="<ast>", mode="exec"), {}, bound)  # pylint: disable=exec-used
+    del bound[rhs_var_name]
+    return bound
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser/ir/__init__.py
similarity index 89%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser/ir/__init__.py
index 926d17b166..bea08cfb1b 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script"""
-import tvm._ffi
-
-tvm._ffi._init_api("script", __name__)
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 66%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 555659d0c5..353963f29b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,8 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
-from . import tir
+from tvm.ir import IRModule
-from .parser import ir_module, from_source
+from ..entry import parse
+from ..utils import inspect_class_capture
+
+
+def ir_module(f: Type) -> IRModule:
+    """Parse the class `f` together with its captured variables.
+
+    Raises TypeError when `f` is not a class.
+    """
+    if inspect.isclass(f):
+        captured = inspect_class_capture(f)
+        return parse(f, captured)
+    raise TypeError(f"Expect a class, but got: {f}")
+
+
+# Token the parser reads from a decorator to select the matching handlers.
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 55%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/parser/ir/parser.py
index 2f2b4bbc25..aec203c7d9 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,18 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+from tvm.ir import ir_builder as I
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
+from .. import dispatch, doc
+from ..parser import Parser
-from .prim_func import prim_func
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
- for _size in ["8", "16", "32", "64"]:
- for _lanes in ["", "x4", "x8", "x16", "x32"]:
- from . import ty
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Visit a class body inside a new variable frame, an IR module frame and the "ir" token."""
+    # Entered left to right and exited in reverse, exactly like nested `with` blocks.
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
- _name = _dtype + _size + _lanes
- globals()[_name] = getattr(ty, _name)
+
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    """Ignore assignment statements under the "ir" token."""
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    """Ignore expression statements under the "ir" token."""
diff --git a/python/tvm/script/parser/parser.py b/python/tvm/script/parser/parser.py
new file mode 100644
index 0000000000..a89cd10fad
--- /dev/null
+++ b/python/tvm/script/parser/parser.py
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+from ...error import DiagnosticError
+from . import dispatch, doc
+from .diagnostics import Diagnostics
+from .evaluator import eval_assign, eval_expr
+from .source import Source
+from .utils import deferred
+from .var_table import VarTable
+
+# Node class names that `Parser.visit` sends straight to `generic_visit`
+# instead of looking up a `visit_<name>` method.
+DEFAULT_VISIT = {
+ "Interactive",
+ "Module",
+ "Expression",
+ "Pass",
+}
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap a parse method so unexpected exceptions are reported against the node.
+
+    A `DiagnosticError` passes through untouched. Any other exception is reported
+    through `report_error` and then re-raised.
+    """
+
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            outcome = func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as err:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(err))
+            raise
+        else:
+            return outcome
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Pick the handler for `type_name`.
+
+    The innermost active dispatch token is tried first, then "default"; when
+    neither has a handler, the parser's `generic_visit` is used. The chosen
+    handler is always wrapped by `_dispatch_wrapper`.
+    """
+    for token in (self.dispatch_tokens[-1], "default"):
+        handler = dispatch.get(token=token, type_name=type_name, default=None)
+        if handler is None:
+            continue
+        return _dispatch_wrapper(handler)
+    return _dispatch_wrapper(lambda self, node: self.generic_visit(node))
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser.
+
+    Walks a `doc` AST. Most statement nodes are forwarded to handlers registered
+    in `dispatch` under the innermost active dispatch token, falling back to the
+    "default" token and finally to `generic_visit` (see `_dispatch`).
+    """
+
+    diag: Diagnostics
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        self.diag = Diagnostics(source)
+        # Stack of active dispatch tokens; the last entry is consulted first.
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Parse the source held by `self.diag`.
+
+        `extra_vars` are added to a fresh variable frame that stays active for
+        the whole visit.
+        """
+        if extra_vars is None:
+            extra_vars = {}
+        with self.var_table.with_frame():
+            for k, v in extra_vars.items():
+                self.var_table.add(k, v)
+            node = self.diag.source.as_ast()
+            self.visit(node)
+
+    def with_dispatch_token(self, token: str):
+        """Push `token` and return a `deferred` object whose callback pops it again."""
+
+        def pop_token():
+            self.dispatch_tokens.pop()
+
+        self.dispatch_tokens.append(token)
+        return deferred(pop_token)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate `node` with the current variable table, overlaid by `extra_vars`."""
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            for k, v in extra_vars.items():
+                var_values[k] = v
+        return eval_expr(self, node, var_values)
+
+    def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
+        """Return True if a name occurs twice in `target`, else the set of its names.
+
+        Only Name, Tuple and List targets are accepted; anything else is reported
+        and raises NotImplementedError.
+        """
+        if isinstance(target, (doc.Tuple, doc.List)):
+            seen: Set[str] = set()
+            for elt in target.elts:
+                res = self._duplicate_lhs_check(elt)
+                if isinstance(res, bool) and res:
+                    return True
+                assert isinstance(res, set)
+                if seen & res:
+                    return True
+                seen |= res
+            return seen
+        if isinstance(target, doc.Name):
+            return {target.id}
+        self.report_error(target, "Invalid type in assign statement")
+        raise NotImplementedError
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        """Assign `source` to `target` and register every bound name.
+
+        Each bound name/value pair is passed to `bind_value`; what it returns is
+        stored in the variable table. The raw name/value mapping is returned.
+        """
+        if self._duplicate_lhs_check(target) is True:
+            self.report_error(target, "Duplicate vars assigned.")
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        """Report `msg` against `node` through the diagnostics object."""
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit `node`, or every item when given a list or tuple.
+
+        Non-AST values are ignored. Raises NotImplementedError when no
+        `visit_<name>` method exists for the node class.
+        """
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        try:
+            func(node)
+        except DiagnosticError:
+            # Already reported at the node where it happened; re-reporting here
+            # would wrap it again at every enclosing node.
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        """Visit each statement of a body in order."""
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a function definition based on its last decorator.
+
+        The decorator is evaluated and must expose a `dispatch_token` for which a
+        "FunctionDef" handler is registered.
+        """
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except DiagnosticError:
+            # Keep the diagnostic raised for the inner node intact.
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a class definition to the "ir" ClassDef handler."""
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except DiagnosticError:
+            # Keep the diagnostic raised for the inner node intact.
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+
+    # The methods below only forward to the handler selected by `_dispatch`.
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
+
+    def visit_Return(self, node: doc.Return) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Return")(self, node)
diff --git a/python/tvm/script/parser/source.py b/python/tvm/script/parser/source.py
new file mode 100644
index 0000000000..a7a436d568
--- /dev/null
+++ b/python/tvm/script/parser/source.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name
+import inspect
+import re
+import sys
+from typing import Union
+
+from . import doc
+
+
+class Source:
+    """Source text of a program to parse, plus where it sits in its file.
+
+    For a string input, the string is both `source` and `full_source`. For any
+    other object, the text is obtained through `inspect` and dedented by the
+    indentation of its first line.
+    """
+
+    source_name: str
+    start_line: int
+    start_column: int
+    source: str
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = getsourcelines(program)  # type: ignore
+        # The indentation of the first line is the amount stripped from every line.
+        first_line = lines[0] if lines else ""
+        self.start_column = len(first_line) - len(first_line.lstrip())
+        if self.start_column:
+            dedented = [line[self.start_column :].rstrip() for line in lines]
+            self.source = "\n".join(dedented)
+        else:
+            self.source = "".join(lines)
+        try:
+            # It will cause a problem when running in Jupyter Notebook.
+            # `mod` will be <module '__main__'>, which is a built-in module
+            # and `getsource` will throw a TypeError
+            mod = inspect.getmodule(program)
+            self.full_source = inspect.getsource(mod) if mod else self.source
+        except TypeError:
+            # It's a work around for Jupyter problem.
+            # Since `findsource` is an internal API of inspect, we just use it
+            # as a fallback method.
+            src, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src)
+
+    def as_ast(self) -> doc.AST:
+        """Parse `self.source` into a `doc` AST."""
+        return doc.parse(self.source)
+
+
+# References to the stock `inspect` implementations, captured at import time so
+# they remain reachable from the class-aware replacements defined in this module.
+_getfile = inspect.getfile # pylint: disable=invalid-name
+_findsource = inspect.findsource # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):
+    """Class-aware replacement for `inspect.getfile`.
+
+    Non-class objects are delegated to the original `inspect.getfile`. For a
+    class, the file of its module is returned when the module is registered in
+    `sys.modules` and has a `__file__`. Otherwise the file of the first function
+    member defined directly on the class is used.
+
+    Raises
+    ------
+    TypeError
+        If no file can be determined for the class.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod = getattr(obj, "__module__", None)
+    if mod is not None:
+        # `.get` instead of indexing: a class whose module is not registered in
+        # sys.modules should fall through to the member scan, not raise KeyError.
+        file = getattr(sys.modules.get(mod), "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        if inspect.isfunction(member):
+            # Only accept functions defined on this class, not inherited ones.
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    # `!r` is a conversion; the former `{obj:!r}` passed "!r" as a format spec,
+    # which fails inside the f-string before this message could be built.
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+def findsource(obj):
+    """Return `(lines, index)` for `obj`: the lines of its file and the index of its definition.
+
+    Non-class objects are delegated to the original `inspect.findsource`. For a
+    class, the file is scanned for the chain of `def`/`class` headers given by
+    the parts of its `__qualname__`, matched in order.
+
+    Raises OSError when the source is unavailable or the definition is not found.
+    """
+    import linecache  # pylint: disable=import-outside-toplevel
+
+    if not inspect.isclass(obj):
+        return _findsource(obj)
+
+    file = inspect.getsourcefile(obj)
+    if file:
+        linecache.checkcache(file)
+    else:
+        file = inspect.getfile(obj)
+        # Only pseudo file names such as "<stdin>" are acceptable here.
+        if not file.startswith("<") or not file.endswith(">"):
+            raise OSError("source code not available")
+
+    module = inspect.getmodule(obj, file)
+    lines = linecache.getlines(file, module.__dict__) if module else linecache.getlines(file)
+    if not lines:
+        raise OSError("could not get source code")
+
+    # One pattern per qualname part: enclosing functions end in "<locals>".
+    patterns = []
+    for part in obj.__qualname__.replace(".<locals>", "<locals>").split("."):
+        if part.endswith("<locals>"):
+            patterns.append(re.compile(r"^(\s*)def\s*" + part[:-8] + r"\b"))
+        else:
+            patterns.append(re.compile(r"^(\s*)class\s*" + part + r"\b"))
+
+    pending = 0  # index of the next pattern that still has to match
+    for line_index, line in enumerate(lines):
+        if patterns[pending].match(line):
+            pending += 1
+            if pending == len(patterns):
+                return lines, line_index
+    raise OSError("could not find class definition")
+
+
+def getsourcelines(obj):
+    """Return the source lines of `obj` and the 1-based line number where they start."""
+    unwrapped = inspect.unwrap(obj)
+    file_lines, index = findsource(unwrapped)
+    block = inspect.getblock(file_lines[index:])
+    return block, index + 1
+
+
+# Monkey-patch executed at import time: from here on `inspect.getfile` is the
+# class-aware `_patched_inspect_getfile` for every user of the `inspect` module.
+inspect.getfile = _patched_inspect_getfile
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 78%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..caa51744c8 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from tvm.tir.ir_builder_v2 import * # pylint: disable=redefined-builtin
-from . import tir
-
-from .parser import ir_module, from_source
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..4a0c7c40fb
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+from tvm.tir.ir_builder_v2 import buffer_decl, ptr
+
+from ..entry import parse
+from ..utils import inspect_function_capture
+
+
+def _is_defined_in_class(frames):
+    """Tell whether the third stack entry looks like a class statement.
+
+    `frames` is a list of frame records as returned by `inspect.stack()`; entry
+    2 is inspected and its code context (index 4) is read. The first context
+    line counts when it starts with "class " or is a decorator line mentioning
+    "ir_module".
+    """
+    if len(frames) <= 2:
+        return False
+    code_context = frames[2][4]
+    if code_context is None:
+        return False
+    head = code_context[0].strip()
+    if head.startswith("class "):
+        return True
+    return head.startswith("@") and "ir_module" in head
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    """Parse the function `f`, unless it is being defined inside a class body.
+
+    Inside a class (as judged by `_is_defined_in_class` on the current stack),
+    `f` is returned unchanged. Raises TypeError when `f` is not a function.
+    """
+    if not inspect.isfunction(f):
+        raise TypeError(f"Expect a function, but got: {f}")
+    # Must be called from this frame: the check looks at a fixed stack depth.
+    inside_class = _is_defined_in_class(inspect.stack())
+    if inside_class:
+        return f
+    captured = inspect_function_capture(f)
+    return parse(f, captured)
+
+
+# Token the parser reads from a decorator to select the matching handlers.
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Callable and subscriptable front-end for `buffer_decl`."""
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        """Forward all arguments to `buffer_decl` unchanged."""
+        return buffer_decl(
+            shape,
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+
+    def __getitem__(self, keys) -> Buffer:
+        """Subscript form of the call.
+
+        A non-tuple key, or a tuple whose second item is not a string, is passed
+        whole as the shape. Any other tuple is unpacked into positional arguments.
+        """
+        whole_key_is_shape = not isinstance(keys, tuple) or (
+            len(keys) >= 2 and not isinstance(keys[1], str)
+        )
+        if whole_key_is_shape:
+            return self(keys)
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    """Callable and subscriptable front-end for `ptr`."""
+
+    def __call__(self, dtype, storage_scope="global"):
+        """Call `ptr`; a callable `dtype` is first called and its `.dtype` taken."""
+        resolved = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        """Subscript form of the call: a tuple key is unpacked, anything else is the sole argument."""
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+# Module-level proxy instances. NOTE: from this point on, `Buffer` rebinds the
+# name imported from `tvm.tir` at the top of this file.
+Buffer = BufferProxy() # pylint: disable=invalid-name
+Ptr = PtrProxy() # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..11ee92ad29
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .. import doc
+from ..dispatch import OpMethod, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    """Register operator handlers for the operand type `ty`.
+
+    Sets `ty._dispatch_type` to `ty` itself, then registers binary, comparison
+    and boolean handlers for operand positions 0 and 1, and unary handlers for
+    position 0.
+    """
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool_imm(value):
+        # Python bools are turned into IntImm("bool", ...); anything else passes through.
+        return IntImm("bool", value) if isinstance(value, bool) else value
+
+    def _and(a, b):
+        return tir.And(_as_bool_imm(a), _as_bool_imm(b))
+
+    def _or(a, b):
+        return tir.Or(_as_bool_imm(a), _as_bool_imm(b))
+
+    def register(op: Type, index: int, method: OpMethod):
+        register_op(ty, op, index)(method)
+
+    two_operand_ops = [
+        # Case 1. binop
+        (doc.Add, tir.Add),
+        (doc.Sub, tir.Sub),
+        (doc.Mult, tir.Mul),
+        (doc.Div, tir.Div),
+        (doc.FloorDiv, tir.FloorDiv),
+        (doc.Mod, tir.FloorMod),
+        (doc.LShift, lambda a, b: a << b),
+        (doc.RShift, lambda a, b: a >> b),
+        (doc.BitOr, lambda a, b: a | b),
+        (doc.BitXor, lambda a, b: a ^ b),
+        (doc.BitAnd, lambda a, b: a & b),
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop
+        (doc.Eq, tir.EQ),
+        (doc.NotEq, tir.NE),
+        (doc.Lt, tir.LT),
+        (doc.LtE, tir.LE),
+        (doc.Gt, tir.GT),
+        (doc.GtE, tir.GE),
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop
+        (doc.And, _and),
+        (doc.Or, _or),
+    ]
+    one_operand_ops = [
+        # Case 4. unaryop
+        (doc.Invert, lambda a: ~a),
+        (doc.Not, tir.Not),
+        (doc.UAdd, lambda a: +a),
+        (doc.USub, lambda a: -a),
+    ]
+
+    for operand_index in (0, 1):
+        for op, method in two_operand_ops:
+            register(op, operand_index, method)
+    for op, method in one_operand_ops:
+        register(op, 0, method)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..38973f6de2
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,269 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.ir.ir_builder import IRBuilderFrame as Frame
+from tvm.ir.ir_builder import name
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+from tvm.tir import ir_builder_v2 as T
+
+from .. import dispatch, doc
+from ..parser import Parser
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name a value that a ``with`` statement binds to ``var_name``.
+
+    Lists and tuples are handled element-wise, with ``_<index>`` appended to
+    the name for each element. ``Buffer`` and ``Var`` values are named
+    directly. Any other type is reported as an error at ``node``.
+
+    Returns
+    -------
+    Any
+        ``value`` itself.
+
+    Raises
+    ------
+    NotImplementedError
+        If ``value`` has a type that cannot be bound and ``report_error`` returns.
+    """
+    if isinstance(value, (list, tuple)):
+        for index, element in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{index}", element)
+        return value
+    if isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+    raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name a value that a ``for`` statement binds to ``var_name``.
+
+    Lists and tuples are handled element-wise, with ``_<index>`` appended to
+    the name for each element. Only ``Var`` values can be bound; any other
+    type is reported as an error at ``node``.
+
+    Returns
+    -------
+    Any
+        ``value`` itself.
+
+    Raises
+    ------
+    NotImplementedError
+        If ``value`` has a type that cannot be bound and ``report_error`` returns.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the for-statement rules. Delegating to bind_with_value here would
+            # accept Buffers and emit the "with statement" error message for nested values.
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    if isinstance(value, Var):
+        name(var_name, value)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+    raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind the right-hand side of an assignment to ``var_name``.
+
+    The handling depends on the type of ``value``:
+
+    - ``T.inline``: unwrapped and returned, nothing is named.
+    - list / tuple: every element is bound under ``<var_name>_<index>``.
+    - ``Frame``: entered immediately, its ``__exit__`` is registered through
+      ``add_callback``, and the value produced by ``__enter__`` is named.
+    - ``Buffer`` / ``IterVar`` / a ``Var`` not yet known to the var table: named directly.
+    - any other ``PrimExpr``: a fresh var of the same dtype is named and bound
+      with a ``T.let`` frame that is entered here.
+    - anything else: returned untouched.
+
+    Returns
+    -------
+    Any
+        The object the assigned name should refer to.
+    """
+    if isinstance(value, T.inline):
+        return value.value
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with the assignment rules. Delegating to bind_with_value here would reject
+            # IterVar / PrimExpr / Frame elements with a misleading "with statement" error.
+            bind_assign_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    if isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    if isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    if isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """Parse a ``for`` loop whose iterable evaluates to a ``T.frame.ForFrame``.
+
+    The loop variables produced by the frame are bound to the loop target
+    inside a fresh variable scope, and the body is visited inside the frame.
+    """
+    loop_frame = self.eval_expr(node.iter)
+    if not isinstance(loop_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame(), loop_frame as loop_vars:
+        self.eval_assign(target=node.target, source=loop_vars, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """Parse a ``while`` loop: evaluate the condition and visit the body inside ``T.While``.
+
+    Both the condition and the body are handled within a fresh variable scope.
+    """
+    with self.var_table.with_frame(), T.While(self.eval_expr(node.test)):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """Parse a plain assignment.
+
+    A subscripted target (``A[i, j] = value``) becomes ``T.buffer_store``;
+    every other target is bound through ``eval_assign`` with the assignment
+    binding rules. Chained assignments are reported as an error.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    target = node.targets[0]
+    value = self.eval_expr(node.value)
+    if not isinstance(target, doc.Subscript):
+        self.eval_assign(target=target, source=value, bind_value=bind_assign_value)
+        return
+    subscript = target.slice
+    if isinstance(subscript, doc.Tuple):
+        indices = [self.eval_expr(index) for index in subscript.elts]
+    else:
+        indices = [self.eval_expr(subscript)]
+    # The buffer expression is evaluated after the indices, as before.
+    buffer = self.eval_expr(target.value)
+    T.buffer_store(buffer, value, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+ """Parse an augmented assignment such as ``a += b``.
+
+ The target and the value are evaluated, stored in the var table under two
+ temporary names, and combined by evaluating a synthetic ``doc.BinOp`` that
+ applies ``node.op`` to those names. The result is then written back either
+ with ``T.buffer_store`` (subscripted target) or through ``eval_assign``.
+ """
+ # Source positions reused for the synthetic doc nodes built below.
+ lhs_pos = (
+ node.target.lineno,
+ node.target.col_offset,
+ node.target.end_lineno,
+ node.target.end_col_offset,
+ )
+ rhs_pos = (
+ node.value.lineno,
+ node.value.col_offset,
+ node.value.end_lineno,
+ node.value.end_col_offset,
+ )
+ # Switch the target to Load context so it can be evaluated as an expression first.
+ node.target.ctx = doc.Load(*lhs_pos)
+ with self.var_table.with_frame():
+ lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+ rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+ lhs_expr = self.eval_expr(node.target)
+ rhs_expr = self.eval_expr(node.value)
+ self.var_table.add(lhs_name, lhs_expr)
+ self.var_table.add(rhs_name, rhs_expr)
+ # Synthetic `lhs <op> rhs` over the temporary names registered above.
+ op = doc.BinOp(
+ doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+ node.op,
+ doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+ *lhs_pos,
+ )
+ rhs = self.eval_expr(op)
+ lhs = node.target
+ # Switch the target to Store context before writing the result back.
+ lhs.ctx = doc.Store(*lhs_pos)
+ if isinstance(lhs, doc.Subscript):
+ if isinstance(lhs.slice, doc.Tuple):
+ indices = []
+ for index in lhs.slice.elts:
+ indices.append(self.eval_expr(index))
+ else:
+ indices = [self.eval_expr(lhs.slice)]
+ T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+ else:
+ self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """Parse an annotated assignment ``x: ann = value``.
+
+    The annotation must evaluate to a ``Var``. That var is bound to the
+    target, then tied to the evaluated value by a ``T.let`` frame that is
+    entered here and whose ``__exit__`` is registered through ``add_callback``.
+    """
+    value = self.eval_expr(node.value)
+    annotated_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(annotated_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=node.target, source=annotated_var, bind_value=bind_assign_value)
+    let_frame = T.let(annotated_var, value)
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """Parse a ``with`` statement whose context expressions evaluate to IR builder frames.
+
+    A fresh variable scope and every frame are entered on a single
+    ``ExitStack``. Each ``as`` target is bound with the with-statement rules,
+    then the body is visited while all of them are still open.
+    """
+    with contextlib.ExitStack() as scopes:
+        scopes.enter_context(self.var_table.with_frame())
+        for with_item in node.items:
+            context = self.eval_expr(with_item.context_expr)
+            if not isinstance(context, Frame):
+                self.report_error(
+                    with_item.context_expr, "Invalid context expression in the with-statement."
+                )
+            entered = scopes.enter_context(context)
+            if with_item.optional_vars is None:
+                continue
+            self.eval_assign(
+                target=with_item.optional_vars, source=entered, bind_value=bind_with_value
+            )
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """Parse a function definition into a ``T.prim_func`` frame.
+
+    ``range`` is mapped to ``T.serial`` in the function's variable scope. The
+    function name and optional return type are set on the frame, then the
+    arguments and body are visited under the ``"tir"`` dispatch token.
+    """
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            return_annotation = node.returns
+            if return_annotation is not None:
+                ret_type = self.eval_expr(return_annotation)
+                if callable(ret_type):
+                    # A callable annotation is invoked and its result's dtype wrapped as PrimType.
+                    ret_type = PrimType(ret_type().dtype)
+                T.func_ret(ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """Declare every positional parameter as a ``T.arg`` and add it to the var table.
+
+    Each parameter must carry a type annotation; a missing one is reported as an error.
+    """
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    parameter: doc.arg
+    for parameter in node.args:
+        if parameter.annotation is None:
+            self.report_error(parameter, "Type annotation is required for function parameters.")
+        annotation = self.visit_tvm_annotation(parameter.annotation)
+        self.var_table.add(parameter.arg, T.arg(parameter.arg, annotation))
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """Evaluate an annotation expression; if the result is callable, call it with no arguments."""
+    evaluated = self.eval_expr(node)
+    return evaluated() if callable(evaluated) else evaluated
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """Evaluate an expression statement.
+
+    If it yields an IR builder frame, the frame is entered here and its
+    ``__exit__`` is registered through ``add_callback``.
+    """
+    result = self.eval_expr(node.value)
+    if not isinstance(result, Frame):
+        return
+    result.add_callback(partial(result.__exit__, None, None, None))
+    result.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """Parse an ``if`` statement into ``T.If`` with a ``T.Then`` and an optional ``T.Else``.
+
+    The condition and both branches are handled within one fresh variable scope.
+    """
+    with self.var_table.with_frame(), T.If(self.eval_expr(node.test)):
+        with T.Then():
+            self.visit_body(node.body)
+        else_body = node.orelse
+        if else_body:
+            with T.Else():
+                self.visit_body(else_body)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """Parse an ``assert`` statement into a ``T.Assert`` frame.
+
+    The frame is entered here and its ``__exit__`` is registered through ``add_callback``.
+    """
+    # NOTE(review): ``node.msg`` is evaluated unconditionally. Presumably it is absent for a
+    # bare ``assert cond`` -- confirm how ``eval_expr`` behaves in that case.
+    assert_frame = T.Assert(self.eval_expr(node.test), self.eval_expr(node.msg))
+    assert_frame.add_callback(partial(assert_frame.__exit__, None, None, None))
+    assert_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+ """Reject ``return`` statements by reporting an error at ``node``."""
+ self.report_error(node, "Return is not allowed.")
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/utils.py
similarity index 52%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/utils.py
index 923eb97d27..4c08a381c0 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/utils.py
@@ -14,32 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script Interface for PrimFunc"""
-
+# pylint: disable=missing-docstring
import inspect
-from typing import Callable
+from contextlib import contextmanager
+from typing import Any, Callable, Dict
+
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+def deferred(f: Callable[[], None]):
+    """Return a context manager that calls ``f`` when its ``with`` block exits.
+
+    ``f`` runs on normal exit and also when the block raises.
+    """
+
+    @contextmanager
+    def _call_on_exit():
+        try:
+            yield
+        finally:
+            f()
+
+    return _call_on_exit()
-def prim_func(input_func: Callable) -> PrimFunc:
- """Decorate a python function as tvm script.
- Parameters
- ----------
- func : input_func
- The function to be parsed.
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    """Collect the names ``func`` can see: its closure nonlocals merged with its globals.
+
+    On a name clash the global value wins, because globals are merged last.
+    """
+    visible: Dict[str, Any] = dict(inspect.getclosurevars(func).nonlocals)
+    visible.update(func.__globals__)
+    return visible
- Returns
- -------
- output : PrimFunc
- The result functions.
- """
- if inspect.isfunction(input_func):
- result = from_source(input_func)
- result.__name__ = input_func.__name__
- result.__qualname__ = input_func.__qualname__
- return result
- raise TypeError("Only function definitions are supported.")
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    """Merge the captured names of every plain function defined directly on ``cls``.
+
+    Only entries of ``cls.__dict__`` that ``inspect.isfunction`` accepts are
+    inspected. Functions are merged in definition order, later ones overriding earlier ones.
+    """
+    captured: Dict[str, Any] = {}
+    for member in cls.__dict__.values():
+        if inspect.isfunction(member):
+            captured.update(**inspect_function_capture(member))
+    return captured
diff --git a/python/tvm/script/parser/var_table.py b/python/tvm/script/parser/var_table.py
new file mode 100644
index 0000000000..32fced625a
--- /dev/null
+++ b/python/tvm/script/parser/var_table.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The symbol table of variable values"""
+
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Set
+
+from .utils import deferred
+
+
+class VarTableFrame:
+    """The set of variable names introduced within one scope of a ``VarTable``."""
+
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        """Record ``var`` in this scope; raise ``ValueError`` if it is already recorded."""
+        if var not in self.vars:
+            self.vars.add(var)
+            return
+        raise ValueError(f"Variable {var} already defined in current scope")
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        """Call ``fn_pop`` once per recorded name, then forget all names."""
+        for recorded in self.vars:
+            fn_pop(recorded)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped symbol table: each name maps to a stack of values, innermost binding last."""
+
+    # Open scopes, innermost last.
+    frames: List[VarTableFrame]
+    # Per-name stack of bound values; the last entry is the visible one.
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        """Open a new scope and return a context manager that closes it on exit.
+
+        Closing the scope pops one value for every name that was added while it was open.
+        """
+
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        """Bind ``var`` to ``value`` in the innermost scope.
+
+        Raises
+        ------
+        ValueError
+            If ``var`` is already defined in the innermost scope.
+        """
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        """Return the currently visible binding for every name that has one."""
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        """Return True if ``value`` (compared by identity) is bound to any name in any scope."""
+        # ``name2value`` maps each name to a *stack* of values, so the identity test has to be
+        # made against the stack entries; testing the stack object itself never matches.
+        return any(
+            known is value for stack in self.name2value.values() for known in stack
+        )
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser_v1/__init__.py
similarity index 95%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser_v1/__init__.py
index 555659d0c5..004e947bf6 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser_v1/__init__.py
@@ -17,5 +17,4 @@
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
from . import tir
-
-from .parser import ir_module, from_source
+from .parser import from_source, ir_module
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py
similarity index 100%
copy from python/tvm/script/_ffi_api.py
copy to python/tvm/script/parser_v1/_ffi_api.py
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
similarity index 98%
rename from python/tvm/script/context_maintainer.py
rename to python/tvm/script/parser_v1/context_maintainer.py
index f7f16855c7..400baacc4b 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/parser_v1/context_maintainer.py
@@ -16,16 +16,16 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
+from typing import Callable, Dict, List, Mapping, Optional, Union
+import synr
import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
+from tvm.tir import Buffer, MatchBufferRegion, PrimExpr, Stmt, Var
from tvm.tir.expr import IterVar
+
from .tir.node import BufferSlice
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py
similarity index 95%
rename from python/tvm/script/diagnostics.py
rename to python/tvm/script/parser_v1/diagnostics.py
index e676461ab3..b15997552f 100644
--- a/python/tvm/script/diagnostics.py
+++ b/python/tvm/script/parser_v1/diagnostics.py
@@ -17,11 +17,11 @@
"""Bridge from synr's (the library used for parsing the python AST)
DiagnosticContext to TVM's diagnostics
"""
-from synr import DiagnosticContext, ast
-
import tvm
+from synr import DiagnosticContext, ast
+from tvm.ir.diagnostics import Diagnostic
from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
+from tvm.ir.diagnostics import DiagnosticLevel, get_renderer
class TVMDiagnosticCtx(DiagnosticContext):
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py
similarity index 100%
rename from python/tvm/script/meta_unparser.py
rename to python/tvm/script/parser_v1/meta_unparser.py
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py
similarity index 99%
rename from python/tvm/script/parser.py
rename to python/tvm/script/parser_v1/parser.py
index 908af081c9..b2b7e388f1 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser_v1/parser.py
@@ -20,35 +20,34 @@ We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
different python versions. Synr also provides an error handling context that we
use for error reporting.
"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
+import inspect
import json
import operator
-import inspect
+
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
+import types
from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
import tvm
+from synr import Transformer, ast, to_ast
from tvm import IRModule
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
from tvm.tir import buffer
from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
+from . import _ffi_api, tir
from .context_maintainer import ContextMaintainer
+from .diagnostics import TVMDiagnosticCtx
from .meta_unparser import MetaUnparser
from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
+from .tir import ty
from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
+from .tir.node import BufferSlice, Slice
+from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler
from .tir.special_stmt import SpecialStmt
-from .tir import ty
+from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr
class CallArgumentReader(object):
diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py
similarity index 97%
rename from python/tvm/script/registry.py
rename to python/tvm/script/parser_v1/registry.py
index e7d90dd515..e816b90f5d 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/parser_v1/registry.py
@@ -17,7 +17,7 @@
"""TVM Script Parser Function Registry """
# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
import types
-from typing import Union, Callable, Dict, Optional, Any
+from typing import Any, Callable, Dict, Optional, Union
class Registry(object):
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 100%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
similarity index 100%
rename from python/tvm/script/tir/__init__.pyi
rename to python/tvm/script/parser_v1/tir/__init__.pyi
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 98%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index 382431c229..8c51decf14 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -17,11 +17,12 @@
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import builtins
-from typing import List, Any
+from typing import Any, List
import tvm.tir
+from tvm.target import codegen
+
from ..registry import register
-from ...target import codegen
from ..utils import get_param_list, tvm_span_from_synr
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 98%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
index 29e79607fb..cfaf9df476 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/parser_v1/tir/node.py
@@ -17,12 +17,13 @@
# pylint: disable=redefined-builtin
"""TVM Script nodes."""
-from typing import Optional, Union, List, Callable
+from typing import Callable, List, Optional, Union
+
import synr
from tvm.arith import Analyzer
+from tvm.ir import Range, Span
from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
+from tvm.tir import Buffer, BufferLoad, BufferRegion, IntImm, PrimExpr, Ramp
class Slice:
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 97%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
index 923eb97d27..a5fdcc15c5 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser_v1/tir/prim_func.py
@@ -19,7 +19,8 @@
import inspect
from typing import Callable
-from tvm.tir.function import PrimFunc
+from tvm.tir import PrimFunc
+
from ..parser import from_source
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 98%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
index da7545c9a9..cd2167f4ab 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/parser_v1/tir/scope_handler.py
@@ -16,24 +16,19 @@
# under the License.
"""TVM Script Parser Scope Handler Classes"""
# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
-import synr
import numpy as np
+import synr
import tvm.tir
+from tvm.ir import Range, Span
from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-
-from .node import BufferSlice
+from tvm.tir import Buffer, BufferRegion, ForKind, IterVar, PrimExpr, Stmt, Var
from ..context_maintainer import ContextMaintainer
from ..registry import register
-from ..utils import (
- get_param_list,
- tvm_span_from_synr,
- call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
class ScopeHandler:
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 99%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
index 15502055b7..42a90f647f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -17,27 +17,21 @@
"""TVM Script Parser Special Stmt Classes"""
# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
# pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
import synr
+import tvm.tir
from synr import ast
+from tvm.ir import Span
from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
from tvm.runtime import Object, String
from tvm.target import Target
-from tvm.ir import Span
from tvm.tir import IntImm, IterVar, Var
-from .node import BufferSlice
-
from ..context_maintainer import BlockInfo, ContextMaintainer
from ..registry import register
-from ..utils import (
- get_param_list,
- tvm_span_from_synr,
- call_with_error_reporting,
-)
+from ..utils import call_with_error_reporting, get_param_list, tvm_span_from_synr
+from .node import BufferSlice
def convert_to_int(
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 99%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
index 4548102a9e..d9e4b3388d 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/parser_v1/tir/ty.py
@@ -23,6 +23,7 @@ a wrapper for uniform Type system in IR
from numbers import Integral
import tvm
+
from .special_stmt import SpecialStmt, convert_to_int
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 97%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
index c655a62237..f358a90081 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/parser_v1/utils.py
@@ -16,13 +16,12 @@
# under the License.
"""Helper functions in TVM Script Parser"""
-from typing import Callable, List, Any, Optional, Tuple
-
import inspect
-import synr
+from typing import Any, Callable, List, Optional, Tuple
-from tvm.ir import Span, SourceName
+import synr
from tvm.error import DiagnosticError
+from tvm.ir import SourceName, Span
def get_param_list(
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index ada5c369ad..b0d72d283b 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -22,9 +22,9 @@ from numbers import Integral as _Integral
from typing import List
import tvm._ffi
+import tvm.arith._ffi_api
import tvm.tir
import tvm.tir._ffi_api
-import tvm.arith._ffi_api
from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert
@@ -420,11 +420,14 @@ def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimF
)
for tensor, buffer in zip(input_tensors, input_buffers):
# TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made?
- assert tensor.shape == buffer.shape, (
- "The input input_tensors provided do not match the input buffers in the ",
- "primfunc. Please check that the order of input te.Input_Tensors and the ",
- "order of the primfunc variables in the params list agree.",
- )
+ assert len(tensor.shape) == len(buffer.shape)
+ for d1, d2 in zip(tensor.shape, buffer.shape):
+ assert d1 == d2, (
+ "The input input_tensors provided do not match the input buffers in the ",
+ "primfunc. Please check that the order of input te.Input_Tensors and the ",
+ "order of the primfunc variables in the params list agree.",
+ )
+
output = extern(
[buf.shape for buf in outputs],
input_tensors,
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index c64b7dfe71..41a3f86233 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -19,50 +19,185 @@
from tvm.ir import PrimExpr
from tvm.runtime import const
-from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
-from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
-from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
-from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
-
-from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
+from . import analysis
+from . import ir_builder_v1 as ir_builder
+from . import schedule, stmt_functor, transform, usmp
+from .buffer import Buffer, DataProducer, decl_buffer
+from .data_layout import BijectiveLayout, Layout, bijective_layout, layout
+from .expr import (
+ EQ,
+ GE,
+ GT,
+ LE,
+ LT,
+ NE,
+ Add,
+ And,
+ Any,
+ Broadcast,
+ BufferLoad,
+ Call,
+ CallEffectKind,
+ Cast,
+ CommReducer,
+ Div,
+ FloatImm,
+ FloorDiv,
+ FloorMod,
+ IntImm,
+ IterVar,
+ Let,
+ Load,
+ Max,
+ Min,
+ Mod,
+ Mul,
+ Not,
+ Or,
+ ProducerLoad,
+ Ramp,
+ Reduce,
+ Select,
+ Shuffle,
+ SizeVar,
+ StringImm,
+ Sub,
+ Var,
+)
+from .function import IndexMap, PrimFunc, TensorIntrin
+from .op import (
+ TVMBackendAllocWorkspace,
+ TVMBackendFreeWorkspace,
+ abs,
+ acos,
+ acosh,
+ address_of,
+ all,
+ any,
+ asin,
+ asinh,
+ assume,
+ atan,
+ atan2,
+ atanh,
+ call_cpacked,
+ call_cpacked_lowered,
+ call_extern,
+ call_intrin,
+ call_llvm_intrin,
+ call_llvm_pure_intrin,
+ call_packed,
+ call_packed_lowered,
+ call_pure_extern,
+ ceil,
+ ceildiv,
+ clz,
+ comm_reducer,
+ copysign,
+ cos,
+ cosh,
+ div,
+ erf,
+ exp,
+ exp2,
+ exp10,
+ floor,
+ floordiv,
+ floormod,
+ fmod,
+ hypot,
+ if_then_else,
+ indexdiv,
+ indexmod,
+ isfinite,
+ isinf,
+ isnan,
+ isnullptr,
+ ldexp,
+ likely,
+ log,
+ log1p,
+ log2,
+ log10,
+ lookup_param,
+ max,
+ max_value,
+ min,
+ min_value,
+ mma_fill,
+ mma_store,
+ nearbyint,
+ nextafter,
+ popcount,
+ power,
+ ptx_commit_group,
+ ptx_cp_async,
+ ptx_ldmatrix,
+ ptx_mma,
+ ptx_mma_sp,
+ ptx_wait_group,
+ q_multiply_shift,
+ ret,
+ round,
+ rsqrt,
+ shift_left,
+ shift_right,
+ sigmoid,
+ sin,
+ sinh,
+ sqrt,
+ sum,
+ tan,
+ tanh,
+ trace,
+ trunc,
+ truncdiv,
+ truncmod,
+ tvm_access_ptr,
+ tvm_bmma_sync,
+ tvm_fill_fragment,
+ tvm_load_matrix_sync,
+ tvm_mma_sync,
+ tvm_stack_alloca,
+ tvm_stack_make_array,
+ tvm_stack_make_shape,
+ tvm_store_matrix_sync,
+ tvm_struct_get,
+ tvm_struct_set,
+ tvm_thread_allreduce,
+ tvm_throw_last_error,
+ tvm_tuple,
+ undef,
+ vectorcombine,
+ vectorhigh,
+ vectorlow,
+)
+from .schedule import BlockScope, Schedule, ScheduleError, ScheduleState, StmtSRef
from .stmt import (
- BufferStore,
- BufferRealize,
- Store,
- ProducerStore,
Allocate,
AllocateConst,
+ AssertStmt,
AttrStmt,
+ Block,
+ BlockRealize,
+ BufferRealize,
+ BufferRegion,
+ BufferStore,
DeclBuffer,
+ Evaluate,
+ For,
+ ForKind,
+ IfThenElse,
+ LetStmt,
+ MatchBufferRegion,
+ Prefetch,
+ ProducerRealize,
+ ProducerStore,
+ SeqStmt,
+ Stmt,
+ Store,
+ While,
+ stmt_list,
+ stmt_seq,
+ type_annotation,
)
-
-from .stmt import ProducerRealize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
-from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-
-from .function import PrimFunc, TensorIntrin, IndexMap
-
-from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
-from .op import sin, sinh, asin, asinh
-from .op import cos, cosh, acos, acosh
-from .op import tan, tanh, atan, atan2, atanh
-from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
-from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
-from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift
-
-from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
-
-from . import schedule
-from . import ir_builder
-from . import transform
-from . import analysis
-from . import stmt_functor
-from . import usmp
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/tir/_ffi_ir_builder_api.py
similarity index 88%
rename from python/tvm/script/_ffi_api.py
rename to python/tvm/tir/_ffi_ir_builder_api.py
index 926d17b166..61b288d498 100644
--- a/python/tvm/script/_ffi_api.py
+++ b/python/tvm/tir/_ffi_ir_builder_api.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script"""
+"""FFI APIs for tvm.ir"""
import tvm._ffi
-tvm._ffi._init_api("script", __name__)
+tvm._ffi._init_api("ir_builder.tir", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 13674daa24..ea220fea22 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Union
from tvm import Object
from tvm.ir import IRModule
from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
+from tvm.tir.stmt import Block, BufferRegion, PrimExpr, Stmt
-from .. import Buffer, Stmt
+from ..buffer import Buffer
from ..function import PrimFunc
from . import _ffi_api
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index d9b0aec76a..e74eb15453 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -16,11 +16,12 @@
# under the License.
"""Abstraction for array data structures."""
from numbers import Integral
-import tvm._ffi
+import tvm._ffi
from tvm._ffi.base import string_types
+from tvm.ir import PointerType, PrimExpr, PrimType, Range
from tvm.runtime import Object, convert
-from tvm.ir import PrimExpr, PointerType, PrimType
+
from . import _ffi_api
@@ -176,6 +177,40 @@ class Buffer(Object):
"""
return _ffi_api.BufferOffsetOf(self, indices) # type: ignore
+    def __getitem__(self, indices):
+        """Index the buffer, producing either a BufferRegion or a BufferLoad.
+
+        Parameters
+        ----------
+        indices : Union[PrimExpr, slice, Tuple, List]
+            A single index or a tuple/list of indices. Each entry is either a
+            plain index or a ``slice``.
+
+        Returns
+        -------
+        result : Union[BufferRegion, BufferLoad]
+            A ``BufferRegion`` when at least one index is a slice and no slice
+            carries a step; otherwise a ``BufferLoad`` in which every slice is
+            lowered to its start (single lane) or to a ``Ramp``.
+
+        Notes
+        -----
+        A missing slice start defaults to 0, a missing stop defaults to the
+        matching entry of ``self.shape`` and a missing step defaults to 1, so
+        ``buf[:n]`` and ``buf[k:]`` no longer fail on ``None`` arithmetic.
+        When stepped and unstepped slices are mixed, the load form is used so
+        that no explicit step is silently discarded.
+        """
+        from ..arith import Analyzer  # pylint: disable=import-outside-toplevel
+        from .expr import BufferLoad, Ramp  # pylint: disable=import-outside-toplevel
+        from .stmt import BufferRegion  # pylint: disable=import-outside-toplevel
+
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+        has_slice = any(isinstance(index, slice) for index in indices)
+        if not has_slice:
+            # Plain element access: no analyzer is needed at all.
+            return BufferLoad(self, list(indices))
+        has_step = any(isinstance(index, slice) and index.step is not None for index in indices)
+        # One analyzer serves every slice of this access.
+        analyzer = Analyzer()
+        if not has_step:
+            region = []
+            for i, index in enumerate(indices):
+                if isinstance(index, slice):
+                    start = 0 if index.start is None else index.start
+                    stop = self.shape[i] if index.stop is None else index.stop
+                    region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
+                else:
+                    region.append(Range.from_min_extent(index, 1))
+            return BufferRegion(self, region)
+        expr_indices = []
+        for i, index in enumerate(indices):
+            if isinstance(index, slice):
+                start = 0 if index.start is None else index.start
+                stop = self.shape[i] if index.stop is None else index.stop
+                step = 1 if index.step is None else index.step
+                # Ceil-divide the span by the step to get the lane count.
+                lanes = analyzer.simplify((stop - start + step - 1) // step)
+                if lanes == 1:
+                    expr_indices.append(start)
+                else:
+                    expr_indices.append(Ramp(start, step, int(lanes)))
+            else:
+                expr_indices.append(index)
+        return BufferLoad(self, expr_indices)
+
def decl_buffer(
shape,
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(y.a == x)
"""
from typing import Optional, Union
-from tvm import ir
+
import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
from . import _ffi_api
+from . import generic as _generic
def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
class ExprOp(object):
"""Operator overloading for Expr like expressions."""
- # TODO(tkonolige): use inspect to add source information to these objects
-
def __add__(self, other):
return _generic.add(self, other)
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
"""
def __init__(self, condition, true_value, false_value, span=None):
+ # A plain Python bool is wrapped as a "bool"-typed IntImm before it is handed to the FFI
+ # constructor; any other condition value is passed through unchanged.
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
self.__init_handle_by_constructor__(
_ffi_api.Select, condition, true_value, false_value, span # type: ignore
)
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index f06376147b..6c57e27b82 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -394,7 +394,7 @@ class IndexMap(Object):
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
- "Instead received {val} of type {type(val)}."
+ f"Instead received {val} of type {type(val)}."
)
return IndexMap(initial_indices, final_indices), axis_separators
diff --git a/python/tvm/tir/ir_builder_frame.py b/python/tvm/tir/ir_builder_frame.py
new file mode 100644
index 0000000000..a1f457aad2
--- /dev/null
+++ b/python/tvm/tir/ir_builder_frame.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IRBuilder for TIR"""
+
+from typing import List, Union
+
+from tvm._ffi import register_object as _register_object
+from tvm.ir.ir_builder import IRBuilderFrame
+
+from .buffer import Buffer
+from .expr import Var
+
+
+@_register_object("ir_builder.tir.TIRFrame")
+class TIRFrame(IRBuilderFrame):
+ """Class registered for the "ir_builder.tir.TIRFrame" object type.
+
+ Adds nothing to IRBuilderFrame; the other frame classes in this module derive from it.
+ """
+ ...
+
+
+@_register_object("ir_builder.tir.BlockFrame")
+class BlockFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.BlockFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.BlockInitFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    """Class registered for the "ir_builder.tir.ForFrame" object type.
+
+    Entering the frame yields the loop variable(s) stored in ``self.vars``.
+    """
+
+    def __enter__(self) -> Union[Var, List[Var]]:
+        """Enter the frame and return its loop variables.
+
+        Returns
+        -------
+        result : Union[Var, List[Var]]
+            ``self.vars`` itself when it holds more than one variable,
+            otherwise its first element, so ``with ... as i`` binds the
+            variable directly rather than a one-element sequence.
+        """
+        super().__enter__()
+        return self.vars if len(self.vars) > 1 else self.vars[0]
+
+
+@_register_object("ir_builder.tir.PrimFuncFrame")
+class PrimFuncFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.PrimFuncFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.AssertFrame")
+class AssertFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.AssertFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.LetFrame")
+class LetFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.LetFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.AllocateFrame")
+class AllocateFrame(TIRFrame):
+    """Class registered for the "ir_builder.tir.AllocateFrame" object type."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame via the base class, then hand back ``self.buffer``."""
+        super().__enter__()
+        allocated = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.AllocateConstFrame")
+class AllocateConstFrame(TIRFrame):
+    """Class registered for the "ir_builder.tir.AllocateConstFrame" object type."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame via the base class, then hand back ``self.buffer``."""
+        super().__enter__()
+        allocated = self.buffer
+        return allocated
+
+
+@_register_object("ir_builder.tir.LaunchThreadFrame")
+class LaunchThreadFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.LaunchThreadFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.RealizeFrame")
+class RealizeFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.RealizeFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.AttrFrame")
+class AttrFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.AttrFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.WhileFrame")
+class WhileFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.WhileFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.IfFrame")
+class IfFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.IfFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.ThenFrame")
+class ThenFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.ThenFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.ElseFrame")
+class ElseFrame(TIRFrame):
+ """Class registered for the "ir_builder.tir.ElseFrame" object type; adds no Python behaviour."""
+ ...
+
+
+@_register_object("ir_builder.tir.DeclBufferFrame")
+class DeclBufferFrame(TIRFrame):
+    """Class registered for the "ir_builder.tir.DeclBufferFrame" object type."""
+
+    def __enter__(self) -> Buffer:
+        """Enter the frame via the base class, then hand back ``self.buffer``."""
+        super().__enter__()
+        declared = self.buffer
+        return declared
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder_v1.py
similarity index 99%
rename from python/tvm/tir/ir_builder.py
rename to python/tvm/tir/ir_builder_v1.py
index ce8cd1b403..8a68faaac6 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder_v1.py
@@ -17,13 +17,13 @@
"""Developer API of IR node builder make function."""
import tvm
from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, convert, const
from tvm.ir import container as _container
+from tvm.runtime import ObjectGeneric, const, convert
-from . import stmt as _stmt
-from . import expr as _expr
from . import buffer as _buffer
+from . import expr as _expr
from . import op
+from . import stmt as _stmt
class WithScope(object):
diff --git a/python/tvm/tir/ir_builder_v2.py b/python/tvm/tir/ir_builder_v2.py
new file mode 100644
index 0000000000..b4c8aa3a1d
--- /dev/null
+++ b/python/tvm/tir/ir_builder_v2.py
@@ -0,0 +1,949 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""IRBuilder for TIR"""
+import functools
+import inspect
+from numbers import Integral
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from tvm.ir import Range, Type
+from tvm.runtime import convert, ndarray
+from tvm.target import Target as target
+
+from . import _ffi_ir_builder_api as _ffi_api
+from . import ir_builder_frame as frame
+from . import op as _tir_op
+from .buffer import Buffer
+from .expr import Broadcast as broadcast
+from .expr import BufferLoad, CommReducer, IntImm, IterVar, Let, PrimExpr
+from .expr import Ramp as ramp, Cast
+from .expr import Select, Shuffle, StringImm, Var
+from .generic import cast
+from .stmt import BufferRegion, type_annotation
+
+
+def buffer_decl(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> Buffer:
+    """Call the ``BufferDecl`` FFI entry and return its result.
+
+    A scalar ``shape`` (a ``PrimExpr`` or an integral number) is wrapped into
+    a one-element tuple first. An empty string is always passed in the third
+    argument position; every other argument is forwarded unchanged.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    ffi_args = (
+        shape,
+        dtype,
+        "",
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+    return _ffi_api.BufferDecl(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def ptr(dtype, storage_scope="global"):
+    """Forward ``dtype`` and ``storage_scope`` to the ``Ptr`` FFI entry and return its result."""
+    pointer = _ffi_api.Ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+    return pointer
+
+
+def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
+    """Forward ``name`` and ``no_realize`` to the ``Block`` FFI entry and return the frame."""
+    block_frame = _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
+    return block_frame
+
+
+def init() -> frame.BlockInitFrame:
+    """Call the argument-less ``Init`` FFI entry and return the frame it produces."""
+    init_frame = _ffi_api.Init()  # pylint: disable=no-member # type: ignore
+    return init_frame
+
+
+def where(predicate) -> None:
+    """Pass ``predicate`` to the ``Where`` FFI entry.
+
+    A Python ``bool``, or an ``int`` equal to 0 or 1, is first wrapped as a
+    "bool"-typed ``IntImm``. Anything else is forwarded as given.
+
+    Raises
+    ------
+    ValueError
+        If ``predicate`` is a Python ``int`` other than 0 or 1.
+    """
+    if isinstance(predicate, bool):
+        predicate = IntImm("bool", predicate)
+    elif isinstance(predicate, int):
+        if predicate not in (0, 1):
+            raise ValueError(f"Invalid value for predicate: {predicate}")
+        predicate = IntImm("bool", predicate)
+    _ffi_api.Where(predicate)  # pylint: disable=no-member # type: ignore
+
+
+def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
+    """Normalise the arguments into one list and pass it to the ``Reads`` FFI entry.
+
+    A single tuple argument is copied into a list, a single list argument is
+    used as is, any other single argument is wrapped in a list, and several
+    positional arguments are collected into a list.
+    """
+    if len(buffer_slices) != 1:
+        regions = list(buffer_slices)
+    else:
+        only = buffer_slices[0]
+        if isinstance(only, tuple):
+            regions = list(only)
+        elif isinstance(only, list):
+            regions = only
+        else:
+            regions = [only]
+    _ffi_api.Reads(regions)  # pylint: disable=no-member # type: ignore
+
+
+def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
+    """Normalise the arguments into one list and pass it to the ``Writes`` FFI entry.
+
+    A single tuple argument is copied into a list, a single list argument is
+    used as is, any other single argument is wrapped in a list, and several
+    positional arguments are collected into a list.
+    """
+    if len(buffer_slices) != 1:
+        regions = list(buffer_slices)
+    else:
+        only = buffer_slices[0]
+        if isinstance(only, tuple):
+            regions = list(only)
+        elif isinstance(only, list):
+            regions = only
+        else:
+            regions = [only]
+    _ffi_api.Writes(regions)  # pylint: disable=no-member # type: ignore
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """Forward ``attrs`` to the ``BlockAttrs`` FFI entry and return whatever it returns."""
+    outcome = _ffi_api.BlockAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def alloc_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Call the ``AllocBuffer`` FFI entry and return its result.
+
+    A scalar ``shape`` (a ``PrimExpr`` or an integral number) becomes a
+    one-element tuple and a ``None`` ``strides`` becomes an empty list; all
+    other arguments are forwarded unchanged.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    ffi_args = (
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+    return _ffi_api.AllocBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def _as_range(dom) -> Range:
+    """Coerce ``dom`` into a ``Range``.
+
+    A ``Range`` is returned as is, a list/tuple becomes
+    ``Range(dom[0], dom[1])`` and anything else becomes ``Range(0, dom)``.
+    """
+    if isinstance(dom, Range):
+        result = dom
+    elif isinstance(dom, (list, tuple)):
+        lower, upper = dom[0], dom[1]
+        result = Range(lower, upper)
+    else:
+        result = Range(0, dom)
+    return result
+
+
+class axis:  # pylint: disable=invalid-name
+    """Namespace of static helpers that forward to the ``Axis*`` FFI entries.
+
+    ``spatial``, ``reduce``, ``scan`` and ``opaque`` convert ``dom`` with
+    ``_as_range`` before the call. ``S`` and ``R`` are aliases of ``spatial``
+    and ``reduce``.
+    """
+
+    @staticmethod
+    def spatial(dom, binding, dtype="int32") -> IterVar:
+        """Call ``AxisSpatial`` with the range form of ``dom``, ``binding`` and ``dtype``."""
+        rng = _as_range(dom)
+        return _ffi_api.AxisSpatial(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def reduce(dom, binding, dtype="int32") -> IterVar:
+        """Call ``AxisReduce`` with the range form of ``dom``, ``binding`` and ``dtype``."""
+        rng = _as_range(dom)
+        return _ffi_api.AxisReduce(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def scan(dom, binding, dtype="int32") -> IterVar:
+        """Call ``AxisScan`` with the range form of ``dom``, ``binding`` and ``dtype``."""
+        rng = _as_range(dom)
+        return _ffi_api.AxisScan(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def opaque(dom, binding, dtype="int32") -> IterVar:
+        """Call ``AxisOpaque`` with the range form of ``dom``, ``binding`` and ``dtype``."""
+        rng = _as_range(dom)
+        return _ffi_api.AxisOpaque(rng, binding, dtype)  # pylint: disable=no-member # type: ignore
+
+    @staticmethod
+    def remap(kinds, bindings, dtype="int32") -> Union[List[IterVar], IterVar]:
+        """Call ``AxisRemap``; a one-element result is unwrapped to its single item."""
+        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: ignore
+            kinds, bindings, dtype
+        )
+        if len(iter_vars) == 1:
+            return iter_vars[0]
+        return iter_vars
+
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
+def serial(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Call the ``Serial`` FFI entry; with ``stop`` omitted the range is ``[0, start)``."""
+    if stop is None:
+        start, stop = 0, start
+    return _ffi_api.Serial(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Call the ``Parallel`` FFI entry; with ``stop`` omitted the range is ``[0, start)``."""
+    if stop is None:
+        start, stop = 0, start
+    return _ffi_api.Parallel(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Call the ``Vectorized`` FFI entry; with ``stop`` omitted the range is ``[0, start)``."""
+    if stop is None:
+        start, stop = 0, start
+    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(start, stop=None, *, annotations=None) -> frame.ForFrame:
+    """Call the ``Unroll`` FFI entry; with ``stop`` omitted the range is ``[0, start)``."""
+    if stop is None:
+        start, stop = 0, start
+    return _ffi_api.Unroll(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start,
+    stop=None,
+    thread=None,
+    *,
+    annotations=None,
+) -> frame.ForFrame:
+    """Call the ``ThreadBinding`` FFI entry after normalising the arguments.
+
+    Accepted forms are ``(start, stop, thread)``, ``(extent, thread)`` where
+    the second positional argument is a string, and ``(extent, thread=...)``.
+    In the two short forms the range becomes ``[0, extent)``.
+
+    Raises
+    ------
+    ValueError
+        If ``thread`` is ``None`` and ``stop`` is not a string.
+    """
+    if thread is None:
+        if not isinstance(stop, str):
+            raise ValueError("Thread cannot be None for thread_binding")
+        # Two-argument form: the second positional argument is the thread tag.
+        start, stop, thread = 0, start, stop
+    elif stop is None:
+        start, stop = 0, start
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents) -> frame.ForFrame:
+    """Pass the positional ``extents`` tuple to the ``Grid`` FFI entry and return the frame."""
+    grid_frame = _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+    return grid_frame
+
+
+def prim_func() -> frame.PrimFuncFrame:
+    """Call the argument-less ``PrimFunc`` FFI entry and return the frame it produces."""
+    func_frame = _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+    return func_frame
+
+
+def arg(name, obj):
+    """Forward ``name`` and ``obj`` to the ``Arg`` FFI entry and return its result."""
+    registered = _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+    return registered
+
+
+def func_name(name: str) -> str:
+    """Forward ``name`` to the ``FuncName`` FFI entry and return its result."""
+    outcome = _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def func_attr(attrs: Dict[str, Any]) -> None:
+    """Forward ``attrs`` to the ``FuncAttrs`` FFI entry and return whatever it returns."""
+    outcome = _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def func_ret(ret_type) -> Type:
+    """Forward ``ret_type`` to the ``FuncRet`` FFI entry and return its result."""
+    outcome = _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def match_buffer(
+    param,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> Buffer:
+    """Call the ``MatchBuffer`` FFI entry and return its result.
+
+    A scalar ``shape`` (a ``PrimExpr`` or an integral number) becomes a
+    one-element tuple and a ``None`` ``strides`` becomes an empty list; all
+    other arguments, ``param`` first, are forwarded unchanged.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    ffi_args = (
+        param,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+    return _ffi_api.MatchBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def preflattened_buffer(
+    postflattened,
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="global",
+    align=-1,
+    offset_factor=0,
+    buffer_type="default",
+    axis_separators=None,
+) -> None:
+    """Call the ``PreflattenedBuffer`` FFI entry; its result is discarded.
+
+    A scalar ``shape`` (a ``PrimExpr`` or an integral number) becomes a
+    one-element tuple and a ``None`` ``strides`` becomes an empty list; all
+    other arguments, ``postflattened`` first, are forwarded unchanged.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    strides = [] if strides is None else strides
+    ffi_args = (
+        postflattened,
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+    _ffi_api.PreflattenedBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: disable=invalid-name
+    """Forward ``condition`` and ``message`` to the ``Assert`` FFI entry and return the frame."""
+    assert_frame = _ffi_api.Assert(condition, message)  # pylint: disable=no-member # type: ignore
+    return assert_frame
+
+
+def let(
+    v: Var,
+    value: PrimExpr,
+    body: PrimExpr = None,
+) -> frame.LetFrame:
+    """Build a let binding in expression or frame form.
+
+    With ``body`` given, a ``Let`` expression ``Let(v, value, body)`` is
+    returned; without it, the result of the ``Let`` FFI entry called with
+    ``v`` and ``value`` is returned.
+    """
+    if body is not None:
+        return Let(v, value, body)
+    return _ffi_api.Let(v, value)  # pylint: disable=no-member # type: ignore
+
+
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str = "",
+    condition: PrimExpr = None,
+    annotations=None,
+) -> frame.AllocateFrame:
+    """Call the ``Allocate`` FFI entry and return the frame.
+
+    A Python ``bool`` ``condition`` is wrapped as a "bool"-typed ``IntImm``
+    first; the remaining arguments are forwarded unchanged.
+    """
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    alloc_frame = _ffi_api.Allocate(  # pylint: disable=no-member # type: ignore
+        extents, dtype, scope, cond, annotations
+    )
+    return alloc_frame
+
+
+def allocate_const(
+    data: List[PrimExpr],
+    dtype: str,
+    extents: List[PrimExpr],
+    annotations=None,
+) -> frame.AllocateConstFrame:
+    """Call the ``AllocateConst`` FFI entry and return the frame.
+
+    ``data`` is converted with ``np.asarray(data, dtype)`` and wrapped by
+    ``ndarray.array`` before the call.
+    """
+    host_array = np.asarray(data, dtype)
+    const_data = ndarray.array(host_array)
+    return _ffi_api.AllocateConst(  # pylint: disable=no-member # type: ignore
+        const_data, dtype, extents, annotations
+    )
+
+
+def realize(
+    buffer_slice: BufferRegion,
+    storage_scope: str,
+    condition: PrimExpr = True,
+) -> frame.RealizeFrame:
+    """Forward all three arguments unchanged to the ``Realize`` FFI entry and return the frame."""
+    realize_frame = _ffi_api.Realize(  # pylint: disable=no-member # type: ignore
+        buffer_slice, storage_scope, condition
+    )
+    return realize_frame
+
+
+def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
+    """Run ``node`` and ``value`` through ``convert``, then call the ``Attr`` FFI entry."""
+    converted_node = convert(node)
+    converted_value = convert(value)
+    return _ffi_api.Attr(  # pylint: disable=no-member # type: ignore
+        converted_node, attr_key, converted_value
+    )
+
+
+def While(condition: PrimExpr) -> frame.WhileFrame:  # pylint: disable=invalid-name
+    """Call the ``While`` FFI entry; a Python ``bool`` is first wrapped as a "bool" ``IntImm``."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.While(cond)  # pylint: disable=no-member # type: ignore
+
+
+def If(condition: PrimExpr) -> frame.IfFrame:  # pylint: disable=invalid-name
+    """Call the ``If`` FFI entry; a Python ``bool`` is first wrapped as a "bool" ``IntImm``."""
+    cond = IntImm("bool", condition) if isinstance(condition, bool) else condition
+    return _ffi_api.If(cond)  # pylint: disable=no-member # type: ignore
+
+
+def Then() -> frame.ThenFrame:  # pylint: disable=invalid-name
+    """Call the argument-less ``Then`` FFI entry and return the frame it produces."""
+    then_frame = _ffi_api.Then()  # pylint: disable=no-member # type: ignore
+    return then_frame
+
+
+def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
+    """Call the argument-less ``Else`` FFI entry and return the frame it produces."""
+    else_frame = _ffi_api.Else()  # pylint: disable=no-member # type: ignore
+    return else_frame
+
+
+def decl_buffer(
+    shape,
+    dtype="float32",
+    data=None,
+    strides=None,
+    elem_offset=None,
+    scope="",
+    align=0,
+    offset_factor=0,
+    buffer_type="",
+    axis_separators=None,
+) -> frame.DeclBufferFrame:
+    """Call the ``DeclBuffer`` FFI entry and return the frame.
+
+    A scalar ``shape`` (a ``PrimExpr`` or an integral number) is wrapped into
+    a one-element tuple first. An empty string is always passed in the third
+    argument position; every other argument is forwarded unchanged.
+    """
+    if isinstance(shape, (PrimExpr, Integral)):
+        shape = (shape,)
+    ffi_args = (
+        shape,
+        dtype,
+        "",
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
+    return _ffi_api.DeclBuffer(*ffi_args)  # pylint: disable=no-member # type: ignore
+
+
+def launch_thread(
+    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    extent: PrimExpr,
+) -> frame.LaunchThreadFrame:
+    """Forward ``iter_var`` and ``extent`` to the ``LaunchThread`` FFI entry and return the frame."""
+    thread_frame = _ffi_api.LaunchThread(  # pylint: disable=no-member # type: ignore
+        iter_var, extent
+    )
+    return thread_frame
+
+
+def env_thread(thread_tag: str) -> IterVar:
+    """Forward ``thread_tag`` to the ``EnvThread`` FFI entry and return its result."""
+    thread_var = _ffi_api.EnvThread(thread_tag)  # pylint: disable=no-member # type: ignore
+    return thread_var
+
+
+def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
+    """Lower slice indices and call the ``BufferStore`` FFI entry.
+
+    Each ``slice`` index is replaced by its start when it covers one lane,
+    otherwise by ``ramp(start, step, lanes)``; a missing step counts as 1.
+    Non-slice indices are kept as given. A Python ``bool`` ``value`` is
+    wrapped as a "bool"-typed ``IntImm`` when ``buffer.dtype`` is "bool".
+    """
+    from tvm.arith import Analyzer  # pylint: disable=import-outside-toplevel
+
+    expr_indices = []
+    for index in indices:
+        if not isinstance(index, slice):
+            expr_indices.append(index)
+            continue
+        step = index.step if index.step is not None else 1
+        # Ceil-divide the span by the step to get the lane count.
+        lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step)
+        expr_indices.append(index.start if lanes == 1 else ramp(index.start, step, int(lanes)))
+    if isinstance(value, bool) and buffer.dtype == "bool":
+        value = IntImm("bool", value)
+    stored = _ffi_api.BufferStore(  # pylint: disable=no-member # type: ignore
+        buffer, value, expr_indices
+    )
+    return stored
+
+
+def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+    """Forward ``buffer`` and ``indices`` to the ``Prefetch`` FFI entry and return its result."""
+    outcome = _ffi_api.Prefetch(buffer, indices)  # pylint: disable=no-member # type: ignore
+    return outcome
+
+
+def evaluate(value: PrimExpr) -> None:
+    """Call the ``Evaluate`` FFI entry; a Python ``str`` is first wrapped in ``StringImm``."""
+    expr = StringImm(value) if isinstance(value, str) else value
+    return _ffi_api.Evaluate(expr)  # pylint: disable=no-member # type: ignore
+
+
+def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int8`` FFI entry and return its result."""
+    result = _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int16`` FFI entry and return its result."""
+    result = _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int32`` FFI entry and return its result."""
+    result = _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int64`` FFI entry and return its result."""
+    result = _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``UInt8`` FFI entry and return its result."""
+    result = _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``UInt16`` FFI entry and return its result."""
+    result = _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``UInt32`` FFI entry and return its result."""
+    result = _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``UInt64`` FFI entry and return its result."""
+    result = _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Call the ``Float8`` FFI entry; a non-``PrimExpr`` input goes through ``convert`` first."""
+    operand = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float8(operand)  # pylint: disable=no-member # type: ignore
+
+
+def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Call the ``Float16`` FFI entry; a non-``PrimExpr`` input goes through ``convert`` first."""
+    operand = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float16(operand)  # pylint: disable=no-member # type: ignore
+
+
+def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Call the ``Float32`` FFI entry; a non-``PrimExpr`` input goes through ``convert`` first."""
+    operand = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float32(operand)  # pylint: disable=no-member # type: ignore
+
+
+def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Call the ``Float64`` FFI entry; a non-``PrimExpr`` input goes through ``convert`` first."""
+    operand = expr if isinstance(expr, PrimExpr) else convert(expr)
+    return _ffi_api.Float64(operand)  # pylint: disable=no-member # type: ignore
+
+
+def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int32x4`` FFI entry and return its result."""
+    result = _ffi_api.Int32x4(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int32x8`` FFI entry and return its result."""
+    result = _ffi_api.Int32x8(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Int32x16`` FFI entry and return its result."""
+    result = _ffi_api.Int32x16(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Boolean`` FFI entry and return its result."""
+    result = _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Handle`` FFI entry and return its result."""
+    result = _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
+    """Forward ``expr`` to the ``Void`` FFI entry and return its result."""
+    result = _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+    return result
+
+
+def min(a, b):  # pylint: disable=redefined-builtin
+    """Build the minimum of two expressions via the ``min`` FFI entry.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        Left-hand operand.
+
+    b : PrimExpr
+        Right-hand operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        The resulting expression.
+    """
+    res = _ffi_api.min(a, b)  # pylint: disable=no-member # type: ignore
+    return res
+
+
+def max(a, b):  # pylint: disable=redefined-builtin
+    """Build the maximum of two expressions via the ``max`` FFI entry.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        Left-hand operand.
+
+    b : PrimExpr
+        Right-hand operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        The resulting expression.
+    """
+    res = _ffi_api.max(a, b)  # pylint: disable=no-member # type: ignore
+    return res
+
+
+def var(dtype, name="") -> Var:
+    """Construct ``Var(name, dtype)``; note that the argument order is swapped relative to ``Var``."""
+    new_var = Var(name, dtype)  # pylint: disable=no-member # type: ignore
+    return new_var
+
+
+def iter_var(v, dom, iter_type, thread_tag):
+    """Construct an ``IterVar``, resolving ``iter_type`` as the name of an ``IterVar`` attribute."""
+    kind = getattr(IterVar, iter_type)
+    return IterVar(dom, v, kind, thread_tag)
+
+
+def comm_reducer(combiner, identity):
+    """Create a CommReducer from a combiner callable and its identity values.
+
+    One ``Var`` is created per combiner parameter, named after the parameter.
+    Its dtype is "int32" when the paired identity value is a Python ``int``,
+    otherwise that value's ``dtype``. The identities are paired against the
+    parameters twice over (``identity + identity``). The first half of the
+    variables and the second half are passed as the reducer's first two
+    arguments, followed by the combiner result (wrapped into a tuple if it is
+    not one) and ``identity``.
+    """
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    args = [
+        Var(name, "int32" if isinstance(ident, int) else ident.dtype)
+        for name, ident in zip(params, identity + identity)
+    ]
+    res = combiner(*args)
+    if not isinstance(res, tuple):
+        res = (res,)
+    half = num_args // 2
+    return CommReducer(args[:half], args[half:], res, identity)
+
+
+def llvm_lookup_intrinsic_id(name):
+    """Delegate to ``tvm.target.codegen.llvm_lookup_intrinsic_id``, imported at call time."""
+    # pylint: disable=import-outside-toplevel
+    from tvm.target.codegen import llvm_lookup_intrinsic_id as _lookup
+
+    # pylint: enable=import-outside-toplevel
+    intrinsic_id = _lookup(name)
+    return intrinsic_id
+
+
+def _op_wrapper(func):
+    """Wrap ``func`` so that a ``dtype`` keyword argument is dropped before the call."""
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        # Discard ``dtype`` if the caller supplied it; everything else passes through.
+        kwargs.pop("dtype", None)
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):
+    """Wrap ``func`` so that a ``dtype`` keyword argument becomes the first positional one."""
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            dtype = kwargs.pop("dtype")
+            args = (dtype, *args)
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+# pylint: disable=invalid-name
+
+buffer_var = ptr
+abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+# ``tvm_access_ptr`` is already bound earlier in this alias list; the identical
+# second assignment that used to sit here was redundant and has been dropped.
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+# NOTE(review): bound without ``_op_wrapper``, so a ``dtype`` keyword reaches
+# ``_tir_op.tvm_struct_get`` unchanged - presumably intentional; confirm.
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+assume = _op_wrapper(_tir_op.assume)
+undef = _op_wrapper(_tir_op.undef)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+class inline:
+    """Wrapper around ``value`` whose iteration yields each element wrapped in ``inline``."""
+
+    def __init__(self, value) -> None:
+        self.i = 0
+        self.value = value
+
+    def __iter__(self):
+        """Lazily yield ``inline(element)`` for every element of ``self.value``."""
+        for element in self.value:
+            yield inline(element)
+
+
+# pylint: enable=invalid-name
+
+
+__all__ = [
+ "Assert",
+ "Cast",
+ "Else",
+ "If",
+ "Let",
+ "Select",
+ "Shuffle",
+ "TVMBackendAllocWorkspace",
+ "TVMBackendFreeWorkspace",
+ "Then",
+ "While",
+ "abs",
+ "acos",
+ "acosh",
+ "address_of",
+ "alloc_buffer",
+ "allocate",
+ "allocate_const",
+ "arg",
+ "asin",
+ "asinh",
+ "assume",
+ "atan",
+ "atan2",
+ "atanh",
+ "attr",
+ "axis",
+ "block",
+ "block_attr",
+ "boolean",
+ "broadcast",
+ "buffer_decl",
+ "buffer_store",
+ "buffer_var",
+ "call_cpacked",
+ "call_cpacked_lowered",
+ "call_extern",
+ "call_intrin",
+ "call_llvm_intrin",
+ "call_llvm_pure_intrin",
+ "call_packed",
+ "call_packed_lowered",
+ "call_pure_extern",
+ "cast",
+ "ceil",
+ "ceildiv",
+ "clz",
+ "comm_reducer",
+ "copysign",
+ "cos",
+ "cosh",
+ "env_thread",
+ "erf",
+ "evaluate",
+ "exp",
+ "exp10",
+ "exp2",
+ "decl_buffer",
+ "fabs",
+ "float16",
+ "float32",
+ "float64",
+ "float8",
+ "floor",
+ "floordiv",
+ "floormod",
+ "fmod",
+ "func_attr",
+ "func_name",
+ "func_ret",
+ "grid",
+ "handle",
+ "hypot",
+ "if_then_else",
+ "infinity",
+ "init",
+ "inline",
+ "int16",
+ "int32",
+ "int32x16",
+ "int32x4",
+ "int32x8",
+ "int64",
+ "int8",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "isnullptr",
+ "iter_var",
+ "launch_thread",
+ "ldexp",
+ "let",
+ "likely",
+ "llvm_lookup_intrinsic_id",
+ "log",
+ "log10",
+ "log1p",
+ "log2",
+ "lookup_param",
+ "match_buffer",
+ "max",
+ "max_value",
+ "min",
+ "min_value",
+ "mma_fill",
+ "mma_store",
+ "nearbyint",
+ "nextafter",
+ "parallel",
+ "popcount",
+ "power",
+ "prefetch",
+ "preflattened_buffer",
+ "prim_func",
+ "ptr",
+ "ptx_commit_group",
+ "ptx_cp_async",
+ "ptx_ldmatrix",
+ "ptx_mma",
+ "ptx_mma_sp",
+ "ptx_wait_group",
+ "q_multiply_shift",
+ "ramp",
+ "reads",
+ "realize",
+ "reinterpret",
+ "ret",
+ "round",
+ "rsqrt",
+ "serial",
+ "shift_left",
+ "shift_right",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "sqrt",
+ "tan",
+ "tanh",
+ "target",
+ "thread_binding",
+ "trunc",
+ "truncdiv",
+ "truncmod",
+ "tvm_access_ptr",
+ "tvm_bmma_sync",
+ "tvm_call_cpacked",
+ "tvm_call_cpacked_lowered",
+ "tvm_call_packed",
+ "tvm_call_packed_lowered",
+ "tvm_fill_fragment",
+ "tvm_load_matrix_sync",
+ "tvm_mma_sync",
+ "tvm_stack_alloca",
+ "tvm_stack_make_array",
+ "tvm_stack_make_shape",
+ "tvm_store_matrix_sync",
+ "tvm_struct_get",
+ "tvm_struct_set",
+ "tvm_thread_allreduce",
+ "tvm_throw_last_error",
+ "tvm_tuple",
+ "type_annotation",
+ "uint16",
+ "uint32",
+ "uint64",
+ "uint8",
+ "undef",
+ "unroll",
+ "var",
+ "vectorcombine",
+ "vectorhigh",
+ "vectorized",
+ "vectorlow",
+ "void",
+ "where",
+ "writes",
+]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 17005b04a4..e2cad37bb6 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,17 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin,invalid-name,no-member,protected-access
"""Operators used in TIR expression."""
+import warnings
from typing import Any, Optional
+
import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
from tvm.ir import Array, Op
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, PrimExprWithOp, StringImm, Var
def _pack_buffer(buf, span=None):
@@ -100,6 +102,64 @@ def call_cpacked(*args, span=None):
return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span)
+def call_packed_lowered(*args, span=None):
+ """Lowered version of call packed.
+
+ The argument to packed function can be Expr or Buffer.
+ The argument is the corresponding POD type when Expr is presented.
+
+ When the argument is Buffer, the corresponding PackedFunc
+ will receive a TVMArrayHandle whose content is valid during the callback period.
+ If the PackedFunc is a python callback, then the corresponding argument is NDArray.
+
+ Parameters
+ ----------
+ args : list of Expr or Buffer.
+ Positional arguments.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+
+ See Also
+ --------
+ te.extern : Create tensor with extern function call.
+ """
+ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+ return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span)
+
+
+def call_cpacked_lowered(*args, span=None):
+ """Lowered version of call c-packed.
+
+ Same as call_packed, except that the first argument is the function name
+ (as in call_extern), and the last argument is the resource handle.
+
+ Parameters
+ ----------
+ args : list of Expr or Buffer.
+ Positional arguments.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+
+ See Also
+ --------
+ te.extern : Create tensor with extern function call.
+ """
+ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+ return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span)
+
+
def call_intrin(dtype, func_name, *args, span=None):
"""Build expression by calling an intrinsic function.
@@ -151,7 +211,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
The call expression.
"""
return Call(
- dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
+ dtype,
+ Op.get("tir.call_pure_extern"),
+ convert((StringImm(func_name),) + args),
+ span,
)
@@ -178,7 +241,10 @@ def call_extern(dtype, func_name, *args, span=None):
The call expression.
"""
return Call(
- dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
+ dtype,
+ Op.get("tir.call_extern"),
+ convert((StringImm(func_name),) + args),
+ span=span,
)
@@ -207,10 +273,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen
- llvm_id = codegen.llvm_lookup_intrinsic_id(name)
- assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+ from .expr import IntImm
+
+ if isinstance(name, str):
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ elif isinstance(name, IntImm):
+ llvm_id = name.value
+ else:
+ llvm_id = name
+ if llvm_id == 0:
+ warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
return call_intrin(
- dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+ dtype,
+ Op.get("tir.call_llvm_intrin"),
+ tvm.tir.const(llvm_id, "uint32"),
+ *args,
+ span=span,
)
@@ -239,8 +317,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen
- llvm_id = codegen.llvm_lookup_intrinsic_id(name)
- assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+ from .expr import IntImm
+
+ if isinstance(name, str):
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ elif isinstance(name, IntImm):
+ llvm_id = name.value
+ else:
+ llvm_id = name
+ if llvm_id == 0:
+ warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
return call_intrin(
dtype,
Op.get("tir.call_llvm_pure_intrin"),
@@ -250,6 +336,326 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+ """Build a call to the ``tir.tvm_access_ptr`` intrinsic.
+
+ ``ptype``, ``data``, ``offset``, ``extent`` and ``rw_mask`` are forwarded,
+ in that order, as the operands of the call; "handle" is passed as the
+ call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+ """Build an operand-less call to the ``tir.tvm_throw_last_error`` intrinsic.
+
+ "handle" is passed as the call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
+def tvm_stack_alloca(dtype_str, num):
+ """Build a call to the ``tir.tvm_stack_alloca`` intrinsic.
+
+ ``dtype_str`` and ``num`` are forwarded, in that order, as the operands of
+ the call; "handle" is passed as the call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num)
+
+
+def tvm_stack_make_shape(*args):
+ """Build a call to the ``tir.tvm_stack_make_shape`` intrinsic.
+
+ Every positional argument is forwarded unchanged as an operand of the
+ call; "handle" is passed as the call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_stack_make_shape", *args)
+
+
+def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
+ """Build a call to the ``tir.tvm_stack_make_array`` intrinsic.
+
+ The six arguments are forwarded, in signature order, as the operands of
+ the call; "handle" is passed as the call dtype.
+ """
+ return call_intrin(
+ "handle", "tir.tvm_stack_make_array", data, shape, strides, ndim, arr_dtype, elem_offset
+ )
+
+
+def address_of(buffer_load, span=None):
+ """Returns the address of an element in the buffer
+
+ Parameters
+ ----------
+ buffer_load: BufferLoad
+ The buffer load.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+
+
+def lookup_param(param_name, span=None):
+ """Returns the param by name
+
+ Parameters
+ ----------
+ param_name : str
+ The name of param.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.lookup_param", param_name, span=span)
+
+
+def tvm_tuple(*value):
+ """Build a call to the ``tir.tvm_tuple`` intrinsic.
+
+ Every positional argument is forwarded unchanged as an operand of the
+ call; "handle" is passed as the call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_tuple", *value)
+
+
+def tvm_struct_get(arr, index, field_id, dtype):
+ """Build a call to the ``tir.tvm_struct_get`` intrinsic.
+
+ ``arr``, ``index`` and ``field_id`` are forwarded, in that order, as the
+ operands of the call. Note that ``dtype`` is the *last* parameter here but
+ is passed first, as the dtype of the resulting call.
+ """
+ return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field_id)
+
+
+def tvm_struct_set(arr, index, field_id, value):
+ """Build a call to the ``tir.tvm_struct_set`` intrinsic.
+
+ ``arr``, ``index``, ``field_id`` and ``value`` are forwarded, in that
+ order, as the operands of the call; "handle" is passed as the call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_struct_set", arr, index, field_id, value)
+
+
+def tvm_thread_allreduce(*freduce_args):
+ """Build a call to the ``tir.tvm_thread_allreduce`` intrinsic.
+
+ Every positional argument is forwarded unchanged as an operand of the
+ call; "handle" is passed as the call dtype.
+ """
+ return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+ return call_intrin(
+ "handle",
+ "tir.tvm_load_matrix_sync",
+ fragment,
+ m,
+ n,
+ k,
+ index,
+ buffer_ptr,
+ stride,
+ layout,
+ )
+
+
+def tvm_mma_sync(
+ fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+ return call_intrin(
+ "handle",
+ "tir.tvm_mma_sync",
+ fragment_d,
+ index_d,
+ fragment_a,
+ index_a,
+ fragment_b,
+ index_b,
+ fragment_c,
+ index_c,
+ )
+
+
+def tvm_bmma_sync(
+ fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
+):
+ return call_intrin(
+ "handle",
+ "tir.tvm_bmma_sync",
+ fragment_d,
+ index_d,
+ fragment_a,
+ index_a,
+ fragment_b,
+ index_b,
+ fragment_c,
+ index_c,
+ )
+
+
+def tvm_fill_fragment(fragment, m, n, k, index, value):
+ return call_intrin(
+ "handle",
+ "tir.tvm_fill_fragment",
+ fragment,
+ m,
+ n,
+ k,
+ index,
+ value,
+ )
+
+
+def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
+ return call_intrin(
+ "handle",
+ "tir.tvm_store_matrix_sync",
+ fragment,
+ m,
+ n,
+ k,
+ index,
+ buffer_ptr,
+ stride,
+ layout,
+ )
+
+
+def ptx_mma( # pylint: disable=missing-docstring
+ dtype,
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ saturate,
+ operator=None,
+):
+ if operator is None:
+ return call_intrin(
+ dtype,
+ "tir.ptx_mma",
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ saturate,
+ )
+ return call_intrin(
+ dtype,
+ "tir.ptx_mma",
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ saturate,
+ operator,
+ )
+
+
+def ptx_mma_sp( # pylint: disable=missing-docstring
+ dtype,
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ metadata,
+ meta_index,
+ sparse_selector,
+ saturate,
+):
+ return call_intrin(
+ dtype,
+ "tir.ptx_mma_sp",
+ shape,
+ A_layout,
+ B_layout,
+ A_dtype,
+ B_dtype,
+ C_dtype,
+ multiplicand_a,
+ a_index,
+ multiplicand_b,
+ b_index,
+ accumulator,
+ c_index,
+ metadata,
+ meta_index,
+ sparse_selector,
+ saturate,
+ )
+
+
+def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
+ return call_intrin(
+ dtype,
+ "tir.ptx_ldmatrix",
+ trans,
+ num,
+ type,
+ local_ptr,
+ local_offset,
+ smem_ptr,
+ smem_offset,
+ )
+
+
+def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
+ """Build a call to the ``tir.ptx_cp_async`` intrinsic.
+
+ ``dtype`` is passed as the call dtype; the remaining arguments are
+ forwarded, in signature order, as the operands of the call. The ``bytes``
+ parameter shadows the Python builtin of the same name inside this function.
+ """
+ return call_intrin(
+ dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
+ )
+
+
+def ptx_commit_group():
+ """Build an operand-less call to the ``tir.ptx_commit_group`` intrinsic.
+
+ An empty string is passed as the call dtype.
+ NOTE(review): confirm that an empty dtype string is the intended way to
+ express "no result" here.
+ """
+ return call_intrin("", "tir.ptx_commit_group")
+
+
+def ptx_wait_group(num):
+ """Build a call to the ``tir.ptx_wait_group`` intrinsic.
+
+ ``num`` is forwarded as the single operand of the call; an empty string is
+ passed as the call dtype.
+ NOTE(review): confirm that an empty dtype string is the intended way to
+ express "no result" here.
+ """
+ return call_intrin("", "tir.ptx_wait_group", num)
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+ return call_intrin(
+ dtype,
+ "tir.mma_store",
+ m,
+ n,
+ dst_ptr,
+ src_ptr,
+ src_offset,
+ dst_stride,
+ )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+ return call_intrin(
+ dtype,
+ "tir.mma_fill",
+ local_size,
+ local_ptr,
+ offset,
+ )
+
+
+def vectorlow(dtype, vec):
+ """Build a call to the ``tir.vectorlow`` intrinsic.
+
+ ``dtype`` is passed as the call dtype and ``vec`` as the single operand.
+ """
+ return call_intrin(dtype, "tir.vectorlow", vec)
+
+
+def vectorhigh(dtype, vec):
+ """Build a call to the ``tir.vectorhigh`` intrinsic.
+
+ ``dtype`` is passed as the call dtype and ``vec`` as the single operand.
+ """
+ return call_intrin(dtype, "tir.vectorhigh", vec)
+
+
+def vectorcombine(dtype, vec1, vec2):
+ """Build a call to the ``tir.vectorcombine`` intrinsic.
+
+ ``dtype`` is passed as the call dtype; ``vec1`` and ``vec2`` are forwarded,
+ in that order, as the operands of the call.
+ """
+ return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
+
+
+def assume(cond=None):
+ """Build a call to the ``tir.assume`` intrinsic.
+
+ ``cond`` is forwarded as the single operand of the call; "int32" is passed
+ as the call dtype.
+ NOTE(review): ``cond`` defaults to None, which is forwarded as-is; confirm
+ whether calling this without a condition is meant to be supported.
+ """
+ return call_intrin("int32", "tir.assume", cond)
+
+
+def undef():
+ """Build an operand-less call to the ``tir.undef`` intrinsic.
+
+ "int32" is passed as the call dtype.
+ """
+ return call_intrin("int32", "tir.undef")
+
+
def ret(val):
"""Create a tir return expression
@@ -286,9 +692,9 @@ def any(*args, span=None):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
- val = _ffi_api._OpOr(args[0], args[1], span) # type: ignore
+ val = _ffi_api._OpOr(args[0], args[1], span) # type: ignore # pylint: disable=no-member,protected-access
for i in range(2, len(args)):
- val = _ffi_api._OpOr(val, args[i], span) # type: ignore
+ val = _ffi_api._OpOr(val, args[i], span) # type: ignore # pylint: disable=no-member,protected-access
return val
@@ -313,9 +719,9 @@ def all(*args, span=None):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
- val = _ffi_api._OpAnd(args[0], args[1], span) # type: ignore
+ val = _ffi_api._OpAnd(args[0], args[1], span) # type: ignore # pylint: disable=no-member,protected-access
for i in range(2, len(args)):
- val = _ffi_api._OpAnd(val, args[i], span) # type: ignore
+ val = _ffi_api._OpAnd(val, args[i], span) # type: ignore # pylint: disable=no-member,protected-access
return val
@@ -394,6 +800,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
return _ffi_api.max_value(dtype, span) # type: ignore
+def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+ """infinity value of dtype
+
+ Parameters
+ ----------
+ dtype : str
+ The data type.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ value : tvm.Expr
+ The infinity value of dtype.
+ """
+ return _ffi_api.infinity(dtype, span) # type: ignore
+
+
+def reinterpret(dtype, value, span=None) -> Any:
+ """Reinterpret cast of the input value to dtype
+
+ Parameters
+ ----------
+ dtype : str
+ The data type.
+
+ value : PrimExpr
+ The input value.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ value : tvm.Expr
+ The reinterpret cast value of dtype.
+ """
+ return _ffi_api.reinterpret(dtype, value, span) # type: ignore
+
+
def exp(x):
"""Take exponential of input x.
@@ -998,6 +1445,25 @@ def ldexp(x1, x2):
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore
+def likely(cond, span=None):
+ """Mark condition as likely.
+
+ Parameters
+ ----------
+ cond : PrimExpr
+ Input argument.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ y : PrimExpr
+ The marked expression.
+ """
+ return _ffi_api.likely(cond, span) # type: ignore
+
+
def isnan(x, span=None):
"""Check if input value is Nan.
@@ -1017,6 +1483,25 @@ def isnan(x, span=None):
return _ffi_api.isnan(x, span) # type: ignore
+def isnullptr(x, span=None):
+ """Check if input value is nullptr.
+
+ Parameters
+ ----------
+ x : PrimExpr
+ Input argument.
+
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ y : PrimExpr
+ The result.
+ """
+ return call_intrin("bool", "tir.isnullptr", x, span=span) # type: ignore
+
+
def isfinite(x, span=None):
"""Check if input value is finite.
@@ -1122,6 +1607,42 @@ def q_multiply_shift(x, y, q, s):
return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
+def shift_left(x, y, span=None):
+ """Return the result of x left shifted by y bits.
+
+ Parameters
+ ----------
+ x : PrimExpr
+ Input argument.
+ y : PrimExpr
+ Input argument.
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ z : PrimExpr
+ The result.
+ """
+ return _ffi_api.left_shift(x, y, span)
+
+
+def shift_right(x, y, span=None):
+ """Return the result of x right shifted by y bits.
+
+ Parameters
+ ----------
+ x : PrimExpr
+ Input argument.
+ y : PrimExpr
+ Input argument.
+ span : Optional[Span]
+ The location of this operator in the source code.
+
+ Returns
+ -------
+ z : PrimExpr
+ The result.
+ """
+ return _ffi_api.right_shift(x, y, span)
+
+
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
@@ -1306,8 +1827,8 @@ def truncmod(a, b, span=None):
return _ffi_api._OpTruncMod(a, b, span) # type: ignore
-def floordiv(a, b, span=None):
- """Compute the floordiv of two expressions.
+def ceildiv(a, b, span=None):
+ """Compute the ceildiv of two expressions.
Parameters
----------
@@ -1325,11 +1846,11 @@ def floordiv(a, b, span=None):
res : PrimExpr
The result expression.
"""
- return _ffi_api._OpFloorDiv(a, b, span) # type: ignore
+ return _ffi_api._OpCeilDiv(a, b, span) # type: ignore
-def floormod(a, b, span=None):
- """Compute the floormod of two expressions.
+def floordiv(a, b, span=None):
+ """Compute the floordiv of two expressions.
Parameters
----------
@@ -1347,27 +1868,29 @@ def floormod(a, b, span=None):
res : PrimExpr
The result expression.
"""
- return _ffi_api._OpFloorMod(a, b, span) # type: ignore
+ return _ffi_api._OpFloorDiv(a, b, span) # type: ignore
-def ceildiv(lhs, rhs, span=None):
- """Generic ceildiv operator.
+def floormod(a, b, span=None):
+ """Compute the floormod of two expressions.
Parameters
----------
- lhs : object
- The left operand.
- rhs : object
- The right operand.
+ a : PrimExpr
+ The left hand operand
+
+ b : PrimExpr
+ The right hand operand
+
span : Optional[Span]
The location of this operator in the source.
Returns
-------
- op : tvm.Expr
- The result Expr of ceildiv operaton.
+ res : PrimExpr
+ The result expression.
"""
- return _ffi_api._OpCeilDiv(lhs, rhs, span) # type: ignore
+ return _ffi_api._OpFloorMod(a, b, span) # type: ignore
def comm_reducer(fcombine, fidentity, name="reduce"):
@@ -1523,6 +2046,22 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+ """Build a call to the ``tir.TVMBackendAllocWorkspace`` intrinsic.
+
+ The five arguments are forwarded, in signature order, as the operands of
+ the call; "handle" is passed as the call dtype.
+ """
+ return call_intrin(
+ "handle",
+ "tir.TVMBackendAllocWorkspace",
+ device_type,
+ device_id,
+ nbytes,
+ dtype_code_hint,
+ dtype_bits_hint,
+ )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+ """Build a call to the ``tir.TVMBackendFreeWorkspace`` intrinsic.
+
+ Parameters
+ ----------
+ device_type, device_id, ptr
+ Forwarded, in that order, as the operands of the intrinsic call.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression, built with "int32" as the call dtype.
+ """
+ # The expression has to be returned so that callers can place it in the IR.
+ return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
from tvm._ffi import register_object
from tvm.runtime import Object
-from tvm.tir import Block, For
+from ..stmt import Block, For
from . import _ffi_api
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index cf031c014c..f26f954d51 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
from . import _ffi_api
from ._type_checker import type_checked
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
from . import _ffi_api
from .block_scope import BlockScope, StmtSRef
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,7 @@ def stmt_list(stmt):
res += stmt_list(x)
return res
return [stmt]
+
+
+def type_annotation(dtype, span=None):
+ """Call the ``TypeAnnotation`` FFI function and return its result.
+
+ ``dtype`` and ``span`` are forwarded unchanged; ``span`` defaults to None.
+ """
+ return _ffi_api.TypeAnnotation(dtype, span)
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index a3b47ff6d5..f0725b666e 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -16,8 +16,4 @@
# under the License.
# pylint: disable=unused-import
"""Intrinsics for tensorization."""
-from .x86 import *
-from .arm_cpu import *
-from .dot_product_common import *
-from .rocm import *
-from .cuda import *
+from . import arm_cpu, cuda, rocm, x86
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 3e934e1b9d..78b57d5fe1 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -17,8 +17,9 @@
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for ARM tensorization."""
from tvm.script import tir as T
-from .. import TensorIntrin
+from .. import TensorIntrin
+from .dot_product_common import DP4A_INTRIN # pylint: disable=unused-import
# TODO(masahi): Parametrize the TVMScript description of dot product by
# shape and dtype, and share the common description with x86.
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 4ac9338ba8..028402a756 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -137,7 +137,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(shared[v0, v1])
- thread_id, local_id = index_map(v0, v1)
+ thread_id, local_id = T.inline(index_map(v0, v1))
T.writes(warp[thread_id, local_id])
warp[thread_id, local_id] = shared[v0, v1]
@@ -242,11 +242,11 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
- b_row_ind, b_col_ind = maybe_swap(k, j)
+ b_row_ind, b_col_ind = T.inline(maybe_swap(k, j))
- thread_id_C, local_id_C = index_map_C(i, j)
- thread_id_A, local_id_A = index_map_A(i, k)
- thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+ thread_id_C, local_id_C = T.inline(index_map_C(i, j))
+ thread_id_A, local_id_A = T.inline(index_map_A(i, k))
+ thread_id_B, local_id_B = T.inline(index_map_B(b_row_ind, b_col_ind))
T.reads(
C[thread_id_C, local_id_C],
@@ -338,7 +338,7 @@ def get_mma_fill_intrin(dtype, local_size):
for i0, i1 in T.grid(M_DIM, N_DIM):
with T.block("C_warp"):
i, j = T.axis.remap("SS", [i0, i1])
- thread_id, local_id = index_map(i, j)
+ thread_id, local_id = T.inline(index_map(i, j))
T.reads()
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = zero
@@ -375,7 +375,7 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
for i0, i1 in T.grid(M_DIM, N_DIM):
with T.block("C_warp"):
v0, v1 = T.axis.remap("SS", [i0, i1])
- thread_id, local_id = index_map(v0, v1)
+ thread_id, local_id = T.inline(index_map(v0, v1))
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[thread_id, local_id]
diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py
index 7a989d0bcc..017b2722a8 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -37,7 +37,7 @@ def sdot4(
T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
T.int32(0),
- T.bool(1),
+ T.boolean(1),
dtype="int32",
)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
from typing import Dict
import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
from . import _ffi_api
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index 336575a93e..06670c8cc2 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -71,7 +71,11 @@ DiagnosticBuilder Diagnostic::Help(Span span) {
/* Diagnostic Renderer */
TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode);
-void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->renderer(ctx); }
+void DiagnosticRenderer::Render(const DiagnosticContext& ctx) {
+ if ((*this)->renderer != nullptr) {
+ (*this)->renderer(ctx);
+ }
+}
TVM_DLL DiagnosticRenderer::DiagnosticRenderer(
TypedPackedFunc<void(DiagnosticContext ctx)> renderer) {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index a3318bf94f..a7d95848da 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -21,6 +21,7 @@
* \file src/ir/expr.cc
* \brief The expression AST nodes for the common IR infra.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/runtime/registry.h>
@@ -49,6 +50,18 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImm(GetRef<runtime::String>(ptr));
}
+ if (auto* ptr = ref.as<tir::BufferRegionNode>()) {
+ tir::BufferRegion buffer_region = GetRef<tir::BufferRegion>(ptr);
+ Array<PrimExpr> indices;
+ for (Range r : buffer_region->region) {
+ if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+ indices.push_back(r->min);
+ } else {
+ indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+ }
+ }
+ return tir::BufferLoad(buffer_region->buffer, indices);
+ }
Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get());
ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but got " << actual_type.value();
diff --git a/src/ir/ir_builder.cc b/src/ir/ir_builder.cc
new file mode 100644
index 0000000000..9f42cdb168
--- /dev/null
+++ b/src/ir/ir_builder.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ir/ir_builder.h>
+#include <tvm/ir/module.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ir_builder {
+
+void IRBuilderFrameNode::EnterWithScope() {
+ IRBuilder::Current()->frames.push_back(GetRef<IRBuilderFrame>(this));
+}
+
+void IRBuilderFrameNode::ExitWithScope() {
+ for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
+ (*it)();
+ }
+ this->callbacks.clear();
+ IRBuilder::Current()->frames.pop_back();
+}
+
+void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
+ if (IRBuilder::Current()->frames.empty()) {
+ LOG(FATAL) << "ValueError: No frames in Builder to add callback";
+ }
+ IRBuilder::Current()->frames.back()->callbacks.push_back(callback);
+}
+
+IRBuilder::IRBuilder() {
+ ObjectPtr<IRBuilderNode> n = make_object<IRBuilderNode>();
+ n->frames.clear();
+ n->result = NullOpt;
+ data_ = n;
+}
+
+std::vector<IRBuilder>* ThreadLocalBuilderStack() {
+ thread_local std::vector<IRBuilder> stack;
+ return &stack;
+}
+
+void IRBuilder::EnterWithScope() {
+ IRBuilderNode* n = this->get();
+ CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
+ << n->frames.size()
+ << ". Please use a fresh new builder every time building IRs";
+ n->result = NullOpt;
+ std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+ stack->push_back(*this);
+}
+
+void IRBuilder::ExitWithScope() {
+ std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+ ICHECK(!stack->empty());
+ stack->pop_back();
+}
+
+IRBuilder IRBuilder::Current() {
+ std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+ CHECK(!stack->empty()) << "ValueError: No builder in current scope";
+ return stack->back();
+}
+
+IRModuleFrame IRModule() {
+ ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
+ n->global_vars.clear();
+ n->functions.clear();
+ return IRModuleFrame(n);
+}
+
+void IRModuleFrameNode::ExitWithScope() {
+ ICHECK_EQ(functions.size(), global_vars.size());
+ int n = functions.size();
+ Map<GlobalVar, BaseFunc> func_map;
+ for (int i = 0; i < n; ++i) {
+ func_map.Set(global_vars[i], functions[i]);
+ }
+ IRBuilder builder = IRBuilder::Current();
+ ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+ builder->result = tvm::IRModule(func_map);
+}
+
+namespace details {
+
+Namer::FType& Namer::vtable() {
+ static FType inst;
+ return inst;
+}
+
+/*!
+ * \brief Name a node by dispatching on its runtime type through Namer::vtable().
+ * \param node The object to name; must be defined.
+ * \param name The name to give it.
+ * \note Fails a CHECK when `node` is undefined or when no naming function is
+ * registered in the vtable for the node's type.
+ */
+void Namer::Name(ObjectRef node, String name) {
+ static const FType& f = vtable();
+ CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
+ // Close the quote opened before the type key so the message reads: type "<key>".
+ CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
+ << node->GetTypeKey() << "\"";
+ f(node, name);
+}
+
+} // namespace details
+
+TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode);
+TVM_REGISTER_NODE_TYPE(IRBuilderNode);
+TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameEnter")
+ .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameExit")
+ .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderFrameAddCallback")
+ .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::AddCallback);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); });
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderGet")
+ .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
+TVM_REGISTER_GLOBAL("ir_builder.IRModule").set_body_typed(IRModule);
+
+} // namespace ir_builder
+} // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 7979c9f47a..e56a7bc4af 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -20,6 +20,7 @@
/*!
* \file expr.cc
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -828,12 +829,26 @@ TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
- ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>())
+ ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
+ it->IsInstance<IterVarNode>() || it->IsInstance<BufferRegionNode>())
<< "Argument " << it << " is not a string or primexpr";
if (const auto* str = it.as<runtime::StringObj>()) {
prim_expr_args.push_back(StringImm(str->data));
+ } else if (const auto* expr = it.as<PrimExprNode>()) {
+ prim_expr_args.push_back(GetRef<PrimExpr>(expr));
+ } else if (const auto* br = it.as<BufferRegionNode>()) {
+ BufferRegion buffer_region = GetRef<BufferRegion>(br);
+ Array<PrimExpr> indices;
+ for (Range r : buffer_region->region) {
+ if (arith::Analyzer().CanProveEqual(r->extent, 1)) {
+ indices.push_back(r->min);
+ } else {
+ indices.push_back(tir::Ramp(r->min, 1, Downcast<IntImm>(r->extent)->value));
+ }
+ }
+ prim_expr_args.push_back(BufferLoad(buffer_region->buffer, indices));
} else {
- prim_expr_args.push_back(Downcast<PrimExpr>(it));
+ prim_expr_args.push_back(Downcast<IterVar>(it).operator PrimExpr());
}
}
return Call(type, op, prim_expr_args, span);
@@ -1089,6 +1104,11 @@ BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) {
<< "-dimensional indices provided.";
ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
+ for (const PrimExpr& i : indices) {
+ ICHECK(i->dtype.is_int() || i->dtype.is_uint())
+ << "ValueError: index of BufferLoad should be int, but got type " << i->dtype
+ << " for index " << i;
+ }
node->buffer = std::move(buffer);
node->indices = std::move(indices);
node->span = std::move(span);
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 7e3d3d1075..b11ca6650a 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -22,11 +22,10 @@
* \brief Used by TVM Script parser to expand incomplete TIR input
*/
+#include "./script_complete.h"
+
#include <tvm/arith/int_set.h>
-#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/stmt_functor.h>
#include <utility>
diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h
new file mode 100644
index 0000000000..8df0456646
--- /dev/null
+++ b/src/tir/ir/script/script_complete.h
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.h
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+#ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Complete a PrimFunc built from TVM Script input (this header's file brief describes
+ *        it as expanding incomplete TIR input).
+ * \param func The PrimFunc to complete.
+ * \param root_allocates Buffers supplied alongside the function; presumably the buffers
+ *        allocated at the function root -- TODO confirm against script_complete.cc.
+ * \return The resulting PrimFunc.
+ */
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates);
+
+} // namespace tir
+} // namespace tvm
+#endif // TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 524204f3d3..ef11d10257 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -810,6 +810,14 @@ BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
CHECK_EQ(buffer->shape.size(), region.size())
<< "The dimension between " << buffer << " and region " << region
<< " mismatched, the buffer is " << buffer;
+ for (const Range& r : region) {
+ ICHECK(r->min->dtype.is_int() || r->min->dtype.is_uint())
+ << "ValueError: ranges of BufferRegion should be int, but got type " << r->min->dtype
+ << " for range " << r << " in its min value " << r->min;
+ ICHECK(r->extent->dtype.is_int() || r->extent->dtype.is_uint())
+ << "ValueError: ranges of BufferRegion should be int, but got type " << r->extent->dtype
+ << " for range " << r << " in its extent value " << r->extent;
+ }
ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
node->buffer = std::move(buffer);
node->region = std::move(region);
@@ -1091,5 +1099,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
TVM_REGISTER_OP("tir.type_annotation")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder.cc b/src/tir/ir_builder/ir_builder.cc
new file mode 100644
index 0000000000..96568e356c
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder.cc
@@ -0,0 +1,664 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+using tvm::tir::IterVar;
+
+/*!
+ * \brief Construct a TIR buffer, creating its data pointer and element offset on demand.
+ *
+ * When `data` is not given, a fresh pointer variable named after the buffer is created; a
+ * boolean element type is stored as int8. When `elem_offset` is not given and `offset_factor`
+ * is non-zero, a fresh "elem_offset" variable is created whose dtype follows the first shape
+ * dimension (int32 for an empty shape).
+ *
+ * \param shape The buffer shape.
+ * \param dtype The element data type.
+ * \param buffer_name The buffer name, also used for an auto-created data variable.
+ * \param data Optional existing data pointer variable.
+ * \param strides Optional strides; an empty array is used when absent.
+ * \param elem_offset Optional element offset.
+ * \param storage_scope Storage scope recorded in the auto-created pointer type.
+ * \param align Forwarded to the Buffer constructor.
+ * \param offset_factor Forwarded to the Buffer constructor; non-zero triggers the
+ *        auto-created element offset.
+ * \param buffer_type "auto_broadcast" selects kAutoBroadcast, anything else kDefault.
+ * \param axis_separators Optional axis separators; an empty array is used when absent.
+ * \return The constructed buffer.
+ */
+Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
+                  Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
+                  String storage_scope, int align, int offset_factor, String buffer_type,
+                  Optional<Array<IntImm>> axis_separators) {
+  Var buffer_data;
+  if (data.defined()) {
+    buffer_data = data.value();
+  } else {
+    // Boolean buffers are backed by int8 storage.
+    const DataType storage_dtype = (dtype == DataType::Bool()) ? DataType::Int(8) : dtype;
+    buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope));
+  }
+  if (offset_factor != 0 && !elem_offset.defined()) {
+    const DataType offset_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
+    elem_offset = tvm::tir::Var("elem_offset", offset_dtype);
+  }
+  const auto kind =
+      (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault;
+  return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
+                elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, kind,
+                axis_separators.value_or(Array<IntImm>()));
+}
+
+/*!
+ * \brief Create a DeclBufferFrame holding a buffer built by BufferDecl.
+ *
+ * All parameters are forwarded unchanged to BufferDecl.
+ * \return The frame whose `buffer` field is the declared buffer.
+ */
+DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
+                           Optional<Var> data, Optional<Array<PrimExpr>> strides,
+                           Optional<PrimExpr> elem_offset, String storage_scope, int align,
+                           int offset_factor, String buffer_type,
+                           Optional<Array<IntImm>> axis_separators) {
+  ObjectPtr<DeclBufferFrameNode> node = make_object<DeclBufferFrameNode>();
+  node->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset,
+                            storage_scope, align, offset_factor, buffer_type,
+                            axis_separators);
+  return DeclBufferFrame(node);
+}
+
+/*!
+ * \brief Create an unnamed variable typed as a pointer to `dtype` in `storage_scope`.
+ * \param dtype The pointee data type.
+ * \param storage_scope The storage scope recorded in the pointer type.
+ * \return The new variable, as a PrimExpr.
+ */
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
+  const tvm::PointerType pointer_type = tvm::PointerType(PrimType(dtype), storage_scope);
+  return tvm::tir::Var("", pointer_type);
+}
+
+/*!
+ * \brief Create an empty BlockFrame.
+ * \param name The block name.
+ * \param no_realize Stored on the frame as its `no_realize` flag.
+ * \return A frame whose containers are empty and whose optional fields are unset.
+ */
+BlockFrame Block(String name, bool no_realize) {
+  ObjectPtr<BlockFrameNode> node = make_object<BlockFrameNode>();
+  node->name = name;
+  node->no_realize = no_realize;
+  // Container fields start out empty.
+  node->iter_vars.clear();
+  node->iter_values.clear();
+  node->alloc_buffers.clear();
+  node->match_buffers.clear();
+  // Optional fields start out undefined so later duplicate declarations can be detected.
+  node->reads = NullOpt;
+  node->writes = NullOpt;
+  node->init = NullOpt;
+  node->annotations = NullOpt;
+  node->predicate = NullOpt;
+  return BlockFrame(node);
+}
+
+/*! \brief Create a fresh BlockInitFrame. */
+BlockInitFrame Init() {
+  ObjectPtr<BlockInitFrameNode> node = make_object<BlockInitFrameNode>();
+  return BlockInitFrame(node);
+}
+
+/*!
+ * \brief Set the predicate of the block frame located by FindBlockFrame("T.where").
+ * \param predicate The block predicate.
+ * \note Fails fatally if the frame already has a predicate.
+ */
+void Where(PrimExpr predicate) {
+  BlockFrame block_frame = FindBlockFrame("T.where");
+  const bool already_declared = block_frame->predicate.defined();
+  if (already_declared) {
+    LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+               << block_frame->predicate;
+  }
+  block_frame->predicate = predicate;
+}
+
+/*!
+ * \brief Declare the read regions of the block frame located by FindBlockFrame("T.reads").
+ * \param buffer_slices Each element must be a BufferRegion or a BufferLoad; loads are
+ *        converted with BufferRegionFromLoad.
+ * \note Fails fatally on a duplicate declaration or on an element of any other type.
+ */
+void Reads(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame block_frame = FindBlockFrame("T.reads");
+  if (block_frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is "
+               << block_frame->reads;
+  }
+  Array<BufferRegion> regions;
+  for (const ObjectRef& slice : buffer_slices) {
+    if (const auto* region = slice.as<BufferRegionNode>()) {
+      regions.push_back(GetRef<BufferRegion>(region));
+      continue;
+    }
+    if (const auto* load = slice.as<BufferLoadNode>()) {
+      regions.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(load)));
+      continue;
+    }
+    LOG(FATAL) << "Invalid type for buffer reads.";
+  }
+  block_frame->reads = regions;
+}
+
+/*!
+ * \brief Declare the write regions of the block frame located by FindBlockFrame("T.writes").
+ * \param buffer_slices Each element must be a BufferRegion or a BufferLoad; loads are
+ *        converted with BufferRegionFromLoad.
+ * \note Fails fatally on a duplicate declaration or on an element of any other type.
+ */
+void Writes(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame block_frame = FindBlockFrame("T.writes");
+  if (block_frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << block_frame->writes;
+  }
+  Array<BufferRegion> regions;
+  for (const ObjectRef& slice : buffer_slices) {
+    if (const auto* region = slice.as<BufferRegionNode>()) {
+      regions.push_back(GetRef<BufferRegion>(region));
+      continue;
+    }
+    if (const auto* load = slice.as<BufferLoadNode>()) {
+      regions.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(load)));
+      continue;
+    }
+    LOG(FATAL) << "Invalid type for buffer writes.";
+  }
+  block_frame->writes = regions;
+}
+
+/*!
+ * \brief Set the annotations of the block frame located by FindBlockFrame("T.block_attr").
+ * \param attrs The annotation map.
+ * \note Fails fatally if the frame already has annotations.
+ */
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  BlockFrame block_frame = FindBlockFrame("T.block_attr");
+  const bool already_declared = block_frame->annotations.defined();
+  if (already_declared) {
+    LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is "
+               << block_frame->annotations;
+  }
+  block_frame->annotations = attrs;
+}
+
+/*!
+ * \brief Declare an unnamed buffer and attach it to the enclosing block or PrimFunc frame.
+ *
+ * The buffer is built with BufferDecl. It is appended to the last BlockFrame's
+ * `alloc_buffers` if one is found, otherwise to the last PrimFuncFrame's
+ * `root_alloc_buffers`.
+ *
+ * \return The declared buffer.
+ * \note Fails fatally when neither kind of frame is found.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  // A block frame takes precedence over the PrimFunc frame.
+  if (Optional<BlockFrame> block_frame = builder->GetLastFrame<BlockFrame>()) {
+    block_frame.value()->alloc_buffers.push_back(buffer);
+    return buffer;
+  }
+  if (Optional<PrimFuncFrame> func_frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    func_frame.value()->root_alloc_buffers.push_back(buffer);
+    return buffer;
+  }
+  LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure "
+                "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  return buffer;
+}
+
+namespace axis {
+
+/*!
+ * \brief Append an iteration variable and its binding to the last BlockFrame.
+ * \param iter_var The block iteration variable.
+ * \param binding The value bound to `iter_var`.
+ * \return `iter_var`, unchanged.
+ * \note Fails fatally when GetLastFrame<BlockFrame>() yields nothing.
+ */
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+  Optional<BlockFrame> last_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  if (!last_frame.defined()) {
+    LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+  }
+  BlockFrame block_frame = last_frame.value();
+  block_frame->iter_vars.push_back(iter_var);
+  block_frame->iter_values.push_back(binding);
+  return iter_var;
+}
+
+// Defines `Var Method(Range dom, PrimExpr binding, DataType dtype)` for one block-axis kind.
+// The generated function:
+//   - requires `dom` to be defined (`Name` only appears in the failure message);
+//   - creates an unnamed variable of `dtype`, widened to the largest bit width among
+//     `dom->min`, `dom->extent` and `dtype`;
+//   - wraps it in an IterVar of type `Kind`, registers it together with `binding` through
+//     PushBlockVar, and returns the IterVar's variable.
+#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \
+ Var Method(Range dom, PrimExpr binding, DataType dtype) { \
+ ICHECK(dom.defined()) << Name << " axis must have a domain"; \
+ int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
+ return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \
+ /*iter_type=*/Kind, /*thread_tag=*/""), \
+ binding) \
+ ->var; \
+ }
+// One factory per axis kind; the macro is undefined right after use.
+TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
+TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
+TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
+TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
+#undef TVM_TIR_IR_BUILDER_AXIS
+
+/*!
+ * \brief Create one block axis per loop variable, reusing each loop's domain.
+ *
+ * For every position i, `bindings[i]` must be a loop variable of some ForFrame currently on
+ * the builder's frame stack. A new unnamed block variable with the same dtype as the loop
+ * variable is created over that loop's domain and bound to the loop variable.
+ *
+ * \param kinds One character per binding: 'S' for a data-parallel axis, 'R' for a
+ *        commutative-reduction axis.
+ * \param bindings The loop variables to bind; must have the same length as `kinds`.
+ * \param dtype Not read by this function: each axis takes the dtype of its loop variable.
+ *        Kept so the signature stays unchanged.
+ * \return The newly created block variables, in order.
+ */
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+  using namespace tvm::tir;
+  Array<Var> results;
+  ICHECK_EQ(kinds.size(), bindings.size());
+  const int num_bindings = bindings.size();
+  results.reserve(num_bindings);
+  for (int i = 0; i < num_bindings; ++i) {
+    const char kind = kinds.c_str()[i];
+    PrimExpr binding = bindings[i];
+    const VarNode* loop_var = binding.as<VarNode>();
+    ICHECK(loop_var) << "TypeError: Only Var is supported in T.axis.remap";
+    // Search the frame stack for the loop that defines this variable.
+    Range dom{nullptr};
+    for (const auto& frame : IRBuilder::Current()->frames) {
+      const auto* for_frame = frame.as<ForFrameNode>();
+      if (for_frame == nullptr) {
+        continue;
+      }
+      ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+      const int num_loops = for_frame->doms.size();
+      for (int j = 0; j < num_loops; ++j) {
+        if (for_frame->vars[j].get() == loop_var) {
+          dom = for_frame->doms[j];
+          break;
+        }
+      }
+      if (dom.defined()) {
+        break;
+      }
+    }
+    ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(loop_var);
+    // The kind is validated only after the domain lookup, matching the original check order.
+    IterVarType iter_type = IterVarType::kDataPar;
+    if (kind == 'R') {
+      iter_type = IterVarType::kCommReduce;
+    } else if (kind != 'S') {
+      LOG(FATAL) << "Unknown axis kind: " << kind;
+    }
+    const DataType var_dtype = loop_var->dtype;
+    results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
+                                           /*var=*/Var("", var_dtype),
+                                           /*iter_type=*/iter_type,
+                                           /*thread_tag=*/""),
+                                   binding)
+                          ->var);
+  }
+  return results;
+}
+
+} // namespace axis
+
+// Defines `ForFrame Method(start, stop, annotations)` for one loop kind.
+// The generated function:
+//   - uses `start` as the loop minimum and the simplified `stop - start` as the extent;
+//   - creates a single loop variable "v" of integer type whose bit width is the larger of
+//     the minimum's and the extent's;
+//   - stores a callback that builds a tvm::tir::For of kind `Kind` from exactly one variable
+//     and one domain, with no thread binding and with the captured annotations (an empty
+//     map when none were given).
+#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \
+ ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \
+ PrimExpr min = start; \
+ PrimExpr extent = arith::Analyzer().Simplify(stop - start); \
+ ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \
+ int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
+ n->vars = {Var("v", DataType::Int(bits))}; \
+ n->doms = {Range::FromMinExtent(min, extent)}; \
+ n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \
+ ICHECK_EQ(vars.size(), 1); \
+ ICHECK_EQ(doms.size(), 1); \
+ return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \
+ annotations.value_or(Map<String, ObjectRef>())); \
+ }; \
+ return ForFrame(n); \
+ }
+
+// One factory per loop kind; the macro is undefined right after use.
+TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
+
+#undef TVM_TIR_IR_BUILDER_FOR_FRAME
+
+/*!
+ * \brief Create a ForFrame for a loop bound to a thread.
+ * \param start The loop minimum.
+ * \param stop The loop upper bound; the extent is the simplified `stop - start`.
+ * \param thread The thread tag given to the binding IterVar.
+ * \param annotations Optional loop annotations; an empty map is used when absent.
+ * \return A frame with one loop variable "v" whose integer width is the larger of the
+ *         minimum's and the extent's, and a callback that builds a kThreadBinding For.
+ */
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations) {
+  using namespace tvm::tir;
+  PrimExpr loop_min = start;
+  PrimExpr loop_extent = arith::Analyzer().Simplify(stop - start);
+  ObjectPtr<ForFrameNode> node = make_object<ForFrameNode>();
+  const int bits = std::max(loop_min.dtype().bits(), loop_extent.dtype().bits());
+  node->vars = {Var("v", DataType::Int(bits))};
+  node->doms = {Range::FromMinExtent(loop_min, loop_extent)};
+  node->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms,
+                                                Stmt body) -> For {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    Range dom = doms[0];
+    // The binding IterVar carries only the thread tag; its own domain is left undefined.
+    IterVar thread_iter(Range(nullptr), Var("iter", DataType::Int(32)),
+                        IterVarType::kThreadIndex, thread);
+    return For(vars[0], dom->min, dom->extent, ForKind::kThreadBinding, body, thread_iter,
+               annotations.value_or(Map<String, ObjectRef>()));
+  };
+  return ForFrame(node);
+}
+
+/*!
+ * \brief Create a ForFrame describing a nest of serial loops, one per extent.
+ * \param extents The loop extents, outermost first; each loop starts at zero and its
+ *        variable "v" has the dtype of its extent.
+ * \return A frame whose callback wraps the body in serial For loops from the innermost
+ *         (last) extent outwards.
+ */
+ForFrame Grid(Array<PrimExpr> extents) {
+  using namespace tvm::tir;
+  ObjectPtr<ForFrameNode> node = make_object<ForFrameNode>();
+  node->vars.reserve(extents.size());
+  node->doms.reserve(extents.size());
+  for (const auto& extent : extents) {
+    const DataType dtype = extent.dtype();
+    node->vars.push_back(Var("v", dtype));
+    node->doms.push_back(Range(make_const(dtype, 0), extent));
+  }
+  node->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    const int num_loops = vars.size();
+    // Build inside-out so that the first extent ends up as the outermost loop.
+    for (int depth = num_loops - 1; depth >= 0; --depth) {
+      Var loop_var = vars[depth];
+      Range dom = doms[depth];
+      body = For(loop_var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
+                 /*thread_binding=*/NullOpt, /*annotations=*/{});
+    }
+    return body;
+  };
+  return ForFrame(node);
+}
+
+/*!
+ * \brief Create an empty PrimFuncFrame.
+ * \return A frame whose containers are empty and whose name, return type and attributes
+ *         are unset.
+ */
+PrimFuncFrame PrimFunc() {
+  ObjectPtr<PrimFuncFrameNode> node = make_object<PrimFuncFrameNode>();
+  // Optional fields start out undefined so later duplicate declarations can be detected.
+  node->name = NullOpt;
+  node->ret_type = NullOpt;
+  node->attrs = NullOpt;
+  // Container fields start out empty.
+  node->args.clear();
+  node->buffer_map.clear();
+  node->preflattened_buffer_map.clear();
+  node->env_threads.clear();
+  node->root_alloc_buffers.clear();
+  return PrimFuncFrame(node);
+}
+
+/*!
+ * \brief Name a variable and append it to the enclosing PrimFunc frame's arguments.
+ * \param name The name given to the variable.
+ * \param var The argument variable.
+ * \return `var`.
+ */
+Var Arg(String name, Var var) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  details::Namer::Name(var, name);
+  func_frame->args.push_back(var);
+  return var;
+}
+
+/*!
+ * \brief Name a buffer and register it as an argument of the enclosing PrimFunc frame.
+ *
+ * A handle variable named "<buffer name>_handle" is appended to the frame's arguments and
+ * mapped to the buffer in the frame's buffer map.
+ *
+ * \param name The name given to the buffer.
+ * \param buffer The argument buffer.
+ * \return `buffer`.
+ */
+Buffer Arg(String name, Buffer buffer) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.Arg");
+  // Name the buffer first: the handle name below is derived from the buffer's new name.
+  details::Namer::Name(buffer, name);
+  Var buffer_handle(buffer->name + "_handle", DataType::Handle());
+  func_frame->args.push_back(buffer_handle);
+  func_frame->buffer_map.Set(buffer_handle, buffer);
+  return buffer;
+}
+
+/*!
+ * \brief Set the name of the enclosing PrimFunc frame.
+ * \param name The function name.
+ * \note Fails fatally if the frame already has a name.
+ */
+void FuncName(String name) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_name");
+  const bool already_declared = func_frame->name.defined();
+  if (already_declared) {
+    LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is "
+               << func_frame->name.value();
+  }
+  func_frame->name = name;
+}
+
+/*!
+ * \brief Set the attributes of the enclosing PrimFunc frame.
+ * \param attrs The attribute map.
+ * \note Fails fatally if the frame already has attributes.
+ */
+void FuncAttrs(Map<String, ObjectRef> attrs) {
+  using namespace tvm::tir;
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.func_attr");
+  const bool already_declared = func_frame->attrs.defined();
+  if (already_declared) {
+    LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is "
+               << func_frame->attrs;
+  }
+  func_frame->attrs = attrs;
+}
+
+/*!
+ * \brief Set the return type of the enclosing PrimFunc frame.
+ * \param ret_type The return type.
+ * \return `ret_type`.
+ * \note Fails fatally if the frame already has a return type.
+ */
+tvm::Type FuncRet(tvm::Type ret_type) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.ret_type");
+  const bool already_declared = func_frame->ret_type.defined();
+  if (already_declared) {
+    LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
+               << func_frame->ret_type.value();
+  }
+  func_frame->ret_type = ret_type;
+  return ret_type;
+}
+
+/*!
+ * \brief Declare an unnamed buffer and match it against a function parameter or a region.
+ *
+ * The buffer is built with BufferDecl, then, depending on the runtime type of `param`:
+ *   - Var: it must be one of the enclosing PrimFunc frame's arguments; the buffer is
+ *     recorded in that frame's buffer map.
+ *   - BufferLoad: it is converted with BufferRegionFromLoad and a MatchBufferRegion is
+ *     appended to the enclosing block frame.
+ *   - BufferRegion: a MatchBufferRegion is appended to the enclosing block frame.
+ *
+ * \param param The object to match against.
+ * \return The declared buffer.
+ * \note Fails fatally for a Var that is not a function argument and for any other type.
+ */
+Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  if (const auto* var_node = param.as<tvm::tir::VarNode>()) {
+    PrimFuncFrame func_frame = FindPrimFuncFrame("T.match_buffer");
+    Var param_var = GetRef<Var>(var_node);
+    for (auto const& arg : func_frame->args) {
+      if (!arg.same_as(param_var)) {
+        continue;
+      }
+      func_frame->buffer_map.Set(param_var, buffer);
+      return buffer;
+    }
+    LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
+    return buffer;
+  }
+  if (const auto* load_node = param.as<tvm::tir::BufferLoadNode>()) {
+    BlockFrame block_frame = FindBlockFrame("T.match_buffer");
+    block_frame->match_buffers.push_back(tvm::tir::MatchBufferRegion(
+        buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(load_node))));
+    return buffer;
+  }
+  if (const auto* region_node = param.as<tvm::tir::BufferRegionNode>()) {
+    BlockFrame block_frame = FindBlockFrame("T.match_buffer");
+    block_frame->match_buffers.push_back(
+        tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(region_node)));
+    return buffer;
+  }
+  LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
+  return buffer;
+}
+
+/*!
+ * \brief Record the pre-flattened form of a buffer already present in the buffer map.
+ *
+ * Looks up `postflattened_buffer` among the values of the enclosing PrimFunc frame's buffer
+ * map. For the matching entry, a buffer named "<postflattened name>_preflatten" is built
+ * (sharing the post-flattened buffer's data pointer unless `data` is given) and stored in
+ * the frame's preflattened buffer map under the same key.
+ *
+ * \param postflattened_buffer The buffer to look up.
+ * \note Fails fatally when `postflattened_buffer` is not in the buffer map.
+ */
+void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
+                        Optional<Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
+                        String storage_scope, int align, int offset_factor, String buffer_type_str,
+                        Array<IntImm> axis_separators) {
+  PrimFuncFrame func_frame = FindPrimFuncFrame("T.preflattened_buffer");
+  for (auto const& entry : func_frame->buffer_map) {
+    if (!entry.second.same_as(postflattened_buffer)) {
+      continue;
+    }
+    String buffer_name(postflattened_buffer->name + "_preflatten");
+    Buffer preflattened =
+        BufferDecl(shape, dtype, buffer_name, data.value_or(entry.second->data), strides,
+                   elem_offset, storage_scope, align, offset_factor, buffer_type_str,
+                   axis_separators);
+    details::Namer::Name(preflattened, buffer_name);
+    func_frame->preflattened_buffer_map.Set(entry.first, preflattened);
+    return;
+  }
+  LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
+             << " does not exist.";
+}
+
+/*!
+ * \brief Create an AssertFrame.
+ * \param condition The asserted condition.
+ * \param message The message, stored on the frame as a StringImm.
+ * \return The new frame.
+ */
+AssertFrame Assert(PrimExpr condition, String message) {
+  ObjectPtr<AssertFrameNode> node = make_object<AssertFrameNode>();
+  node->message = tvm::tir::StringImm(message);
+  node->condition = condition;
+  return AssertFrame(node);
+}
+
+/*!
+ * \brief Create a LetFrame binding `var` to `value`.
+ * \param var The bound variable.
+ * \param value The bound value.
+ * \return The new frame.
+ */
+LetFrame Let(Var var, PrimExpr value) {
+  ObjectPtr<LetFrameNode> node = make_object<LetFrameNode>();
+  node->value = value;
+  node->var = var;
+  return LetFrame(node);
+}
+
+/*!
+ * \brief Create an AllocateFrame together with an unnamed buffer describing the allocation.
+ * \param extents The allocation extents, also used as the buffer shape.
+ * \param dtype The element data type.
+ * \param storage_scope The storage scope.
+ * \param condition Optional condition; a true constant is used when absent.
+ * \param annotations Optional annotations; an empty map is used when absent.
+ * \return The new frame.
+ */
+AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
+                       Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
+  ObjectPtr<AllocateFrameNode> node = make_object<AllocateFrameNode>();
+  node->extents = extents;
+  node->dtype = dtype;
+  node->storage_scope = storage_scope;
+  // Fill in defaults for the optional arguments.
+  node->condition = condition.value_or(tvm::Bool(true));
+  node->annotations = annotations.value_or(Map<String, ObjectRef>());
+  node->buffer = BufferDecl(extents, dtype, /*buffer_name=*/"", /*data=*/NullOpt,
+                            /*strides=*/NullOpt, /*elem_offset=*/NullOpt, storage_scope,
+                            /*align=*/0, /*offset_factor=*/0, /*buffer_type=*/"default",
+                            /*axis_separators=*/NullOpt);
+  return AllocateFrame(node);
+}
+
+/*!
+ * \brief Create an AllocateConstFrame together with an unnamed buffer describing the data.
+ * \param data The constant data.
+ * \param dtype The element data type.
+ * \param extents The allocation extents, also used as the buffer shape.
+ * \param annotations The annotations stored on the frame.
+ * \return The new frame.
+ */
+AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
+                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+  ObjectPtr<AllocateConstFrameNode> node = make_object<AllocateConstFrameNode>();
+  node->data = data;
+  node->dtype = dtype;
+  node->extents = extents;
+  node->annotations = annotations;
+  node->buffer = BufferDecl(extents, dtype, /*buffer_name=*/"", /*data=*/NullOpt,
+                            /*strides=*/NullOpt, /*elem_offset=*/NullOpt, /*storage_scope=*/"",
+                            /*align=*/0, /*offset_factor=*/0, /*buffer_type=*/"default",
+                            /*axis_separators=*/NullOpt);
+  return AllocateConstFrame(node);
+}
+
+/*!
+ * \brief Create a LaunchThreadFrame for an environment thread variable.
+ *
+ * `var` must have been registered in the enclosing PrimFunc frame's `env_threads` map (see
+ * EnvThread). If the associated IterVar has no domain yet, it is given the range built from
+ * 0 and `extent`; otherwise its existing extent must be provably equal to `extent`.
+ *
+ * \param var The environment thread variable.
+ * \param extent The thread extent.
+ * \return A frame whose attribute key is "virtual_thread" when the thread tag is "vthread"
+ *         and "thread_extent" otherwise.
+ * \note Fails fatally when called outside a PrimFunc, when `var` is not a registered
+ *       environment thread, or when the extents are inconsistent.
+ */
+LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
+  IterVar iter_var{nullptr};
+
+  if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
+    if (Optional<IterVar> opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
+      iter_var = opt_iter_var.value();
+    } else {
+      // Must be fatal: `iter_var` is still null here and is dereferenced right below, so
+      // merely logging would turn a user error into a null-pointer access.
+      LOG(FATAL) << "ValueError: " << var->name_hint
+                 << " is not an env_thread created using T.env_thread.";
+    }
+  } else {
+    LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc";
+  }
+  ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
+  if (!iter_var->dom.defined()) {
+    // First launch of this thread: record the extent on the shared IterVar.
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+  } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
+    LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
+               << iter_var->dom->extent << " vs " << extent;
+  }
+  n->iter_var = iter_var;
+  n->extent = extent;
+  n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent";
+  return LaunchThreadFrame(n);
+}
+
+/*!
+ * \brief Create a RealizeFrame.
+ * \param buffer_slice The buffer region stored on the frame.
+ * \param storage_scope The storage scope stored on the frame.
+ * \param condition The condition stored on the frame.
+ * \return The new frame.
+ */
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+                     PrimExpr condition) {
+  ObjectPtr<RealizeFrameNode> node = make_object<RealizeFrameNode>();
+  node->condition = condition;
+  node->storage_scope = storage_scope;
+  node->buffer_slice = buffer_slice;
+  return RealizeFrame(node);
+}
+
+/*!
+ * \brief Create an AttrFrame.
+ * \param node The object the attribute is attached to.
+ * \param attr_key The attribute key.
+ * \param value The attribute value.
+ * \return The new frame.
+ */
+AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
+  ObjectPtr<AttrFrameNode> frame_node = make_object<AttrFrameNode>();
+  frame_node->value = value;
+  frame_node->attr_key = attr_key;
+  frame_node->node = node;
+  return AttrFrame(frame_node);
+}
+
+/*!
+ * \brief Create a WhileFrame.
+ * \param condition The loop condition stored on the frame.
+ * \return The new frame.
+ */
+WhileFrame While(PrimExpr condition) {
+  ObjectPtr<WhileFrameNode> node = make_object<WhileFrameNode>();
+  node->condition = condition;
+  return WhileFrame(node);
+}
+
+/*!
+ * \brief Create an IfFrame with both branches unset.
+ * \param condition The branch condition stored on the frame.
+ * \return The new frame.
+ */
+IfFrame If(PrimExpr condition) {
+  ObjectPtr<IfFrameNode> node = make_object<IfFrameNode>();
+  // Both branches start out undefined.
+  node->then_stmts = NullOpt;
+  node->else_stmts = NullOpt;
+  node->condition = condition;
+  return IfFrame(node);
+}
+
+/*! \brief Create a fresh ThenFrame. */
+ThenFrame Then() {
+  ObjectPtr<ThenFrameNode> node = make_object<ThenFrameNode>();
+  return ThenFrame(node);
+}
+
+/*! \brief Create a fresh ElseFrame. */
+ElseFrame Else() {
+  ObjectPtr<ElseFrameNode> node = make_object<ElseFrameNode>();
+  return ElseFrame(node);
+}
+
+/*!
+ * \brief Create an environment thread variable and register it on the enclosing PrimFunc.
+ *
+ * An unnamed int32 variable is wrapped in a kThreadIndex IterVar with an undefined domain
+ * and the given thread tag; the pair is stored in the PrimFunc frame's `env_threads` map.
+ *
+ * \param thread_tag The thread tag.
+ * \return The new variable.
+ * \note Fails fatally when no PrimFunc frame is found.
+ */
+Var EnvThread(String thread_tag) {
+  IterVar thread_iter(Range{nullptr}, Var("", DataType::Int(32)),
+                      tvm::tir::IterVarType::kThreadIndex, thread_tag);
+  Var thread_var = thread_iter->var;
+  Optional<PrimFuncFrame> func_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>();
+  if (!func_frame.defined()) {
+    LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
+  } else {
+    func_frame.value()->env_threads.Set(thread_var, thread_iter);
+  }
+  return thread_var;
+}
+
+/*!
+ * \brief Build a tvm::tir::BufferStore and pass it to AddToParent.
+ * \param buffer The destination buffer.
+ * \param value The value to store.
+ * \param indices The indices to store at.
+ */
+void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+ AddToParent(tvm::tir::BufferStore(buffer, value, indices));
+}
+
+/*!
+ * \brief Build a tvm::tir::Prefetch and pass it to AddToParent.
+ * \param buffer The buffer to prefetch.
+ * \param bounds The ranges passed to the Prefetch statement.
+ */
+void Prefetch(Buffer buffer, Array<Range> bounds) {
+ AddToParent(tvm::tir::Prefetch(buffer, bounds));
+}
+
+/*! \brief Wrap `value` in a tvm::tir::Evaluate statement and pass it to AddToParent. */
+void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
+
+using tvm::ir_builder::details::Namer;
+
+// Naming rule for buffers: set the buffer's own name, give its data variable the same
+// name, and name every stride that is a plain variable "<name>_s<index>".
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
+ tvm::tir::BufferNode* buffer =
+ const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>());
+ buffer->name = name;
+ Namer::Name(buffer->data, name);
+ int n = buffer->strides.size();
+ for (int i = 0; i < n; ++i) {
+ PrimExpr e = buffer->strides[i];
+ if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) {
+ Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i));
+ }
+ }
+ });
+
+// Naming rule for size variables: overwrite the name hint in place.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
+ using namespace tvm::tir;
+ SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>());
+ var->name_hint = name;
+ });
+
+// Naming rule for variables: overwrite the name hint in place.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
+ using namespace tvm::tir;
+ VarNode* var = const_cast<VarNode*>(node.as<VarNode>());
+ var->name_hint = name;
+ });
+
+// Naming rule for iteration variables: name the underlying variable.
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+ .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
+ using namespace tvm::tir;
+ IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>());
+ Namer::Name(var->var, name);
+ });
+
+// Global function registrations for the builder API, all under the "ir_builder.tir." prefix.
+
+// Buffer and pointer declarations.
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Ptr").set_body_typed(Ptr);
+
+// Block construction, block-level declarations and block axes.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Init").set_body_typed(Init);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Where").set_body_typed(Where);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Reads").set_body_typed(Reads);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Writes").set_body_typed(Writes);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
+// Loop frames.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Serial").set_body_typed(Serial);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Parallel").set_body_typed(Parallel);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Vectorized").set_body_typed(Vectorized);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Unroll").set_body_typed(Unroll);
+TVM_REGISTER_GLOBAL("ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Grid").set_body_typed(Grid);
+
+// PrimFunc construction and function-level declarations.
+TVM_REGISTER_GLOBAL("ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
+// `Arg` is overloaded, so dispatch on the runtime type of the object (Var or Buffer).
+TVM_REGISTER_GLOBAL("ir_builder.tir.Arg")
+ .set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
+ using namespace tvm::tir;
+ if (const auto* var = obj.as<VarNode>()) {
+ return Arg(name, GetRef<tvm::tir::Var>(var));
+ }
+ if (const auto* buffer = obj.as<BufferNode>()) {
+ return Arg(name, GetRef<Buffer>(buffer));
+ }
+ LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
+ throw;
+ });
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncName").set_body_typed(FuncName);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
+TVM_REGISTER_GLOBAL("ir_builder.tir.FuncRet").set_body_typed(FuncRet);
+TVM_REGISTER_GLOBAL("ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
+
+// Statement frames and leaf statements.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Assert").set_body_typed(Assert);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Let").set_body_typed(Let);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Allocate").set_body_typed(Allocate);
+TVM_REGISTER_GLOBAL("ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Realize").set_body_typed(Realize);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Attr").set_body_typed(Attr);
+TVM_REGISTER_GLOBAL("ir_builder.tir.While").set_body_typed(While);
+TVM_REGISTER_GLOBAL("ir_builder.tir.If").set_body_typed(If);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Then").set_body_typed(Then);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Else").set_body_typed(Else);
+TVM_REGISTER_GLOBAL("ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
+TVM_REGISTER_GLOBAL("ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.EnvThread").set_body_typed(EnvThread);
+TVM_REGISTER_GLOBAL("ir_builder.tir.BufferStore").set_body_typed(BufferStore);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Prefetch").set_body_typed(Prefetch);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+
+// Data-type helpers; these functions are defined outside this file.
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int8").set_body_typed(Int8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int16").set_body_typed(Int16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32").set_body_typed(Int32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int64").set_body_typed(Int64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt8").set_body_typed(UInt8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt16").set_body_typed(UInt16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt32").set_body_typed(UInt32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.UInt64").set_body_typed(UInt64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float8").set_body_typed(Float8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float16").set_body_typed(Float16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float32").set_body_typed(Float32);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Float64").set_body_typed(Float64);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32x4").set_body_typed(Int32x4);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32x8").set_body_typed(Int32x8);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Int32x16").set_body_typed(Int32x16);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Boolean").set_body_typed(Boolean);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_REGISTER_GLOBAL("ir_builder.tir.Void").set_body_typed(Void);
+// `min`/`max` are wrapped in lambdas that forward to tvm::min / tvm::max.
+TVM_REGISTER_GLOBAL("ir_builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+ return tvm::min(a, b);
+});
+TVM_REGISTER_GLOBAL("ir_builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr {
+ return tvm::max(a, b);
+});
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
diff --git a/src/tir/ir_builder/ir_builder_frame.cc b/src/tir/ir_builder/ir_builder_frame.cc
new file mode 100644
index 0000000000..18ccc2a1ad
--- /dev/null
+++ b/src/tir/ir_builder/ir_builder_frame.cc
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/function.h>
+#include <tvm/tir/ir_builder.h>
+
+#include "../../tir/ir/script/script_complete.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+// Finalizes a block frame: builds a tvm::tir::Block from the frame's fields and hands it to
+// AddToParent, wrapped in a tvm::tir::BlockRealize unless `no_realize` is set.
+void BlockFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ // Copy the frame's alloc_buffers into an Array<tvm::tir::Buffer> for the Block constructor.
+ Array<tvm::tir::Buffer> tir_alloc_buffers;
+ for (const tvm::tir::Buffer& buffer : alloc_buffers) {
+ tir_alloc_buffers.push_back(buffer);
+ }
+ Map<String, ObjectRef> attrs = annotations.value_or({});
+ // Bitmask: bit 0 is set when `reads` is undefined, bit 1 when `writes` is undefined.
+ // Unary `!` binds tighter than `<<`, so the second term is `(!writes.defined()) << 1`.
+ // A non-zero mask is stored under the "tir.script_parsing_detect_access" annotation.
+ if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) {
+ attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access));
+ }
+ // Undefined reads/writes fall back to empty region arrays.
+ tvm::tir::Block block(iter_vars, reads.value_or(Array<tvm::tir::BufferRegion>()),
+ writes.value_or(Array<tvm::tir::BufferRegion>()), name, AsStmt(stmts), init,
+ tir_alloc_buffers, match_buffers, attrs);
+ if (no_realize) {
+ // A bare Block carries no bindings or predicate, so both must be absent here.
+ CHECK(iter_values.empty())
+ << "ValueError: Block bindings are not allowed when `no_realize=True`";
+ CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
+ AddToParent(block);
+ } else {
+ // A missing predicate defaults to Bool(true).
+ AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
+ }
+}
+
+// Rejects a second init declaration for the block frame found by FindBlockFrame, then runs
+// the base-class enter logic.
+void BlockInitFrameNode::EnterWithScope() {
+  const bool init_already_set = FindBlockFrame("T.init")->init.defined();
+  if (init_already_set) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// After the base-class exit logic, stores the statements collected in this frame as the
+// `init` statement of the block frame found by FindBlockFrame.
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  BlockFrame block_frame = FindBlockFrame("T.init");
+  tvm::tir::Stmt init_body = AsStmt(stmts);
+  block_frame->init = init_body;
+}
+
+// Closes the loop frame: the collected statements become the loop body, and whatever
+// `f_make_for_loop` builds from (vars, doms, body) is attached to the parent.
+void ForFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(this->f_make_for_loop(vars, doms, body));
+}
+
+// Finalizes a PrimFunc frame: builds a tvm::tir::PrimFunc from the frame's fields, passes it
+// through tvm::tir::ScriptComplete, and stores it either as the builder result or in the
+// IRModuleFrame located via FindFrame.
+void PrimFuncFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ // Missing ret_type falls back to TupleType::Empty(); missing attrs to a null DictAttrs.
+ tvm::tir::PrimFunc func(
+ /*params=*/args,
+ /*body=*/AsStmt(stmts),
+ /*ret_type=*/ret_type.value_or(TupleType::Empty()),
+ /*buffer_map=*/buffer_map,
+ /*preflattened_buffer_map=*/preflattened_buffer_map,
+ /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue<DictAttrs>());
+ func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
+ IRBuilder builder = IRBuilder::Current();
+ if (builder->frames.empty()) {
+ // No frame left on the builder: the function itself becomes the builder result.
+ ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+ builder->result = func;
+ } else if (Optional<IRModuleFrame> opt_frame = builder->FindFrame<IRModuleFrame>()) {
+ // Inside an IRModule frame: record the function under a GlobalVar named after `name`
+ // (an empty string when `name` is not set).
+ IRModuleFrame frame = opt_frame.value();
+ frame->global_vars.push_back(GlobalVar(name.value_or("")));
+ frame->functions.push_back(func);
+ } else {
+ LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
+ }
+}
+
+// Wraps the collected statements in a tvm::tir::AssertStmt guarded by `condition` with
+// `message`, and attaches it to the parent.
+void AssertFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AssertStmt(condition, message, body));
+}
+
+// Wraps the collected statements in a tvm::tir::LetStmt built from `var` and `value`,
+// and attaches it to the parent.
+void LetFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::LetStmt(var, value, body));
+}
+
+// Emits a tvm::tir::Allocate that takes its data var, dtype and shape from `buffer`, with the
+// collected statements as its body, and attaches it to the parent.
+void AllocateFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition, body,
+                                 annotations));
+}
+
+// Emits a tvm::tir::AllocateConst from `buffer->data`, `dtype`, `extents` and `data`, with the
+// collected statements as its body, and attaches it to the parent.
+void AllocateConstFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AllocateConst(buffer->data, dtype, extents, data, body, annotations));
+}
+
+// Wraps the collected statements in a tvm::tir::AttrStmt on `iter_var`, keyed by `attr_key`
+// with value `extent`, and attaches it to the parent.
+void LaunchThreadFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, body));
+}
+
+// Emits a tvm::tir::BufferRealize over the buffer slice's region, nested inside a
+// "realize_scope" AttrStmt carrying `storage_scope`, and attaches the result to the parent.
+void RealizeFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::BufferRealize realize(buffer_slice->buffer, buffer_slice->region, condition,
+                                  AsStmt(stmts));
+  AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope",
+                                 tvm::tir::StringImm(storage_scope), realize));
+}
+
+// Wraps the collected statements in a tvm::tir::AttrStmt built from `node`, `attr_key` and
+// `value`, and attaches it to the parent.
+void AttrFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::AttrStmt(node, attr_key, value, body));
+}
+
+// Wraps the collected statements in a tvm::tir::While on `condition` and attaches it to the
+// parent.
+void WhileFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::While(condition, body));
+}
+
+// Finalizes an if frame: builds a tvm::tir::IfThenElse from `condition`, `then_stmts` and the
+// optional `else_stmts`, and attaches it to the parent.
+void IfFrameNode::ExitWithScope() {
+ TIRFrameNode::ExitWithScope();
+ // Statements must have been collected by the then/else frames, not by this frame itself.
+ if (!stmts.empty()) {
+ LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame";
+ }
+ if (!then_stmts.defined()) {
+ LOG(FATAL) << "IfThenElse frame should have at least one then branch";
+ }
+ // A missing else branch is passed as a null Stmt.
+ AddToParent(tvm::tir::IfThenElse(
+ condition, AsStmt(then_stmts.value()),
+ else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr)));
+}
+
+// Rejects a second then-branch for the if frame found by FindIfFrame, then runs the
+// base-class enter logic.
+void ThenFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.then_");
+  const bool has_then_branch = if_frame->then_stmts.defined();
+  if (has_then_branch) {
+    LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is "
+               << if_frame->then_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// After the base-class exit logic, records this frame's statements as the then-branch of the
+// if frame found by FindIfFrame.
+void ThenFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  IfFrame if_frame = FindIfFrame("T.then_");
+  if_frame->then_stmts = stmts;
+}
+
+// Validates that the if frame found by FindIfFrame already has a then-branch and no
+// else-branch yet, then runs the base-class enter logic.
+void ElseFrameNode::EnterWithScope() {
+  IfFrame if_frame = FindIfFrame("T.else_");
+  const bool has_then_branch = if_frame->then_stmts.defined();
+  if (!has_then_branch) {
+    LOG(FATAL) << "The else branch should follow then branch";
+  }
+  const bool has_else_branch = if_frame->else_stmts.defined();
+  if (has_else_branch) {
+    LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is "
+               << if_frame->else_stmts.value();
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+// After the base-class exit logic, records this frame's statements as the else-branch of the
+// if frame found by FindIfFrame.
+void ElseFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  IfFrame if_frame = FindIfFrame("T.else_");
+  if_frame->else_stmts = stmts;
+}
+
+// Wraps the collected statements in a tvm::tir::DeclBuffer for `buffer` and attaches it to
+// the parent.
+void DeclBufferFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  tvm::tir::Stmt body = AsStmt(stmts);
+  AddToParent(tvm::tir::DeclBuffer(buffer, body));
+}
+
+// Node type registration for TIRFrameNode and each frame node whose scope hooks are
+// defined in this file.
+TVM_REGISTER_NODE_TYPE(TIRFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
+TVM_REGISTER_NODE_TYPE(ForFrameNode);
+TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
+TVM_REGISTER_NODE_TYPE(AssertFrameNode);
+TVM_REGISTER_NODE_TYPE(LetFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateFrameNode);
+TVM_REGISTER_NODE_TYPE(AllocateConstFrameNode);
+TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode);
+TVM_REGISTER_NODE_TYPE(RealizeFrameNode);
+TVM_REGISTER_NODE_TYPE(AttrFrameNode);
+TVM_REGISTER_NODE_TYPE(WhileFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+TVM_REGISTER_NODE_TYPE(DeclBufferFrameNode);
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
diff --git a/src/tir/ir_builder/utils.h b/src/tir/ir_builder/utils.h
new file mode 100644
index 0000000000..6d9cbda72a
--- /dev/null
+++ b/src/tir/ir_builder/utils.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_IR_BUILDER_UTILS_H_
+#define TVM_TIR_IR_BUILDER_UTILS_H_
+
+#include <tvm/tir/ir_builder.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace ir_builder {
+namespace tir {
+
+// Hands `stmt` to whatever currently encloses it in the active IRBuilder:
+//  - with no frames on the builder, `stmt` becomes the builder result (which must be unset);
+//  - with a TIRFrameNode as the last frame, `stmt` is appended to that frame's `stmts`;
+//  - any other last frame is reported as a TypeError via LOG(FATAL).
+inline void AddToParent(tvm::tir::Stmt stmt) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = stmt;
+    return;
+  }
+  const auto* parent = builder->frames.back().as<TIRFrameNode>();
+  if (!parent) {
+    LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
+  }
+  GetRef<TIRFrame>(parent)->stmts.push_back(stmt);
+}
+
+// Collapses an array of statements into a single statement:
+//  - empty array  -> tvm::tir::Evaluate(0)
+//  - one element  -> that element
+//  - otherwise    -> a tvm::tir::SeqStmt over the whole array
+inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmt) {
+  switch (stmt.size()) {
+    case 0:
+      return tvm::tir::Evaluate(0);
+    case 1:
+      return stmt[0];
+    default:
+      return tvm::tir::SeqStmt(stmt);
+  }
+}
+
+// Returns the BlockFrame reported by IRBuilder::Current()->GetLastFrame<BlockFrame>().
+// `method` names the calling construct (e.g. "T.init") and is used only in the error message.
+// Fails via LOG(FATAL) with a ValueError-prefixed message when no BlockFrame is returned.
+inline BlockFrame FindBlockFrame(const String& method) {
+  if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: Block frame not found. Please ensure '" << method
+             << "' is called under T.block()";
+  throw;  // Placed after LOG(FATAL) so no path falls off the end without a return.
+}
+
+// Returns the PrimFuncFrame reported by IRBuilder::Current()->GetLastFrame<PrimFuncFrame>().
+// `method` names the calling construct and is used only in the error message.
+// Fails via LOG(FATAL) with a ValueError-prefixed message when no PrimFuncFrame is returned.
+inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
+  if (Optional<PrimFuncFrame> frame = IRBuilder::Current()->GetLastFrame<PrimFuncFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: PrimFunc frame not found. Please ensure '" << method
+             << "' is called under T.prim_func()";
+  throw;  // Placed after LOG(FATAL) so no path falls off the end without a return.
+}
+
+// Returns the IfFrame reported by IRBuilder::Current()->GetLastFrame<IfFrame>().
+// `method` names the calling construct (e.g. "T.then_") and is used only in the error message.
+// Fails via LOG(FATAL) with a ValueError-prefixed message when no IfFrame is returned.
+inline IfFrame FindIfFrame(const String& method) {
+  if (Optional<IfFrame> frame = IRBuilder::Current()->GetLastFrame<IfFrame>()) {
+    return frame.value();
+  }
+  LOG(FATAL) << "ValueError: IfThenElse frame not found. Please ensure '" << method
+             << "' is called under T.if_()";
+  throw;  // Placed after LOG(FATAL) so no path falls off the end without a return.
+}
+
+// Builds a BufferRegion on the loaded buffer in which every index of `buffer_load` becomes a
+// Range with that index as its min and an extent of 1 (an IntImm of the index's dtype).
+// NOTE(review): presumably expects scalar integer indices -- confirm how vector-typed indices
+// should be handled.
+inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
+ Array<Range> ranges;
+ for (const PrimExpr& index : buffer_load->indices) {
+ ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1)));
+ }
+ return tvm::tir::BufferRegion(buffer_load->buffer, ranges);
+}
+
+} // namespace tir
+} // namespace ir_builder
+} // namespace tvm
+
+#endif // TVM_TIR_IR_BUILDER_UTILS_H_
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 114571218b..dd3c7a0b0f 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}
+// address_of: builds a handle-typed call to tir::builtin::address_of() whose single argument
+// is `buffer_load`; `span` is forwarded to the call.
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+ return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
+}
+
+// lookup_param: builds a handle-typed call to tir::builtin::lookup_param() whose single
+// argument is `param_name` wrapped in a tir::StringImm; `span` is forwarded to the call.
+PrimExpr lookup_param(String param_name, Span span) {
+ return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
+ span);
+}
+
// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -702,6 +713,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
}
+// isnullptr: builds a call to tir::builtin::isnullptr() on `x` with result type
+// DataType::Bool(1); `span` is forwarded to the call.
+PrimExpr isnullptr(PrimExpr x, Span span) {
+ return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
+}
+
// isinf
PrimExpr isinf(PrimExpr x, Span span) {
DataType t = DataType::Bool(x.dtype().lanes());
@@ -931,6 +947,8 @@ TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
+TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
+
TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
@@ -949,6 +967,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);
+
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -997,6 +1019,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, right_shift);
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);
+
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
return if_then_else(cond, true_value, false_value, span);
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index c3b8fd6766..4e0aa4ae24 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -121,7 +121,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
// Create block vars, block's accessed region and accessing indices
for (const PrimExpr& dim : cache_region->buffer->shape) {
Var var("v" + std::to_string(access_indices.size()), dim.dtype());
- block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, dim),
+ block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim),
/*var=*/var,
/*IterVarType=*/kDataPar));
access_indices.push_back(var);
diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py
index 5a60c19568..04ccdc514d 100644
--- a/tests/python/integration/test_meta_schedule_auto_tensorize.py
+++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py
@@ -14,12 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Integration test for metascheduler's auto tensorization."""
+"""Integration test for MetaSchedule's auto tensorization."""
import tempfile
import numpy as np
import pytest
-
import tvm
import tvm.testing
import tvm.topi.testing
@@ -29,8 +28,9 @@ from tvm.meta_schedule import ApplyHistoryBest, postproc, schedule_rule
from tvm.meta_schedule.relay_integration import extract_task_from_relay
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.meta_schedule.tune import tune_extracted_tasks
-from tvm.tir.tensor_intrin import AMDGPU_SDOT4_INTRIN, DP4A_INTRIN
-from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
+from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
CONFIG = ms.TuneConfig(
strategy="evolutionary",
@@ -393,7 +393,7 @@ def test_cuda_tensor_core(model_name, input_shape):
)
print(profiler.table())
- # Compile without meta-scheduler for correctness check
+ # Compile without MetaSchedule for correctness check
with tvm.transform.PassContext(opt_level=0):
rt_mod2 = relay.build(mod, target=target, params=params)
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index 9c597a55e5..bf0edaef5c 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -85,7 +85,7 @@ class Expected:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
- T.cast(0, dtype="float32"),
+ T.Cast("float32", 0),
0,
dtype="handle",
),
@@ -94,7 +94,7 @@ class Expected:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
- T.cast(0, dtype="float32"),
+ T.Cast("float32", 0),
0,
dtype="handle",
),
@@ -103,7 +103,7 @@ class Expected:
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
- T.cast(0, dtype="float32"),
+ T.Cast("float32", 0),
0,
dtype="handle",
),
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
index a1184c1edf..fc624cd5a6 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py
@@ -16,9 +16,9 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
-import tvm.tir.tensor_intrin
from tvm.meta_schedule import TuneContext, postproc
from tvm.script import tir as T
+from tvm.tir.tensor_intrin import arm_cpu, cuda, rocm, x86
@tvm.script.ir_module
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index d415ae9ce6..4da870e455 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -31,8 +31,8 @@ from tvm.meta_schedule.tune_context import TuneContext
from tvm.script import tir as T
from tvm.target import Target
from tvm.te import create_prim_func
-from tvm.tir.tensor_intrin import DP4A_INTRIN
-from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
def _create_context(mod, target, rule) -> TuneContext:
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index ce333887ec..97e8a69556 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -910,7 +910,7 @@ def test_cuda_nrm():
for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
with T.block("D"):
b = T.axis.spatial(1, i0_1)
- T.where(0 * 128 + i0_1 < 1)
+ T.where(T.int32(0) * T.int32(128) + i0_1 < 1)
T.reads(C_shared[b])
T.writes(D[b])
D[b] = T.sqrt(C_shared[b], dtype="float32")
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index d86b6fe48b..7d85b8757a 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -152,7 +152,7 @@ def test_meta_schedule_tune_relay(
work_dir=work_dir,
)
print(profiler.table())
- # Compile without meta-scheduler for correctness check
+ # Compile without MetaSchedule for correctness check
with tvm.transform.PassContext(opt_level=0):
rt_mod2 = relay.build(mod, target=target, params=params)
@@ -252,7 +252,7 @@ def test_meta_schedule_te2primfunc_argument_order():
):
rt_mod1 = relay.build(mod, target=target, params=params)
- # Compile without meta-scheduler for correctness check
+ # Compile without MetaSchedule for correctness check
with tvm.transform.PassContext(opt_level=0):
rt_mod2 = relay.build(mod, target=target, params=params)
@@ -314,7 +314,7 @@ def test_meta_schedule_relay_lowering():
):
rt_mod1 = relay.build(mod, target=target, params=params)
- # Compile without meta-scheduler for correctness check
+ # Compile without MetaSchedule for correctness check
with tvm.transform.PassContext(opt_level=0):
rt_mod2 = relay.build(mod, target=target, params=params)
@@ -516,7 +516,7 @@ def test_tune_relay_manual_tir_vnni():
attrs={"schedule_rule": "meta_schedule.dense_vnni"},
)
- When the meta scheduler encounters a TensorIR block with the "schedule_rule" annotation,
+ When MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation,
it looks up the packed func registry for a function that is associated with the given schedule
rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule
functions must be
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 18bf9d1184..c576483828 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -17,20 +17,19 @@
import collections
import ctypes
import json
+import math
+import re
import sys
+import numpy as np
+import pytest
import tvm
import tvm.testing
from tvm import te
+from tvm.contrib import clang, utils
from tvm.relay.backend import Runtime
-from tvm.contrib import utils, clang
-from tvm.target.codegen import llvm_lookup_intrinsic_id, llvm_get_intrinsic_name
-import tvm.script.tir as T
-import numpy as np
-
-import math
-import re
-import pytest
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_get_intrinsic_name, llvm_lookup_intrinsic_id
@tvm.testing.requires_llvm
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index 93b7caf9cd..0e08c0df8c 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -16,7 +16,6 @@
# under the License.
import pytest
-
import tvm
from tvm.script import tir as T
@@ -63,10 +62,23 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None:
@tvm.ir.register_op_attr("tir.intrin_test", "")
-def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1):
+def _intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1):
return 0
+def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1, dtype):
+ """Return a `tvm.tir.call_intrin` call to "tir.intrin_test" with result dtype `dtype`."""
+ # The six remaining arguments are forwarded unchanged, in order, as the intrinsic's operands.
+ return tvm.tir.call_intrin(
+ dtype,
+ "tir.intrin_test",
+ data,
+ elem_offset,
+ stride_0,
+ stride_1,
+ shape_0,
+ shape_1,
+ )
+
+
@T.prim_func
def opaque_access(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (32, 64, 128))
@@ -82,7 +94,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
@@ -105,7 +117,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_B.data,
sub_B.elem_offset,
sub_B.strides[0],
@@ -126,7 +138,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
T.reads([])
T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 131072 + j * 128 + k * 16,
8192,
@@ -141,7 +153,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
T.reads([])
T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
T.evaluate(
- T.intrin_test(
+ intrin_test(
B.data,
i * 4096 + j * 2048 + k * 8,
64,
@@ -169,7 +181,7 @@ def high_dim_opaque_access(a: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
@@ -189,7 +201,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None:
T.reads([])
T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 2048 + j * 1024 + k * 16,
64,
@@ -217,7 +229,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
@@ -237,7 +249,7 @@ def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
T.reads([])
T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 2576 + j * 1280 + k * 16,
80,
@@ -298,7 +310,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_sub_A.data,
sub_sub_A.elem_offset,
sub_sub_A.strides[0],
@@ -343,7 +355,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None:
]
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4,
64,
@@ -375,7 +387,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None:
sub_A[ii, jj] = 1
for j in range(0, 4):
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_B.data,
sub_B.elem_offset,
sub_B.strides[0],
@@ -399,7 +411,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32)
A[i * m + ii, jj] = 1
for j in range(0, 4):
T.evaluate(
- T.intrin_test(
+ intrin_test(
B.data,
i * n * (m * 4),
m * 4,
@@ -423,7 +435,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None:
sub_B = T.match_buffer(B[i, j], (), offset_factor=1)
sub_A[()] = 1
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_B.data,
sub_B.elem_offset,
0,
@@ -445,7 +457,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None:
T.writes([A[i, j], B[i, j]])
A[i, j] = 1
T.evaluate(
- T.intrin_test(
+ intrin_test(
B.data,
i * 8 + j,
0,
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index a97060f01b..929a6cfa19 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -16,19 +16,20 @@
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
+
import pytest
import tvm
import tvm.testing
-from tvm import tir, te
+from tvm import te, tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
-from tvm.tir.tensor_intrin import (
- VNNI_DOT_16x4_INTRIN,
+from tvm.tir.tensor_intrin.arm_cpu import (
+ DP4A_INTRIN,
ARM_DOT_4x4_i8_NEON_INTRIN,
ARM_DOT_4x4_i8_SDOT_INTRIN,
- AMDGPU_SDOT4_INTRIN,
- DP4A_INTRIN,
)
+from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py
index 01496e0e0f..056a13cb1a 100644
--- a/tests/python/unittest/test_tir_transform_helpers.py
+++ b/tests/python/unittest/test_tir_transform_helpers.py
@@ -15,15 +15,14 @@
# specific language governing permissions and limitations
# under the License.
import pytest
-
import tvm
-from tvm.script import tir as T
import tvm.testing
+from tvm.script import tir as T
def test_annotate_entry_func_single_primfunc():
@tvm.script.ir_module
- class MockModule:
+ class MockModule2:
@T.prim_func
def func1(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
@@ -31,7 +30,7 @@ def test_annotate_entry_func_single_primfunc():
if i == 5:
A[i] = 0.0
- mod = MockModule
+ mod = MockModule2
assert mod
assert mod["func1"].attrs is None
after = tvm.tir.transform.AnnotateEntryFunc()(mod)
diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py
index 8b7fc98bfd..26464baed3 100644
--- a/tests/python/unittest/test_tir_transform_hoist_expression.py
+++ b/tests/python/unittest/test_tir_transform_hoist_expression.py
@@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
import tvm.testing
-
+from tvm import tir
+from tvm.script import from_source
from tvm.script import tir as T
-from tvm.tir.transform import HoistExpression, HoistedConditionals, HoistedLetBindings
+from tvm.tir.transform import HoistedConditionals, HoistedLetBindings, HoistExpression
class BaseBeforeAfter:
@@ -27,7 +27,7 @@ class BaseBeforeAfter:
hoisted_let_bindings = tvm.testing.parameter(HoistedLetBindings.All)
def test_hoist(self, hoisted_conditionals, hoisted_let_bindings):
- before = self.before
+ before = from_source(self.before)
before_mod = tvm.IRModule.from_expr(before)
config = {
@@ -41,7 +41,7 @@ class BaseBeforeAfter:
after_mod = tvm.tir.transform.HoistExpression()(before_mod)
after = after_mod["main"]
- expected = self.expected
+ expected = from_source(self.expected)
try:
tvm.ir.assert_structural_equal(after, expected)
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index edaeb7c9b6..34f988c77c 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -14,16 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pytest
import sys
-import numpy as np
+import numpy as np
+import pytest
import tvm
import tvm.testing
import tvm.tir.tensor_intrin.cuda
-from tvm import tir, te, TVMError
-from tvm.script import tir as T
+from tvm import TVMError, te, tir
from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
from tvm.testing.tir import mma_schedule
from tvm.tir.tensor_intrin.cuda import (
LDMATRIX_16x16_A_DYN_INTRIN,
@@ -1060,7 +1060,7 @@ def test_simple_compute_async():
T.writes(B[0, tx, 0])
with T.attr(0, "async_commit_queue_scope", 0):
with T.attr(0, "async_scope", 1):
- B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+ B[T.int32(0) % 2, tx, 0] = A[tx, 0] * T.float32(2)
with T.block():
T.reads(A[tx, 1:16], B[0:2, tx, 0])
T.writes(B[0:2, tx, 0], C[tx, 0:15])
@@ -1080,11 +1080,11 @@ def test_simple_compute_async():
with T.attr(0, "async_wait_inflight_count", 1):
C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1)
with T.block():
- T.reads(B[15 % 2, tx, 0])
+ T.reads(B[T.int32(15) % 2, tx, 0])
T.writes(C[tx, 15])
with T.attr(0, "async_wait_queue_scope", 0):
with T.attr(0, "async_wait_inflight_count", 0):
- C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+ C[tx, 15] = B[T.int32(15) % 2, tx, 0] + T.float32(1)
tvm.ir.assert_structural_equal(mod["main"], ref, True)
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index b96afb6a09..f80573c43b 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import tvm
+import tvm.testing
from tvm import te
-
from tvm.script import tir as T
vthread_name = tvm.testing.parameter("vthread", "cthread")
@@ -153,10 +153,10 @@ def test_vthread_simplified():
B = T.allocate([16], "int32", "shared")
# The indices for B should each be a single Ramp node, and
# should not be the sum of a Ramp and Broadcast node.
- B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4)
- B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4)
- B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4)
- B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4)
+ B[T.int32(0) * 4 : T.int32(0) * 4 + 4] = T.broadcast(0, 4)
+ B[T.int32(1) * 4 : T.int32(1) * 4 + 4] = T.broadcast(1, 4)
+ B[T.int32(2) * 4 : T.int32(2) * 4 + 4] = T.broadcast(2, 4)
+ B[T.int32(3) * 4 : T.int32(3) * 4 + 4] = T.broadcast(3, 4)
before_mod = tvm.IRModule.from_expr(before_func)
after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
@@ -178,10 +178,10 @@ def test_vthread_vectorized():
@T.prim_func
def expected_func():
B = T.allocate([4], "int32x4", "shared")
- B[0 * 4 / 4] = T.broadcast(0, 4)
- B[1 * 4 / 4] = T.broadcast(1, 4)
- B[2 * 4 / 4] = T.broadcast(2, 4)
- B[3 * 4 / 4] = T.broadcast(3, 4)
+ B[T.int32(0) * 4 / 4] = T.broadcast(0, 4)
+ B[T.int32(1) * 4 / 4] = T.broadcast(1, 4)
+ B[T.int32(2) * 4 / 4] = T.broadcast(2, 4)
+ B[T.int32(3) * 4 / 4] = T.broadcast(3, 4)
before_mod = tvm.IRModule.from_expr(before_func)
intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index acc68af065..aff503c272 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -76,14 +76,6 @@ def test_missing_type_annotation():
check_error(missing_type_annotation, 1)
-def invalid_expr_stmt() -> None:
- T.max(1, 2) # error
-
-
-def test_invalid_expr_stmt():
- check_error(invalid_expr_stmt, 2)
-
-
def invalid_for_function(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
@@ -115,14 +107,6 @@ def test_return_not_allowed():
check_error(return_not_allowed, 2)
-def tir_assert(a: T.handle) -> None:
- T.Assert(0, "") # error
-
-
-def test_tir_assert():
- check_error(tir_assert, 2)
-
-
def no_body(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
T.realize(A, "") # error
@@ -250,19 +234,6 @@ def test_invalid_match_buffer_region():
check_error(invalid_match_buffer_region, 5)
-def duplicate_buffer() -> None:
- A = T.alloc_buffer((128, 128), "float32")
- for i, j in T.grid(128, 128):
- with T.block():
- vi, vj = T.axis.remap("SS", [i, j])
- A = T.alloc_buffer((128, 128), "float32") # error
- T.evaluate(1.0)
-
-
-def test_duplicate_buffer():
- check_error(duplicate_buffer, 6)
-
-
def duplicate_reads() -> None:
A = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
@@ -334,7 +305,7 @@ def opaque_access_during_complete(a: T.handle) -> None: # error
def test_opaque_access_during_complete():
- check_error(opaque_access_during_complete, 1)
+ check_error(opaque_access_during_complete, 0)
def convert_slice_to_bufferload() -> None:
@@ -608,15 +579,6 @@ def test_binop_bad_type():
check_error(binop_bad_type, 3)
-def floor_dtype(h: T.handle):
- h_ = T.match_buffer(h, [1])
- h_[0] = T.floor(2) # error floor requires a dtype
-
-
-def test_floor_dtype():
- check_error(floor_dtype, 3)
-
-
def non_integer_typed_block_iter():
with T.block():
i = T.axis.S(0.1, 0.1) # error IterVar requires an integer dtype
diff --git a/tests/python/unittest/test_tvmscript_spans.py b/tests/python/unittest/test_tvmscript_spans.py
deleted file mode 100644
index f863a4dd98..0000000000
--- a/tests/python/unittest/test_tvmscript_spans.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-from tvm.script import tir as T
-
-
-@T.prim_func
-def loops() -> None:
- for i in T.parallel(0, 2):
- for j in T.serial(0, 1):
- for z in T.vectorized(3, 4):
- T.evaluate(0)
-
-
-def test_loops():
- start_line = 23
- parsed = loops
-
- assert parsed.span.line == start_line
-
- assert parsed.body.span.line == start_line + 1
- assert parsed.body.min.span.column == 25
- assert parsed.body.extent.span.column == 28
- assert parsed.body.extent.span.line == start_line + 1
-
- assert parsed.body.body.span.line == start_line + 2
- assert parsed.body.body.loop_var.span.line == start_line + 2
- assert parsed.body.body.loop_var.span.column == 13
-
- assert parsed.body.body.body.span.line == start_line + 3
- assert parsed.body.body.body.span.column == 22
-
- assert parsed.body.body.body.body.span.line == start_line + 4
- assert parsed.body.body.body.body.span.column == 17
-
-
-@T.prim_func
-def statements() -> None:
- T.evaluate(1)
- T.evaluate("test")
-
-
-def test_statements():
- start_line = 53
- parsed = statements
-
- assert parsed.body.span.line == start_line + 1
-
- assert parsed.body[0].span.line == start_line + 1
- assert parsed.body[0].span.column == 5
-
- assert parsed.body[0].span.line == start_line + 1
- assert parsed.body[0].span.column == 5
-
-
-if __name__ == "__main__":
- test_loops()
- test_statements()
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 7248a3a5f4..d09a0d143a 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -20,8 +20,8 @@ import sys
import pytest
import tvm.testing
from tvm.ir import assert_structural_equal
+from tvm.script import from_source
from tvm.script import tir as T
-from tvm.script.parser import from_source
from tvm.testing import check_error
@@ -164,15 +164,24 @@ def test_match_buffer_1d():
# match buffer failed case
-def test_match_buffer_no_kwargs_failed():
- with pytest.raises(ValueError) as e:
-
- @T.prim_func
- def elementwise_buffer_no_kwargs_failed(
- a: T.Buffer[(128, 128, 128, 128)],
- b: T.Buffer[(128, 128, 128, 128)],
- ) -> None:
- pass
+def test_match_buffer_without_dtype():
+ @T.prim_func
+ def no_dtype(
+ a: T.Buffer[(128, 128, 128, 128)],
+ b: T.Buffer[(128, 128, 128, 128)],
+ ) -> None:
+ pass
+
+ a0, a1, a2, a3 = no_dtype.buffer_map[no_dtype.params[0]].shape
+ b0, b1, b2, b3 = no_dtype.buffer_map[no_dtype.params[1]].shape
+ assert a0 == 128
+ assert a1 == 128
+ assert a2 == 128
+ assert a3 == 128
+ assert b0 == 128
+ assert b1 == 128
+ assert b2 == 128
+ assert b3 == 128
# dynamic shape gemm
@@ -274,8 +283,8 @@ def test_letstmt_bind_with_constant():
@T.prim_func
def constant_binds_wrapped():
- x = T.int32(1)
- y = T.float32(42.0)
+ x = T.inline(T.int32(1))
+ y = T.inline(T.float32(42.0))
T.evaluate(T.cast(x, "float32") + y)
assert_structural_equal(constant_binds, constant_binds_wrapped)
@@ -298,9 +307,9 @@ def test_func_call():
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
- thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
- thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
- thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
+ thread_id_C, local_id_C = T.inline(shared_16x16_to_ldmatrix_32x8_layout(i, j))
+ thread_id_A, local_id_A = T.inline(shared_16x16_to_ldmatrix_32x8_layout(i, k))
+ thread_id_B, local_id_B = T.inline(shared_16x16_to_ldmatrix_32x8_layout(k, j))
T.reads(
C[thread_id_C, local_id_C],