Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/01/23 11:50:13 UTC

[GitHub] [incubator-tvm] mbarrett97 opened a new pull request #4771: [Relay] Added Merge Composite pass

mbarrett97 opened a new pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771
 
 
   This adds a pass, MergeComposite, which merges patterns (expressed as Relay expressions). Each detected pattern is wrapped in a function, called in place of the matched expressions, and that function is marked with a 'Composite' attribute naming the pattern.
   
   This is primarily for use with the external codegen infrastructure in the case where a combination of Relay ops maps to a single external codegen op. It is not the same as fusion.
   
   Further discussion on this PR can be found at https://discuss.tvm.ai/t/rfc-external-codegen-defining-composite-relay-operators/5470.
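   
   For illustration, here is a minimal sketch of how the pass might be driven from Python, modelled directly on the unit tests in this PR (the pattern helper, names and shapes below are example values, not something the pass prescribes):
   
   ```python
   from tvm import relay
   from tvm.relay.testing import run_opt_pass
   
   def make_add_relu_pattern():
       # Free vars mark where the pattern can 'attach' to the rest of the graph.
       x = relay.var('x')
       y = relay.var('y')
       return relay.nn.relu(relay.add(x, y))
   
   # Pattern name -> pattern expression; the name ends up in the 'Composite' attribute.
   pattern_table = {
       "add_relu": make_add_relu_pattern()
   }
   
   a = relay.var('a', shape=(10, 10))
   b = relay.var('b', shape=(10, 10))
   func = relay.Function([a, b], relay.nn.relu(relay.add(a, b)))
   
   # Matched regions are replaced by calls to functions tagged Composite="add_relu".
   merged = run_opt_pass(func, relay.transform.MergeComposite(pattern_table))
   ```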

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r375547370
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+    : pattern_name_(pattern_name), pattern_(pattern) {}
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      // if we haven't encountered this var yet, make a new free var and associate
+      // it with the value at 'root'
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return std::move(free_var);
+    } else {
+      // if we have encountered this var already, return the free var that was created
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  /* How does this work?
+   *
+   * A pattern consists of Relay expression containing only operator call nodes, constants
+   * and free variables. The free variables indicate where the pattern can 'attach' in your
+   * graph. This function takes the final call node of the pattern and the call node currently
+   * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
+   * from the graph (referred to as the 'root' node here) to check they're identical. If at any point
+   * they differ, an empty expression is returned to signify the extract failed. If a free var is
+   * reached in the pattern, the corresponding value in the root is associated with the name of the
+   * free var (via the var_map) so that when we construct the composite function, the inputs match
+   * up correctly with the rest of the graph. The return value of this function when successful is
+   * a new Relay expression ready to be wrapped into a composite function.
+   */
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    // check to make sure both calls are to operators (not functions)
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return Expr();
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return Expr();
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      Expr new_arg;
+      if (arg->IsInstance<CallNode>()) {
+        // fail if the root argument is not also a call node
+        if (!root->args[i]->IsInstance<CallNode>()) {
+          return Expr();
+        }
+        // if it's a call node, recursively call this function
+        new_arg = ExtractPattern(Downcast<Call>(arg),
+                                 Downcast<Call>(root->args[i]),
+                                 var_map);
+      } else if (arg->IsInstance<VarNode>()) {
+        // if there's a var in the pattern, it must be a free var
+        // so call the function to update the var_map
+        new_arg = ExtractPattern(Downcast<Var>(arg),
+                                 root->args[i],
+                                 var_map);
+      } else if (arg->IsInstance<ConstantNode>()) {
+        // if there's a constant, simply get the corresponding
+        // value of the constant from the root
+        new_arg = ExtractPattern(Downcast<Constant>(arg),
+                                 root->args[i],
+                                 var_map);
+      }
+      if (!new_arg.defined()) {
+        return Expr();
+      }
+      new_args.push_back(new_arg);
+      i++;
+    }
+    return CallNode::make(root->op, new_args, root->attrs);
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      // don't step into existing composite functions
+      if (name_node->value != "") {
 
 Review comment:
   `if (name_node && name_node->value != "")`


[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377545174
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       ReLu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       ReLu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
 
 Review comment:
   It should be as simple as implementing get_attribute in Python for relay.Functions. It hasn't been an essential part of this PR because the majority of passes (I think all the non-prototype passes?) that use this are in C++, where you can already get attributes.


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584074285
 
 
   @zhiics Regarding the add-sub/sub-add case, yes, this would require two patterns, with the order of merging controlled by their priority. I can't think of any general way to express both of these cases as a single pattern, but if you have any thoughts I'd be glad to hear them. There is potentially an issue with requiring lots of patterns, and if we can come up with some concrete examples where that may be the case, then I can try to reason about how to improve the pattern matching.
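   
   To make that concrete, here is a small sketch (using the dict-based pattern table from the unit tests; the exact pattern shapes are only a guess at the add-sub/sub-add case being discussed) of how the two patterns might be written, with dict insertion order standing in for priority since the pass merges patterns one by one in order:
   
   ```python
   from tvm import relay
   
   def make_add_sub_pattern():
       # add feeding into subtract
       x = relay.var('x')
       y = relay.var('y')
       return relay.subtract(relay.add(x, y), y)
   
   def make_sub_add_pattern():
       # subtract feeding into add
       x = relay.var('x')
       y = relay.var('y')
       return relay.add(relay.subtract(x, y), y)
   
   # Whichever pattern should win on overlapping regions must be merged first.
   pattern_table = {
       "add_sub": make_add_sub_pattern(),
       "sub_add": make_sub_add_pattern(),
   }
   ```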


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-581887507
 
 
   I've done a big clean-up, simplified the code a bit, and added some comments.


[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370875473
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const tvm::Map<std::string, Expr>& pattern_map)
+    : pattern_map_(pattern_map) {}
+
+  bool MatchPattern(const Call& pattern, const Call& root) {
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return false;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return false;
+    if (pattern->args.size() != root->args.size())
+      return false;
+
+    unsigned int i = 0;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        if (!root->args[i]->IsInstance<CallNode>())
+          return false;
+        if (!MatchPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i])))
+          return false;
+      }
+      i++;
+    }
+    return true;
+  }
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return free_var;
+    } else {
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    Expr expr;
+    Expr empty_expr;
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return empty_expr;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return empty_expr;
+    if (pattern->args.size() != root->args.size())
+      return empty_expr;
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Call>(arg),
+                                          Downcast<Call>(root->args[i]),
+                                          var_map));
+      }
+      if (arg->IsInstance<VarNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Var>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      if (arg->IsInstance<ConstantNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Constant>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      i++;
+    }
+
+    auto new_call = CallNode::make(root->op, new_args, root->attrs);
+    return new_call;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return call;
+
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.defined());
+    for (const auto& x : pattern_map_) {
+      Call pattern = Downcast<Call>(x.second);
+      if (Downcast<Op>(pattern->op)->name != op->name)
+        continue;
+
+      if (MatchPattern(pattern, call)) {
+        Map<std::string, Array<Expr>> args_map;
+        auto extract = ExtractPattern(pattern, call, &args_map);
+        auto free_vars = FreeVars(extract);
+        Function new_func = FunctionNode::make(free_vars, extract,
+                call->checked_type_, {}, Attrs());
+        new_func = FunctionSetAttr(new_func, attr::kComposite,
+                                   tir::StringImmNode::make(x.first));
+        new_func = FunctionSetAttr(new_func, attr::kPrimitive,
+            tvm::Integer(1));
+        Array<Expr> args;
+        for (const auto& free_var : free_vars) {
+          args.push_back(args_map[free_var->name_hint()][1]);
+        }
+        auto new_call = CallNode::make(new_func, args);
+        return new_call;
 
 Review comment:
   Does this mean that if there are multiple patterns which can be matched to this call, it will return the first match as determined by tvm::Map key order?


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579448339
 
 
   It's important to emphasise the difference between partitioned subgraphs (which correspond to 'External' functions) and the merged patterns here (which correspond to 'Composite' functions).
   
   Composite functions are a way to group together Relay patterns which correspond to a particular external codegen function (e.g. qnn_conv2d + bias_add + requantize maps to a single convolution function in Arm Compute Library). External functions are a way to represent a subgraph that should not go via TVM codegen but should instead go through an external compiler (e.g. DNNL).
   
   For libraries like ACL or DNNL, it happens to be the case that you can just put every supported operator/composite function in its own external function (partitioned subgraph), and because the library calls are standalone, that works well enough.
   
   A more complex external compiler, though, may want to do some advanced fusion or manipulation, which means it needs to see the full supported subgraph, not just a single operator/composite function at a time. It will produce some binary artifact that computes the entire subgraph rather than a particular operator within it. All that complexity should be handled by the external compiler's codegen pass; the responsibility of the partitioner is to ensure that pass receives a valid subgraph that is as large as possible while still being fully supported and introducing no dataflow problems.
   
   In summary, I don't think this pass replaces partitioning, which is essential for the 'complex compiler' case, but I think it may replace the annotation pass.
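   
   To make the distinction concrete, here is a small Python sketch in the style of the expected() helpers from the unit tests (the attribute tagging is only described in comments, since in this PR setting 'Composite'/'Primitive'/'Compiler' attributes is done from C++):
   
   ```python
   from tvm import relay
   
   # Composite function: groups add + relu so the external codegen sees one 'op'.
   in_1 = relay.var('in_1', shape=(10, 10))
   in_2 = relay.var('in_2', shape=(10, 10))
   add_relu = relay.Function([in_1, in_2], relay.nn.relu(relay.add(in_1, in_2)))
   # MergeComposite would tag this with Composite="add_relu", Primitive=1.
   
   # External function: a partitioned subgraph handed to the external compiler.
   # It can contain many operators and composite calls; here one composite call
   # plus a further multiply, to show it is larger than a single composite.
   x = relay.var('x', shape=(10, 10))
   y = relay.var('y', shape=(10, 10))
   body = relay.multiply(relay.Call(add_relu, [x, y]), y)
   external_fn = relay.Function([x, y], body)
   # The partitioner would tag this with the 'Compiler'/'ExternalSymbol' attributes.
   
   a = relay.var('a', shape=(10, 10))
   b = relay.var('b', shape=(10, 10))
   main = relay.Function([a, b], relay.Call(external_fn, [a, b]))
   ```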


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-578066583
 
 
   Hooking it up end-to-end will, for now, require writing a custom annotation pass that can recognise composite functions and understand that they need to be treated in the same way as operators. I've got a generic pass that does this, but it builds on top of @zhiics' annotation PR, which has since been withdrawn. If we agree that this merge_composite pass is useful, then I can discuss with zhiics how we want to move forward with the annotator.


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-578057856
 
 
   OK, I already have a working test case for detecting conv + bias + relu (note the "Composite" attribute). Now I really want to hook this up with the dnnl "Compiler" attribute to enable codegen.
   
   ```
   def @main(%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
     %2 = fn (%data1: Tensor[(1, 3, 224, 224), float32], %weight: Tensor[(16, 3, 3, 3), float32], %bias: Tensor[(16, 1, 1), float32], Primitive=1, Composite="dnnl.conv_bias_relu") -> Tensor[(1, 16, 224, 224), float32] {
       %0 = nn.conv2d(%data1, %weight, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       %1 = add(%0, %bias) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       nn.relu(%1) /* ty=Tensor[(1, 16, 224, 224), float32] */
     };
     %3 = %2(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */ /* ty=Tensor[(16, 3, 3, 3), float32] */, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */;
     %6 = fn (%data2: Tensor[(1, 16, 224, 224), float32], %weight1: Tensor[(16, 16, 3, 3), float32], %bias1: Tensor[(16, 1, 1), float32], Primitive=1, Composite="dnnl.conv_bias_relu") -> Tensor[(1, 16, 224, 224), float32] {
       %4 = nn.conv2d(%data2, %weight1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       %5 = add(%4, %bias1) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       nn.relu(%5) /* ty=Tensor[(1, 16, 224, 224), float32] */
     };
     %6(%3, meta[relay.Constant][2] /* ty=Tensor[(16, 16, 3, 3), float32] */ /* ty=Tensor[(16, 16, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(16, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
   }
   ```


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-578057856
 
 
   OK, I already have a working test case for detecting conv + bias + relu (note the "Composite" attribute). Now I really want to hook this up with the "Compiler" and "ExternalSymbol" attributes to enable codegen.
   
   ```
   def @main(%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
     %2 = fn (%data1: Tensor[(1, 3, 224, 224), float32], %weight: Tensor[(16, 3, 3, 3), float32], %bias: Tensor[(16, 1, 1), float32], Primitive=1, Composite="dnnl.conv_bias_relu") -> Tensor[(1, 16, 224, 224), float32] {
       %0 = nn.conv2d(%data1, %weight, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       %1 = add(%0, %bias) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       nn.relu(%1) /* ty=Tensor[(1, 16, 224, 224), float32] */
     };
     %3 = %2(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */ /* ty=Tensor[(16, 3, 3, 3), float32] */, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */;
     %6 = fn (%data2: Tensor[(1, 16, 224, 224), float32], %weight1: Tensor[(16, 16, 3, 3), float32], %bias1: Tensor[(16, 1, 1), float32], Primitive=1, Composite="dnnl.conv_bias_relu") -> Tensor[(1, 16, 224, 224), float32] {
       %4 = nn.conv2d(%data2, %weight1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       %5 = add(%4, %bias1) /* ty=Tensor[(1, 16, 224, 224), float32] */;
       nn.relu(%5) /* ty=Tensor[(1, 16, 224, 224), float32] */
     };
     %6(%3, meta[relay.Constant][2] /* ty=Tensor[(16, 16, 3, 3), float32] */ /* ty=Tensor[(16, 16, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(16, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
   }
   ```


[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370468299
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       ReLu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       ReLu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
 
 Review comment:
   add_relu


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584313221
 
 
   @zhiics @comaniac It is worth discussing whether we can use the composite and partitioning passes to remove the annotation pass, as mentioned by @mbarrett97.


[GitHub] [incubator-tvm] masahi removed a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi removed a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-577983004
 
 
   From your test cases, I don't see why this pass would help remove pattern matching on the codegen side. Since Relay doesn't have an `add_sub_mul` op, I still need to first find multiply, then traverse further to find sub, etc., no? From the codegen side I don't see the difference compared to manual partitioning. Or is that part of another future PR?
   
   It definitely removes the need to write a custom annotator for each pattern, though.


[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377023736
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+    : pattern_name_(pattern_name), pattern_(pattern) {}
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      // if we haven't encountered this var yet, make a new free var and associate
+      // it with the value at 'root'
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return std::move(free_var);
+    } else {
+      // if we have encountered this var already, return the free var that was created
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  /* How does this work?
+   *
+   * A pattern consists of Relay expression containing only operator call nodes, constants
+   * and free variables. The free variables indicate where the pattern can 'attach' in your
+   * graph. This function takes the final call node of the pattern and the call node currently
+   * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
+   * from the graph (referred to as the 'root' node here) to check they're identical. If at any point
+   * they differ, an empty expression is returned to signify the extract failed. If a free var is
+   * reached in the pattern, the corresponding value in the root is associated with the name of the
+   * free var (via the var_map) so that when we construct the composite function, the inputs match
+   * up correctly with the rest of the graph. The return value of this function when successful is
+   * a new Relay expression ready to be wrapped into a composite function.
+   */
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    // check to make sure both calls are to operators (not functions)
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return Expr();
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return Expr();
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      Expr new_arg;
+      if (arg->IsInstance<CallNode>()) {
+        // fail if the root argument is not also a call node
+        if (!root->args[i]->IsInstance<CallNode>()) {
+          return Expr();
+        }
+        // if it's a call node, recursively call this function
+        new_arg = ExtractPattern(Downcast<Call>(arg),
+                                 Downcast<Call>(root->args[i]),
+                                 var_map);
+      } else if (arg->IsInstance<VarNode>()) {
+        // if there's a var in the pattern, it must be a free var
+        // so call the function to update the var_map
+        new_arg = ExtractPattern(Downcast<Var>(arg),
+                                 root->args[i],
+                                 var_map);
+      } else if (arg->IsInstance<ConstantNode>()) {
+        // if there's a constant, simply get the corresponding
+        // value of the constant from the root
+        new_arg = ExtractPattern(Downcast<Constant>(arg),
+                                 root->args[i],
+                                 var_map);
+      }
+      if (!new_arg.defined()) {
+        return Expr();
+      }
+      new_args.push_back(new_arg);
+      i++;
+    }
+    return CallNode::make(root->op, new_args, root->attrs);
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      // don't step into existing composite functions
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return std::move(call);
+
+    // only call patterns are supported
+    Call pattern = Downcast<Call>(pattern_);
+    CHECK(pattern.defined());
+    Map<std::string, Array<Expr>> args_map;
+    auto extract = ExtractPattern(pattern, call, &args_map);
+    if (extract.defined()) {
+      auto free_vars = FreeVars(extract);
+      // make the composite function
+      auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
+      f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
+      // find the expressions associated with the free vars using the args_map
+      // this tells us which expressions should be given as inputs to the composite function
+      Array<Expr> args;
+      for (const auto& free_var : free_vars) {
+        args.push_back(args_map[free_var->name_hint()][1]);
+      }
+      auto new_call = CallNode::make(f, args);
+      return std::move(new_call);
+    }
+    return std::move(call);
+  }
+
+ private:
+  /*! \brief The name of the pattern to match */
+  std::string pattern_name_;
+  /*! \brief The pattern to match */
+  Expr pattern_;
+};
+
+Expr MergeComposite(const Expr& expr,
+    const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
+  CHECK(pattern_names.size() == patterns.size());
+  Expr merged_expr = expr;
+  // merge the patterns one-by-one in order
+  for (size_t i = 0; i < patterns.size(); i++) {
+    std::string pattern_name = pattern_names[i]->value;
+    Expr pattern = patterns[i];
 
 Review comment:
   I don't think it's necessary here. I'm iterating over the index 'i' rather than over the patterns directly, as I'm traversing two arrays, patterns and pattern_names.


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750739
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
 
 Review comment:
   Do you think we should also think about ops like concat, split, etc.? Basically, the ops that involve TupleNode and TupleGetItemNode.


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750402
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+    : pattern_name_(pattern_name), pattern_(pattern) {}
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      // if we haven't encountered this var yet, make a new free var and associate
+      // it with the value at 'root'
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return std::move(free_var);
+    } else {
+      // if we have encountered this var already, return the free var that was created
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  /* How does this work?
+   *
+   * A pattern consists of Relay expression containing only operator call nodes, constants
+   * and free variables. The free variables indicate where the pattern can 'attach' in your
+   * graph. This function takes the final call node of the pattern and the call node currently
+   * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
+   * from the graph (referred to as the 'root' node here) to check they're identical. If at any point
+   * they differ, an empty expression is returned to signify the extract failed. If a free var is
+   * reached in the pattern, the corresponding value in the root is associated with the name of the
+   * free var (via the var_map) so that when we construct the composite function, the inputs match
+   * up correctly with the rest of the graph. The return value of this function when successful is
+   * a new Relay expression ready to be wrapped into a composite function.
+   */
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    // check to make sure both calls are to operators (not functions)
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return Expr();
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return Expr();
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      Expr new_arg;
+      if (arg->IsInstance<CallNode>()) {
+        // fail if the root argument is not also a call node
+        if (!root->args[i]->IsInstance<CallNode>()) {
+          return Expr();
+        }
+        // if it's a call node, recursively call this function
+        new_arg = ExtractPattern(Downcast<Call>(arg),
+                                 Downcast<Call>(root->args[i]),
+                                 var_map);
+      } else if (arg->IsInstance<VarNode>()) {
+        // if there's a var in the pattern, it must be a free var
+        // so call the function to update the var_map
+        new_arg = ExtractPattern(Downcast<Var>(arg),
+                                 root->args[i],
+                                 var_map);
+      } else if (arg->IsInstance<ConstantNode>()) {
+        // if there's a constant, simply get the corresponding
+        // value of the constant from the root
+        new_arg = ExtractPattern(Downcast<Constant>(arg),
+                                 root->args[i],
+                                 var_map);
+      }
+      if (!new_arg.defined()) {
+        return Expr();
+      }
+      new_args.push_back(new_arg);
+      i++;
+    }
+    return CallNode::make(root->op, new_args, root->attrs);
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      // don't step into existing composite functions
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return std::move(call);
+
+    // only call patterns are supported
+    Call pattern = Downcast<Call>(pattern_);
+    CHECK(pattern.defined());
+    Map<std::string, Array<Expr>> args_map;
+    auto extract = ExtractPattern(pattern, call, &args_map);
+    if (extract.defined()) {
+      auto free_vars = FreeVars(extract);
+      // make the composite function
+      auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
+      f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
+      // find the expressions associated with the free vars using the args_map
+      // this tells us which expressions should be given as inputs to the composite function
+      Array<Expr> args;
+      for (const auto& free_var : free_vars) {
+        args.push_back(args_map[free_var->name_hint()][1]);
+      }
+      auto new_call = CallNode::make(f, args);
+      return std::move(new_call);
+    }
+    return std::move(call);
+  }
+
+ private:
+  /*! \brief The name of the pattern to match */
+  std::string pattern_name_;
+  /*! \brief The pattern to match */
+  Expr pattern_;
+};
+
+Expr MergeComposite(const Expr& expr,
+    const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
+  CHECK(pattern_names.size() == patterns.size());
+  Expr merged_expr = expr;
+  // merge the patterns one-by-one in order
+  for (size_t i = 0; i < patterns.size(); i++) {
+    std::string pattern_name = pattern_names[i]->value;
+    Expr pattern = patterns[i];
 
 Review comment:
   Do we need a `const Expr&`?


[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370510632
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'.
+ */
+
+#include <tvm/top/operation.h>
 
 Review comment:
   top -> te after rebase


[GitHub] [incubator-tvm] zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r375526504
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+    : pattern_name_(pattern_name), pattern_(pattern) {}
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      // if we haven't encountered this var yet, make a new free var and associate
+      // it with the value at 'root'
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return std::move(free_var);
+    } else {
+      // if we have encountered this var already, return the free var that was created
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  /* How does this work?
 
 Review comment:
   Let's document this in the following style
   
   \brief
   
   \param A
   ...
   \param N
   
   \return
   
   \Note How does it work? 


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579340640
 
 
   I think we can probably get rid of the annotation pass. If we expose the logic that determines whether an operator is supported by a given external compiler to the partitioning pass, we also won't need to define a pattern for every single operator: just define patterns for composite operators and have a common way of expressing whether a given operator (composite or otherwise) is supported by an external compiler.
   
   That would lose some of the control that could have been gained by writing a custom annotation pass, but I'm not sure what use cases would leverage that.
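   A rough sketch of this idea, using illustrative names rather than any existing TVM API: patterns are registered only for composite (multi-op) operators, and a single hypothetical predicate answers whether a given call, composite or plain, is supported by the external compiler.
   ```python
   from tvm import relay

   def make_add_relu_pattern():
       # free vars mark where the pattern attaches to the graph
       x = relay.var('x')
       y = relay.var('y')
       return relay.nn.relu(relay.add(x, y))

   # patterns are only needed for composite (multi-op) operators
   pattern_table = [("my_codegen.add_relu", make_add_relu_pattern())]

   # hypothetical predicate a partitioning pass could consult for every call,
   # whether it is a composite function or a single Relay operator
   def is_supported_by_my_codegen(name):
       return name in {"my_codegen.add_relu", "nn.conv2d", "nn.relu"}
   ```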


[GitHub] [incubator-tvm] zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r375524422
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
 
 Review comment:
   remove one blank line


[GitHub] [incubator-tvm] comaniac commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579014540
 
 
   So you want the flow to be:
   ```
   CompositeMerge -> Annotation -> Partitioning. 
   ```
   I agree that this would make annotation generic and straightforward, although it seems like we don't need annotation anymore if we specify all patterns, including single ops. While there are lots of approaches to do so, maybe we could accept this solution first and consider further steps later. What do you think? @zhiics
   
   Also, please help clarify the question asked by @masahi and me about multiple matching. Thanks.
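   Written out, that flow is roughly the sketch below. Here `pattern_table` and `mod` are assumed to exist already, and "AnnotateCompiler" is only the placeholder annotation step discussed in this thread (not a settled pass), so it is left commented out.
   ```python
   from tvm import relay

   mod = relay.transform.MergeComposite(pattern_table)(mod)  # merge patterns into composite functions
   # mod = AnnotateCompiler("my_codegen")(mod)               # placeholder annotation step, still under discussion
   mod = relay.transform.PartitionGraph()(mod)               # lift annotated regions into external functions
   ```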


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584313916
 
 
   Thanks @mbarrett97 @comaniac @zhiics @anijain2305 this is merged.


[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370447222
 
 

 ##########
 File path: python/tvm/relay/transform.py
 ##########
 @@ -508,6 +508,41 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
     return _transform.Legalize(legalize_map_attr_name)
 
 
+def AnnotateCompiler(compiler):
 
 Review comment:
   remove this


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579436246
 
 
   do we still need partitioning if we get rid of the annotation pass? At least for simple patterns like conv + bias + relu, I can get the same composite (or partitioning, whatever) without writing a clunky custom annotator. 
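   For example, using the helper style from this PR's tests, the conv + bias + relu case can be handled purely with a pattern (a sketch; `net` stands in for whatever Relay function is being compiled):
   ```python
   from tvm import relay
   from tvm.relay.testing import run_opt_pass

   def make_conv_bias_relu_pattern():
       x, w, b = relay.var('x'), relay.var('w'), relay.var('b')
       return relay.nn.relu(relay.nn.bias_add(relay.nn.conv2d(x, w), b))

   pattern_table = [("conv2d_bias_relu", make_conv_bias_relu_pattern())]
   merged = run_opt_pass(net, relay.transform.MergeComposite(pattern_table))
   ```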


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-577985906
 
 
   Ok, reading your original proposal carefully, I understand that I can look up the 'Composite' attribute to know whether a particular function contains a pattern I'm looking for. Of course I still need to traverse the arguments, but I can remove the "detection" part from DetectFusedConv2DBiasReLU.
   
   Sounds great!


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-577985906
 
 
   Ok, reading your original proposal carefully, I understand that I can look up the 'Composite' attribute to know whether a particular function contains a pattern I'm looking for. Of course I still need to traverse the arguments, but I can separate the "detection" logic from the "traversal" logic in DetectFusedConv2DBiasReLU.
   
   It seems I only have to make minimal changes to use this feature in my PR. It also lets me remove my custom annotator. Sounds great!



[GitHub] [incubator-tvm] zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r375530819
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'. This is primarily intended to be used alongside the
+ * external codegen infrastructure to support the case where multiple
+ * Relay operators map to a single external operator.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+    : pattern_name_(pattern_name), pattern_(pattern) {}
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      // if we haven't encountered this var yet, make a new free var and associate
+      // it with the value at 'root'
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return std::move(free_var);
+    } else {
+      // if we have encountered this var already, return the free var that was created
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  /* How does this work?
+   *
+   * A pattern consists of Relay expression containing only operator call nodes, constants
+   * and free variables. The free variables indicate where the pattern can 'attach' in your
+   * graph. This function takes the final call node of the pattern and the call node currently
+   * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
+   * from the graph (referred to as the 'root' node here) to check they're identical. If at any point
+   * they differ, an empty expression is returned to signify the extract failed. If a free var is
+   * reached in the pattern, the corresponding value in the root is associated with the name of the
+   * free var (via the var_map) so that when we construct the composite function, the inputs match
+   * up correctly with the rest of the graph. The return value of this function when successful is
+   * a new Relay expression ready to be wrapped into a composite function.
+   */
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    // check to make sure both calls are to operators (not functions)
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return Expr();
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return Expr();
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      Expr new_arg;
+      if (arg->IsInstance<CallNode>()) {
+        // fail if the root argument is not also a call node
+        if (!root->args[i]->IsInstance<CallNode>()) {
+          return Expr();
+        }
+        // if it's a call node, recursively call this function
+        new_arg = ExtractPattern(Downcast<Call>(arg),
+                                 Downcast<Call>(root->args[i]),
+                                 var_map);
+      } else if (arg->IsInstance<VarNode>()) {
+        // if there's a var in the pattern, it must be a free var
+        // so call the function to update the var_map
+        new_arg = ExtractPattern(Downcast<Var>(arg),
+                                 root->args[i],
+                                 var_map);
+      } else if (arg->IsInstance<ConstantNode>()) {
+        // if there's a constant, simply get the corresponding
+        // value of the constant from the root
+        new_arg = ExtractPattern(Downcast<Constant>(arg),
+                                 root->args[i],
+                                 var_map);
+      }
+      if (!new_arg.defined()) {
+        return Expr();
+      }
+      new_args.push_back(new_arg);
+      i++;
+    }
+    return CallNode::make(root->op, new_args, root->attrs);
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      // don't step into existing composite functions
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return std::move(call);
+
+    // only call patterns are supported
+    Call pattern = Downcast<Call>(pattern_);
+    CHECK(pattern.defined());
+    Map<std::string, Array<Expr>> args_map;
+    auto extract = ExtractPattern(pattern, call, &args_map);
+    if (extract.defined()) {
+      auto free_vars = FreeVars(extract);
+      // make the composite function
+      auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
+      f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
+      // find the expressions associated with the free vars using the args_map
+      // this tells us which expressions should be given as inputs to the composite function
+      Array<Expr> args;
+      for (const auto& free_var : free_vars) {
+        args.push_back(args_map[free_var->name_hint()][1]);
+      }
+      auto new_call = CallNode::make(f, args);
+      return std::move(new_call);
+    }
+    return std::move(call);
+  }
+
+ private:
+  /*! \brief The name of the pattern to match */
+  std::string pattern_name_;
+  /*! \brief The pattern to match */
+  Expr pattern_;
+};
+
+Expr MergeComposite(const Expr& expr,
+    const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
+  CHECK(pattern_names.size() == patterns.size());
 
 Review comment:
   CHECK_EQ
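   To make the lockstep matching described in ExtractPattern's header comment concrete: in the add_sub_mul pattern used by the tests below, the free vars x and y each appear under both the add and the sub, so the second time the matcher reaches one of them it reuses the entry already recorded in var_map instead of creating a new free var. A small Python sketch of that pattern and a graph it matches:
   ```python
   from tvm import relay

   # pattern: (x + y) * (x - y); x and y are free vars, i.e. attachment points
   x = relay.var('x')
   y = relay.var('y')
   pattern = relay.multiply(relay.add(x, y), relay.subtract(x, y))

   # graph being matched: the same structure with concrete inputs a and b
   a = relay.var('a', shape=(10, 10))
   b = relay.var('b', shape=(10, 10))
   graph = relay.multiply(relay.add(a, b), relay.subtract(a, b))
   # on a match, the pass rewrites this into a call of a composite function
   # fn(x, y) carrying Composite="add_sub_mul", applied to (a, b)
   ```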


[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r371508067
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const tvm::Map<std::string, Expr>& pattern_map)
+    : pattern_map_(pattern_map) {}
+
+  bool MatchPattern(const Call& pattern, const Call& root) {
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return false;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return false;
+    if (pattern->args.size() != root->args.size())
+      return false;
+
+    unsigned int i = 0;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        if (!root->args[i]->IsInstance<CallNode>())
+          return false;
+        if (!MatchPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i])))
+          return false;
+      }
+      i++;
+    }
+    return true;
+  }
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return free_var;
+    } else {
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    Expr expr;
+    Expr empty_expr;
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return empty_expr;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return empty_expr;
+    if (pattern->args.size() != root->args.size())
+      return empty_expr;
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Call>(arg),
+                                          Downcast<Call>(root->args[i]),
+                                          var_map));
+      }
+      if (arg->IsInstance<VarNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Var>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      if (arg->IsInstance<ConstantNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Constant>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      i++;
+    }
+
+    auto new_call = CallNode::make(root->op, new_args, root->attrs);
+    return new_call;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return call;
+
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.defined());
+    for (const auto& x : pattern_map_) {
+      Call pattern = Downcast<Call>(x.second);
+      if (Downcast<Op>(pattern->op)->name != op->name)
+        continue;
+
+      if (MatchPattern(pattern, call)) {
+        Map<std::string, Array<Expr>> args_map;
+        auto extract = ExtractPattern(pattern, call, &args_map);
+        auto free_vars = FreeVars(extract);
+        Function new_func = FunctionNode::make(free_vars, extract,
+                call->checked_type_, {}, Attrs());
+        new_func = FunctionSetAttr(new_func, attr::kComposite,
+                                   tir::StringImmNode::make(x.first));
+        new_func = FunctionSetAttr(new_func, attr::kPrimitive,
+            tvm::Integer(1));
+        Array<Expr> args;
+        for (const auto& free_var : free_vars) {
+          args.push_back(args_map[free_var->name_hint()][1]);
+        }
+        auto new_call = CallNode::make(new_func, args);
+        return new_call;
 
 Review comment:
   Yep, this is an oversight. I can fix it pretty easily to work in priority order though and shall update the PR accordingly. I'll add tests for this case.
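   Since the pattern table is an ordered list, priority can be expressed simply by listing the higher-priority (typically larger) pattern first. A sketch reusing the helpers from the updated test file (make_conv_bias_relu_pattern, make_add_relu_pattern, before and run_opt_pass as defined there):
   ```python
   pattern_table = [
       # checked first, so it wins wherever both patterns would match
       ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
       # only merged where the pattern above did not match
       ("add_relu", make_add_relu_pattern()),
   ]
   result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
   ```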


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750916
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators, that
+match a given pattern, and combine them into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
 
 Review comment:
   Can we add one more test case where Var('a') = Var('c') in the primary graph? Maybe my previous comment already covers that.
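   A sketch of the kind of graph being asked for here, where the same variable 'a' takes the place of 'c' so both composite calls share an input (following the style of test_branch_merge below):
   ```python
   from tvm import relay

   def before_shared_input():
       a = relay.var('a', shape=(10, 10))
       b = relay.var('b', shape=(10, 10))
       mul_node = relay.multiply(relay.add(a, b), relay.subtract(a, b))
       # reuse 'a' where test_branch_merge uses a fresh 'c'
       mul_node_2 = relay.multiply(relay.add(a, mul_node), relay.subtract(a, mul_node))
       return relay.Function([a, b], relay.nn.relu(mul_node_2))
   ```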


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584334732
 
 
   @masahi @comaniac @zhiics Thanks for the reviews. An RFC on alternative annotation mechanisms would be great.


[GitHub] [incubator-tvm] comaniac commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579461903
 
 
   While @mbarrett97 has pointed out the difference between composite/annotation and partition, let me add one more clarification: the main task of the partition pass is transforming compiler annotations in a Relay graph into external functions, and it is compiler agnostic. This means we may even consider partitioning a Relay graph across more than one external compiler.


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377411817
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       ReLu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       ReLu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
 
 Review comment:
   What's stopping Composite from being exposed? It seems pretty important to be able to see the composite name in Python for a lot of optimizations.


[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377133432
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       ReLu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       ReLu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
 
 Review comment:
   Unfortunately this doesn't seem to be exposed by the Python API (I can set attributes but not retrieve them).


[GitHub] [incubator-tvm] zhiics commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
zhiics commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584274614
 
 
   @mbarrett97 Considering that this is a relatively standalone pass and it helps fusion with external functions, I think it is okay to take it in.
   
   For the pattern matching cases, one possible approach I am thinking of is automatically generating candidate patterns from some metadata.
   
   @comaniac and I will start an RFC to discuss how we can add the whitelist-based annotation back. It will probably leverage the pattern matching here.
   
   @masahi @anijain2305 PTAL and share your comments if there is any. Thanks.


[GitHub] [incubator-tvm] anijain2305 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584286159
 
 
   Oh, just saw. The tests are already added. Thanks, it's good from my side :)


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376751173
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators, that
+match a given pattern, and combine them into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
 
 Review comment:
   Sorry, one more test case request: have multiple patterns in the graph such that the last subgraph's input variables come from multiple prior subgraphs (instead of just one).
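   A sketch of that graph shape: two add_sub_mul regions whose outputs both feed a final add_sub_mul region, so the last composite call takes its inputs from two prior composite calls (style follows the existing tests):
   ```python
   from tvm import relay

   def before_multi_region_inputs():
       a = relay.var('a', shape=(10, 10))
       b = relay.var('b', shape=(10, 10))
       c = relay.var('c', shape=(10, 10))
       d = relay.var('d', shape=(10, 10))
       mul_1 = relay.multiply(relay.add(a, b), relay.subtract(a, b))  # first region
       mul_2 = relay.multiply(relay.add(c, d), relay.subtract(c, d))  # second region
       # final region takes its inputs from the two regions above
       out = relay.multiply(relay.add(mul_1, mul_2), relay.subtract(mul_1, mul_2))
       return relay.Function([a, b, c, d], relay.nn.relu(out))
   ```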


[GitHub] [incubator-tvm] zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r375542736
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators, that
+match a given pattern, and combine them into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = [
+        ("add_sub_mul", make_add_sub_mul_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test different patterns are merged correctly in the graph.
+
+    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+    to be merged into a single op `add_relu`.
+
+        data   kernel
+          \      /
+           \    /
+           conv2d                   data   kernel   bias
+             |                         \      |      /
+             |   bias                 conv2d_bias_relu
+             |   /                            |
+          bias_add        ====>               |    a
+             |                                |   /
+           relu  a                        add_relu
+             \  /                             |
+             add                              |  b
+              |                               | /
+            relu  b                          mul
+              |  /
+             mul
+    """
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(data,
+                                    kernel,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, bias)
+        relu_node = relay.nn.relu(bias_node)
+        add_node = relay.add(relu_node, a)
+        relu_node_2 = relay.nn.relu(add_node)
+        r = relay.multiply(relu_node_2, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    def expected():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        # conv_bias_relu function
+        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+        in_3 = relay.var('in_3', shape=(256,))
+
+        conv_node = relay.nn.conv2d(in_1,
+                                    in_2,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, in_3)
+        r = relay.nn.relu(bias_node)
+        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+        # add_relu function
+        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        add_node = relay.add(in_4, in_5)
+        r = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_4, in_5], r)
+
+        # merged function
+        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+        r = relay.multiply(add_relu_1, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
+    """Test that patterns are merged in the order they exist in the pattern table.
+
+    There can be cases where one pattern is a subgraph of another, in which case
+    it is not clear which match should take priority. The priority should come
+    from the order in which the patterns are declared in the pattern table. The
+    first patterns will be merged with highest priority and the last with lowest.
+
+    A:       B:       C:
+    add      add      abs
+     |        |        |
+    abs      abs      relu
+     |
+    relu
+
+    """
+
+    def pattern_A():
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        return out
+
+    def pattern_B():
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        return out
+
+    def pattern_C():
+        x = relay.var('x')
+        out = relay.abs(x)
+        out = relay.nn.relu(out)
+        return out
+
+    def before():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        out = relay.add(input_1, input_2)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        return relay.Function([input_1, input_2], out)
+
+    def after_A_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        merged_func = relay.Function([x, y], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('A'))
+        ret = relay.Call(merged_func, [input_1, input_2])
+        return relay.Function([input_1, input_2], ret)
+
+    def after_B_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        merged_func = relay.Function([x, y], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('B'))
+        merged_call = relay.Call(merged_func, [input_1, input_2])
+        ret = relay.nn.relu(merged_call)
+        return relay.Function([input_1, input_2], ret)
+
+    def after_C_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        add = relay.add(input_1, input_2)
+        x = relay.var('x')
+        out = relay.abs(x)
+        out = relay.nn.relu(out)
+        merged_func = relay.Function([x], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('C'))
+        ret = relay.Call(merged_func, [add])
+        return relay.Function([input_1, input_2], ret)
+
+    # check A highest priority
+    pattern_table = [
+        ("A", pattern_A()),
+        ("B", pattern_B()),
+        ("C", pattern_C()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+    # check B highest priority
+    pattern_table = [
+        ("B", pattern_B()),
+        ("C", pattern_C()),
+        ("A", pattern_A()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_B_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+    # check C highest priority
+    pattern_table = [
+        ("C", pattern_C()),
+        ("A", pattern_A()),
+        ("B", pattern_B()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_C_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
 
 Review comment:
   Besides checking the expected graph, can we also compile and run at least one of the tests?


[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377029670
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = [
+        ("add_sub_mul", make_add_sub_mul_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test different patterns are merged correctly in the graph.
+
+    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+    to be merged into a single op `add_relu`.
+
+        data   kernel
+          \      /
+           \    /
+           conv2d                   data   kernel   bias
+             |                         \      |      /
+             |   bias                 conv2d_bias_relu
+             |   /                            |
+          bias_add        ====>               |    a
+             |                                |   /
+           relu  a                        add_relu
+             \  /                             |
+             add                              |  b
+              |                               | /
+            relu  b                          mul
+              |  /
+             mul
+    """
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(data,
+                                    kernel,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, bias)
+        relu_node = relay.nn.relu(bias_node)
+        add_node = relay.add(relu_node, a)
+        relu_node_2 = relay.nn.relu(add_node)
+        r = relay.multiply(relu_node_2, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    def expected():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        # conv_bias_relu function
+        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+        in_3 = relay.var('in_3', shape=(256,))
+
+        conv_node = relay.nn.conv2d(in_1,
+                                    in_2,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, in_3)
+        r = relay.nn.relu(bias_node)
+        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+        # add_relu function
+        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        add_node = relay.add(in_4, in_5)
+        r = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_4, in_5], r)
+
+        # merged function
+        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+        r = relay.multiply(add_relu_1, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
+    """Test that patterns are merged in the order they exist in the pattern table.
+
+    There can be cases where one pattern is a subgraph of another, in which case
+    it is not clear which match should take priority. The priority should come
+    from the order in which the patterns are declared in the pattern table. The
+    first patterns will be merged with highest priority and the last with lowest.
+
+    A:       B:       C:
+    add      add      abs
+     |        |        |
+    abs      abs      relu
+     |
+    relu
+
+    """
+
+    def pattern_A():
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        return out
+
+    def pattern_B():
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        return out
+
+    def pattern_C():
+        x = relay.var('x')
+        out = relay.abs(x)
+        out = relay.nn.relu(out)
+        return out
+
+    def before():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        out = relay.add(input_1, input_2)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        return relay.Function([input_1, input_2], out)
+
+    def after_A_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        merged_func = relay.Function([x, y], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('A'))
+        ret = relay.Call(merged_func, [input_1, input_2])
+        return relay.Function([input_1, input_2], ret)
+
+    def after_B_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        merged_func = relay.Function([x, y], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('B'))
+        merged_call = relay.Call(merged_func, [input_1, input_2])
+        ret = relay.nn.relu(merged_call)
+        return relay.Function([input_1, input_2], ret)
+
+    def after_C_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        add = relay.add(input_1, input_2)
+        x = relay.var('x')
+        out = relay.abs(x)
+        out = relay.nn.relu(out)
+        merged_func = relay.Function([x], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('C'))
+        ret = relay.Call(merged_func, [add])
+        return relay.Function([input_1, input_2], ret)
+
+    # check A highest priority
+    pattern_table = [
+        ("A", pattern_A()),
+        ("B", pattern_B()),
+        ("C", pattern_C()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+    # check B highest priority
+    pattern_table = [
+        ("B", pattern_B()),
+        ("C", pattern_C()),
+        ("A", pattern_A()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_B_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+    # check C highest priority
+    pattern_table = [
+        ("C", pattern_C()),
+        ("A", pattern_A()),
+        ("B", pattern_B()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_C_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
 
 Review comment:
   It's not particularly meaningful to compile beyond this point as it requires a specific codegen pass to be able to interpret the composite function. I think that coverage will come from the tests for external codegens that make use of this infrastructure.
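   For reference, a codegen that does understand the composite function would look for calls whose callee is a Function carrying the 'Composite' attribute and dispatch on the pattern name instead of re-matching the ops. A rough sketch only (the attribute accessor and the emit hook are assumptions for illustration, not APIs added by this PR):

       def visit(node):
           # Calls whose callee is a Function with a 'Composite' attribute are whole
           # matched patterns, so the codegen can dispatch on the pattern name directly.
           if isinstance(node, relay.expr.Call) and isinstance(node.op, relay.Function):
               name = node.op.attrs['Composite'] if node.op.attrs else None  # assumed accessor, may differ
               if name is not None and str(name) == 'conv2d_bias_relu':
                   emit_conv2d_bias_relu(node)  # hypothetical hook in the external codegen

       # 'func' is assumed to be the partitioned relay.Function handed to the codegen.
       relay.analysis.post_order_visit(func.body, visit)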


[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377003354
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
 
 Review comment:
   To be more general, yes we will need to think about this case. For this 1st iteration I've just considered patterns which are composed of Calls. I'd prefer to start with this and add the additional functionality as and when it's required.
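   As a concrete illustration of that scope (hypothetical, not code from this PR): a pattern is fine as long as its graph is built purely from Call nodes, whereas something like the tuple output of batch_norm already falls outside it, because the pattern graph would contain a TupleGetItem node:

       # Illustrative only: this graph contains a TupleGetItem, not just Calls,
       # so it is out of scope for the first iteration described above.
       x = relay.var('x')
       gamma = relay.var('gamma')
       beta = relay.var('beta')
       mean = relay.var('mean')
       var = relay.var('var')
       bn = relay.nn.batch_norm(x, gamma, beta, mean, var)
       out = bn[0]  # TupleGetItem -- not a Call node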


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-577983004
 
 
   From your test cases, I don't see why this pass would help remove pattern matching on the codegen side. Since Relay doesn't have an `add_sub_mul` op, I still need to first find multiply, then traverse further to find sub, etc., no? From the codegen side I don't see the difference compared to manual partitioning. Or is that part of another future PR?
   
   It definitely removes the need for writing a custom annotator for each pattern, though.


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579436246
 
 
   do we still need partitioning if we get rid of the annotation pass? At least for simple patterns like conv + bias + relu, I can get the same composite (or partitioning, whatever) without writing a clunky custom annotator and doing the partitioning pass. 
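   In other words, for that case the annotate-and-partition flow would presumably reduce to something like this (sketch only, reusing the pattern helper from the tests):

       pattern_table = [("conv2d_bias_relu", make_conv_bias_relu_pattern())]
       mod = relay.transform.MergeComposite(pattern_table)(mod)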


[GitHub] [incubator-tvm] anijain2305 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-584284517
 
 
   My only request is to add more test cases. In my experience, things get very ugly as the networks get bigger, and things that seem like corner cases become very common. But I am ok with the scope of this PR and delaying aggressive testing to a later PR. I think it helps to quickly flush out the e2e pipeline.


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377835990
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       ReLu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We could expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       ReLu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
 
 Review comment:
   Ok thanks, maybe I'll take a look at a separate PR that adds that functionality.


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-578329595
 
 
   If there are multiple patterns to detect, can a composite function detected by one pattern be used as part of a match for another pattern?
   
   Not sure if this is useful though. Maybe it enables breaking up a big pattern into chunks or pattern reuse. 


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-577983004
 
 
   From your test cases, I don't see why this pass would help remove pattern matching on the codegen side. Since Relay doesn't have an `add_sub_mul` op, I still need to first find multiply, then traverse further to find sub, etc., no? From the codegen side I don't see the difference compared to manual partitioning.
   
   It definitely removes the need for writing a custom annotator for each pattern, though.


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750989
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = [
+        ("add_sub_mul", make_add_sub_mul_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test different patterns are merged correctly in the graph.
+
+    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+    to be merged into a single op `add_relu`.
+
+        data   kernel
+          \      /
+           \    /
+           conv2d                   data   kernel   bias
+             |                         \      |      /
+             |   bias                 conv2d_bias_relu
+             |   /                            |
+          bias_add        ====>               |    a
+             |                                |   /
+           relu  a                        add_relu
+             \  /                             |
+             add                              |  b
+              |                               | /
+            relu  b                          mul
+              |  /
+             mul
+    """
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(data,
+                                    kernel,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, bias)
+        relu_node = relay.nn.relu(bias_node)
+        add_node = relay.add(relu_node, a)
+        relu_node_2 = relay.nn.relu(add_node)
+        r = relay.multiply(relu_node_2, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    def expected():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        # conv_bias_relu function
+        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+        in_3 = relay.var('in_3', shape=(256,))
+
+        conv_node = relay.nn.conv2d(in_1,
+                                    in_2,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, in_3)
+        r = relay.nn.relu(bias_node)
+        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+        # add_relu function
+        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        add_node = relay.add(in_4, in_5)
+        r = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_4, in_5], r)
+
+        # merged function
+        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+        r = relay.multiply(add_relu_1, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
 
 Review comment:
   What happens when pattern1 is a subset of pattern2? Basically you want to fuse as much as possible, so in the ideal case, pattern2. But if pattern2 isn't applicable, we still want to fuse pattern1.
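   Presumably the table would then be written with the larger pattern first, and the smaller one kept as a fallback for the graphs the larger one cannot cover, e.g. using the patterns from this test:

       pattern_table = [
           ("A", pattern_A()),  # add -> abs -> relu: larger pattern, tried first
           ("B", pattern_B()),  # add -> abs: smaller pattern, fires when no relu follows
       ]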


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-578329595
 
 
   If there are multiple patterns to detect, can a composite function detected by one pattern be used as part of a match for another pattern?
   
   Not sure if this is useful though. Maybe it enables breaking up a big pattern into chunks or pattern reuse. 


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750840
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
 
 Review comment:
   Can we also try two parallel add_sub_mul subgraphs with the same inputs? This will ensure that Relay vars are generated correctly when they are used across two subgraphs.
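   For reference, a rough sketch of what such a test could look like, following the conventions of the other tests in this file (the name `test_parallel_merge` and the exact expected structure are illustrative, not something already in the PR):

```python
def test_parallel_merge():
    """Sketch: two parallel add_sub_mul subgraphs sharing the same inputs."""
    pattern_table = [
        ("add_sub_mul", make_add_sub_mul_pattern())
    ]

    def before():
        a = relay.var('a', shape=(10, 10))
        b = relay.var('b', shape=(10, 10))
        # two independent add_sub_mul subgraphs, both fed by a and b
        branch_1 = relay.multiply(relay.add(a, b), relay.subtract(a, b))
        branch_2 = relay.multiply(relay.add(a, b), relay.subtract(a, b))
        return relay.Function([a, b], relay.multiply(branch_1, branch_2))

    def expected():
        a = relay.var('a', shape=(10, 10))
        b = relay.var('b', shape=(10, 10))
        # each branch should become its own call to the composite function,
        # each with a fresh set of free vars
        in_1 = relay.var('in_1', shape=(10, 10))
        in_2 = relay.var('in_2', shape=(10, 10))
        add_sub_mul = relay.Function(
            [in_1, in_2],
            relay.multiply(relay.add(in_1, in_2), relay.subtract(in_1, in_2)))
        branch_1 = relay.Call(add_sub_mul, [a, b])
        branch_2 = relay.Call(add_sub_mul, [a, b])
        return relay.Function([a, b], relay.multiply(branch_1, branch_2))

    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
    assert not relay.analysis.free_vars(result)
    expected = run_opt_pass(expected(), relay.transform.InferType())
    assert relay.analysis.alpha_equal(result, expected)
```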

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376800880
 
 

 ##########
 File path: python/tvm/relay/transform.py
 ##########
 @@ -513,6 +513,31 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
     return _transform.Legalize(legalize_map_attr_name)
 
 
+def MergeComposite(pattern_table):
+    """Merge multiple operators into a single composite relay function.
+
+    Parameters
+    ----------
+    pattern_table : list(tuple)
+        A list of (pattern_name, pattern) tuples.
 
 Review comment:
   Because the patterns have priorities and the index in this list implies the order. There was a discussion about how to deal with multiple patterns matching the same subgraph, and this is the workaround for this PR.
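   To make the intended semantics concrete, the matching is effectively first-match-wins over the ordered table. A simplified sketch (not the actual pass implementation; `matches` and `make_composite_call` are hypothetical helpers):

```python
def try_merge(call, pattern_table):
    # Patterns earlier in the table take priority: the first pattern that
    # matches this call (and its arguments) wins; later patterns are ignored.
    for pattern_name, pattern in pattern_table:
        if matches(pattern, call):                                    # hypothetical matcher
            return make_composite_call(pattern_name, pattern, call)  # hypothetical helper
    return call  # no pattern matched, leave the call untouched
```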

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370507137
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = {
+        "add_sub_mul": make_add_sub_mul_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
 
 Review comment:
   you are missing ```if __name__ == "__main__":``` here
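   i.e. something along these lines at the bottom of the file, so the tests also run as a plain script:

```python
if __name__ == "__main__":
    test_simple_merge()
    test_branch_merge()
```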

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r374225876
 
 

 ##########
 File path: python/tvm/relay/transform.py
 ##########
 @@ -513,6 +513,49 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
     return _transform.Legalize(legalize_map_attr_name)
 
 
+def AnnotateCompiler(compiler):
 
 Review comment:
   Is this necessary for this PR?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi merged pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771
 
 
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] comaniac commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579344847
 
 
   That's exactly what I was thinking -- we can debate whether we should merge the annotation into this pass, or just get rid of the annotation and make this pass general enough. Either way, this can be discussed in the follow-up steps.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370871786
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const tvm::Map<std::string, Expr>& pattern_map)
+    : pattern_map_(pattern_map) {}
+
+  bool MatchPattern(const Call& pattern, const Call& root) {
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return false;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return false;
+    if (pattern->args.size() != root->args.size())
+      return false;
+
+    unsigned int i = 0;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        if (!root->args[i]->IsInstance<CallNode>())
+          return false;
+        if (!MatchPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i])))
+          return false;
+      }
+      i++;
+    }
+    return true;
+  }
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return free_var;
+    } else {
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    Expr expr;
+    Expr empty_expr;
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return empty_expr;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return empty_expr;
+    if (pattern->args.size() != root->args.size())
+      return empty_expr;
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Call>(arg),
+                                          Downcast<Call>(root->args[i]),
+                                          var_map));
+      }
+      if (arg->IsInstance<VarNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Var>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      if (arg->IsInstance<ConstantNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Constant>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      i++;
+    }
+
+    auto new_call = CallNode::make(root->op, new_args, root->attrs);
+    return new_call;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return call;
+
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.defined());
+    for (const auto& x : pattern_map_) {
+      Call pattern = Downcast<Call>(x.second);
 
 Review comment:
   better to add
   CHECK(pattern) 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r371504533
 
 

 ##########
 File path: src/relay/pass/merge_composite.cc
 ##########
 @@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/merge_composite.cc
+ * \brief Merges expressions matching patterns into functions marked
+ * as 'composite'.
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const tvm::Map<std::string, Expr>& pattern_map)
+    : pattern_map_(pattern_map) {}
+
+  bool MatchPattern(const Call& pattern, const Call& root) {
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return false;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return false;
+    if (pattern->args.size() != root->args.size())
+      return false;
+
+    unsigned int i = 0;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        if (!root->args[i]->IsInstance<CallNode>())
+          return false;
+        if (!MatchPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i])))
+          return false;
+      }
+      i++;
+    }
+    return true;
+  }
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return free_var;
+    } else {
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+          Map<std::string, Array<Expr>>* var_map) {
+    Expr expr;
+    Expr empty_expr;
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return empty_expr;
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return empty_expr;
+    if (pattern->args.size() != root->args.size())
+      return empty_expr;
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      if (arg->IsInstance<CallNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Call>(arg),
+                                          Downcast<Call>(root->args[i]),
+                                          var_map));
+      }
+      if (arg->IsInstance<VarNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Var>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      if (arg->IsInstance<ConstantNode>()) {
+        new_args.push_back(ExtractPattern(Downcast<Constant>(arg),
+                                          root->args[i],
+                                          var_map));
+      }
+      i++;
+    }
+
+    auto new_call = CallNode::make(root->op, new_args, root->attrs);
+    return new_call;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      if (name_node->value != "") {
+        tvm::Array<tvm::relay::Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return call;
+
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.defined());
+    for (const auto& x : pattern_map_) {
+      Call pattern = Downcast<Call>(x.second);
+      if (Downcast<Op>(pattern->op)->name != op->name)
+        continue;
+
+      if (MatchPattern(pattern, call)) {
+        Map<std::string, Array<Expr>> args_map;
+        auto extract = ExtractPattern(pattern, call, &args_map);
+        auto free_vars = FreeVars(extract);
+        Function new_func = FunctionNode::make(free_vars, extract,
+                call->checked_type_, {}, Attrs());
+        new_func = FunctionSetAttr(new_func, attr::kComposite,
+                                   tir::StringImmNode::make(x.first));
+        new_func = FunctionSetAttr(new_func, attr::kPrimitive,
+            tvm::Integer(1));
+        Array<Expr> args;
+        for (const auto& free_var : free_vars) {
+          args.push_back(args_map[free_var->name_hint()][1]);
+        }
+        auto new_call = CallNode::make(new_func, args);
+        return new_call;
 
 Review comment:
   Same question here. If this PR does not attempt to deal with multiple patterns matching one call, it would be better to provide a priority number so users can specify which pattern should be matched before the others.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r370470303
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = {
+        "add_sub_mul": make_add_relu_pattern()
+    }
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
 
 Review comment:
   Can you add an example of retrieving the Composite attribute and verifying it?
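   For instance, a check along these lines (a sketch only -- `get_composite_name` stands in for whatever attribute accessor is exposed, e.g. `FunctionGetAttr` on the C++ side):

```python
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
# the body of the transformed function is a call whose operator is the
# composite function carrying the 'Composite' attribute
composite_fn = result.body.op
assert get_composite_name(composite_fn) == "add_relu"  # hypothetical helper; name comes from the pattern table
```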

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-581489473
 
 
   I've updated the PR to account for the priority ordering of patterns. You no longer specify a pattern map, but a list of tuples (pattern_name, pattern), where the order of the tuples corresponds to the order of the matching. To get this to work, I had to pass strings between Python/C++, which required this change: https://github.com/apache/incubator-tvm/pull/4806.
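   Concretely, the interface now looks something like this (taken from the updated tests; entries earlier in the list are matched with higher priority):

```python
pattern_table = [
    ("conv2d_bias_relu", make_conv_bias_relu_pattern()),  # tried first
    ("add_relu", make_add_relu_pattern())                 # tried second
]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
```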

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377332546
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern and combine them into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = [
+        ("add_sub_mul", make_add_sub_mul_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test different patterns are merged correctly in the graph.
+
+    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+    to be merged into a single op `add_relu`.
+
+        data   kernel
+          \      /
+           \    /
+           conv2d                   data   kernel   bias
+             |                         \      |      /
+             |   bias                 conv2d_bias_relu
+             |   /                            |
+          bias_add        ====>               |    a
+             |                                |   /
+           relu  a                        add_relu
+             \  /                             |
+             add                              |  b
+              |                               | /
+            relu  b                          mul
+              |  /
+             mul
+    """
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(data,
+                                    kernel,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, bias)
+        relu_node = relay.nn.relu(bias_node)
+        add_node = relay.add(relu_node, a)
+        relu_node_2 = relay.nn.relu(add_node)
+        r = relay.multiply(relu_node_2, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    def expected():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        # conv_bias_relu function
+        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+        in_3 = relay.var('in_3', shape=(256,))
+
+        conv_node = relay.nn.conv2d(in_1,
+                                    in_2,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, in_3)
+        r = relay.nn.relu(bias_node)
+        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+        # add_relu function
+        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        add_node = relay.add(in_4, in_5)
+        r = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_4, in_5], r)
+
+        # merged function
+        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+        r = relay.multiply(add_relu_1, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
+    """Test that patterns are merged in the order they exist in the pattern table.
+
+    There can be cases where one pattern is a subgraph of another, in which case
+    it is not clear which match should take priority. The priority should come
+    from the order in which the patterns are declared in the pattern table. The
+    first patterns will be merged with highest priority and the last with lowest.
+
+    A:       B:       C:
+    add      add      abs
+     |        |        |
+    abs      abs      relu
+     |
+    relu
+
+    """
+
+    def pattern_A():
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        return out
+
+    def pattern_B():
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        return out
+
+    def pattern_C():
+        x = relay.var('x')
+        out = relay.abs(x)
+        out = relay.nn.relu(out)
+        return out
+
+    def before():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        out = relay.add(input_1, input_2)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        return relay.Function([input_1, input_2], out)
+
+    def after_A_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        out = relay.nn.relu(out)
+        merged_func = relay.Function([x, y], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('A'))
+        ret = relay.Call(merged_func, [input_1, input_2])
+        return relay.Function([input_1, input_2], ret)
+
+    def after_B_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        x = relay.var('x')
+        y = relay.var('y')
+        out = relay.add(x, y)
+        out = relay.abs(out)
+        merged_func = relay.Function([x, y], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('B'))
+        merged_call = relay.Call(merged_func, [input_1, input_2])
+        ret = relay.nn.relu(merged_call)
+        return relay.Function([input_1, input_2], ret)
+
+    def after_C_priority():
+        input_1 = relay.var('input_1', shape=(10, 10))
+        input_2 = relay.var('input_2', shape=(10, 10))
+        add = relay.add(input_1, input_2)
+        x = relay.var('x')
+        out = relay.abs(x)
+        out = relay.nn.relu(out)
+        merged_func = relay.Function([x], out)
+        merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite', expr.StringImm('C'))
+        ret = relay.Call(merged_func, [add])
+        return relay.Function([input_1, input_2], ret)
+
+    # check A highest priority
+    pattern_table = [
+        ("A", pattern_A()),
+        ("B", pattern_B()),
+        ("C", pattern_C()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+    # check B highest priority
+    pattern_table = [
+        ("B", pattern_A()),
+        ("C", pattern_B()),
+        ("A", pattern_C()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+    # check C highest priority
+    pattern_table = [
+        ("C", pattern_A()),
+        ("A", pattern_B()),
+        ("B", pattern_C()),
+    ]
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
 
 Review comment:
   my DNNL PR should do this job.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r377003977
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern and combine them into a single relay function.
+
+For example suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = [
+        ("add_sub_mul", make_add_sub_mul_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test different patterns are merged correctly in the graph.
+
+    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+    to be merged into a single op `add_relu`.
+
+        data   kernel
+          \      /
+           \    /
+           conv2d                   data   kernel   bias
+             |                         \      |      /
+             |   bias                 conv2d_bias_relu
+             |   /                            |
+          bias_add        ====>               |    a
+             |                                |   /
+           relu  a                        add_relu
+             \  /                             |
+             add                              |  b
+              |                               | /
+            relu  b                          mul
+              |  /
+             mul
+    """
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(data,
+                                    kernel,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, bias)
+        relu_node = relay.nn.relu(bias_node)
+        add_node = relay.add(relu_node, a)
+        relu_node_2 = relay.nn.relu(add_node)
+        r = relay.multiply(relu_node_2, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    def expected():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        # conv_bias_relu function
+        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+        in_3 = relay.var('in_3', shape=(256,))
+
+        conv_node = relay.nn.conv2d(in_1,
+                                    in_2,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, in_3)
+        r = relay.nn.relu(bias_node)
+        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+        # add_relu function
+        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        add_node = relay.add(in_4, in_5)
+        r = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_4, in_5], r)
+
+        # merged function
+        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+        r = relay.multiply(add_relu_1, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
 
 Review comment:
   Have pattern 2 be above pattern 1 in the priority order. That way it will first try to fuse pattern 2 in the graph and fail, then try to fuse pattern 1 and succeed.
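
   A hedged sketch of that ordering (this is not the actual body of `test_merge_order`, which isn't shown in this diff hunk; it reuses the pattern helpers defined elsewhere in this test file):

```python
from tvm import relay
from tvm.relay.testing import run_opt_pass

def graph():
    # add -> relu only, so the conv pattern cannot match anywhere
    a = relay.var('a', shape=(10, 10))
    b = relay.var('b', shape=(10, 10))
    return relay.Function([a, b], relay.nn.relu(relay.add(a, b)))

# pattern 2 sits above pattern 1, so it is tried first, fails to match,
# and the pass then merges the add/relu pair under "add_relu" instead
pattern_table = [
    ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
    ("add_relu", make_add_relu_pattern()),
]
result = run_opt_pass(graph(), relay.transform.MergeComposite(pattern_table))
```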


[GitHub] [incubator-tvm] masahi commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579456842
 
 
   thanks, such "complex compiler" use cases sound quite interesting and the difference in Composite/External there makes a lot of sense. I also understand your earlier comments better now.


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750989
 
 

 ##########
 File path: tests/python/relay/test_pass_merge_composite.py
 ##########
 @@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern and combine them into a single relay function.
+
+For example, suppose we have the graph:
+
+    conv2d
+      |       (merge composite pass)
+   bias_add            ====>           conv2d_bias_relu
+      |            (our target)
+     relu
+
+Our Relay IR before the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
+            /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+Our Relay IR after the pass:
+    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
+            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
+      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
+            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
+            Tensor[(1, 256, 28, 28), float32] {
+        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
+        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
+      };
+      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
+    }
+
+As you can see in the second relay example, the pattern we specified has been wrapped
+in a function. The function is then called, producing the same result as the first relay
+example.
+
+One convenient use for this pass is to offload multiple operators to a single external
+codegen function.
+"""
+
+
+def make_add_sub_mul_pattern():
+    """Create a pattern to match the following graph.
+
+        add  sub
+         \   /
+          \ /
+          mul
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    sub_node = relay.subtract(x, y)
+    mul_node = relay.multiply(add_node, sub_node)
+    return mul_node
+
+
+def make_add_relu_pattern():
+    """Create a pattern to match the following graph.
+
+        add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    r = relay.nn.relu(add_node)
+    return r
+
+
+def make_conv_bias_relu_pattern():
+    """Create a pattern to match the following graph.
+
+       conv2d
+         |
+      bias_add
+         |
+       relu
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.var('z')
+    conv_node = relay.nn.conv2d(x, y)
+    bias_node = relay.nn.bias_add(conv_node, z)
+    r = relay.nn.relu(bias_node)
+    return r
+
+
+def test_simple_merge():
+    """Test composite function is correctly produced from simple graph.
+
+    We would expect the pattern `make_add_relu_pattern` to be merged
+    into a single op `add_relu`.
+
+        a  b
+        \ /               a  b
+        add    ====>      \ /
+         |             add_relu
+       relu
+
+    """
+    pattern_table = [
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        add_node = relay.add(a, b)
+        r = relay.nn.relu(add_node)
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        return relay.Function([a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_branch_merge():
+    """Test composite function is correctly produced from branching graph.
+
+    We would expect the pattern `make_add_sub_mul_pattern` to be merged
+    into a single op `add_sub_mul`.
+
+       a  b  a  b
+        \/    \/
+        add  sub                       a  b
+         \   /                          \/
+          \ /                      add_sub_mul
+          mul                     c     |
+          /  \                     \    |
+       c /  c |       ====>        add_sub_mul
+       \/   \/                          |
+       add  sub                         |
+        \   /                         relu
+         \ /
+         mul
+          |
+          |
+        relu
+    """
+
+    pattern_table = [
+        ("add_sub_mul", make_add_sub_mul_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+        add_node = relay.add(a, b)
+        sub_node = relay.subtract(a, b)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_node_2 = relay.add(c, mul_node)
+        sub_node_2 = relay.subtract(c, mul_node)
+        mul_node_2 = relay.multiply(add_node_2, sub_node_2)
+        r = relay.nn.relu(mul_node_2)
+        return relay.Function([a, b, c], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        c = relay.var('c', shape=(10, 10))
+
+        # add_sub_mul function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        sub_node = relay.subtract(in_1, in_2)
+        mul_node = relay.multiply(add_node, sub_node)
+        add_sub_mul = relay.Function([in_1, in_2], mul_node)
+
+        # merged function
+        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
+        r = relay.nn.relu(add_sub_mul_2)
+        return relay.Function([a, b, c], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_multiple_patterns():
+    """Test different patterns are merged correctly in the graph.
+
+    We would expect the pattern `make_conv_bias_relu_pattern` to be merged
+    into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
+    to be merged into a single op `add_relu`.
+
+        data   kernel
+          \      /
+           \    /
+           conv2d                   data   kernel   bias
+             |                         \      |      /
+             |   bias                 conv2d_bias_relu
+             |   /                            |
+          bias_add        ====>               |    a
+             |                                |   /
+           relu  a                        add_relu
+             \  /                             |
+             add                              |  b
+              |                               | /
+            relu  b                          mul
+              |  /
+             mul
+    """
+    pattern_table = [
+        ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
+        ("add_relu", make_add_relu_pattern())
+    ]
+
+    def before():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        conv_node = relay.nn.conv2d(data,
+                                    kernel,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, bias)
+        relu_node = relay.nn.relu(bias_node)
+        add_node = relay.add(relu_node, a)
+        relu_node_2 = relay.nn.relu(add_node)
+        r = relay.multiply(relu_node_2, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    def expected():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        bias = relay.var('bias', shape=(256,))
+        a = relay.var('a', shape=(1, 256, 28, 28))
+        b = relay.var('b', shape=(1, 256, 28, 28))
+
+        # conv_bias_relu function
+        in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
+        in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
+        in_3 = relay.var('in_3', shape=(256,))
+
+        conv_node = relay.nn.conv2d(in_1,
+                                    in_2,
+                                    kernel_size=(1, 1),
+                                    padding=(0, 0),
+                                    strides=(1, 1))
+
+        bias_node = relay.nn.bias_add(conv_node, in_3)
+        r = relay.nn.relu(bias_node)
+        conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+
+        # add_relu function
+        in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
+        in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
+        add_node = relay.add(in_4, in_5)
+        r = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_4, in_5], r)
+
+        # merged function
+        conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
+        add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
+        r = relay.multiply(add_relu_1, b)
+        return relay.Function([data, kernel, bias, a, b], r)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
+def test_merge_order():
 
 Review comment:
   What happens when pattern 1 is a subset of pattern 2? Basically you want to fuse as much as possible, so in the ideal case pattern 2. But if pattern 2 isn't applicable, we still want to fuse pattern 1.
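
   One hedged way to read this "subset" case, assuming patterns are merged in table order: `make_conv_bias_pattern` below is a hypothetical helper (conv2d -> bias_add without the relu), a strict subset of `make_conv_bias_relu_pattern` from this file. Putting the larger pattern first means it is fused wherever the trailing relu is present, and the subset only fires where it is not:

```python
from tvm import relay

def make_conv_bias_pattern():
    # hypothetical subset pattern: conv2d -> bias_add, no relu
    x = relay.var('x')
    y = relay.var('y')
    z = relay.var('z')
    return relay.nn.bias_add(relay.nn.conv2d(x, y), z)

# larger pattern first, so it takes priority where it applies;
# the subset acts as a fallback for chains without a trailing relu
pattern_table = [
    ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
    ("conv2d_bias", make_conv_bias_pattern()),
]
merge_pass = relay.transform.MergeComposite(pattern_table)
```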


[GitHub] [incubator-tvm] mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-578979321
 
 
   It's the intention that this can be called on the entire Relay graph so that it can be used to help implement a generic annotation pass (one that is aware of composite functions). That way we can define functions similar to the 'Is Supported?' mechanism in the original annotation PR (since taken down) where you could declare in Python whether an operator was supported. That could be extended to say whether a composite function is supported without having to add pattern matching code to the annotator.
   
   The problem there is the case where composite functions do not end up in an external function after partitioning. My thinking is to have some legalize pass after the partitioning that removes the composite functions from sections of the graph not marked as external.
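
   A purely illustrative sketch of that 'Is Supported?' idea (the helper name, the supported-set contents and the exact attribute access are assumptions, not part of this PR): an annotator could handle a call to a composite function by reading the 'Composite' name that MergeComposite attaches.

```python
from tvm import relay

# assumed capability of some hypothetical external target
SUPPORTED_COMPOSITES = {"conv2d_bias_relu", "add_relu"}

def is_supported_composite(call):
    """Return True if `call` targets a composite function the target supports."""
    op = call.op
    # composite functions produced by MergeComposite carry the pattern name;
    # the exact attrs lookup spelling may vary between TVM versions
    if isinstance(op, relay.Function) and op.attrs is not None and \
            "Composite" in op.attrs.keys():
        return op.attrs["Composite"] in SUPPORTED_COMPOSITES
    return False
```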


[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#discussion_r376750328
 
 

 ##########
 File path: python/tvm/relay/transform.py
 ##########
 @@ -513,6 +513,31 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
     return _transform.Legalize(legalize_map_attr_name)
 
 
+def MergeComposite(pattern_table):
+    """Merge multiple operators into a single composite relay function.
+
+    Parameters
+    ----------
+    pattern_table : list(tuple)
+        A list of (pattern_name, pattern) tuples.
 
 Review comment:
   Might be a dumb question, but why can't this be a map?
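
   One plausible answer, not spelled out in this thread but consistent with the "priority order" discussion on `test_merge_order` above, is that the table encodes match priority: earlier entries are merged first, and a plain map would not make that ordering explicit. A small sketch under that assumption, reusing the pattern helpers from the test file:

```python
from tvm import relay

# earlier entries are merged first, so a list of (name, pattern) tuples
# keeps the match priority explicit; an unordered map would lose it
pattern_table = [
    ("conv2d_bias_relu", make_conv_bias_relu_pattern()),  # higher priority
    ("add_relu", make_add_relu_pattern()),                 # lower priority
]
composite_pass = relay.transform.MergeComposite(pattern_table)
```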


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-579436246
 
 
   do we still need partitioning if we get rid of the annotation pass? At least for simple patterns like conv + bias + relu, I can get the same composite (or partitioning, whatever) without writing a clunky custom annotator and the partitioning pass. 


[GitHub] [incubator-tvm] masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #4771: [Relay] Added Merge Composite pass
URL: https://github.com/apache/incubator-tvm/pull/4771#issuecomment-577985906
 
 
   ok, reading your original proposal carefully I understand that I can look up the 'Composite' attribute to know if a particular function contains a pattern I'm looking for. Of course I still need to traverse the arguments, but I can remove the "detection" part from DetectFusedConv2DBiasReLU.
   
   It seems I only have to make minimal changes to make use of this feature in my PR. It also enables removing my custom annotator. Sounds great!
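
   A hedged sketch of what that lookup could look like on the codegen side (this is not the actual DNNL code from that PR; the class name, the composite name string and the attrs access are illustrative assumptions):

```python
from tvm import relay
from tvm.relay.expr_functor import ExprVisitor

class FindComposites(ExprVisitor):
    """Collect calls to a named composite function instead of re-detecting op chains."""

    def __init__(self, wanted="conv2d_bias_relu"):
        super().__init__()
        self.wanted = wanted
        self.matches = []

    def visit_call(self, call):
        op = call.op
        if isinstance(op, relay.Function) and op.attrs is not None and \
                "Composite" in op.attrs.keys() and op.attrs["Composite"] == self.wanted:
            # the fused op's inputs are call.args; the wrapped body (op.body)
            # can still be traversed for attributes, but no pattern detection
            # is needed any more
            self.matches.append(call)
        super().visit_call(call)
```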
