You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/01/04 20:26:32 UTC
[incubator-tvm] branch master updated: [REFACTOR] Unified IR base
types. (#4616)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 1ecd3ee [REFACTOR] Unified IR base types. (#4616)
1ecd3ee is described below
commit 1ecd3ee2b3a4c7ca8c56145f74b82372d7126882
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sat Jan 4 12:26:21 2020 -0800
[REFACTOR] Unified IR base types. (#4616)
This PR moves a few base types from relay to the ir sub-folder.
These types will serve as a common type system across the stack.
Notably, we want to be able to use the same FuncType for all function signatures.
I tried to make the minimal move needed to bring over the dependencies required by FuncType.
We can discuss what additional things we want to move as a follow-up.
Notably, because the TensorType will have a dependency on low-level Expr,
we will need to split type.h into two files and introduce a
tensor_type.h (or leave those types in relay for now).
---
CMakeLists.txt | 3 +-
include/tvm/{relay/base.h => ir/span.h} | 106 ++------------
include/tvm/ir/type.h | 246 ++++++++++++++++++++++++++++++++
include/tvm/relay/base.h | 86 +----------
include/tvm/relay/type.h | 187 ++----------------------
src/{relay/ir/base.cc => ir/span.cc} | 34 ++---
src/ir/type.cc | 96 +++++++++++++
src/relay/ir/base.cc | 64 +--------
src/relay/ir/hash.cc | 5 -
src/relay/ir/module.cc | 2 +-
src/relay/ir/type.cc | 67 ---------
tests/cpp/relay_pass_type_infer_test.cc | 2 +-
12 files changed, 386 insertions(+), 512 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7d5c04a..b823528 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -125,6 +125,8 @@ assign_source_group("Include" ${GROUP_INCLUDE})
# Source file lists
file(GLOB COMPILER_SRCS
+ src/node/*.cc
+ src/ir/*.cc
src/api/*.cc
src/arithmetic/*.cc
src/autotvm/*.cc
@@ -132,7 +134,6 @@ file(GLOB COMPILER_SRCS
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
- src/node/*.cc
src/schedule/*.cc
)
diff --git a/include/tvm/relay/base.h b/include/tvm/ir/span.h
similarity index 50%
copy from include/tvm/relay/base.h
copy to include/tvm/ir/span.h
index d64d05f..8cbfff7 100644
--- a/include/tvm/relay/base.h
+++ b/include/tvm/ir/span.h
@@ -18,47 +18,18 @@
*/
/*!
- * \file tvm/relay/base.h
- * \brief Base classes for the Relay IR.
+ * \file tvm/ir/span.h
+ * \brief Span information for debugging purposes.
*/
-#ifndef TVM_RELAY_BASE_H_
-#define TVM_RELAY_BASE_H_
+#ifndef TVM_IR_SPAN_H_
+#define TVM_IR_SPAN_H_
-#include <tvm/api_registry.h>
-#include <tvm/ir.h>
+#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <string>
-#include <vector>
namespace tvm {
/*!
- * \brief Relay: a high level functional IR for TVM.
- *
- * This namespace contains the abstract syntax tree, and other
- * essential data structures for the Relay IR.
- *
- * You can find more about Relay by reading the language reference.
- */
-namespace relay {
-
-#define RELAY_DEBUG(...) \
-{ auto fdebug = runtime::Registry::Get("relay.debug"); \
- CHECK(fdebug) << "Could not find Relay Python debugger function."; \
- (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
-}
-
-#define RELAY_DEBUG_INTERP(...) \
-{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
- CHECK(fdebug) << "Could not find Relay Python debugger function."; \
- (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
-}
-
-/*!
- * \brief Symbolic expression for tensor shape.
- */
-using IndexExpr = ::tvm::Expr;
-
-/*!
* \brief The source name in the Span
* \sa SourceNameNode, Span
*/
@@ -83,19 +54,6 @@ class SourceNameNode : public Object {
*/
class SourceName : public ObjectRef {
public:
- /*! \brief default constructor */
- SourceName() {}
-
- /*! \brief constructor from node pointer */
- explicit SourceName(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const SourceNameNode* operator->() const {
- return static_cast<const SourceNameNode*>(get());
- }
-
/*!
* \brief Get an SourceName for a given operator name.
* Will raise an error if the source name has not been registered.
@@ -104,8 +62,7 @@ class SourceName : public ObjectRef {
*/
TVM_DLL static SourceName Get(const std::string& name);
- /*! \brief specify container node */
- using ContainerType = SourceNameNode;
+ TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
};
/*!
@@ -136,58 +93,11 @@ class SpanNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
+
class Span : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
-/*!
- * \brief This is the base node container of all relay structures.
- */
-class RelayNode : public Object {
- public:
- /*! \brief The location of the program in a SourceFragment can be null,
- * check with span.defined() */
- mutable Span span;
-
- static constexpr const char* _type_key = "relay.Node";
- TVM_DECLARE_BASE_OBJECT_INFO(RelayNode, Object);
-};
-
-/*!
- * \brief The unique identifier of variables.
- *
- * Id is like name to the variables,
- * except that id is unique for each Var.
- *
- * \note Do not create Id directly, they are created in Var.
- */
-class IdNode : public Object {
- public:
- /*!
- * \brief The name of the variable,
- * this only acts as a hint to the user,
- * and is not used for equality.
- */
- std::string name_hint;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name_hint", &name_hint);
- }
-
- static constexpr const char* _type_key = "relay.Id";
- TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
-};
-
-class Id : public ObjectRef {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
-};
-
-
-struct Module;
-
-} // namespace relay
} // namespace tvm
-
-#endif // TVM_RELAY_BASE_H_
+#endif // TVM_IR_SPAN_H_
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
new file mode 100644
index 0000000..ffe1ba8
--- /dev/null
+++ b/include/tvm/ir/type.h
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/type.h
+ * \brief IR/AST nodes for the unified type system in TVM.
+ *
+ * We use Relay's type system as the unified type system
+ * throughout the stack.
+ *
+ * This file contains types that are common across IR variants.
+ *
+ * ## Relation between Type and runtime::DataType
+ *
+ * Besides Type, we also store a dtype field in some of the low-level IR's Expr.
+ * runtime::DataType(dtype) provides coarse grained type information
+ * during compile time and runtime. It is eagerly built in
+ * low-level expression construction and can be used for
+ * quick type checking in the low-level IR.
+ * For example, when an Expr's dtype is int32,
+ * we know for sure that its type is also int32.
+ *
+ * On the other hand, Type provides more fine grained information.
+ * For example, a low level expression can have DataType::Handle() as
+ * its dtype and MemRef[float32] as its type.
+ * Types are usually lazily constructed via type checking,
+ * so they may not readily be available during IR construction.
+ *
+ * The unified Type serves as a common bridge across IR dialects.
+ * For example, we require all the functions to have a type signature,
+ * which allows us to build cross-dialect function calls.
+ */
+#ifndef TVM_IR_TYPE_H_
+#define TVM_IR_TYPE_H_
+
+#include <tvm/runtime/object.h>
+#include <tvm/node/node.h>
+#include <tvm/node/container.h>
+#include <tvm/ir/span.h>
+#include <string>
+
+namespace tvm {
+
+/*! \brief Base type of all the types. */
+class TypeNode : public Object {
+ public:
+ /*!
+ * \brief Span that points to the original source code.
+ * Reserved debug information.
+ */
+ mutable Span span;
+
+ static constexpr const char* _type_key = "relay.Type";
+ TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
+};
+
+/*!
+ * \brief Type is the base type of all types.
+ *
+ * Relay's type system contains the following three key concepts:
+ *
+ * - PrimitiveType: type of primitive type values used in the low-level IR.
+ * - TensorType: type of certain Tensor values in the expression.
+ * - FunctionType: the type of the function.
+ *
+ * There are also advanced types that support generic (polymorphic) types,
+ * which can be ignored when first reading the code base.
+ */
+class Type : public ObjectRef {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
+};
+
+/*! \brief Possible kinds of TypeVars. */
+enum TypeKind : int {
+ kType = 0,
+ /*! \brief Template variable in shape expression. */
+ kShapeVar = 1,
+ kBaseType = 2,
+ kShape = 3,
+ kConstraint = 4,
+ kAdtHandle = 5,
+ kTypeData = 6
+};
+
+/*!
+ * \brief Type parameter in the function.
+ * This can be viewed as template parameter in c++ template function.
+ *
+ * For example, in the following pseudo code,
+ * the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
+ * This function can take in a Tensor with shape=(3, 3) and
+ * returns a Tensor with shape=(9,)
+ *
+ * \code
+ *
+ * template<i32 n>
+ * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
+ *
+ * \endcode
+ * \sa TypeVarNode The actual container class of TypeVar
+ */
+class TypeVar;
+/*! \brief TypeVar container node */
+class TypeVarNode : public TypeNode {
+ public:
+ /*!
+ * \brief The name of the variable,
+ * this only acts as a hint to the user,
+ * and is not used for equality.
+ */
+ std::string name_hint;
+ /*! \brief The kind of type parameter */
+ TypeKind kind;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("name_hint", &name_hint);
+ v->Visit("kind", &kind);
+ v->Visit("span", &span);
+ }
+
+ TVM_DLL static TypeVar make(std::string name, TypeKind kind);
+
+ static constexpr const char* _type_key = "relay.TypeVar";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
+};
+
+class TypeVar : public Type {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
+};
+
+/*!
+ * \brief A global type variable that is used for defining new types or type aliases.
+ */
+class GlobalTypeVar;
+/*! \brief GlobalTypeVar container node */
+class GlobalTypeVarNode : public TypeNode {
+ public:
+ /*!
+ * \brief The name of the variable,
+ * this only acts as a hint to the user,
+ * and is not used for equality.
+ */
+ std::string name_hint;
+ /*! \brief The kind of type parameter */
+ TypeKind kind;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("name_hint", &name_hint);
+ v->Visit("kind", &kind);
+ }
+
+ TVM_DLL static GlobalTypeVar make(std::string name, TypeKind kind);
+
+ static constexpr const char* _type_key = "relay.GlobalTypeVar";
+ TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
+};
+
+class GlobalTypeVar : public Type {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
+};
+
+/*!
+ * \brief Potential Constraints in the type.
+ * \note This is reserved for future use.
+ */
+class TypeConstraint;
+/*! \brief TypeConstraint container node. */
+class TypeConstraintNode : public TypeNode {
+ public:
+ static constexpr const char* _type_key = "relay.TypeConstraint";
+ TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
+};
+
+class TypeConstraint : public Type {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
+};
+
+class FuncType;
+/*!
+ * \brief Function type in Relay.
+ *
+ * Relay supports polymorphic function types.
+ * This can be roughly viewed as template function in C++.
+ *
+ * \sa TypeVar, TypeConstraint
+ */
+class FuncTypeNode : public TypeNode {
+ public:
+ /*! \brief The types of the arguments. */
+ Array<Type> arg_types;
+ /*! \brief The type of return value. */
+ Type ret_type;
+ // The following fields are used in polymorphic(template) functions
+ // For normal functions, the following two fields will be empty.
+ /*! \brief The type parameters of the function */
+ Array<TypeVar> type_params;
+ /*!
+ * \brief Potential constraints that the type needs to obey.
+ * \note This field is reserved for future use.
+ */
+ Array<TypeConstraint> type_constraints;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("arg_types", &arg_types);
+ v->Visit("ret_type", &ret_type);
+ v->Visit("type_params", &type_params);
+ v->Visit("type_constraints", &type_constraints);
+ v->Visit("span", &span);
+ }
+
+ TVM_DLL static FuncType make(Array<Type> arg_types,
+ Type ret_type,
+ Array<TypeVar> type_params,
+ Array<TypeConstraint> type_constraints);
+
+ static constexpr const char* _type_key = "relay.FuncType";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
+};
+
+class FuncType : public Type {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
+};
+
+} // namespace tvm
+#endif // TVM_IR_TYPE_H_
diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h
index d64d05f..7191e1f 100644
--- a/include/tvm/relay/base.h
+++ b/include/tvm/relay/base.h
@@ -25,6 +25,7 @@
#define TVM_RELAY_BASE_H_
#include <tvm/api_registry.h>
+#include <tvm/ir/span.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
#include <string>
@@ -58,88 +59,9 @@ namespace relay {
*/
using IndexExpr = ::tvm::Expr;
-/*!
- * \brief The source name in the Span
- * \sa SourceNameNode, Span
- */
-class SourceName;
-/*!
- * \brief The name of a source fragment.
- */
-class SourceNameNode : public Object {
- public:
- /*! \brief The source name. */
- std::string name;
- // override attr visitor
- void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
-
- static constexpr const char* _type_key = "relay.SourceName";
- TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
-};
-
-/*!
- * \brief The source name of a file span.
- * \sa SourceNameNode, Span
- */
-class SourceName : public ObjectRef {
- public:
- /*! \brief default constructor */
- SourceName() {}
-
- /*! \brief constructor from node pointer */
- explicit SourceName(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const SourceNameNode* operator->() const {
- return static_cast<const SourceNameNode*>(get());
- }
-
- /*!
- * \brief Get an SourceName for a given operator name.
- * Will raise an error if the source name has not been registered.
- * \param name Name of the operator.
- * \return SourceName valid throughout program lifetime.
- */
- TVM_DLL static SourceName Get(const std::string& name);
-
- /*! \brief specify container node */
- using ContainerType = SourceNameNode;
-};
-
-/*!
- * \brief Span information for debugging purposes
- */
-class Span;
-/*!
- * \brief Stores locations in frontend source that generated a node.
- */
-class SpanNode : public Object {
- public:
- /*! \brief The source name */
- SourceName source;
- /*! \brief Line number */
- int lineno;
- /*! \brief column offset */
- int col_offset;
- // override attr visitor
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("source", &source);
- v->Visit("lineno", &lineno);
- v->Visit("col_offset", &col_offset);
- }
-
- TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
-
- static constexpr const char* _type_key = "relay.Span";
- TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
-};
-
-class Span : public ObjectRef {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
-};
+using SourceName = tvm::SourceName;
+using Span = tvm::Span;
+using SpanNode = tvm::SpanNode;
/*!
* \brief This is the base node container of all relay structures.
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
index 8f51ea9..c6a560a 100644
--- a/include/tvm/relay/type.h
+++ b/include/tvm/relay/type.h
@@ -25,8 +25,8 @@
#define TVM_RELAY_TYPE_H_
#include <tvm/api_registry.h>
+#include <tvm/ir/type.h>
#include <tvm/ir.h>
-#include <tvm/node/node.h>
#include <string>
#include "base.h"
@@ -36,32 +36,17 @@ namespace tvm {
namespace relay {
using Any = tvm::ir::Any;
-
-/*! \brief Base type of the Relay type hiearchy. */
-class TypeNode : public RelayNode {
- public:
- static constexpr const char* _type_key = "relay.Type";
- TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
-};
-
-/*!
- * \brief Type is the base type of relay type hiearchy.
- *
- * Relay's type system contains following two key concepts:
- *
- * - TensorType: type of certain Tensor values in the expression.
- * - FunctionType: the type of the function.
- *
- * There are also advanced types to support generic(polymorphic types),
- * which can be ignored when first reading the code base.
- */
-class Type : public ObjectRef {
- public:
- Type() {}
- explicit Type(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
-
- using ContainerType = TypeNode;
-};
+using Kind = TypeKind;
+using Type = tvm::Type;
+using TypeNode = tvm::TypeNode;
+using TypeVar = tvm::TypeVar;
+using TypeVarNode = tvm::TypeVarNode;
+using GlobalTypeVar = tvm::GlobalTypeVar;
+using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
+using TypeConstraint = tvm::TypeConstraint;
+using TypeConstraintNode = tvm::TypeConstraintNode;
+using FuncType = tvm::FuncType;
+using FuncTypeNode = tvm::FuncTypeNode;
/*!
* \brief Base of all Tensor types
@@ -124,90 +109,6 @@ class TensorType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};
-/*! \brief Possible kinds of Type. */
-enum Kind : int {
- kType = 0,
- /*! \brief Template variable in shape expression. */
- kShapeVar = 1,
- kBaseType = 2,
- kShape = 3,
- kConstraint = 4,
- kAdtHandle = 5,
- kTypeData = 6
-};
-
-/*!
- * \brief Type parameter in the function.
- * This can be viewed as template parameter in c++ template function.
- *
- * For example, in the following pesudo code,
- * the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
- * This function can take in a Tensor with shape=(3, 3) and
- * returns a Tensor with shape=(9,)
- *
- * \code
- *
- * template<i32 n>
- * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
- *
- * \endcode
- * \sa TypeVarNode The actual container class of TypeVar
- */
-class TypeVar;
-/*! \brief TypeVar container node */
-class TypeVarNode : public TypeNode {
- public:
- /*! \brief Name of the variable, it only acts as a hint. */
- std::string name_hint;
- /*! \brief The kind of type parameter */
- Kind kind;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name_hint", &name_hint);
- v->Visit("kind", &kind);
- v->Visit("span", &span);
- }
-
- TVM_DLL static TypeVar make(std::string name, Kind kind);
-
- static constexpr const char* _type_key = "relay.TypeVar";
- TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
-};
-
-class TypeVar : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
-};
-
-/*!
- * \brief A global type variable that is used for defining new types or type aliases.
- */
-class GlobalTypeVar;
-/*! \brief GlobalTypeVar container node */
-class GlobalTypeVarNode : public TypeNode {
- public:
- /*! \brief Name of the variable, it only acts as a hint. */
- std::string name_hint;
- /*! \brief The kind of type parameter */
- Kind kind;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name_hint", &name_hint);
- v->Visit("kind", &kind);
- v->Visit("span", &span);
- }
-
- TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);
-
- static constexpr const char* _type_key = "relay.GlobalTypeVar";
- TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
-};
-
-class GlobalTypeVar : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
-};
-
/*!
* \brief Type application.
*/
@@ -271,70 +172,6 @@ class IncompleteType : public Type {
};
/*!
- * \brief Potential Constraints in the type.
- * \note This is reserved for future use.
- */
-class TypeConstraint;
-/*! \brief TypeConstraint container node. */
-class TypeConstraintNode : public TypeNode {
- public:
- static constexpr const char* _type_key = "relay.TypeConstraint";
- TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
-};
-
-class TypeConstraint : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
-};
-
-class FuncType;
-/*!
- * \brief Function type in Relay.
- *
- * Relay support polymorphic function type.
- * This can be roughly viewed as template function in C++.
- *
- * \sa TypeVar, TypeConstraint
- */
-class FuncTypeNode : public TypeNode {
- public:
- /*! \brief type type of arguments */
- tvm::Array<Type> arg_types;
- /*! \brief The type of return value. */
- Type ret_type;
- // The following fields are used in polymorphic(template) functions
- // For normal functions, the following two fields will be empty.
- /*! \brief The type parameters of the function */
- tvm::Array<TypeVar> type_params;
- /*!
- * \brief potential constraint the type need to obey
- * \note this field is reserved for futher purposes.
- */
- tvm::Array<TypeConstraint> type_constraints;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("arg_types", &arg_types);
- v->Visit("ret_type", &ret_type);
- v->Visit("type_params", &type_params);
- v->Visit("type_constraints", &type_constraints);
- v->Visit("span", &span);
- }
-
- TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
- Type ret_type,
- tvm::Array<TypeVar> type_params,
- tvm::Array<TypeConstraint> type_constraints);
-
- static constexpr const char* _type_key = "relay.FuncType";
- TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
-};
-
-class FuncType : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
-};
-
-/*!
* \brief The type of tuple values.
*/
class TupleType;
diff --git a/src/relay/ir/base.cc b/src/ir/span.cc
similarity index 75%
copy from src/relay/ir/base.cc
copy to src/ir/span.cc
index ca87557..1d9f079 100644
--- a/src/relay/ir/base.cc
+++ b/src/ir/span.cc
@@ -16,19 +16,14 @@
* specific language governing permissions and limitations
* under the License.
*/
-
/*!
- * \file base.cc
- * \brief The core base types for Relay.
+ * \file span.cc
+ * \brief The span data structure.
*/
-#include <tvm/api_registry.h>
-#include <tvm/relay/base.h>
+#include <tvm/ir/span.h>
+#include <tvm/packed_func_ext.h>
namespace tvm {
-namespace relay {
-
-using tvm::IRPrinter;
-using namespace tvm::runtime;
ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
// always return pointer as the reference can change as map re-allocate.
@@ -50,11 +45,11 @@ SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
-TVM_REGISTER_API("relay._make.SourceName")
+TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<SourceNameNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
+.set_dispatch<SourceNameNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
@@ -75,24 +70,13 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode);
-TVM_REGISTER_API("relay._make.Span")
+TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<SpanNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
+.set_dispatch<SpanNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
- p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
+ p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
});
-
-TVM_REGISTER_NODE_TYPE(IdNode);
-
-TVM_REGISTER_API("relay._base.set_span")
-.set_body_typed<void(ObjectRef, Span)>([](ObjectRef node_ref, Span sp) {
- auto rn = node_ref.as<RelayNode>();
- CHECK(rn);
- rn->span = sp;
-});
-
-} // namespace relay
} // namespace tvm
diff --git a/src/ir/type.cc b/src/ir/type.cc
new file mode 100644
index 0000000..ef5f75b
--- /dev/null
+++ b/src/ir/type.cc
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/type.cc
+ * \brief Common type system AST nodes throughout the IR.
+ */
+#include <tvm/ir/type.h>
+#include <tvm/packed_func_ext.h>
+
+namespace tvm {
+
+TypeVar TypeVarNode::make(std::string name, TypeKind kind) {
+ ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
+ n->name_hint = std::move(name);
+ n->kind = std::move(kind);
+ return TypeVar(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TypeVarNode);
+
+TVM_REGISTER_GLOBAL("relay._make.TypeVar")
+.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
+ return TypeVarNode::make(name, static_cast<TypeKind>(kind));
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
+ auto* node = static_cast<const TypeVarNode*>(ref.get());
+ p->stream << "TypeVar(" << node->name_hint << ", "
+ << node->kind << ")";
+});
+
+GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) {
+ ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
+ n->name_hint = std::move(name);
+ n->kind = std::move(kind);
+ return GlobalTypeVar(n);
+}
+
+TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
+
+TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
+.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
+ return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
+ auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
+ p->stream << "GlobalTypeVar(" << node->name_hint << ", "
+ << node->kind << ")";
+});
+
+FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
+ Type ret_type,
+ tvm::Array<TypeVar> type_params,
+ tvm::Array<TypeConstraint> type_constraints) {
+ ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
+ n->arg_types = std::move(arg_types);
+ n->ret_type = std::move(ret_type);
+ n->type_params = std::move(type_params);
+ n->type_constraints = std::move(type_constraints);
+ return FuncType(n);
+}
+
+TVM_REGISTER_NODE_TYPE(FuncTypeNode);
+
+TVM_REGISTER_GLOBAL("relay._make.FuncType")
+.set_body_typed(FuncTypeNode::make);
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
+ auto* node = static_cast<const FuncTypeNode*>(ref.get());
+ p->stream << "FuncType(" << node->type_params << ", "
+ << node->arg_types << ", " << node->ret_type << ", "
+ << node->type_constraints << ")";
+});
+
+} // namespace tvm
diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc
index ca87557..3f98d87 100644
--- a/src/relay/ir/base.cc
+++ b/src/relay/ir/base.cc
@@ -22,76 +22,26 @@
* \brief The core base types for Relay.
*/
#include <tvm/api_registry.h>
+#include <tvm/ir/type.h>
#include <tvm/relay/base.h>
namespace tvm {
namespace relay {
-using tvm::IRPrinter;
using namespace tvm::runtime;
-ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
- // always return pointer as the reference can change as map re-allocate.
- // or use another level of indirection by creating a unique_ptr
- static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map;
-
- auto sn = source_map.find(name);
- if (sn == source_map.end()) {
- ObjectPtr<SourceNameNode> n = make_object<SourceNameNode>();
- source_map[name] = n;
- n->name = std::move(name);
- return n;
- } else {
- return sn->second;
- }
-}
-
-SourceName SourceName::Get(const std::string& name) {
- return SourceName(GetSourceNameNode(name));
-}
-
-TVM_REGISTER_API("relay._make.SourceName")
-.set_body_typed(SourceName::Get);
-
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<SourceNameNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
- auto* node = static_cast<const SourceNameNode*>(ref.get());
- p->stream << "SourceName(" << node->name << ", " << node << ")";
- });
-
-TVM_REGISTER_NODE_TYPE(SourceNameNode)
-.set_creator(GetSourceNameNode)
-.set_global_key([](const Object* n) {
- return static_cast<const SourceNameNode*>(n)->name;
- });
-
-Span SpanNode::make(SourceName source, int lineno, int col_offset) {
- auto n = make_object<SpanNode>();
- n->source = std::move(source);
- n->lineno = lineno;
- n->col_offset = col_offset;
- return Span(n);
-}
-
-TVM_REGISTER_NODE_TYPE(SpanNode);
-
-TVM_REGISTER_API("relay._make.Span")
-.set_body_typed(SpanNode::make);
-
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<SpanNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
- auto* node = static_cast<const SpanNode*>(ref.get());
- p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
- << node->col_offset << ")";
- });
-
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body_typed<void(ObjectRef, Span)>([](ObjectRef node_ref, Span sp) {
- auto rn = node_ref.as<RelayNode>();
+ if (auto* rn = node_ref.as<RelayNode>()) {
CHECK(rn);
rn->span = sp;
+ } else if (auto* rn = node_ref.as<TypeNode>()) {
+ rn->span = sp;
+ } else {
+ LOG(FATAL) << "Expect Type or RelayNode ";
+ }
});
} // namespace relay
diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc
index 459e8b0..6199c54 100644
--- a/src/relay/ir/hash.cc
+++ b/src/relay/ir/hash.cc
@@ -228,11 +228,6 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(var_node->type_annotation));
}
hash_map_[var] = hash;
- // TODO(tqchen) Introduce TypeVarExpr
- // const auto* ty_param = var.as<TypeVarNode>();
- // if (ty_param && ty_param->kind == Kind::kShapeVar) {
- // hash_map_[ty_param->var] = hash;
- // }
return hash;
}
diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc
index 38f86a5..9f371dd 100644
--- a/src/relay/ir/module.cc
+++ b/src/relay/ir/module.cc
@@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const {
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
- << "There is no definition of " << var->name_hint;
+ << "There is no definition of " << var->name_hint;
return (*it).second;
}
diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc
index 48f211b..f1efddf 100644
--- a/src/relay/ir/type.cc
+++ b/src/relay/ir/type.cc
@@ -63,48 +63,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
-TypeVar TypeVarNode::make(std::string name, Kind kind) {
- ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
- n->name_hint = std::move(name);
- n->kind = std::move(kind);
- return TypeVar(n);
-}
-
-TVM_REGISTER_NODE_TYPE(TypeVarNode);
-
-TVM_REGISTER_API("relay._make.TypeVar")
-.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
- return TypeVarNode::make(name, static_cast<Kind>(kind));
-});
-
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
- auto* node = static_cast<const TypeVarNode*>(ref.get());
- p->stream << "TypeVarNode(" << node->name_hint << ", "
- << node->kind << ")";
-});
-
-GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
- ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
- n->name_hint = std::move(name);
- n->kind = std::move(kind);
- return GlobalTypeVar(n);
-}
-
-TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
-
-TVM_REGISTER_API("relay._make.GlobalTypeVar")
-.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
- return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
- });
-
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
- auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
- p->stream << "GlobalTypeVarNode(" << node->name_hint << ", "
- << node->kind << ")";
-});
-
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
ObjectPtr<TypeCallNode> n = make_object<TypeCallNode>();
n->func = std::move(func);
@@ -143,31 +101,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
-FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
- Type ret_type,
- tvm::Array<TypeVar> type_params,
- tvm::Array<TypeConstraint> type_constraints) {
- ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
- n->arg_types = std::move(arg_types);
- n->ret_type = std::move(ret_type);
- n->type_params = std::move(type_params);
- n->type_constraints = std::move(type_constraints);
- return FuncType(n);
-}
-
-TVM_REGISTER_NODE_TYPE(FuncTypeNode);
-
-TVM_REGISTER_API("relay._make.FuncType")
-.set_body_typed(FuncTypeNode::make);
-
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
- auto* node = static_cast<const FuncTypeNode*>(ref.get());
- p->stream << "FuncTypeNode(" << node->type_params << ", "
- << node->arg_types << ", " << node->ret_type << ", "
- << node->type_constraints << ")";
-});
-
TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc
index cdc6996..03ad228 100644
--- a/tests/cpp/relay_pass_type_infer_test.cc
+++ b/tests/cpp/relay_pass_type_infer_test.cc
@@ -38,7 +38,7 @@ TEST(Relay, SelfReference) {
auto type_fx = mod->Lookup("main");
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
- CHECK(AlphaEqual(type_fx->checked_type(), expected));
+ CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
}
int main(int argc, char ** argv) {