Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/14 06:24:00 UTC

[GitHub] [incubator-tvm] icemelon9 opened a new pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 opened a new pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052


   Use the pattern-matching rewriter to merge two consecutive reshape ops.
   
   @mbrookhart I also added an InferType pass after each pattern rewrite. I think this change makes the pattern rewriter more useful; at the very least, I need this feature for my pass.
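
   Merging is valid because a reshape only reinterprets the row-major element order of its input. So `reshape(reshape(x, s1), s2)` equals a single reshape of `x` to the final shape, as long as that shape is fully static. The sketch below models this in plain Python, assuming shapes are tuples whose symbolic dims are non-int values. The helper names `resolve_newshape` and `merge_reshapes` are hypothetical, not TVM APIs. Only `-1` is handled as a special value.

```python
from math import prod


def resolve_newshape(in_shape, newshape):
    """Resolve a reshape target that may contain a single -1 (inferred dim).

    Simplified on purpose: Relay's reshape also accepts 0, -2, -3 and -4.
    """
    total = prod(in_shape)
    known = prod(d for d in newshape if d != -1)
    if -1 in newshape:
        assert total % known == 0, "element count mismatch"
        return tuple(total // known if d == -1 else d for d in newshape)
    assert known == total, "element count mismatch"
    return tuple(newshape)


def merge_reshapes(x_shape, newshapes):
    """Fold a chain of reshapes applied to x into one static target shape.

    Returns None if any input dim is symbolic (anything other than an int),
    meaning the chain would be left as is.
    """
    if not all(isinstance(d, int) for d in x_shape):
        return None
    shape = tuple(x_shape)
    for newshape in newshapes:
        shape = resolve_newshape(shape, newshape)
    return shape


# Shapes from the test below: conv2d output (1, 32, 16, 16), then two reshapes.
merged = merge_reshapes((1, 32, 16, 16), [(1, 16, -1), (4, 8, -1, 16)])
print(merged)  # (4, 8, 16, 16)
```

   The C++ callback in the diff follows the same idea. It returns `post` unchanged unless every dim of the output type is a constant integer.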


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454712719



##########
File path: python/tvm/relay/dataflow_pattern/__init__.py
##########
@@ -748,11 +756,14 @@ def rewrite(callbacks, expr: Expr) -> Expr:
         The Expression with matched subgraphs rewritten by the callbacks.
     """
     if isinstance(callbacks, DFPatternCallback):
-        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
+        assert callbacks.pattern is not None
+        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback, callbacks.require_type)]
     else:
         tmp = []
         for callback in callbacks:
-            tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
+            assert callback.pattern is not None
+            tmp.append(_DFPatternCallback(callback.pattern, callback.callback,
+                                          callback.require_type))

Review comment:
       Fixed







[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454668508



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.cc
+ * \brief A pass for simplifying the Relay expression.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/support/logging.h>
+#include "../op/tensor/transform.h"
+
+namespace tvm {
+namespace relay {
+
+static Op reshape_op = Op::Get("reshape");
+static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape");
+
+/*!
+ * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
+ *   and merges into one reshape op.
+ */
+class SimplifyReshape {
+ public:
+  SimplifyReshape() {
+    x_ = WildcardPattern(make_object<WildcardPatternNode>());
+    auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
+    auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
+    pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {});
+  }
+
+  Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
+    auto x = node_map[x_][0];
+    bool const_shape = true;
+    Array<Integer> newshape;
+    for (auto dim : Downcast<TensorType>(pre->checked_type())->shape) {
+      if (dim.as<IntImmNode>() == nullptr) {
+        const_shape = false;
+        break;
+      }
+      newshape.push_back(Downcast<Integer>(dim));
+    }
+    if (const_shape) {
+      return MakeReshape(x, newshape);
+    }
+    return post;
+  }
+
+  DFPattern pattern() const { return pattern_; }
+
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+  /*! \brief Pattern for consecutive reshape or reverse_reshape ops */
+  DFPattern pattern_;
+};
+
+/*!
+ * \brief ExprSimplifier simplifies the Relay expression.
+ */
+class ExprSimplifier {
+ public:
+  ExprSimplifier() {
+    auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
+      Expr pre = args[0];
+      Expr post = args[1];
+      Map<DFPattern, Array<Expr>> node_map = args[2];
+      *rv = simplify_reshape_.callback(pre, post, node_map);
+    };
+    callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func),

Review comment:
       I didn't inherit directly from `DFPatternCallback` because its constructor requires the pattern, so the pattern would have to be created somewhere else before the callback is constructed.
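
   The constraint is about ordering: a subclass of a base whose constructor requires the pattern has to build that pattern before delegating to the base. One common workaround is to build the pattern in a static helper that is evaluated inside the delegating call. The sketch below shows this in Python. `PatternCallbackBase` and the tuple-based pattern are hypothetical stand-ins, not TVM's API.

```python
class PatternCallbackBase:
    """Hypothetical stand-in for a callback whose constructor requires a pattern."""

    def __init__(self, pattern, function):
        if pattern is None:
            raise ValueError("pattern is required")
        self.pattern = pattern
        self.function = function


class SimplifyReshapeCallback(PatternCallbackBase):
    """Builds its pattern first, then hands it to the base constructor."""

    def __init__(self):
        # The pattern must exist before the base constructor runs; a static
        # helper makes that ordering explicit (a C++ subclass could do the
        # same by calling a static function from its initializer list).
        super().__init__(self._make_pattern(), self.callback)

    @staticmethod
    def _make_pattern():
        # Toy pattern: (outer op alternatives, inner op alternatives).
        reshape_like = ("reshape", "contrib_reverse_reshape")
        return (reshape_like, reshape_like)

    def callback(self, pre, post, node_map):
        # Placeholder body; a real callback would build the merged reshape.
        return post
```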







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

mbrookhart commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454451030



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -740,10 +740,10 @@ class PatternRewriter : protected MixedModeMutator {
         groups_ = grouper.GroupMatches(callback_->pattern_, post);
         gid_assignments_ = grouper.GetGIDAssignments();
         memo_.clear();
-        post = this->VisitExpr(post);
+        post = InferType(this->VisitExpr(post));

Review comment:
       This breaks all of the pattern language unit tests, because they don't assume a typed graph is needed for pattern matching. Maybe we should make this behavior optional? Or should we change the API to require that expressions be well typed before running the pattern rewriter?
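
   One way to make it optional is a per-callback flag that the rewriter checks before matching. Under that scheme, untyped graphs still work for callbacks that never read type information. The toy sketch below uses hypothetical names: the `Callback` class, `rewrite_all`, and the `infer_type` parameter are not TVM APIs. The flag name `require_type` follows the field added to `DFPatternCallbackNode` in this PR.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Callback:
    """Toy stand-in for a pattern callback (not TVM's class)."""
    matches: Callable[[Any], bool]
    rewrite: Callable[[Any], Any]
    require_type: bool = False


def rewrite_all(callbacks, expr, infer_type):
    """Apply each callback in turn, running type inference only on request.

    A callback that reads types (e.g. checked_type) sets require_type=True,
    so infer_type runs on the expression just before that callback matches.
    """
    for cb in callbacks:
        if cb.require_type:
            expr = infer_type(expr)
        if cb.matches(expr):
            expr = cb.rewrite(expr)
    return expr
```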

##########
File path: python/tvm/relay/transform/simplify_expr.py
##########
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""
+A pass for simplifying the Relay expression.
+"""
+from . import transform
+from ..dataflow_pattern import wildcard, is_op, DFPatternCallback, rewrite
+from .. import op as _op
+
+class SimplifyReshapeCallback(DFPatternCallback):
+    """Callback to merge consecutive reshape ops"""
+    def __init__(self):
+        self.x = wildcard()
+        reshape1 = is_op("reshape") | is_op("contrib_reverse_reshape")
+        reshape2 = is_op("reshape") | is_op("contrib_reverse_reshape")
+        self.pattern = reshape1(reshape2(self.x))
+
+    def callback(self, pre, post, node_map):
+        x = node_map[self.x][0]
+        return _op.reshape(x, newshape=pre.checked_type.shape)
+
+
+@transform.function_pass(opt_level=0, required=["InferType"])
+class SimplifyExpr:
+    """ A pass to simplify the Relay expression."""
+    def __init__(self):
+        self.callbacks = [SimplifyReshapeCallback()]
+
+    def transform_function(self, func, mod, _):
+        return rewrite(self.callbacks, func)

Review comment:
       :) I've been thinking about putting together an algebraic simplifier for a while; this seems like a great first step.
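
   A simplifier of this kind usually applies rewrite rules repeatedly until the expression stops changing. The toy sketch below works on a tuple-based IR. The `("var", name, shape)` and `("reshape", child, shape)` nodes are hypothetical stand-ins for Relay expressions, and each node carries its already-inferred static output shape. Only the reshape-merging rule is included.

```python
def simplify(expr):
    """Collapse nested reshape nodes in a toy IR until a fixed point."""

    def step(e):
        if e[0] != "reshape":
            return e
        child = step(e[1])  # simplify bottom-up first
        if child[0] == "reshape":
            # reshape(reshape(x, s1), s2) -> reshape(x, s2)
            return ("reshape", child[1], e[2])
        return ("reshape", child, e[2])

    # With this single rule one bottom-up pass already suffices; the loop is
    # what lets several rules feed each other in a larger simplifier.
    while True:
        new = step(expr)
        if new == expr:
            return new
        expr = new
```

   For example, `simplify(("reshape", ("reshape", ("reshape", x, (1, 16, 512)), (4, 8, 16, 16)), (8192,)))` collapses to `("reshape", x, (8192,))`.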

##########
File path: tests/python/relay/test_pass_simplify_expr.py
##########
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.testing import run_opt_pass
+
+
+def test_simplify_reshape():
+    def before():
+        x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+        w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
+        y = relay.nn.conv2d(x, w, padding=(1, 1))
+        y = relay.reshape(y, newshape=(1, 16, -1))
+        y = relay.reshape(y, newshape=(4, 8, -1, 16))

Review comment:
       Can you add a test case that hits `contrib_reverse_reshape`?







[GitHub] [incubator-tvm] mbrookhart commented on pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

mbrookhart commented on pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#issuecomment-658271299


   Thanks!





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

mbrookhart commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r455145780



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -0,0 +1,123 @@
+    callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func),

Review comment:
       :/ I think I focused too much on the Python API and left an ugly C++ API. I'll see if I can clean that up in a follow-up PR. Thanks!







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

comaniac commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454650764



##########
File path: python/tvm/relay/dataflow_pattern/__init__.py
##########
@@ -748,11 +756,14 @@ def rewrite(callbacks, expr: Expr) -> Expr:
         The Expression with matched subgraphs rewritten by the callbacks.
     """
     if isinstance(callbacks, DFPatternCallback):
-        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
+        assert callbacks.pattern is not None
+        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback, callbacks.require_type)]
     else:
         tmp = []
         for callback in callbacks:
-            tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
+            assert callback.pattern is not None
+            tmp.append(_DFPatternCallback(callback.pattern, callback.callback,
+                                          callback.require_type))

Review comment:
       We can simplify this part by adding
   `callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks`
   at the top.
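
   Once the argument is normalized to a list, both branches of `rewrite` collapse into one loop. The sketch below writes the suggested normalization with `isinstance`. It uses a minimal stand-in for `DFPatternCallback` so it runs without TVM. The returned tuples stand in for the `_DFPatternCallback(pattern, callback, require_type)` objects built in the diff, and `normalize_callbacks` is a hypothetical name.

```python
class DFPatternCallback:
    """Minimal stand-in for tvm.relay.dataflow_pattern.DFPatternCallback."""

    def __init__(self, pattern=None, require_type=False):
        self.pattern = pattern
        self.require_type = require_type

    def callback(self, pre, post, node_map):
        return post


def normalize_callbacks(callbacks):
    """Accept one callback or an iterable of callbacks and return a list.

    Keeps the invariant asserted in the diff: every callback has a pattern.
    """
    callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else list(callbacks)
    for cb in callbacks:
        assert cb.pattern is not None, "DFPatternCallback.pattern must be set"
    return [(cb.pattern, cb.callback, cb.require_type) for cb in callbacks]
```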







[GitHub] [incubator-tvm] tqchen merged pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

tqchen merged pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052


   





[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454666904



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -42,11 +42,16 @@ class DFPatternCallback;
 class DFPatternCallbackNode : public Object {
  public:
   /*! \brief Pattern this callback matches */
-  DFPattern pattern_;
+  DFPattern pattern;
   /*! \brief Function to call when finding a matched expression */
-  PackedFunc function_;
+  PackedFunc function;
+  /*! \brief Require InferType to be run before the callback */
+  bool require_type;

Review comment:
       Because these members are public, it's probably better and more consistent to name them without the trailing "_", imo.







[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r455221697



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -0,0 +1,123 @@
+    callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func),

Review comment:
       Sounds good. :)







[GitHub] [incubator-tvm] icemelon9 commented on pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#issuecomment-657994257


   cc @zhiics @comaniac @jroesch 





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

mbrookhart commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454650310



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -42,11 +42,16 @@ class DFPatternCallback;
 class DFPatternCallbackNode : public Object {
  public:
   /*! \brief Pattern this callback matches */
-  DFPattern pattern_;
+  DFPattern pattern;
   /*! \brief Function to call when finding a matched expression */
-  PackedFunc function_;
+  PackedFunc function;
+  /*! \brief Require InferType to be run before the callback */
+  bool require_type;

Review comment:
       https://tvm.apache.org/docs/contribute/code_guide.html
   https://google.github.io/styleguide/cppguide.html#Variable_Names
   
   Why the move away from the Google Style Guide convention? You seem to use the var_name_ convention in simplify_expr.cc.

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -0,0 +1,123 @@
+    callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func),

Review comment:
       Maybe have SimplifyReshape inherit directly from DFPatternCallback? Then you could fold this logic into it and keep it out of the main simplifier.

##########
File path: tests/python/relay/test_dataflow_pattern.py
##########
@@ -599,6 +599,7 @@ def test_rewrite():
 
     class TestRewrite(DFPatternCallback):
         def __init__(self):
+            super(TestRewrite, self).__init__()

Review comment:
       Thanks for updating these!







[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454467609



##########
File path: tests/python/relay/test_pass_simplify_expr.py
##########
@@ -0,0 +1,47 @@
+        y = relay.reshape(y, newshape=(4, 8, -1, 16))

Review comment:
       The line below uses `contrib_reverse_reshape`. It's called `reverse_reshape` in the frontend.
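
   For context, `reverse_reshape` accepts the same special values as `reshape` but infers them from right to left. For an input of shape (10, 5, 4) with newshape (-1, 0), `reshape` yields (40, 5) while `reverse_reshape` yields (50, 4). The simplified pure-Python model below handles only the special values 0 (copy the input dim at that position) and -1 (infer). The helper names are hypothetical.

```python
from math import prod


def reshape_shape(in_shape, newshape):
    """Output shape of reshape, supporting only the special values 0 and -1."""
    # 0 copies the input dim at the same position (matched from the left).
    out = [in_shape[i] if d == 0 else d for i, d in enumerate(newshape)]
    if -1 in out:
        known = prod(d for d in out if d != -1)
        out[out.index(-1)] = prod(in_shape) // known
    return tuple(out)


def reverse_reshape_shape(in_shape, newshape):
    """Same special values, but matched against the input from the right."""
    return reshape_shape(tuple(in_shape)[::-1], tuple(newshape)[::-1])[::-1]
```

   Both ops only change how the same row-major data is viewed. So a `reverse_reshape` followed by a `reshape` can still be merged into one `reshape` to the final static shape, and that is the case the test exercises.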







[GitHub] [incubator-tvm] icemelon9 commented on pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

icemelon9 commented on pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#issuecomment-658270784


   @mbrookhart I'll try to move the pass to C++ and make the type inference optional.





[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

mbrookhart commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454471816



##########
File path: tests/python/relay/test_pass_simplify_expr.py
##########
@@ -0,0 +1,47 @@
+        y = relay.reshape(y, newshape=(4, 8, -1, 16))

Review comment:
       :+1: 
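The test above chains two reshapes after a `conv2d`. The weight has 32 output channels and a 3x3 kernel, and padding 1 keeps the spatial size, so the conv output is (1, 32, 16, 16). Because a row-major reshape never reorders elements, `reshape(reshape(x, s1), s2)` equals a single reshape to the *concrete* output shape of the second reshape. The pure-Python sketch below models a tensor as a `(flat_data, shape)` pair to check this. The helper names are hypothetical, and only positive sizes plus one `-1` are supported. Using the concrete final shape means the merged newshape must be known at rewrite time, which is presumably why the PR wants type information inside the rewriter.

```python
from math import prod


def resolve(in_shape, newshape):
    """Resolve a newshape containing positive ints and at most one -1."""
    if list(newshape).count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    total = prod(in_shape)
    known = prod(d for d in newshape if d != -1)
    out = tuple(total // known if d == -1 else d for d in newshape)
    if prod(out) != total:
        raise ValueError(f"cannot reshape {in_shape} to {newshape}")
    return out


def reshape(tensor, newshape):
    """Row-major reshape of a (flat_data, shape) pair: data order is untouched."""
    data, shape = tensor
    return data, resolve(shape, newshape)


def merge_reshapes(in_shape, first, second):
    """Collapse reshape(reshape(x, first), second) into one concrete newshape."""
    return resolve(resolve(in_shape, first), second)


conv_out = (1, 32, 16, 16)  # conv2d output shape in the test above
x = (list(range(prod(conv_out))), conv_out)
two_step = reshape(reshape(x, (1, 16, -1)), (4, 8, -1, 16))
merged_shape = merge_reshapes(conv_out, (1, 16, -1), (4, 8, -1, 16))
one_step = reshape(x, merged_shape)
print(merged_shape)          # (4, 8, 16, 16)
print(two_step == one_step)  # True
```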







[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454473775



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -740,10 +740,10 @@ class PatternRewriter : protected MixedModeMutator {
         groups_ = grouper.GroupMatches(callback_->pattern_, post);
         gid_assignments_ = grouper.GetGIDAssignments();
         memo_.clear();
-        post = this->VisitExpr(post);
+        post = InferType(this->VisitExpr(post));

Review comment:
       I agree: in the multi-stage rewrite scenario, it makes sense to have InferType here. 







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#discussion_r454471458



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -740,10 +740,10 @@ class PatternRewriter : protected MixedModeMutator {
         groups_ = grouper.GroupMatches(callback_->pattern_, post);
         gid_assignments_ = grouper.GetGIDAssignments();
         memo_.clear();
-        post = this->VisitExpr(post);
+        post = InferType(this->VisitExpr(post));

Review comment:
       I agree with making InferType optional, but an assertion may not work: one pattern may rewrite a graph multiple times, so the rewritten nodes are still untyped even if the original nodes were well typed before the rewriter ran. One solution is to require users to manually type new nodes in the rewrite callback, but that seems nontrivial.
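The failure mode described above can be reproduced with a toy model. The sketch below is not TVM's `PatternRewriter`. It is a hypothetical mini-IR in plain Python where `shape is None` stands for "no checked type yet". Each stage folds the innermost `reshape(reshape(x))` pair into one reshape whose newshape is read from the matched node's type, and the loop repeats until nothing changes. Folding only one pair per stage is a simplification. The `require_type` flag mirrors the attribute name in the Python diff earlier in this thread, but its behavior here is an assumption. Rebuilt parents and newly created nodes start untyped, so without re-running type inference the second stage fails even though the input graph was fully typed.

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple


@dataclass(eq=False)
class Node:
    """Toy IR node: a graph input ("var") or a "reshape" of another node."""
    op: str
    arg: Optional["Node"] = None
    newshape: Tuple[int, ...] = ()
    shape: Optional[Tuple[int, ...]] = None  # None means "not typed yet"


def infer_type(node: Node) -> None:
    """Toy InferType: (re)compute output shapes bottom-up."""
    if node.op == "reshape":
        infer_type(node.arg)
        total = prod(node.arg.shape)
        known = prod(d for d in node.newshape if d != -1)
        node.shape = tuple(total // known if d == -1 else d for d in node.newshape)


def fold_innermost(node: Node) -> Tuple[Node, bool]:
    """One stage: replace the innermost reshape(reshape(x)) with a single reshape."""
    if node.op != "reshape":
        return node, False
    new_arg, changed = fold_innermost(node.arg)
    if changed:
        return Node("reshape", new_arg, node.newshape), True  # rebuilt parent: untyped
    if node.arg.op == "reshape":
        if node.shape is None:
            raise TypeError("callback read the type of a node created by an earlier rewrite")
        return Node("reshape", node.arg.arg, node.shape), True  # new node: untyped
    return node, False


def rewrite(expr: Node, require_type: bool) -> Node:
    """Run fold_innermost to a fixpoint, optionally re-typing after each stage."""
    infer_type(expr)  # the input graph starts out fully typed
    changed = True
    while changed:
        expr, changed = fold_innermost(expr)
        if changed and require_type:
            infer_type(expr)
    return expr


x = Node("var", shape=(1, 32, 16, 16))
chain = Node("reshape", Node("reshape", Node("reshape", x, (1, 16, -1)), (4, 8, -1, 16)), (-1,))
out = rewrite(chain, require_type=True)
print(out.op, out.arg is x, out.newshape)  # reshape True (8192,)
try:
    rewrite(chain, require_type=False)
except TypeError as err:
    print("without re-typing:", err)
```

Typing only the nodes a callback creates would not be enough in this model, because the rebuilt parent (the outer reshape after stage one) is also untyped. That is one reason manual typing in the callback is harder than it first looks.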







[GitHub] [incubator-tvm] tqchen commented on pull request #6052: [Relay][Pass] Merge two consecutive reshape ops

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6052:
URL: https://github.com/apache/incubator-tvm/pull/6052#issuecomment-659104000


   Thanks @zhiics @icemelon9 @mbrookhart 

