You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/17 00:03:25 UTC

[GitHub] [tvm] mbs-octoml commented on a change in pull request #9735: [AMP][Pass][Typing] Add faster type inference

mbs-octoml commented on a change in pull request #9735:
URL: https://github.com/apache/tvm/pull/9735#discussion_r771001558



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -381,6 +381,18 @@ class MixedPrecisionPass : public MixedModeMutator {
     return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
   }
 
+  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
+    // The old checked type in the expression may not be valid so clear it
+    post->checked_type_ = Type(nullptr);

Review comment:
       Am I missing something, or will checked_type_ be null iff some sub-expression of post has been rewritten and thus its type has changed?
   i.e., checked_type_ is non-null only if pre == post.get()?
   

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -176,13 +176,13 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   Type GetType(const Expr& expr) const {
-    auto mod = IRModule::FromExpr(expr);
-    mod = transform::InferType()(mod);
-    if (expr.as<FunctionNode>()) {
-      return mod->Lookup("main")->checked_type();
-    } else {
-      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    Type checked_type = expr->checked_type_;
+    if (checked_type.defined()) {
+      return checked_type;

Review comment:
       // The expression has not been changed AND its existing type
   // is known to still be valid. (See the special handling for tuples etc.
   // below for where we null out checked_type_ when we cannot be
   // sure it is still valid.)
   
   (though see my comment below)

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*

Review comment:
       nit: Returns the largest sub-graph whose inner nodes need types and whose leaves are vars standing in
   for already-typed sub-expressions.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*

Review comment:
       micro nit: move this comment to before the class, and use /*! etc.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*
+  Creates a small subgraph with the same type as the input expression. We attempt to do
+  by depending on existing type information being populated in expressions the target
+  node depends on. If a node with populated type information is found we simply
+  replace it with a variable of that type. In this way, we can avoid copying and
+  recursing through most of the expression graph. Note, this assumes that current
+  populated type information is correct!
+
+  ExprMutator is sufficient over MixedModemutator since we will not recurse much.
+  */
+
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(get_analogous_expression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Here will be the only VisitExpr
+    return Function(op->params, get_analogous_expression(op->body), op->ret_type, op->type_params,
+                    op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, get_analogous_expression(op->args), op->attrs, op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, get_analogous_expression(op->value), get_analogous_expression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(get_analogous_expression(op->cond), get_analogous_expression(op->true_branch),
+              get_analogous_expression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(get_analogous_expression(op->tuple), op->index, op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(get_analogous_expression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(get_analogous_expression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(get_analogous_expression(op->ref), get_analogous_expression(op->value),
+                    op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(get_analogous_expression(op->data), op->clauses, op->complete, op->span);
+  }
+
+ private:
+  Expr get_analogous_expression(const Expr& expr) {
+    // Replace the expression with a potentially simpler expression of the same type
+    if (!expr->checked_type_.defined()) {
+      return VisitExpr(expr);
+    }
+
+    return Var("dummy_var", expr->checked_type(), expr->span);

Review comment:
       // Since the expression already has a checked_type which we trust, we don't need
   // full type inference to descend into it, so stub it out with a dummy var of the same type.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*
+  Creates a small subgraph with the same type as the input expression. We attempt to do
+  by depending on existing type information being populated in expressions the target
+  node depends on. If a node with populated type information is found we simply
+  replace it with a variable of that type. In this way, we can avoid copying and
+  recursing through most of the expression graph. Note, this assumes that current
+  populated type information is correct!
+
+  ExprMutator is sufficient over MixedModemutator since we will not recurse much.
+  */
+
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(get_analogous_expression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Here will be the only VisitExpr
+    return Function(op->params, get_analogous_expression(op->body), op->ret_type, op->type_params,
+                    op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, get_analogous_expression(op->args), op->attrs, op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, get_analogous_expression(op->value), get_analogous_expression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(get_analogous_expression(op->cond), get_analogous_expression(op->true_branch),
+              get_analogous_expression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(get_analogous_expression(op->tuple), op->index, op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(get_analogous_expression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(get_analogous_expression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(get_analogous_expression(op->ref), get_analogous_expression(op->value),
+                    op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(get_analogous_expression(op->data), op->clauses, op->complete, op->span);
+  }
+
+ private:
+  Expr get_analogous_expression(const Expr& expr) {

Review comment:
       nit: GetAnalogousExpression




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org