You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/01/22 00:09:36 UTC
[tvm] branch main updated: Fix an issue with dynamic functions overwriting call arg types (#7295)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7b6a1a7 Fix an issue with dynamic functions overwriting call arg types (#7295)
7b6a1a7 is described below
commit 7b6a1a7bcaa403b1c277a494e58774dc36b38326
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Thu Jan 21 17:09:19 2021 -0700
Fix an issue with dynamic functions overwriting call arg types (#7295)
* Fix an issue with dynamic functions overwriting call arg types
* fix a bug for un-annotated inputs
* normalize names in TypeSolver::Unifier
* fix name normalization
---
src/relay/analysis/type_solver.cc | 18 ++++++++++--------
src/relay/analysis/type_solver.h | 3 ++-
src/relay/transforms/type_infer.cc | 12 ++++++------
tests/python/relay/test_type_infer.py | 14 ++++++++++++++
4 files changed, 32 insertions(+), 15 deletions(-)
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index 64db13a..cc1ada6 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -102,11 +102,12 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
public:
explicit Unifier(TypeSolver* solver, const Span& span) : solver_(solver), span(span) {}
- Type Unify(const Type& src, const Type& dst) {
+ Type Unify(const Type& lhs_type, const Type& rhs_type, bool assign_lhs = true,
+ bool assign_rhs = true) {
// Known limitation
// - handle shape pattern matching
- TypeNode* lhs = solver_->GetTypeNode(dst);
- TypeNode* rhs = solver_->GetTypeNode(src);
+ TypeNode* lhs = solver_->GetTypeNode(lhs_type);
+ TypeNode* rhs = solver_->GetTypeNode(rhs_type);
// do occur check so we don't create self-referencing structure
if (lhs->FindRoot() == rhs->FindRoot()) {
@@ -127,7 +128,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
solver_->MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
- Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
+ Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type);
if (!resolved.defined()) {
solver_->diag_ctx_.Emit(
@@ -139,8 +140,8 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return lhs->resolved_type;
} else {
TypeNode* top = solver_->GetTypeNode(resolved);
- solver_->MergeFromTo(lhs, top);
- solver_->MergeFromTo(rhs, top);
+ if (assign_lhs) solver_->MergeFromTo(lhs, top);
+ if (assign_rhs) solver_->MergeFromTo(rhs, top);
return resolved;
}
}
@@ -549,9 +550,10 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
}
// Add equality constraint
-Type TypeSolver::Unify(const Type& dst, const Type& src, const Span& span) {
+Type TypeSolver::Unify(const Type& dst, const Type& src, const Span& span, bool assign_lhs,
+ bool assign_rhs) {
Unifier unifier(this, span);
- return unifier.Unify(dst, src);
+ return unifier.Unify(dst, src, assign_lhs, assign_rhs);
}
// Add type constraint to the solver.
diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h
index 4ae2e6a..56cea60 100644
--- a/src/relay/analysis/type_solver.h
+++ b/src/relay/analysis/type_solver.h
@@ -88,7 +88,8 @@ class TypeSolver {
* \param rhs The right operand
* \param location The location at which the unification problem arose.
*/
- Type Unify(const Type& lhs, const Type& rhs, const Span& span);
+ Type Unify(const Type& lhs, const Type& rhs, const Span& span, bool assign_lhs = true,
+ bool assign_rhs = true);
/*!
* \brief Report a diagnostic.
* \param diag The diagnostic to report.
diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index 327b5d1..921e83f 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -162,9 +162,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// Perform unification on two types and report the error at the expression
// or the span of the expression.
- Type Unify(const Type& t1, const Type& t2, const Span& span) {
+ Type Unify(const Type& t1, const Type& t2, const Span& span, bool assign_lhs = true,
+ bool assign_rhs = true) {
try {
- return solver_.Unify(t1, t2, span);
+ return solver_.Unify(t1, t2, span, assign_lhs, assign_rhs);
} catch (const dmlc::Error& e) {
this->EmitFatal(Diagnostic::Error(span)
<< "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what());
@@ -495,7 +496,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
- this->Unify(fn_ty->arg_types[i], arg_types[i], call->span);
+ this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false);
}
for (auto cs : fn_ty->type_constraints) {
@@ -526,6 +527,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
}
+ solver_.Solve();
return GeneralCall(call, arg_types);
}
@@ -572,9 +574,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {});
}
- void Solve() {
- solver_.Solve();
- }
+ void Solve() { solver_.Solve(); }
};
class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index b518c31..e8179a3 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -402,6 +402,20 @@ def @main(%f: float32) -> float32 {
tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
+def test_dynamic_function():
+ dy_tt = relay.TensorType([relay.Any()], "float32")
+ s_tt = relay.TensorType([10], "float32")
+ x = relay.Var("x", dy_tt)
+ f = relay.Function([x], x + x)
+ y = relay.Var("y", s_tt)
+ c = f(y)
+
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function([y], c)
+ mod = transform.InferType()(mod)
+ assert mod["main"].params[0].checked_type == s_tt
+
+
if __name__ == "__main__":
import sys