You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/01/22 00:09:36 UTC

[tvm] branch main updated: Fix an issue with dynamic functions overwriting call arg types (#7295)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b6a1a7  Fix an issue with dynamic functions overwriting call arg types (#7295)
7b6a1a7 is described below

commit 7b6a1a7bcaa403b1c277a494e58774dc36b38326
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Thu Jan 21 17:09:19 2021 -0700

    Fix an issue with dynamic functions overwriting call arg types (#7295)
    
    * Fix an issue with dynamic functions overwriting call arg types
    
    * fix a bug for un-annotated inputs
    
    * normalize names in TypeSolver::Unifier
    
    * fix name normalization
---
 src/relay/analysis/type_solver.cc     | 18 ++++++++++--------
 src/relay/analysis/type_solver.h      |  3 ++-
 src/relay/transforms/type_infer.cc    | 12 ++++++------
 tests/python/relay/test_type_infer.py | 14 ++++++++++++++
 4 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index 64db13a..cc1ada6 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -102,11 +102,12 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
  public:
   explicit Unifier(TypeSolver* solver, const Span& span) : solver_(solver), span(span) {}
 
-  Type Unify(const Type& src, const Type& dst) {
+  Type Unify(const Type& lhs_type, const Type& rhs_type, bool assign_lhs = true,
+             bool assign_rhs = true) {
     // Known limitation
     // - handle shape pattern matching
-    TypeNode* lhs = solver_->GetTypeNode(dst);
-    TypeNode* rhs = solver_->GetTypeNode(src);
+    TypeNode* lhs = solver_->GetTypeNode(lhs_type);
+    TypeNode* rhs = solver_->GetTypeNode(rhs_type);
 
     // do occur check so we don't create self-referencing structure
     if (lhs->FindRoot() == rhs->FindRoot()) {
@@ -127,7 +128,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       solver_->MergeFromTo(rhs, lhs);
       return lhs->resolved_type;
     } else {
-      Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
+      Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type);
 
       if (!resolved.defined()) {
         solver_->diag_ctx_.Emit(
@@ -139,8 +140,8 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
         return lhs->resolved_type;
       } else {
         TypeNode* top = solver_->GetTypeNode(resolved);
-        solver_->MergeFromTo(lhs, top);
-        solver_->MergeFromTo(rhs, top);
+        if (assign_lhs) solver_->MergeFromTo(lhs, top);
+        if (assign_rhs) solver_->MergeFromTo(rhs, top);
         return resolved;
       }
     }
@@ -549,9 +550,10 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
 }
 
 // Add equality constraint
-Type TypeSolver::Unify(const Type& dst, const Type& src, const Span& span) {
+Type TypeSolver::Unify(const Type& dst, const Type& src, const Span& span, bool assign_lhs,
+                       bool assign_rhs) {
   Unifier unifier(this, span);
-  return unifier.Unify(dst, src);
+  return unifier.Unify(dst, src, assign_lhs, assign_rhs);
 }
 
 // Add type constraint to the solver.
diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h
index 4ae2e6a..56cea60 100644
--- a/src/relay/analysis/type_solver.h
+++ b/src/relay/analysis/type_solver.h
@@ -88,7 +88,8 @@ class TypeSolver {
    * \param rhs The right operand
    * \param location The location at which the unification problem arose.
    */
-  Type Unify(const Type& lhs, const Type& rhs, const Span& span);
+  Type Unify(const Type& lhs, const Type& rhs, const Span& span, bool assign_lhs = true,
+             bool assign_rhs = true);
   /*!
    * \brief Report a diagnostic.
    * \param diag The diagnostic to report.
diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index 327b5d1..921e83f 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -162,9 +162,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 
   // Perform unification on two types and report the error at the expression
   // or the span of the expression.
-  Type Unify(const Type& t1, const Type& t2, const Span& span) {
+  Type Unify(const Type& t1, const Type& t2, const Span& span, bool assign_lhs = true,
+             bool assign_rhs = true) {
     try {
-      return solver_.Unify(t1, t2, span);
+      return solver_.Unify(t1, t2, span, assign_lhs, assign_rhs);
     } catch (const dmlc::Error& e) {
       this->EmitFatal(Diagnostic::Error(span)
                       << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what());
@@ -495,7 +496,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     }
 
     for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
-      this->Unify(fn_ty->arg_types[i], arg_types[i], call->span);
+      this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false);
     }
 
     for (auto cs : fn_ty->type_constraints) {
@@ -526,6 +527,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       }
     }
 
+    solver_.Solve();
     return GeneralCall(call, arg_types);
   }
 
@@ -572,9 +574,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {});
   }
 
-  void Solve() {
-    solver_.Solve();
-  }
+  void Solve() { solver_.Solve(); }
 };
 
 class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index b518c31..e8179a3 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -402,6 +402,20 @@ def @main(%f: float32) -> float32 {
     tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
 
 
+def test_dynamic_function():
+    dy_tt = relay.TensorType([relay.Any()], "float32")
+    s_tt = relay.TensorType([10], "float32")
+    x = relay.Var("x", dy_tt)
+    f = relay.Function([x], x + x)
+    y = relay.Var("y", s_tt)
+    c = f(y)
+
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([y], c)
+    mod = transform.InferType()(mod)
+    assert mod["main"].params[0].checked_type == s_tt
+
+
 if __name__ == "__main__":
     import sys