You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/01/11 21:02:37 UTC

[incubator-tvm] branch master updated: [REFACTOR][IR] Unified IR Primitive Op and Registry (#4687)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d8f0602  [REFACTOR][IR] Unified IR Primitive Op and Registry (#4687)
d8f0602 is described below

commit d8f06020a10deb25722c9f97363368e8ae1c2b62
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sat Jan 11 13:02:29 2020 -0800

    [REFACTOR][IR] Unified IR Primitive Op and Registry (#4687)
    
    This PR migrates relay's Op into the ir folder.
    Op and its registry provide a useful mechanism to
    store any attribute meta-data of an operator, including
    function signatures, lowering rules, side effects, etc.
    
    These features are not only useful for Relay, but also needed in the low-level IR.
    At the current moment, intrinsic functions in the low-level IR are simply
    represented by a string. This means we cannot type-check the low-level IR
    when the type does not meet the constraint, nor can we obtain further
    information such as side effects and read/write relations of these intrinsics
    with respect to their arguments.
    
    Op will be used as the way to handle primitive ops (in DL terminology),
    also known as builtin intrinsics (in compiler terminology).
    We will perform follow-up refactors to make low-level CallNode
    take Op as the function argument.
---
 include/tvm/{relay => ir}/op.h |  85 +++---
 include/tvm/ir/type.h          |   1 +
 include/tvm/ir/type_relation.h | 175 ++++++++++++
 include/tvm/relay/op.h         | 588 +----------------------------------------
 include/tvm/relay/type.h       | 147 +----------
 src/{relay => }/ir/op.cc       |  12 +-
 src/ir/type_relation.cc        |  54 ++++
 src/relay/ir/type.cc           |  24 --
 8 files changed, 298 insertions(+), 788 deletions(-)

diff --git a/include/tvm/relay/op.h b/include/tvm/ir/op.h
similarity index 90%
copy from include/tvm/relay/op.h
copy to include/tvm/ir/op.h
index 6bd0a35..19c5a51 100644
--- a/include/tvm/relay/op.h
+++ b/include/tvm/ir/op.h
@@ -18,27 +18,25 @@
  */
 
 /*!
- * \file tvm/relay/op.h
- * \brief Primitive operator definition.
+ * \file tvm/ir/op.h
+ * \brief Primitive operators(builtin intrinsics)
+ *        and registry for them.
  */
-#ifndef TVM_RELAY_OP_H_
-#define TVM_RELAY_OP_H_
+#ifndef TVM_IR_OP_H_
+#define TVM_IR_OP_H_
 
 #include <dmlc/registry.h>
+#include <tvm/attrs.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/ir/expr.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_relation.h>
 
-#include <functional>
-#include <limits>
 #include <string>
-#include <typeinfo>
 #include <utility>
 #include <vector>
 
-#include "base.h"
-#include "expr.h"
-#include "type.h"
-
 namespace tvm {
-namespace relay {
 
 // forward declare name.
 template <typename ValueType>
@@ -46,10 +44,19 @@ class OpMap;
 class GenericOpMap;
 class OpRegistry;
 
+// TODO(tvm-team): migrate low-level intrinsics to use Op
 /*!
- * \brief Node container of operator structure.
+ * \brief Primitive Op(builtin intrinsics)
+ *
+ * This data structure stores the meta-data
+ * about primitive operators that can be invoked via Call.
+ *
+ * Low-level IR intrinsics(such as libc.expf) are also
+ * implemented via Op.
+ *
+ * \sa Op
  */
-class OpNode : public relay::ExprNode {
+class OpNode : public RelayExprNode {
  public:
   /*! \brief name of the operator */
   std::string name;
@@ -106,13 +113,13 @@ class OpNode : public relay::ExprNode {
   }
 
   static constexpr const char* _type_key = "relay.Op";
-  TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
 
  private:
   // friend class
   friend class GenericOpMap;
   friend class OpRegistry;
-  friend bool IsPrimitiveOp(const Expr&);
+  friend bool IsPrimitiveOp(const RelayExpr&);
   // Program internal unique index of operator.
   // Used to help index the program.
   uint32_t index_{0};
@@ -133,9 +140,10 @@ class OpNode : public relay::ExprNode {
 };
 
 /*!
- * \brief Operator reference class.
+ * \brief Managed reference class to OpNode.
+ * \sa OpNode
  */
-class Op : public relay::Expr {
+class Op : public RelayExpr {
  public:
   /*! \brief default constructor  */
   Op() {}
@@ -187,7 +195,10 @@ class Op : public relay::Expr {
   TVM_DLL static const bool HasGenericAttr(const std::string& key);
 };
 
-/*! \brief Helper structure to register operators */
+/*!
+ * \brief Helper structure to register operators
+ * \sa TVM_REGISTER_OP
+ */
 class OpRegistry {
  public:
   /*! \return the operator */
@@ -324,7 +335,7 @@ class GenericOpMap {
    * \tparam ValueType The content value type.
    */
   template <typename ValueType>
-  inline ValueType get(const Expr& expr, ValueType def_value) const;
+  inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
 
  private:
   friend class OpRegistry;
@@ -369,7 +380,7 @@ class OpMap {
    *         or if expr is not an Op.
    * \return the const reference to the content value.
    */
-  inline ValueType get(const Expr& expr, ValueType def_value) const;
+  inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
 
  private:
   friend class Op;
@@ -380,28 +391,28 @@ class OpMap {
 };
 
 // internal macros to make
-#define RELAY_REGISTER_VAR_DEF \
-  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
+#define TVM_OP_REGISTER_VAR_DEF \
+  static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op
 
 /*!
- * \def RELAY_REGISTER_OP
+ * \def TVM_REGISTER_OP
  * \brief Register a new operator, or set attribute of the corresponding op.
  *
  * \param OpName The name of registry
  *
  * \code
  *
- *  RELAY_REGISTER_OP("add")
+ *  TVM_REGISTER_OP("add")
  *  .describe("add two inputs together")
  *  .set_num_inputs(2)
  *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
  *
  * \endcode
  */
-#define RELAY_REGISTER_OP(OpName)                        \
-  DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
-      ::tvm::relay::OpRegistry::Registry()               \
-          ->__REGISTER_OR_GET__(OpName)                  \
+#define TVM_REGISTER_OP(OpName)                                \
+  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) =       \
+  ::tvm::OpRegistry::Registry()                                \
+          ->__REGISTER_OR_GET__(OpName)                        \
           .set_name()
 
 // implementations
@@ -465,7 +476,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
   std::string input_name_prefix = "in";
   for (int i = 0; i < get()->num_inputs; i++) {
     auto name = input_name_prefix + std::to_string(i);
-    auto param = TypeVarNode::make(name, Kind::kType);
+    auto param = TypeVarNode::make(name, TypeKind::kType);
     type_params.push_back(param);
     arg_types.push_back(param);
   }
@@ -473,7 +484,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
   Array<Type> ty_call_args = arg_types;
 
   // Add output type.
-  auto out_param = TypeVarNode::make("out", Kind::kType);
+  auto out_param = TypeVarNode::make("out", TypeKind::kType);
   type_params.push_back(out_param);
   // this will trigger copy on write.
   ty_call_args.push_back(out_param);
@@ -558,7 +569,7 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
 }
 
 template <typename ValueType>
-inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
+inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const {
   CHECK(expr.defined());
   if (const OpNode* op = expr.as<OpNode>()) {
     const uint32_t idx = op->index_;
@@ -589,7 +600,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
 }
 
 template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const Expr& expr,
+inline ValueType OpMap<ValueType>::get(const RelayExpr& expr,
                                        ValueType def_value) const {
   return map_.get<ValueType>(expr, def_value);
 }
@@ -603,12 +614,14 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
  *
  * That is the arguments are all type variables, and there is a single
  * type relation applied to the input and output types.
+ *
+ * \param expr An expression.
+ * \return Whether the expression is primitive op.
  */
-inline bool IsPrimitiveOp(const Expr& expr) {
+inline bool IsPrimitiveOp(const RelayExpr& expr) {
   const auto* op = expr.as<OpNode>();
   return op != nullptr && op->IsPrimitiveOp();
 }
 
-}  // namespace relay
 }  // namespace tvm
-#endif  // TVM_RELAY_OP_H_
+#endif  // TVM_IR_OP_H_
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index ab2003e..ddabd0f 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -51,6 +51,7 @@
 
 #include <tvm/runtime/object.h>
 #include <tvm/node/node.h>
+#include <tvm/node/env_func.h>
 #include <tvm/node/container.h>
 #include <tvm/ir/span.h>
 #include <string>
diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h
new file mode 100644
index 0000000..71d1d9e
--- /dev/null
+++ b/include/tvm/ir/type_relation.h
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/type_relation.h
+ * \brief Type relation function for type checking.
+ */
+#ifndef TVM_IR_TYPE_RELATION_H_
+#define TVM_IR_TYPE_RELATION_H_
+
+#include <tvm/ir/type.h>
+#include <tvm/attrs.h>
+
+namespace tvm {
+
+// TODO(tqchen): remove after migrate Module to ir.
+namespace relay {
+struct Module;
+}
+
+/*!
+ * \brief reporter that reports back to the
+ *  type resolution information.
+ */
+class TypeReporterNode : public Object {
+ public:
+  /*!
+   * \brief Create a type equality constraint.
+   *
+   *  The "assign direction" acts as a hint to the solver
+   *  showing that it is more likely to resolve dst by src.
+   *  But it is possible for the solver to resolve src by dst as well.
+   */
+  TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
+
+  /*!
+   * \brief assert shape expression comparison.
+   * \note Use assert only if any of the condition input is symbolic.
+   * \param cond The condition of operation.
+   * \return false if assertation can be proven to have failed
+   *      true if solver can still proceed.
+   */
+  TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0;
+  /*!
+   * \brief assert shape expression equals each other.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return false if assertation can be proven to have failed
+   *      true if solver can still proceed.
+   */
+  TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0;
+
+  /*!
+   * \brief Set the location at which to report unification errors.
+   * \param ref The program node to report the error.
+   */
+  TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0;
+
+  /*!
+   * \brief Retrieve the current global module.
+   * \return The global module.
+   */
+  TVM_DLL virtual relay::Module GetModule() = 0;
+
+  // solver is not serializable.
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "relay.TypeReporter";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
+};
+
+/*!
+ * \brief Container class of TypeReporter.
+ * \sa TypeReporterNode
+ */
+class TypeReporter : public ObjectRef {
+ public:
+  TypeReporter() {}
+  explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
+  }
+  TypeReporterNode* operator->() const {
+    return const_cast<TypeReporterNode*>(
+        static_cast<const TypeReporterNode*>(get()));
+  }
+  using ContainerType = TypeReporterNode;
+};
+
+/*!
+ * \brief User defined type constraint function.
+ *
+ * If the input type information can be used to fully decide
+ * the IncompleteTypes, then the function should call
+ * reporter.Assign to report the new types, and return true.
+ * Otherwise, the function should return false.
+ *
+ * \param args The arguments to the relation.
+ *   The types are stored in the form of
+ *   [input_type_0, input_type_1, ... input_type_n,
+ *    output_type_0, output_type_1, ... output_type_m]
+ *
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved.
+ *   true if this relation has been resolved.
+ */
+using TypeRelationFn =
+    TypedEnvFunc<bool(const Array<Type>& args,
+                      int num_inputs,
+                      const Attrs& attrs,
+                      const TypeReporter& reporter)>;
+
+/*!
+ * \brief User defined type relation, is an input-output relation on types.
+ */
+class TypeRelation;
+/*!
+ * \brief TypeRelation container.
+ * \note This node is not directly serializable.
+ * The type function need to be lookedup in the module.
+ */
+class TypeRelationNode : public TypeConstraintNode {
+ public:
+  /*!
+   * \brief The function on input and output variables which
+   *  this is not directly serializable,
+   *  need to be looked-up in the module.
+   */
+  TypeRelationFn func;
+  /*! \brief The type arguments to the type function. */
+  tvm::Array<Type> args;
+  /*! \brief Number of inputs arguments */
+  int num_inputs;
+  /*! \brief Attributes to the relation function */
+  Attrs attrs;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("func", &func);
+    v->Visit("args", &args);
+    v->Visit("num_inputs", &num_inputs);
+    v->Visit("attrs", &attrs);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static TypeRelation make(TypeRelationFn func,
+                                   Array<Type> args,
+                                   int num_args,
+                                   Attrs attrs);
+
+  static constexpr const char* _type_key = "relay.TypeRelation";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
+};
+
+class TypeRelation : public TypeConstraint {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
+};
+}  // namespace tvm
+#endif  // TVM_IR_TYPE_RELATION_H_
diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h
index 6bd0a35..fa47da2 100644
--- a/include/tvm/relay/op.h
+++ b/include/tvm/relay/op.h
@@ -19,595 +19,23 @@
 
 /*!
  * \file tvm/relay/op.h
- * \brief Primitive operator definition.
+ * \brief Primitive operators(builtin intrinsics).
  */
 #ifndef TVM_RELAY_OP_H_
 #define TVM_RELAY_OP_H_
 
-#include <dmlc/registry.h>
-
-#include <functional>
-#include <limits>
-#include <string>
-#include <typeinfo>
-#include <utility>
-#include <vector>
-
-#include "base.h"
-#include "expr.h"
-#include "type.h"
+#include <tvm/ir/op.h>
+#include <tvm/relay/type.h>
+#include <tvm/relay/expr.h>
 
 namespace tvm {
 namespace relay {
 
-// forward declare name.
-template <typename ValueType>
-class OpMap;
-class GenericOpMap;
-class OpRegistry;
-
-/*!
- * \brief Node container of operator structure.
- */
-class OpNode : public relay::ExprNode {
- public:
-  /*! \brief name of the operator */
-  std::string name;
-  /*! \brief the type of the operator */
-  mutable FuncType op_type;
-  /*!
-   * \brief detailed description of the operator
-   *  This can be used to generate docstring automatically for the operator.
-   */
-  std::string description;
-  /* \brief Information of input arguments to the operator */
-  Array<AttrFieldInfo> arguments;
-  /*!
-   * \brief The type key of the attribute field
-   *  This can be empty, in which case it defaults to anything.
-   */
-  std::string attrs_type_key;
-  /*!
-   * \brief attribute type index,
-   * this field varies in each run and is not exposed to frontend.
-   */
-  uint32_t attrs_type_index{0};
-  /*!
-   * \brief number of input arguments to the operator,
-   * -1 means it is variable length
-   */
-  int32_t num_inputs = -1;
-  /*!
-   * \brief support level of the operator,
-   *  The lower the more priority it contains.
-   *  This is in analogies to BLAS levels.
-   */
-  int32_t support_level = 10;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("op_type", &op_type);
-    v->Visit("description", &description);
-    v->Visit("arguments", &arguments);
-    v->Visit("attrs_type_key", &attrs_type_key);
-    v->Visit("num_inputs", &num_inputs);
-    v->Visit("support_level", &support_level);
-  }
-
-  /*!
-   * \brief Check that if current op is a "primtive operator".
-   * That is the arguments are all type variables, and there is a single
-   * type relation applied to the input and output types.
-   */
-  bool IsPrimitiveOp() const {
-    if (is_primitive_ != -1) return is_primitive_ != 0;
-    is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
-    return is_primitive_ != 0;
-  }
-
-  static constexpr const char* _type_key = "relay.Op";
-  TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode);
-
- private:
-  // friend class
-  friend class GenericOpMap;
-  friend class OpRegistry;
-  friend bool IsPrimitiveOp(const Expr&);
-  // Program internal unique index of operator.
-  // Used to help index the program.
-  uint32_t index_{0};
-  // whether this is a primitive op. -1 means unknown.
-  mutable int is_primitive_{-1};
-  // Internal function to compute if it is primitive op
-  bool IsPrimitiveOp_() const {
-    const auto& fn_ty = this->op_type;
-    if (fn_ty->type_constraints.size() != 1) return false;
-    const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
-    if (rel == nullptr) return false;
-    // validate if the type parameter matches up
-    for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
-      if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
-    }
-    return true;
-  }
-};
-
-/*!
- * \brief Operator reference class.
- */
-class Op : public relay::Expr {
- public:
-  /*! \brief default constructor  */
-  Op() {}
-  /*! \brief constructor from node pointer */
-  explicit Op(ObjectPtr<Object> n) : RelayExpr(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const OpNode* operator->() const;
-  /*!
-   * \brief Get additional registered attribute about operators.
-   *  If nothing has been registered, an empty OpMap will be returned.
-   * \param attr_name The name of the attribute.
-   * \return An OpMap of specified attr_name.
-   * \tparam ValueType The type of the attribute.
-   */
-  template <typename ValueType>
-  inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
-  /*!
-   * \brief Checks if an attr is present in the registry.
-   * \param attr_name The name of the attribute.
-   * \return bool True if the attr is present.
-   */
-  inline static bool HasAttr(const std::string& attr_name);
-  /*!
-   * \brief Get an Op for a given operator name.
-   *  Will raise an error if the op has not been registered.
-   * \param op_name Name of the operator.
-   * \return Pointer to a Op, valid throughout program lifetime.
-   */
-  TVM_DLL static const Op& Get(const std::string& op_name);
-
-  /*! \brief specify container node */
-  using ContainerType = OpNode;
-
- private:
-  /*!
-   * \brief Get generic attrmap given attr name
-   * \param key The attribute key
-   * \return reference to GenericOpMap
-   */
-  TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
-  /*!
-   * \brief Checks if the key is present in the registry
-   * \param key The attribute key
-   * \return bool True if the key is present
-   */
-  TVM_DLL static const bool HasGenericAttr(const std::string& key);
-};
-
-/*! \brief Helper structure to register operators */
-class OpRegistry {
- public:
-  /*! \return the operator */
-  const Op& op() const { return op_; }
-  /*!
-   * \brief setter function during registration
-   *  Set the description of operator
-   * \param descr the description string.
-   * \return reference to self.
-   */
-  inline OpRegistry& describe(const std::string& descr);  // NOLINT(*)
-  /*!
-   * \brief Add argument information to the function.
-   * \param name Name of the argument.
-   * \param type Type of the argument.
-   * \param description Description of the argument.
-   * \return reference to self.
-   */
-  inline OpRegistry& add_argument(const std::string& name,
-                                  const std::string& type,
-                                  const std::string& description);
-  /*!
-   * \brief Attach the type function corresponding to the return type.
-   * \param rel_name The type relation name to register.
-   * \param type_rel_func The backing relation function which can solve an arbitrary
-   * relation on variables.
-   * \return reference to self.
-   */
-  inline OpRegistry& add_type_rel(
-      const std::string& rel_name,
-      runtime::TypedPackedFunc<bool(const Array<Type>&,
-                                    int,
-                                    const Attrs&,
-                                    const TypeReporter&)> type_rel_func);
-  /*!
-   * \brief Set the the attrs type key and index to be AttrsType.
-   * \tparam AttrsType the attribute type to b set.
-   * \return reference to self.
-   */
-  template<typename AttrsType>
-  inline OpRegistry& set_attrs_type();
-  /*!
-   * \brief Set the num_inputs
-   * \param n The number of inputs to be set.
-   * \return reference to self.
-   */
-  inline OpRegistry& set_num_inputs(int32_t n);  // NOLINT(*)
-  /*!
-   * \brief Set the support level of op.
-   * \param level The support level.
-   * \return reference to self.
-   */
-  inline OpRegistry& set_support_level(int32_t level);  // NOLINT(*)
-  /*!
-   * \brief Register additional attributes to operator.
-   * \param attr_name The name of the attribute.
-   * \param value The value to be set.
-   * \param plevel The priority level of this set,
-   *  an higher priority level attribute
-   *  will replace lower priority level attribute.
-   *  Must be bigger than 0.
-   *
-   *  Cannot set with same plevel twice in the code.
-   *
-   * \tparam ValueType The type of the value to be set.
-   */
-  template <typename ValueType>
-  inline OpRegistry& set_attr(const std::string& attr_name,  // NOLINT(*)
-                              const ValueType& value, int plevel = 10);
-
-  /*!
-   * \brief Resets an attr of the registry.
-   * \param attr_name The name of the attribute.
-   */
-  inline void reset_attr(const std::string& attr_name);
-
-  // set the name of the op to be the same as registry
-  inline OpRegistry& set_name() {  // NOLINT(*)
-    if (get()->name.length() == 0) {
-      get()->name = name;
-    }
-    return *this;
-  }
-  /*! \return The global single registry */
-  TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
-
- private:
-  friend class ::dmlc::Registry<OpRegistry>;
-  // the name
-  std::string name;
-  /*! \brief The operator */
-  Op op_;
-  // private constructor
-  TVM_DLL OpRegistry();
-  // return internal pointer to op.
-  inline OpNode* get();
-  // update the attribute OpMap
-  TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
-                          int plevel);
-};
-
-/*!
- * \brief Generic map to store additional information of Op.
- */
-class GenericOpMap {
- public:
-  /*!
-   * \brief Check if the map has op as key.
-   * \param op The key to the map
-   * \return 1 if op is contained in map, 0 otherwise.
-   */
-  inline int count(const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op
-   * \param op The key to the map
-   * \return the const reference to the content value.
-   */
-  inline const TVMRetValue& operator[](const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param op The key to the map
-   * \param def_value The default value when the key does not exist.
-   * \return the const reference to the content value.
-   * \tparam ValueType The content value type.
-   */
-  template <typename ValueType>
-  inline ValueType get(const Op& op, ValueType def_value) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param expr The key to the map
-   * \param def_value The default value when the key does not exist
-   *         or if expr is not an Op.
-   * \return the const reference to the content value.
-   * \tparam ValueType The content value type.
-   */
-  template <typename ValueType>
-  inline ValueType get(const Expr& expr, ValueType def_value) const;
-
- private:
-  friend class OpRegistry;
-  // the attribute field.
-  std::string attr_name_;
-  // internal data
-  std::vector<std::pair<TVMRetValue, int> > data_;
-  // The value
-  GenericOpMap() = default;
-};
-
-/*!
- * \brief Map<Op,ValueType> used to store meta-information about Op.
- * \tparam ValueType The type of the value stored in map.
- */
-template <typename ValueType>
-class OpMap {
- public:
-  /*!
-   * \brief Check if the map has op as key.
-   * \param op The key to the map
-   * \return 1 if op is contained in map, 0 otherwise.
-   */
-  inline int count(const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op
-   * \param op The key to the map
-   * \return the const reference to the content value.
-   */
-  inline ValueType operator[](const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param op The key to the map
-   * \param def_value The default value when the key does not exist.
-   * \return the const reference to the content value.
-   */
-  inline ValueType get(const Op& op, ValueType def_value) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param expr The key to the map
-   * \param def_value The default value when the key does not exist
-   *         or if expr is not an Op.
-   * \return the const reference to the content value.
-   */
-  inline ValueType get(const Expr& expr, ValueType def_value) const;
-
- private:
-  friend class Op;
-  // constructor
-  explicit OpMap(const GenericOpMap& map) : map_(map) {}
-  /*! \brief The internal map field */
-  const GenericOpMap& map_;
-};
-
-// internal macros to make
-#define RELAY_REGISTER_VAR_DEF \
-  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
-
-/*!
- * \def RELAY_REGISTER_OP
- * \brief Register a new operator, or set attribute of the corresponding op.
- *
- * \param OpName The name of registry
- *
- * \code
- *
- *  RELAY_REGISTER_OP("add")
- *  .describe("add two inputs together")
- *  .set_num_inputs(2)
- *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
- *
- * \endcode
- */
-#define RELAY_REGISTER_OP(OpName)                        \
-  DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
-      ::tvm::relay::OpRegistry::Registry()               \
-          ->__REGISTER_OR_GET__(OpName)                  \
-          .set_name()
-
-// implementations
-inline const OpNode* Op::operator->() const {
-  return static_cast<const OpNode*>(get());
-}
-
-template <typename ValueType>
-inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
-  return OpMap<ValueType>(Op::GetGenericAttr(key));
-}
-
-inline bool Op::HasAttr(const std::string& key) {
-  return Op::HasGenericAttr(key);
-}
-
-inline OpNode* OpRegistry::get() {
-  return const_cast<OpNode*>(op_.operator->());
-}
-
-inline OpRegistry& OpRegistry::describe(
-    const std::string& descr) {  // NOLINT(*)
-  get()->description = descr;
-  return *this;
-}
-
-inline OpRegistry& OpRegistry::add_argument(const std::string& name,
-                                            const std::string& type,
-                                            const std::string& description) {
-  auto n = make_object<AttrFieldInfoNode>();
-  n->name = name;
-  n->type_info = type;
-  n->description = description;
-  get()->arguments.push_back(AttrFieldInfo(n));
-  return *this;
-}
-
-inline OpRegistry& OpRegistry::add_type_rel(
-    const std::string& rel_name,
-    runtime::TypedPackedFunc<bool(const Array<Type>&,
-                                  int,
-                                  const Attrs&,
-                                  const TypeReporter&)> type_rel_func) {
-  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
-  TypeRelationFn env_type_rel_func;
-
-  if (runtime::Registry::Get(func_name)) {
-    auto env_func = EnvFunc::Get(func_name);
-    env_type_rel_func = env_func;
-  } else {
-    runtime::Registry::Register(func_name)
-        .set_body(type_rel_func.packed());
-    auto env_func = EnvFunc::Get(func_name);
-    env_type_rel_func = env_func;
-  }
-
-  Array<TypeVar> type_params;
-  Array<Type> arg_types;
-
-  // Add inputs.
-  std::string input_name_prefix = "in";
-  for (int i = 0; i < get()->num_inputs; i++) {
-    auto name = input_name_prefix + std::to_string(i);
-    auto param = TypeVarNode::make(name, Kind::kType);
-    type_params.push_back(param);
-    arg_types.push_back(param);
-  }
-
-  Array<Type> ty_call_args = arg_types;
-
-  // Add output type.
-  auto out_param = TypeVarNode::make("out", Kind::kType);
-  type_params.push_back(out_param);
-  // this will trigger copy on write.
-  ty_call_args.push_back(out_param);
-
-  // The attributes of primitive op is nullptr
-  //
-  // The attributes of primitive operator can vary at the call site.
-  // The type of sum is also dependent on Attrs being passed.
-  // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs.
-  //
-  // A common example is sum(x, axis), where the choice of axis
-  // can affect the type of the function.
-  TypeConstraint type_rel =
-      TypeRelationNode::make(env_type_rel_func,
-                             ty_call_args,
-                             arg_types.size(),
-                             Attrs());
-
-  auto func_type =
-      FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
-
-  get()->op_type = func_type;
-
-  return *this;
-}
+using Op = tvm::Op;
+using OpNode = tvm::OpNode;
 
-inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
-  get()->num_inputs = n;
-  return *this;
-}
-
-template<typename AttrsType>
-inline OpRegistry& OpRegistry::set_attrs_type() {  // NOLINT(*)
-  get()->attrs_type_key = AttrsType::_type_key;
-  get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
-  return *this;
-}
-
-inline OpRegistry& OpRegistry::set_support_level(int32_t n) {  // NOLINT(*)
-  get()->support_level = n;
-  return *this;
-}
-
-template <typename ValueType>
-inline OpRegistry& OpRegistry::set_attr(  // NOLINT(*)
-    const std::string& attr_name, const ValueType& value, int plevel) {
-  CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
-  TVMRetValue rv;
-  rv = value;
-  UpdateAttr(attr_name, rv, plevel);
-  return *this;
-}
-
-// member functions of OpMap
-inline int GenericOpMap::count(const Op& op) const {
-  if (op.defined()) {
-    const uint32_t idx = op->index_;
-    return idx < data_.size() ? (data_[idx].second != 0) : 0;
-  } else {
-    return 0;
-  }
-}
-
-inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
-  CHECK(op.defined());
-  const uint32_t idx = op->index_;
-  CHECK(idx < data_.size() && data_[idx].second != 0)
-      << "Attribute " << attr_name_ << " has not been registered for Operator "
-      << op->name;
-  return data_[idx].first;
-}
-
-template <typename ValueType>
-inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
-  CHECK(op.defined());
-  const uint32_t idx = op->index_;
-  if (idx < data_.size() && data_[idx].second != 0) {
-    return data_[idx].first;
-  } else {
-    return value;
-  }
-}
-
-template <typename ValueType>
-inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
-  CHECK(expr.defined());
-  if (const OpNode* op = expr.as<OpNode>()) {
-    const uint32_t idx = op->index_;
-    if (idx < data_.size() && data_[idx].second != 0) {
-      return data_[idx].first;
-    } else {
-      return value;
-    }
-  } else {
-    return value;
-  }
-}
-
-template <typename ValueType>
-inline int OpMap<ValueType>::count(const Op& op) const {
-  return map_.count(op);
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
-  return map_[op];
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const Op& op,
-                                       ValueType def_value) const {
-  return map_.get<ValueType>(op, def_value);
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const Expr& expr,
-                                       ValueType def_value) const {
-  return map_.get<ValueType>(expr, def_value);
-}
-
-/*!
- * \brief Check that an expression is a "primitive operator".
- *
- * Will return true if the expression is an operator which
- * matches the form of primitive operators registered directly
- * by the Relay codebase.
- *
- * That is the arguments are all type variables, and there is a single
- * type relation applied to the input and output types.
- */
-inline bool IsPrimitiveOp(const Expr& expr) {
-  const auto* op = expr.as<OpNode>();
-  return op != nullptr && op->IsPrimitiveOp();
-}
+#define RELAY_REGISTER_OP(OpName)               \
+  TVM_REGISTER_OP(OpName)
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
index 31e85f9..7748bd1 100644
--- a/include/tvm/relay/type.h
+++ b/include/tvm/relay/type.h
@@ -24,8 +24,8 @@
 #ifndef TVM_RELAY_TYPE_H_
 #define TVM_RELAY_TYPE_H_
 
-
 #include <tvm/ir/type.h>
+#include <tvm/ir/type_relation.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/node/env_func.h>
@@ -51,6 +51,11 @@ using TypeConstraint = tvm::TypeConstraint;
 using TypeConstraintNode = tvm::TypeConstraintNode;
 using FuncType = tvm::FuncType;
 using FuncTypeNode = tvm::FuncTypeNode;
+using TypeRelation = tvm::TypeRelation;
+using TypeRelationNode = tvm::TypeRelationNode;
+using TypeRelationFn = tvm::TypeRelationFn;
+using TypeReporter = tvm::TypeReporter;
+using TypeReporterNode = tvm::TypeReporterNode;
 
 /*!
  * \brief Base of all Tensor types
@@ -235,146 +240,6 @@ class RefType : public Type {
   TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
 };
 
-class TypeReporter;
-
-/*!
- * \brief reporter that reports back to the
- *  type resolution information.
- */
-class TypeReporterNode : public Object {
- public:
-  /*!
-   * \brief Create a type equality constraint.
-   *
-   *  The "assign direction" acts as a hint to the solver
-   *  showing that it is more likely to resolve dst by src.
-   *  But it is possible for the solver to resolve src by dst as well.
-   */
-  TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
-
-  /*!
-   * \brief assert shape expression comparison.
-   * \note Use assert only if any of the condition input is symbolic.
-   * \param cond The condition of operation.
-   * \return false if assertation can be proven to have failed
-   *      true if solver can still proceed.
-   */
-  TVM_DLL virtual bool Assert(const IndexExpr& cond)= 0;
-  /*!
-   * \brief assert shape expression equals each other.
-   * \param lhs The left operand.
-   * \param rhs The right operand.
-   * \return false if assertation can be proven to have failed
-   *      true if solver can still proceed.
-   */
-  TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
-
-  /*!
-   * \brief Set the location at which to report unification errors.
-   * \param ref The program node to report the error.
-   */
-  TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0;
-
-  /*!
-   * \brief Retrieve the current global module.
-   * \return The global module.
-   */
-  TVM_DLL virtual Module GetModule() = 0;
-
-  // solver is not serializable.
-  void VisitAttrs(tvm::AttrVisitor* v) {}
-
-  static constexpr const char* _type_key = "relay.TypeReporter";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
-};
-
-/*!
- * \brief Container class of TypeReporter.
- * \sa TypeReporterNode
- */
-class TypeReporter : public ObjectRef {
- public:
-  TypeReporter() {}
-  explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
-  }
-  TypeReporterNode* operator->() const {
-    return const_cast<TypeReporterNode*>(
-        static_cast<const TypeReporterNode*>(get()));
-  }
-  using ContainerType = TypeReporterNode;
-};
-
-/*!
- * \brief User defined type constraint function.
- *
- * If the input type information can be used to fully decide
- * the IncompleteTypes, then the function should call
- * reporter.Assign to report the new types, and return true.
- * Otherwise, the function should return false.
- *
- * \param args The arguments to the relation.
- *   The types are stored in the form of
- *   [input_type_0, input_type_1, ... input_type_n,
- *    output_type_0, output_type_1, ... output_type_m]
- *
- * \param num_inputs Number of input types in the args.
- * \param attrs The additional attributes of the operator.
- * \param reporter The reporter to report solution to.
- * \return false if This relation cannot be resolved.
- *   true if this relation has been resolved.
- */
-using TypeRelationFn =
-    TypedEnvFunc<bool(const Array<Type>& args,
-                      int num_inputs,
-                      const Attrs& attrs,
-                      const TypeReporter& reporter)>;
-
-/*!
- * \brief User defined type relation, is an input-output relation on types.
- */
-class TypeRelation;
-/*!
- * \brief TypeRelation container.
- * \note This node is not directly serializable.
- * The type function need to be lookedup in the module.
- */
-class TypeRelationNode : public TypeConstraintNode {
- public:
-  /*!
-   * \brief The function on input and output variables which
-   *  this is not directly serializable,
-   *  need to be looked-up in the module.
-   */
-  TypeRelationFn func;
-  /*! \brief The type arguments to the type function. */
-  tvm::Array<Type> args;
-  /*! \brief Number of inputs arguments */
-  int num_inputs;
-  /*! \brief Attributes to the relation function */
-  Attrs attrs;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("func", &func);
-    v->Visit("args", &args);
-    v->Visit("num_inputs", &num_inputs);
-    v->Visit("attrs", &attrs);
-    v->Visit("span", &span);
-  }
-
-  TVM_DLL static TypeRelation make(TypeRelationFn func,
-                                   Array<Type> args,
-                                   int num_args,
-                                   Attrs attrs);
-
-  static constexpr const char* _type_key = "relay.TypeRelation";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
-};
-
-class TypeRelation : public TypeConstraint {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
-};
-
 // The following fields contains advanced typing
 // Only keep the class name and reserved for future usage.
 class GenericTensorType;
diff --git a/src/relay/ir/op.cc b/src/ir/op.cc
similarity index 96%
rename from src/relay/ir/op.cc
rename to src/ir/op.cc
index b888ecb..0ed2f3d 100644
--- a/src/relay/ir/op.cc
+++ b/src/ir/op.cc
@@ -18,11 +18,11 @@
  */
 
 /*!
- * \file src/tvm/relay/op.cc
- * \brief Resolve incomplete types to complete types.
+ * \file src/ir/op.cc
+ * \brief Primitive operators and intrinsics.
  */
-#include <tvm/relay/op.h>
-#include <tvm/relay/type.h>
+#include <tvm/ir/op.h>
+#include <tvm/ir/type.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/packed_func.h>
 
@@ -31,11 +31,10 @@
 
 namespace dmlc {
 // enable registry
-DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
+DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
 }  // namespace dmlc
 
 namespace tvm {
-namespace relay {
 
 ::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
   return ::dmlc::Registry<OpRegistry>::Get();
@@ -230,5 +229,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "Op(" << node->name << ")";
   });
 
-}  // namespace relay
 }  // namespace tvm
diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc
new file mode 100644
index 0000000..cc5ceef
--- /dev/null
+++ b/src/ir/type_relation.cc
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/type_relation.cc
+ * \brief TypeRelationNode construction, node/global registration, and printing.
+ */
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_relation.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/packed_func_ext.h>
+
+namespace tvm {
+TypeRelation TypeRelationNode::make(TypeRelationFn func,
+                                    Array<Type> args,
+                                    int num_inputs,
+                                    Attrs attrs) {
+  ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
+  n->func = std::move(func);
+  n->args = std::move(args);
+  n->num_inputs = num_inputs;
+  n->attrs = std::move(attrs);
+  return TypeRelation(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TypeRelationNode);
+
+TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
+.set_body_typed(TypeRelationNode::make);
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
+    auto* node = static_cast<const TypeRelationNode*>(ref.get());
+    p->stream << "TypeRelationNode("
+              << node->func->name
+              << ", " << node->args << ")";
+});
+}  // namespace tvm
diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc
index 4ae2ee5..099b801 100644
--- a/src/relay/ir/type.cc
+++ b/src/relay/ir/type.cc
@@ -101,30 +101,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
   });
 
-TypeRelation TypeRelationNode::make(TypeRelationFn func,
-                                    Array<Type> args,
-                                    int num_inputs,
-                                    Attrs attrs) {
-  ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
-  n->func = std::move(func);
-  n->args = std::move(args);
-  n->num_inputs = num_inputs;
-  n->attrs = std::move(attrs);
-  return TypeRelation(n);
-}
-
-TVM_REGISTER_NODE_TYPE(TypeRelationNode);
-
-TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
-.set_body_typed(TypeRelationNode::make);
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
-    auto* node = static_cast<const TypeRelationNode*>(ref.get());
-    p->stream << "TypeRelationNode("
-              << node->func->name
-              << ", " << node->args << ")";
-});
 
 TupleType TupleTypeNode::make(Array<Type> fields) {
   ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();