You are viewing a plain-text version of this content. The canonical link to it ("here") was not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/10 04:01:40 UTC
[incubator-tvm] branch master updated: Legalize - Use Non-recursive Rewriter. (#5296)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 7d670b0 Legalize - Use Non-recursive Rewriter. (#5296)
7d670b0 is described below
commit 7d670b041e2e1b509ceac503b6635ec4434499c4
Author: Animesh Jain <an...@umich.edu>
AuthorDate: Thu Apr 9 21:01:35 2020 -0700
Legalize - Use Non-recursive Rewriter. (#5296)
* Legalize - Use Non-recursive Rewriter.
* Cleanup.
---
include/tvm/relay/expr_functor.h | 4 ++--
src/relay/transforms/legalize.cc | 19 +++++++++----------
2 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 6f8ac69..04b2754 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
*
- * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
+ * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
@@ -408,7 +408,7 @@ class ExprRewriter {
/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
- * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
+ * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc
index 250dd69..01411a6 100644
--- a/src/relay/transforms/legalize.cc
+++ b/src/relay/transforms/legalize.cc
@@ -35,19 +35,18 @@ namespace legalize {
// Call registered FTVMLegalize of an op
// Returns the legalized expression
-class Legalizer : public ExprMutator {
+class Legalizer : public ExprRewriter {
public:
explicit Legalizer(const std::string& legalize_map_attr_name)
: legalize_map_attr_name_{legalize_map_attr_name} {}
- Expr VisitExpr_(const CallNode* call_node) {
+ Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
// Get the new_call node without any changes to current call node.
- Expr new_e = ExprMutator::VisitExpr_(call_node);
- Call new_call = Downcast<Call>(new_e);
+ Call new_call = Downcast<Call>(post);
// Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) {
- return new_e;
+ return post;
}
// Collect the registered legalize function.
@@ -70,19 +69,18 @@ class Legalizer : public ExprMutator {
// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
- // Reassign new_e if the transformation succeeded.
+ // Return the new expr if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";
-
- new_e = legalized_value;
+ return legalized_value;
}
}
}
- return new_e;
+ return post;
}
private:
@@ -90,7 +88,8 @@ class Legalizer : public ExprMutator {
};
Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
- return Legalizer(legalize_map_attr_name).Mutate(expr);
+ auto rewriter = Legalizer(legalize_map_attr_name);
+ return PostOrderRewrite(expr, &rewriter);
}
} // namespace legalize