You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/01/04 07:09:05 UTC

[incubator-tvm] branch master updated: [REFACTOR][TYPE] Remove un-necessary var sub-field in GlobalTypeVar and TypeVar (#4615)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 24e6fcb  [REFACTOR][TYPE] Remove un-necessary var sub-field in GlobalTypeVar and TypeVar (#4615)
24e6fcb is described below

commit 24e6fcb687dab623393926cf162ef41932901695
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Fri Jan 3 23:08:55 2020 -0800

    [REFACTOR][TYPE] Remove un-necessary var sub-field in GlobalTypeVar and TypeVar (#4615)
    
    Currently, we use a tvm::Var to represent a placeholder for shapes in generic types.
    This is not necessary for GlobalTypeVar (as we never parameterize by shape var),
    and is somewhat convoluted for TypeVar.
    
    As we move to a unified type system, we want to break the dependency
    of the base TypeVar (which is shared across languages) on the expression.
    Note that it is fine for TensorType to depend on Expr.
    
    One alternative way to embed the Var would be to introduce a TypeVarExpr,
    which can wrap a TypeVar as an Expr. However, this alternative won't be
    natural until we migrate the types to the global scope.
    
    Luckily, we do not yet depend heavily on shape parameterization.
    
    This PR removes the tvm::Var field from TypeVar and GlobalTypeVar. We will follow
    up with another PR to migrate the types to a base location. After that, we should
    be able to use the more elegant approach via TypeVarExpr.
---
 include/tvm/relay/type.h         | 18 ++++++------------
 python/tvm/relay/_parser.py      |  2 +-
 python/tvm/relay/type_functor.py |  4 ++--
 src/relay/ir/alpha_equal.cc      |  9 ++-------
 src/relay/ir/hash.cc             | 10 +++++-----
 src/relay/ir/module.cc           | 16 ++++++++--------
 src/relay/ir/pretty_printer.cc   |  8 ++++----
 src/relay/ir/type.cc             | 16 ++++++++--------
 src/relay/pass/de_duplicate.cc   |  2 +-
 src/relay/pass/to_cps.cc         |  2 +-
 tests/python/relay/test_any.py   |  7 ++++---
 11 files changed, 42 insertions(+), 52 deletions(-)

diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
index 08fe957..8f51ea9 100644
--- a/include/tvm/relay/type.h
+++ b/include/tvm/relay/type.h
@@ -157,16 +157,13 @@ class TypeVar;
 /*! \brief TypeVar container node */
 class TypeVarNode : public TypeNode {
  public:
-  /*!
-   * \brief The variable itself is only meaningful when
-   *  kind is ShapeVar, otherwise, we only use the name.
-   */
-  tvm::Var var;
+  /*! \brief Name of the variable, it only acts as a hint. */
+  std::string name_hint;
   /*! \brief The kind of type parameter */
   Kind kind;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("var", &var);
+    v->Visit("name_hint", &name_hint);
     v->Visit("kind", &kind);
     v->Visit("span", &span);
   }
@@ -189,16 +186,13 @@ class GlobalTypeVar;
 /*! \brief GlobalTypeVar container node */
 class GlobalTypeVarNode : public TypeNode {
  public:
-  /*!
-   * \brief The variable itself is only meaningful when
-   *  kind is ShapeVar; otherwise, we only use the name.
-   */
-  tvm::Var var;
+  /*! \brief Name of the variable, it only acts as a hint. */
+  std::string name_hint;
   /*! \brief The kind of type parameter */
   Kind kind;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("var", &var);
+    v->Visit("name_hint", &name_hint);
     v->Visit("kind", &kind);
     v->Visit("span", &span);
   }
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 45822c5..1f0088a 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -272,7 +272,7 @@ class ParseTreeToRelayIR(RelayVisitor):
 
     def _type_expr_name(self, e):
         if isinstance(e, adt.Constructor):
-            return "`{0}` ADT constructor".format(e.belong_to.var.name)
+            return "`{0}` ADT constructor".format(e.belong_to.name_hint)
         elif isinstance(e, ty.GlobalTypeVar):
             if e.kind == ty.Kind.AdtHandle:
                 return "ADT definition"
diff --git a/python/tvm/relay/type_functor.py b/python/tvm/relay/type_functor.py
index 1331058..7139ccb 100644
--- a/python/tvm/relay/type_functor.py
+++ b/python/tvm/relay/type_functor.py
@@ -143,7 +143,7 @@ class TypeMutator(TypeFunctor):
     and reconstructs the AST.
     """
     def visit_type_var(self, tv):
-        return TypeVar(tv.var.name, tv.kind)
+        return TypeVar(tv.name_hint, tv.kind)
 
     def visit_incomplete_type(self, it):
         return IncompleteType(it.kind)
@@ -180,7 +180,7 @@ class TypeMutator(TypeFunctor):
         return RefType(self.visit(rt.value))
 
     def visit_global_type_var(self, gtv):
-        return GlobalTypeVar(gtv.var.name, gtv.kind)
+        return GlobalTypeVar(gtv.name_hint, gtv.kind)
 
     def visit_type_call(self, tc):
         return TypeCall(
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 589de09..d8dcddd 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -69,8 +69,8 @@ class AlphaEqualHandler:
       }
       if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
       for (const auto& p : lhsm->type_definitions) {
-        if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
-            !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
+        if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
+            !Equal(p.second, rhsm->LookupDef(p.first->name_hint))) {
           return false;
         }
       }
@@ -233,11 +233,6 @@ class AlphaEqualHandler:
           return false;
         }
         equal_map_[lhs->type_params[i]] = rhs->type_params[i];
-        // set up type parameter equal
-        if (lhs->type_params[i]->kind == Kind::kShapeVar) {
-          // map variable
-          equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
-        }
       }
       for (size_t i = 0; i < lhs->arg_types.size(); i++) {
         if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc
index 15f5105..459e8b0 100644
--- a/src/relay/ir/hash.cc
+++ b/src/relay/ir/hash.cc
@@ -228,11 +228,11 @@ class RelayHashHandler:
       hash = Combine(hash, TypeHash(var_node->type_annotation));
     }
     hash_map_[var] = hash;
-
-    const auto* ty_param = var.as<TypeVarNode>();
-    if (ty_param && ty_param->kind == Kind::kShapeVar) {
-      hash_map_[ty_param->var] = hash;
-    }
+    // TODO(tqchen) Introduce TypeVarExpr
+    // const auto* ty_param = var.as<TypeVarNode>();
+    // if (ty_param && ty_param->kind == Kind::kShapeVar) {
+    //   hash_map_[ty_param->var] = hash;
+    // }
     return hash;
   }
 
diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc
index 2fa79c7..38f86a5 100644
--- a/src/relay/ir/module.cc
+++ b/src/relay/ir/module.cc
@@ -55,9 +55,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
 
   for (const auto& kv : n->type_definitions) {
     // set global typevar map
-    CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0)
-      << "Duplicate global type definition name " << kv.first->var->name_hint;
-    n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
+    CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
+      << "Duplicate global type definition name " << kv.first->name_hint;
+    n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
     n->RegisterConstructors(kv.first, kv.second);
   }
 
@@ -177,7 +177,7 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
   // We hash the global type var name to use as a globally unique prefix for tags.
   // The hash will be used as the most significant byte of the tag, with the index of
   // the constructor in the less significant bytes
-  size_t hash = std::hash<std::string>()(var->var->name_hint);
+  size_t hash = std::hash<std::string>()(var->name_hint);
   int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
   for (size_t i = 0; i < type->constructors.size(); ++i) {
     type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
@@ -197,10 +197,10 @@ void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
   this->type_definitions.Set(var, type);
   if (!update) {
     // set global type var map
-    CHECK(global_type_var_map_.count(var->var->name_hint) == 0)
-      << "Duplicate global type definition name " << var->var->name_hint;
+    CHECK(global_type_var_map_.count(var->name_hint) == 0)
+      << "Duplicate global type definition name " << var->name_hint;
   }
-  global_type_var_map_.Set(var->var->name_hint, var);
+  global_type_var_map_.Set(var->name_hint, var);
   RegisterConstructors(var, type);
 }
 
@@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const {
 TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
   auto it = type_definitions.find(var);
   CHECK(it != type_definitions.end())
-    << "There is no definition of " << var->var->name_hint;
+      << "There is no definition of " << var->name_hint;
   return (*it).second;
 }
 
diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc
index 478469c..9926844 100644
--- a/src/relay/ir/pretty_printer.cc
+++ b/src/relay/ir/pretty_printer.cc
@@ -312,7 +312,7 @@ class PrettyPrinter :
       val << "-malformed-ir";
       return val;
     }
-    std::string name = var->var->name_hint;
+    std::string name = var->name_hint;
     if (name.length() == 0 || !std::isalpha(name[0])) {
       name = "t" + name;
     }
@@ -493,7 +493,7 @@ class PrettyPrinter :
       doc << "[";
       std::vector<Doc> type_params;
       for (const TypeVar& tv : fn->type_params) {
-        type_params.push_back(Doc(tv->var->name_hint));
+        type_params.push_back(Doc(tv->name_hint));
       }
       doc << PrintSep(type_params);
       doc << "]";
@@ -701,11 +701,11 @@ class PrettyPrinter :
   }
 
   Doc VisitType_(const TypeVarNode* node) final {
-    return Doc(node->var->name_hint);
+    return Doc(node->name_hint);
   }
 
   Doc VisitType_(const GlobalTypeVarNode* node) final {
-    return Doc(node->var->name_hint);
+    return Doc(node->name_hint);
   }
 
   Doc VisitType_(const TypeCallNode* node) final {
diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc
index 70071d0..48f211b 100644
--- a/src/relay/ir/type.cc
+++ b/src/relay/ir/type.cc
@@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TypeVar TypeVarNode::make(std::string name, Kind kind) {
   ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
-  n->var = tvm::Var(name);
+  n->name_hint = std::move(name);
   n->kind = std::move(kind);
   return TypeVar(n);
 }
@@ -74,19 +74,19 @@ TVM_REGISTER_NODE_TYPE(TypeVarNode);
 
 TVM_REGISTER_API("relay._make.TypeVar")
 .set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
-    return TypeVarNode::make(name, static_cast<Kind>(kind));
-    });
+  return TypeVarNode::make(name, static_cast<Kind>(kind));
+});
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
-    auto* node = static_cast<const TypeVarNode*>(ref.get());
-  p->stream << "TypeVarNode(" << node->var->name_hint << ", "
-    << node->kind << ")";
+  auto* node = static_cast<const TypeVarNode*>(ref.get());
+  p->stream << "TypeVarNode(" << node->name_hint << ", "
+            << node->kind << ")";
 });
 
 GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
   ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
-  n->var = tvm::Var(name);
+  n->name_hint = std::move(name);
   n->kind = std::move(kind);
   return GlobalTypeVar(n);
 }
@@ -101,7 +101,7 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar")
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
     auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
-  p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
+  p->stream << "GlobalTypeVarNode(" << node->name_hint << ", "
             << node->kind << ")";
 });
 
diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc
index 6816cc7..cf99dc3 100644
--- a/src/relay/pass/de_duplicate.cc
+++ b/src/relay/pass/de_duplicate.cc
@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) {
                        public PatternMutator {
    public:
     TypeVar Fresh(const TypeVar& tv) {
-      TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
+      TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind);
       type_rename_[tv] = ret;
       return ret;
     }
diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc
index 1dfa327..96e7f1a 100644
--- a/src/relay/pass/to_cps.cc
+++ b/src/relay/pass/to_cps.cc
@@ -334,7 +334,7 @@ Function UnCPS(const Function& f) {
   auto new_ret_type = Type(cont_type->arg_types[0]);
   std::vector<TypeVar> new_type_params;
   for (const auto& tp : f->type_params) {
-    new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind));
+    new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind));
   }
   auto answer_type = new_type_params.back();
   new_type_params.pop_back();
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index a30326c..fe2e9e9 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -534,15 +534,16 @@ def test_fused_ops():
         tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)
 
 def test_arange_with_dynamic_shape():
-    m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
-    x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
+    # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
+    m, n, k = relay.Any(), relay.Any(), relay.Any()
+    x = relay.var('x', shape=(m, n, k), dtype='float32')
     y0 = relay.shape_of(x)
     y1 = relay.take(y0, relay.const(0, 'int32'))
     y2 = relay.op.arange(y1, dtype="int32")
     y3 = y2 + relay.const(1, dtype="int32")
     data = np.random.rand(10, 5, 3).astype('float32')
     mod = relay.module.Module()
-    mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
+    mod["main"] = relay.Function([x], y3)
     for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
         result = ex.evaluate()(data)