You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/03/14 21:26:34 UTC
[incubator-tvm] branch master updated: [TIR] Introduce
tir::PrimFunc (#5070)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new e031641 [TIR] Introduce tir::PrimFunc (#5070)
e031641 is described below
commit e03164159ce08f2739a26c10531b26713e72153e
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sat Mar 14 14:26:23 2020 -0700
[TIR] Introduce tir::PrimFunc (#5070)
This PR introduces tir::PrimFunc which will be used as the TIR function
container in the unified IR.
Also streamlined the function attributes a bit further.
- All common attributes are under tvm::attr
- TIR-specific attributes are under tvm::tir::attr and come with a tir prefix
- Use stl_style for attributes for now
---
include/tvm/ir/function.h | 99 ++++++++++++
include/tvm/ir/type.h | 7 +
include/tvm/relay/function.h | 27 ----
include/tvm/tir/function.h | 177 +++++++++++++++++++++
include/tvm/tir/op.h | 12 ++
python/tvm/ir/__init__.py | 3 +-
python/tvm/ir/expr.py | 9 --
python/tvm/ir/{__init__.py => function.py} | 24 ++-
python/tvm/relay/expr.py | 3 +-
python/tvm/tir/__init__.py | 2 +
python/tvm/tir/function.py | 86 ++++++++++
src/ir/function.cc | 1 +
src/printer/relay_text_printer.cc | 13 +-
src/relay/ir/function.cc | 18 +--
src/tir/ir/function.cc | 91 +++++++++++
src/tir/ir/op.cc | 12 ++
.../{test_lang_basic.py => test_tir_nodes.py} | 24 ++-
17 files changed, 536 insertions(+), 72 deletions(-)
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 4cb5d70..db7f446 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -33,6 +33,36 @@
namespace tvm {
/*!
+ * \brief Possible calling conventions.
+ *
+ * NOTE: The calling convention also implies
+ * the way we implement the function during lowering.
+ */
+enum class CallingConv : int {
+ /*!
+ * \brief Default calling convention.
+ *
+ * - Uses the native calling convention of the target.
+ * - Implementation: specified by the native target.
+ */
+ kDefault = 0,
+ /*!
+ * \brief Device kernel launch.
+ *
+ * - Call by PackedFunc calling convention.
+ * - Implementation: defined by the device runtime (e.g. runtime/cuda).
+ */
+ kDeviceKernelLaunch = 2,
+ /*!
+ * \brief PackedFunc that exposes a CPackedFunc signature.
+ *
+ * - Call by PackedFunc calling convention.
+ * - Implementation: Expose a function with the CPackedFunc signature.
+ */
+ kCPackedFunc = 3,
+};
+
+/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
@@ -115,5 +145,74 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
+/*!
+ * \brief Create a new function that copies func, but overrides
+ * the value of attr_key with attr_value (inserting it if absent).
+ *
+ * \param func The input function.
+ * \param attr_key The attribute key.
+ * \param attr_value The attribute value.
+ *
+ * \tparam TFunc The function reference type; its node type must be final.
+ *
+ * \returns The new function with updated attributes.
+ *
+ * \note This function performs copy on write optimization for func.
+ * If we move a uniquely referenced func into WithAttr,
+ * then no additional copy will be performed.
+ *
+ * This is also why we make it a free function instead of a member function
+ * and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ * // Recommended way to trigger copy on write
+ * func = WithAttr(std::move(func), "key1", value1);
+ * func = WithAttr(std::move(func), "key2", value2);
+ *
+ * \endcode
+ */
+template<typename TFunc,
+ typename = typename std::enable_if<
+ std::is_base_of<BaseFunc, TFunc>::value>::type>
+inline TFunc WithAttr(TFunc func,
+ const std::string& attr_key,
+ ObjectRef attr_value) {
+ using TNode = typename TFunc::ContainerType;
+ static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+ TNode* node = func.CopyOnWrite();
+ if (node->attrs.defined()) {
+ node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
+ } else {
+ Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
+ node->attrs = DictAttrs(dict);
+ }
+ return func;
+}
+
+/*!
+ * \brief Generic attribute names that can be attached to any function.
+ *
+ * \sa tvm::tir::attr, tvm::relay::attr
+ */
+namespace attr {
+/*!
+ * \brief Indicates the special calling convention.
+ *
+ * Type: Integer
+ *
+ * \sa tvm::CallingConv
+ */
+constexpr const char* kCallingConv = "calling_conv";
+
+/*!
+ * \brief Compilation target of the function.
+ *
+ * Type: Target
+ *
+ * \sa tvm::Target
+ */
+constexpr const char* kTarget = "target";
+} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 9e87731..7fd224b 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -277,6 +277,13 @@ class TupleType : public Type {
};
/*!
+ * \return A type that represents void (the empty TupleType).
+ */
+inline Type VoidType() {
+ return TupleType::Empty();
+}
+
+/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
*/
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 27aa2e8..f7514c7 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -115,33 +115,6 @@ class Function : public BaseFunc {
};
/*!
- * \brief Create a new function that copies func, but overrides
- * the attribute value key with the value.
- *
- * \param func The input function.
- * \param attr_key The attribute key.
- * \param attr_value The value attribute value.
- *
- * \returns The new function with updated attributes.
- *
- * \note This function performs copy on write optimization for func.
- * If we move a uniquely referenced func into WithAttr,
- * then no additional copy will be performed.
- *
- * This is also why we make it as a function instead of a member function
- * and why we pass by value in the first argument.
- *
- * \code
- *
- * // Recommended way to trigger copy on write
- * func = WithAttr(std::move(func), "key1", value1);
- * func = WithAttr(std::move(func), "key2", value2);
- *
- * \endcode
- */
-TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);
-
-/*!
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
namespace attr {
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
new file mode 100644
index 0000000..0680267
--- /dev/null
+++ b/include/tvm/tir/function.h
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/function.h
+ * \brief TIR Function.
+ */
+#ifndef TVM_TIR_FUNCTION_H_
+#define TVM_TIR_FUNCTION_H_
+
+#include <tvm/ir/function.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/stmt.h>
+#include <string>
+
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Primitive functions that contain TIR statements.
+ *
+ * The PrimFunc provides a low-level code representation that does not
+ * automatically manage (NOTE(review): sentence truncated; TODO confirm).
+ *
+ * \sa PrimFunc
+ */
+class PrimFuncNode : public BaseFuncNode {
+ public:
+ /*! \brief Function parameters */
+ Array<tir::Var> params;
+ /*! \brief The body of the function */
+ tir::Stmt body;
+ /*! \brief The return type of the function. */
+ Type ret_type;
+ /*!
+ * \brief Maps some parameters to specific Buffer data structures.
+ *
+ * buffer_map provides a way to express data structure's field and shape
+ * constraints. The provided information is used in the program analysis
+ * and the code generation.
+ *
+ * - It defines the vars in the Buffer (m, n) in the cases below when
+ * they appear in the buffer_map for the first time.
+ * - When a var appears multiple times, they translate into runtime
+ * assertions to check the field constraint.
+ *
+ * \code
+ *
+ * # The corresponding fields of f are as follows
+ * #
+ * # - f.params = [a, b]
+ * # - f.buffer_map = {a: A, b: B}
+ * # - A = decl_buffer(shape=[m, n])
+ * # - B = decl_buffer(shape=[m, n])
+ *
+ * def f(a, b):
+ * m, n = var(), var()
+ * A = bind_buffer(a, shape=[m, n])
+ * B = bind_buffer(b, shape=[m, n])
+ * # body
+ *
+ * \endcode
+ *
+ * buffer_map is a sugar to express:
+ * - Parameter unpacking: e.g. we can load a.shape[0] to get the value of m
+ * - Constraint checking: a.shape[0] must equal b.shape[0] because they
+ * both correspond to m.
+ *
+ * While we could have expressed parameter unpacking and constraints using
+ * normal statements, making buffer_map a first-class citizen of PrimFunc
+ * will make program analysis much easier.
+ *
+ * \note This field can be nullptr
+ */
+ Map<tir::Var, Buffer> buffer_map;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("params", &params);
+ v->Visit("body", &body);
+ v->Visit("ret_type", &ret_type);
+ v->Visit("buffer_map", &buffer_map);
+ v->Visit("attrs", &attrs);
+ v->Visit("span", &span);
+ v->Visit("_checked_type_", &checked_type_);
+ }
+
+ /*!
+ * \brief Return the derived function annotation of this function.
+ *
+ * \return The function type annotation.
+ * \note The function type annotation of PrimFunc is
+ * directly derived from the Vars without the need of type inference.
+ */
+ TVM_DLL FuncType func_type_annotation() const;
+
+ static constexpr const char* _type_key = "tir.PrimFunc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
+};
+
+/*!
+ * \brief Managed reference to PrimFuncNode.
+ * \sa PrimFuncNode
+ */
+class PrimFunc : public BaseFunc {
+ public:
+ /*!
+ * \brief Constructor
+ * \param params The parameters of the function.
+ * \param body The body of the function.
+ * \param ret_type The return type of the function (defaults to void).
+ * \param buffer_map The buffer map for parameter buffer unpacking.
+ * \param attrs Additional function attributes.
+ */
+ TVM_DLL PrimFunc(Array<tir::Var> params,
+ Stmt body,
+ Type ret_type = VoidType(),
+ Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
+ DictAttrs attrs = NullValue<DictAttrs>());
+
+ TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
+};
+
+/*!
+ * \brief PrimFunc specific attribute names.
+ *
+ * \sa tvm::attr
+ */
+namespace attr {
+/*!
+ * \brief List of thread IterVar that a DeviceLaunch function corresponds to.
+ *
+ * Type: Array<tir::IterVar>
+ *
+ * We call a device kernel launch function f using the following convention:
+ *
+ * Call(f,
+ * [arg1, arg2, ..., arg_n,
+ * work_size_1, work_size_2, ... work_size_m])
+ *
+ * Here n = len(arg), m = len(work_size) = len(device_thread_axis).
+ *
+ * The list of device_thread_axis indicates how the work_size
+ * arguments are bound to the corresponding threads.
+ *
+ * \sa tvm::CallingConv::kDeviceKernelLaunch
+ */
+constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
+
+/*!
+ * \brief Whether to set noalias rule on the function arguments.
+ *
+ * Type: Integer
+ */
+constexpr const char* kNoAlias = "tir.noalias";
+} // namespace attr
+} // namespace tir
+} // namespace tvm
+#endif // TVM_TIR_FUNCTION_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index a30c3c9..6ee5063 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -28,6 +28,7 @@
#ifndef TVM_TIR_OP_H_
#define TVM_TIR_OP_H_
+#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
@@ -37,6 +38,7 @@
namespace tvm {
+
// Most common operators can be overloaded by argument type(PrimExpr).
// So we put them under the root namespace.
// It is also necessary to overload operators for PrimExpr.
@@ -45,6 +47,16 @@ namespace tvm {
// as they are more specific to the tir namespace.
/*!
+ * \brief Get the type of the expression under the unified type system.
+ *
+ * This function could return a more refined type than
+ * the runtime type provided by expr->dtype.
+ * \param expr The input expression.
+ * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
+ */
+TVM_DLL Type GetType(const PrimExpr& expr);
+
+/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \return the maximum possible value in this format.
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index a718124..4160326 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -21,7 +21,8 @@ from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
+from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
+from .function import BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 00ceb5b..4e6bf16 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -51,15 +51,6 @@ class RelayExpr(BaseExpr):
return ret
-class BaseFunc(RelayExpr):
- """Base class of all functions."""
- @property
- def attrs(self):
- """Return the attrs member of the function.
- """
- return _ffi_api.BaseFunc_Attrs(self)
-
-
@tvm._ffi.register_object("relay.GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/function.py
similarity index 55%
copy from python/tvm/ir/__init__.py
copy to python/tvm/ir/function.py
index a718124..70eb51a 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/function.py
@@ -14,17 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=unused-import
-"""Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
-from .adt import Constructor, TypeData
-from .module import IRModule
-from .attrs import Attrs, DictAttrs, make_node
-from .container import Array, Map
+"""Function definitions."""
+from .expr import RelayExpr
+from . import _ffi_api
-from . import transform
+
+class BaseFunc(RelayExpr):
+ """Base class of all functions."""
+ @property
+ def attrs(self):
+ """Return the attrs member of the function (None when no attrs are set).
+ """
+ return _ffi_api.BaseFunc_Attrs(self)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index a3c6251..61a5fb7 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -282,7 +282,8 @@ class Function(BaseFunc):
func : Function
A new copy of the function
"""
- return _expr.FunctionWithAttr(self, attr_key, attr_value)
+ return _expr.FunctionWithAttr(
+ self, attr_key, convert(attr_value))
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index fa244ac..b8a56f8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -31,6 +31,8 @@ from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
+from .function import PrimFunc
+
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
new file mode 100644
index 0000000..37946f6
--- /dev/null
+++ b/python/tvm/tir/function.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Function data types."""
+
+import tvm._ffi
+import tvm.runtime
+from tvm.ir import BaseFunc
+from .buffer import Buffer
+from .expr import Var
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("tir.PrimFunc")
+class PrimFunc(BaseFunc):
+ """A primitive function that contains TIR statements.
+
+ Parameters
+ ----------
+ params: List[Union[tvm.tir.Var, tvm.tir.Buffer]]
+ Input parameters; each Buffer gets a fresh handle Var bound to it.
+
+ body: tvm.tir.Stmt
+ The body of the function.
+
+ ret_type: tvm.ir.Type
+ The return type annotation of the function (None means void).
+
+ buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]
+ The buffer binding map; copied, so the caller's map is never mutated.
+
+ attrs: Optional[tvm.ir.DictAttrs]
+ Attributes of the function, can be None
+ """
+ def __init__(self,
+ params,
+ body,
+ ret_type=None,
+ buffer_map=None,
+ attrs=None):
+ param_list = []
+ buffer_map = {} if buffer_map is None else dict(buffer_map.items())
+ for x in params:
+ if isinstance(x, Buffer):
+ var = Var(x.name, dtype="handle")
+ param_list.append(var)
+ buffer_map[var] = x
+ elif isinstance(x, Var):
+ param_list.append(x)
+ else:
+ raise TypeError("params can only contain Var or Buffer")
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
+
+ def with_attr(self, attr_key, attr_value):
+ """Create a new copy of the function and update the attribute
+
+ Parameters
+ ----------
+ attr_key : str
+ The attribute key to use.
+
+ attr_value : Object
+ The new attribute value.
+
+ Returns
+ -------
+ func : PrimFunc
+ A new copy of the function
+ """
+ return _ffi_api.PrimFuncWithAttr(
+ self, attr_key, tvm.runtime.convert(attr_value))
diff --git a/src/ir/function.cc b/src/ir/function.cc
index d3753d8..e7ccbbe 100644
--- a/src/ir/function.cc
+++ b/src/ir/function.cc
@@ -30,4 +30,5 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
.set_body_typed([](BaseFunc func) {
return func->attrs;
});
+
} // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index 2799be0..56e77b7 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -99,7 +99,11 @@ class RelayTextPrinter :
}
Doc PrintFinal(const ObjectRef& node) {
- if (node.as<ExprNode>()) {
+ if (node->IsInstance<BaseFuncNode>() &&
+ !node->IsInstance<relay::FunctionNode>()) {
+ // Temporarily skip non-relay functions.
+ // TODO(tvm-team) enhance the code to work for all functions
+ } else if (node.as<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}
@@ -122,7 +126,10 @@ class RelayTextPrinter :
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) {
- if (node.as<ExprNode>()) {
+ bool is_non_relay_func =
+ node->IsInstance<BaseFuncNode>() &&
+ !node->IsInstance<relay::FunctionNode>();
+ if (node.as<ExprNode>() && !is_non_relay_func) {
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
@@ -134,7 +141,7 @@ class RelayTextPrinter :
// default module.
std::ostringstream os;
os << node;
- return Doc() << os.str();
+ return Doc::RawText(os.str());
}
}
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index c1bd710..63ad4dd 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -60,18 +60,6 @@ bool FunctionNode::UseDefaultCompiler() const {
return !val.defined() || val->value == "default";
}
-Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) {
- FunctionNode* node = func.CopyOnWrite();
- if (node->attrs.defined()) {
- node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
- } else {
- Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
- node->attrs = DictAttrs(dict);
- }
- return func;
-}
-
-
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
@@ -94,9 +82,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
.set_body_typed(
- [](Function func, std::string name, ObjectRef ref) {
- return WithAttr(std::move(func), name, ref);
-});
+ [](Function func, std::string name, ObjectRef ref) {
+ return WithAttr(std::move(func), name, ref);
+ });
} // namespace relay
} // namespace tvm
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
new file mode 100644
index 0000000..7464e3a
--- /dev/null
+++ b/src/tir/ir/function.cc
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/function.cc
+ * \brief The function data structure.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace tir {
+
+PrimFunc::PrimFunc(Array<tir::Var> params,
+ Stmt body,
+ Type ret_type,
+ Map<tir::Var, Buffer> buffer_map,
+ DictAttrs attrs) {
+ // An undefined ret_type (e.g. None from Python) defaults to void for now.
+ // TODO(tvm-team) consider type deduction from body.
+ if (!ret_type.defined()) {
+ ret_type = VoidType();
+ }
+ auto n = make_object<PrimFuncNode>();
+ n->params = std::move(params);
+ n->body = std::move(body);
+ n->ret_type = std::move(ret_type);
+ n->buffer_map = std::move(buffer_map);
+ n->attrs = std::move(attrs);
+ data_ = std::move(n);
+}
+
+FuncType PrimFuncNode::func_type_annotation() const {
+ Array<Type> param_types;
+ for (const auto& param : this->params) {
+ param_types.push_back(GetType(param));
+ }
+ return FuncType(param_types, ret_type, {}, {});
+}
+
+TVM_REGISTER_NODE_TYPE(PrimFuncNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ // TODO(tvm-team) redirect to Text printer once we have a good text format.
+ auto* node = static_cast<const PrimFuncNode*>(ref.get());
+ p->stream << "PrimFunc(" << node->params << ")";
+ if (node->attrs.defined()) {
+ p->stream << " attrs=" << node->attrs;
+ }
+ p->stream << " {\n";
+ p->indent += 2;
+ p->Print(node->body);
+ p->indent -= 2;
+ p->stream << "}\n";
+});
+
+
+TVM_REGISTER_GLOBAL("tir.PrimFunc")
+.set_body_typed([](Array<tir::Var> params,
+ Stmt body,
+ Type ret_type,
+ Map<tir::Var, Buffer> buffer_map,
+ DictAttrs attrs) {
+ return PrimFunc(params, body, ret_type, buffer_map, attrs);
+});
+
+
+TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
+.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
+ return WithAttr(std::move(func), name, ref);
+});
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index 2882fea..b073643 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -32,6 +32,18 @@ namespace tvm {
using namespace tir;
+
+Type GetType(const PrimExpr& expr) {
+ runtime::DataType dtype = expr.dtype();
+ // int/uint/float dtypes already fully determine the type.
+ if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
+ return PrimType(dtype);
+ }
+ // TODO(tqchen): add recursive type inference for Var and Call here
+ // once we introduce those fields; until then every dtype maps to PrimType.
+ return PrimType(dtype);
+}
+
// simple cast that only checks if type matches and cast
inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_tir_nodes.py
similarity index 92%
rename from tests/python/unittest/test_lang_basic.py
rename to tests/python/unittest/test_tir_nodes.py
index c279194..3a7985d 100644
--- a/tests/python/unittest/test_lang_basic.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -19,6 +19,7 @@ from tvm import te
import numpy as np
+
def test_const():
x = tvm.tir.const(1, "int32")
print(x.dtype)
@@ -46,8 +47,8 @@ def test_make():
x = tvm.tir.const(1, "int32")
y = te.var("x")
z = x + y
- assert isinstance(tvm.te.max(x, y), tvm.tir.Max)
- assert isinstance(tvm.te.min(x, y), tvm.tir.Min)
+ assert isinstance(tvm.tir.max(x, y), tvm.tir.Max)
+ assert isinstance(tvm.tir.min(x, y), tvm.tir.Min)
def test_ir():
@@ -111,7 +112,6 @@ def test_stmt():
tvm.tir.For.Serial, 0,
x)
-
def test_dir():
x = te.var('x')
dir(x)
@@ -247,8 +247,26 @@ def test_equality_string_imm():
x == y.value
x == y
+def test_prim_func():
+ x = te.var('x')
+ y = te.var('y')
+ b = tvm.tir.decl_buffer((x,), "float32")
+ stmt = tvm.tir.LetStmt(
+ x, 10, tvm.tir.Evaluate(x + 1))
+
+ func = tvm.tir.PrimFunc(
+ [x, y, b], stmt)
+ # a Buffer param is replaced by a handle Var bound in buffer_map
+ assert func.buffer_map[func.params[2]].same_as(b)
+
+ assert len(func.buffer_map) == 1
+ f2 = func.with_attr("calling_conv", 1)
+ assert f2.attrs["calling_conv"].value == 1
+ assert func.attrs is None
+
if __name__ == "__main__":
+ test_prim_func()
test_cast()
test_attr()
test_const()