You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/03 15:14:13 UTC

[GitHub] [incubator-tvm] mbrookhart opened a new pull request #5231: [WIP][POC] Pattern Language and Matcher

mbrookhart opened a new pull request #5231: [WIP][POC] Pattern Language and Matcher
URL: https://github.com/apache/incubator-tvm/pull/5231
 
 
   C++ implementation of the RFC below:
   
   https://discuss.tvm.ai/t/rfc-relay-program-matching-for-relay-pt-1-a-pattern-language/5833
   
   Needs a little more work on documentation; I'm also working on a second stage for pattern-based graph rewriting.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408461040
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Dataflow pattern matcher for Relay expressions.
+ *
+ * Every successful sub-match is memoized as a pattern -> expression binding,
+ * so a pattern node that occurs more than once must bind to the same
+ * expression each time. Bindings are also logged in order so that a failed
+ * attempt can be rolled back to an earlier point.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Clear all bindings and test whether `pattern` matches `expr`. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  /*! \brief Memoizing dispatch used for every (sub-)pattern. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  // Matching rules, one per pattern node kind.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Roll back every binding logged at index `watermark` or later. */
+  void ClearMap(size_t watermark);
+  /*! \brief Current pattern -> expression bindings. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound patterns in the order they were bound (the roll-back log). */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+/*! \brief Entry point: forget earlier bindings, then match `pattern` against `expr`. */
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Each top-level match starts from an empty binding table and log.
+  matched_nodes_.clear();
+  memo_.clear();
+  const bool result = VisitDFPattern(pattern, expr);
+  return result;
+}
+
+/*! \brief Undo every binding logged at position `watermark` or later. */
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  const auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+/*!
+ * \brief Memoizing wrapper around the functor dispatch.
+ *
+ * A pattern that is already bound matches only the very same expression
+ * object. Otherwise the pattern is dispatched; on success the binding is
+ * recorded, and on failure every binding made during the attempt is undone.
+ */
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  const size_t rollback_point = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(rollback_point);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+/*! \brief An alternation matches when either branch matches; the left branch is tried first. */
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+/*!
+ * \brief Match an operator against an attribute pattern.
+ *
+ * `expr` must be an Op. Every (name, value) entry of the pattern's dictionary
+ * attributes must agree with the value registered for that operator via
+ * Op::GetAttr. Integer, float and string values are compared. A pattern value
+ * of a different kind, or an attribute the operator does not define, makes the
+ * match fail. An empty dictionary, non-dictionary attributes, or a non-Op
+ * expression never match. Registered values of any other type are unsupported
+ * and raise an exception.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  // Guard against non-dictionary attributes instead of dereferencing a null pointer.
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  // Becomes true only after at least one attribute has been checked. Every
+  // attribute must agree, so the first mismatch rejects the operator.
+  bool matches = false;
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      // The operator does not define this attribute, so it cannot satisfy the pattern.
+      return false;
+    }
+    const TVMRetValue& registered = op_map[op];
+    bool attr_matches = false;
+    switch (registered.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          attr_matches = val->value == registered.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          attr_matches = val->value == registered.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          attr_matches = val->value == registered.operator std::string();
+        }
+        break;
+      default:
+        // NOTE(review): throwing a string literal bypasses std::exception handlers;
+        // consider LOG(FATAL) or a std::exception subtype.
+        throw "Unsupported type";
+    }
+    if (!attr_matches) {
+      return false;
+    }
+    matches = true;
+  }
+  return matches;
+}
+
+/*!
+ * \brief Return a copy of `args` with the element order reversed.
+ * \param args Patterns to reverse; taken by const reference to avoid a copy.
+ */
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+/*!
+ * \brief Match a call pattern against `expr`.
+ *
+ * When `expr` is a Call whose operator matches `op->op`, the arguments are
+ * matched pairwise in order. For patterns whose operator is the Relay op
+ * "add" or "multiply", the reversed argument order is tried as well.
+ * When the operator pattern does not match, two re-associations are tried:
+ *   - pattern divide(multiply(a, b), c) against a multiply call that has a
+ *     divide call as an argument: tried as multiply(b, divide(a, c)) and then
+ *     as multiply(a, divide(b, c));
+ *   - pattern multiply(divide(a, b), d) (divide in either argument slot)
+ *     against a divide call that has a multiply call as an argument: tried as
+ *     divide(multiply(a, d), b).
+ * Returns false for any non-Call expression.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // The Op wrapped by the call pattern's operator, when that operator is an
+  // ExprPattern holding an Op; nullptr otherwise.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the call pattern's operator is an Op named `op_type`.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True when `expr` is a Call to an Op named `op_type`.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Length of the binding log on entry; ClearMap(watermark) undoes anything bound since.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Roll-back point that keeps the operator binding but drops argument bindings.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match arguments pairwise and stop at the first failure. On any failure,
+      // including an arity mismatch, drop the argument bindings made so far.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching: for add/multiply also accept the arguments in reverse order.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+      // Neither order matched: fall through to `return false`. The memoizing
+      // VisitDFPattern wrapper rolls back the operator binding on failure.
+    } else {
+      // Discard any bindings made since entry before trying re-associated patterns.
+      ClearMap(watermark);
+      // associate divide/multiply: divide(multiply(a, b), c) vs. multiply(x, y)
+      // where x or y is a divide call.
+      // NOTE(review): args[0]/args[1] are indexed without an arity check; this
+      // assumes binary divide/multiply patterns and calls.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                // arg_id == 0 tries multiply(b, divide(a, c));
+                // arg_id == 1 tries multiply(a, divide(b, c)).
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide: multiply(divide(a, b), d) (divide in either slot)
+        // vs. a divide call with a multiply argument, tried as divide(multiply(a, d), b).
+        // Only the first slot holding a divide pattern that passes the checks is tried.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief Match a dominator pattern whose child is rooted at `expr`.
+ *
+ * First `op->child` must match `expr`. The expressions bound to non-wildcard
+ * child-pattern nodes whose dominator parent (in the child pattern's indexed
+ * graph) is the child pattern itself are collected. All bindings made for the
+ * child are then discarded. Next, every node of `expr`'s indexed graph whose
+ * dominator parent is `expr`, other than the collected expressions, must match
+ * either `op->parent` or `op->path`. Bindings made by these checks are
+ * discarded too.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  // Roll-back point covering every binding made by this visit.
+  auto watermark = matched_nodes_.size();
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches = true;
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+    auto child_graph = CreateIndexedGraph(op->child);
+    for (const auto& node : child_graph.topological_order_) {
+      if (node->ref_.as<WildcardPatternNode>()) {
+        continue;
+      }
+      if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) {
+        // Look the binding up with find() rather than operator[]. A pattern node the
+        // child match left unbound (e.g. the untried right side of an AltPattern) must
+        // not receive a default entry. Such an entry is not logged in matched_nodes_,
+        // so ClearMap() could never remove it, and any later visit of that pattern in
+        // this match would compare against a null Expr and fail.
+        auto bound = memo_.find(node->ref_);
+        if (bound != memo_.end()) {
+          dominated_exprs.insert(bound->second);
+        }
+      }
+    }
+    // The child bindings were only needed to collect dominated_exprs.
+    ClearMap(watermark);
+    auto expr_graph = CreateIndexedGraph(expr);
+    for (const auto& node : expr_graph.topological_order_) {
+      if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) {
+        if (dominated_exprs.count(node->ref_) == 0) {
+          bool node_matches = VisitDFPattern(op->parent, node->ref_);
+          ClearMap(watermark);
+          matches = node_matches || VisitDFPattern(op->path, node->ref_);
+          ClearMap(watermark);
+          if (!matches) {
+            return false;
+          }
+        }
+      }
+    }
+    return matches;
+  }
+  return false;
+}
+
+/*! \brief An expression pattern matches any expression structurally equal to its stored expression. */
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+/*!
+ * \brief Matches a TupleGetItem with the same index whose tuple operand matches
+ *        the nested tuple pattern. The operand is only visited once the index agrees.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  const auto* get_item = expr.as<TupleGetItemNode>();
+  if (get_item == nullptr || get_item->index != op->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, get_item->tuple);
+}
+
+/*!
+ * \brief Matches a Tuple of the same arity whose fields match the field patterns
+ *        pairwise, left to right, stopping at the first mismatch.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr || tuple->fields.size() != op->fields.size()) {
+    return false;
+  }
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+Expr InferType(const Expr& expr) {
 
 Review comment:
   This seems useful outside of pattern matching. In the Python frontend we also use this API a lot. Maybe move it to a common utility?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409155545
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Dataflow pattern matcher for Relay expressions.
+ *
+ * Every successful sub-match is memoized as a pattern -> expression binding,
+ * so a pattern node that occurs more than once must bind to the same
+ * expression each time. Bindings are also logged in order so that a failed
+ * attempt can be rolled back to an earlier point.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Clear all bindings and test whether `pattern` matches `expr`. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  /*! \brief Memoizing dispatch used for every (sub-)pattern. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  // Matching rules, one per pattern node kind.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Roll back every binding logged at index `watermark` or later. */
+  void ClearMap(size_t watermark);
+  /*! \brief Current pattern -> expression bindings. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound patterns in the order they were bound (the roll-back log). */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+/*! \brief Entry point: forget earlier bindings, then match `pattern` against `expr`. */
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Each top-level match starts from an empty binding table and log.
+  matched_nodes_.clear();
+  memo_.clear();
+  const bool result = VisitDFPattern(pattern, expr);
+  return result;
+}
+
+/*! \brief Undo every binding logged at position `watermark` or later. */
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  const auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+/*!
+ * \brief Memoizing wrapper around the functor dispatch.
+ *
+ * A pattern that is already bound matches only the very same expression
+ * object. Otherwise the pattern is dispatched; on success the binding is
+ * recorded, and on failure every binding made during the attempt is undone.
+ */
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  const size_t rollback_point = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(rollback_point);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+/*! \brief An alternation matches when either branch matches; the left branch is tried first. */
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+/*!
+ * \brief Match an operator against an attribute pattern.
+ *
+ * `expr` must be an Op. Every (name, value) entry of the pattern's dictionary
+ * attributes must agree with the value registered for that operator via
+ * Op::GetAttr. Integer, float and string values are compared. A pattern value
+ * of a different kind, or an attribute the operator does not define, makes the
+ * match fail. An empty dictionary, non-dictionary attributes, or a non-Op
+ * expression never match. Registered values of any other type are unsupported
+ * and raise an exception.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  // Guard against non-dictionary attributes instead of dereferencing a null pointer.
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  // Becomes true only after at least one attribute has been checked. Every
+  // attribute must agree, so the first mismatch rejects the operator.
+  bool matches = false;
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      // The operator does not define this attribute, so it cannot satisfy the pattern.
+      return false;
+    }
+    const TVMRetValue& registered = op_map[op];
+    bool attr_matches = false;
+    switch (registered.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          attr_matches = val->value == registered.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          attr_matches = val->value == registered.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          attr_matches = val->value == registered.operator std::string();
+        }
+        break;
+      default:
+        // NOTE(review): throwing a string literal bypasses std::exception handlers;
+        // consider LOG(FATAL) or a std::exception subtype.
+        throw "Unsupported type";
+    }
+    if (!attr_matches) {
+      return false;
+    }
+    matches = true;
+  }
+  return matches;
+}
+
+/*! \brief Build a new array holding the elements of `args` from last to first. */
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+/*!
+ * \brief Match a call pattern against `expr`.
+ *
+ * When `expr` is a Call whose operator matches `op->op`, the arguments are
+ * matched pairwise in order. For patterns whose operator is the Relay op
+ * "add" or "multiply", the reversed argument order is tried as well.
+ * When the operator pattern does not match, two re-associations are tried:
+ *   - pattern divide(multiply(a, b), c) against a multiply call that has a
+ *     divide call as an argument: tried as multiply(b, divide(a, c)) and then
+ *     as multiply(a, divide(b, c));
+ *   - pattern multiply(divide(a, b), d) (divide in either argument slot)
+ *     against a divide call that has a multiply call as an argument: tried as
+ *     divide(multiply(a, d), b).
+ * Returns false for any non-Call expression.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // The Op wrapped by the call pattern's operator, when that operator is an
+  // ExprPattern holding an Op; nullptr otherwise.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the call pattern's operator is an Op named `op_type`.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True when `expr` is a Call to an Op named `op_type`.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Length of the binding log on entry; ClearMap(watermark) undoes anything bound since.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Roll-back point that keeps the operator binding but drops argument bindings.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match arguments pairwise and stop at the first failure. On any failure,
+      // including an arity mismatch, drop the argument bindings made so far.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching: for add/multiply also accept the arguments in reverse order.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+      // Neither order matched: fall through to `return false`. The memoizing
+      // VisitDFPattern wrapper rolls back the operator binding on failure.
+    } else {
+      // Discard any bindings made since entry before trying re-associated patterns.
+      ClearMap(watermark);
+      // associate divide/multiply: divide(multiply(a, b), c) vs. multiply(x, y)
+      // where x or y is a divide call.
+      // NOTE(review): args[0]/args[1] are indexed without an arity check; this
+      // assumes binary divide/multiply patterns and calls.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                // arg_id == 0 tries multiply(b, divide(a, c));
+                // arg_id == 1 tries multiply(a, divide(b, c)).
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide: multiply(divide(a, b), d) (divide in either slot)
+        // vs. a divide call with a multiply argument, tried as divide(multiply(a, d), b).
+        // Only the first slot holding a divide pattern that passes the checks is tried.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
+                   &expr_graph, &find_parent](
+                      const Expr& expr,
+                      const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs) {
+      bool out = true;
+      for (auto node : expr_graph.node_map_[expr]->dominator_children_) {
+        if (out && dominated_exprs.count(node->ref_) == 0) {
+          if (VisitDFPattern(op->parent, node->ref_)) {
+            backup_memo[op->parent] = memo_.at(op->parent);
+            backup_matched_nodes.push_back(op->parent);
+            memo_ = backup_memo;
+            matched_nodes_ = backup_matched_nodes;
+            watermark += 1;
+            return true;
+          } else {
+            if (VisitDFPattern(op->path, node->ref_)) {
+              auto new_dominated_exprs = find_dominated(op->path);
+              ClearMap(watermark);
+              out &= find_parent(node->ref_, new_dominated_exprs);
+            } else {
+              out = false;
+            }
+          }
+        }
+      }
+      return out;
+    };
+
+    auto dominated_exprs = find_dominated(op->child);
+    ClearMap(watermark);
+    bool matches = find_parent(expr, dominated_exprs);
+    if (matches) {
+      backup_memo[op->parent] = memo_.at(op->parent);
+      backup_memo[op->child] = expr;
+      memo_ = backup_memo;
 
 Review comment:
   :/ This is a leftover from an earlier design; I guess there's no real reason for it anymore. Will simplify, thanks.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-614331022
 
 
   Nice!
   
   > Any patterns in particular you want to see tested?
   
   No. I don't know whether these complicated patterns come up in practice, but it is great to be future-proof :) It is also a prerequisite if we want to replace the current fusion implementation with a pattern-matching-based one.
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409138177
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Dataflow pattern matcher for Relay expressions.
+ *
+ * Every successful sub-match is memoized as a pattern -> expression binding,
+ * so a pattern node that occurs more than once must bind to the same
+ * expression each time. Bindings are also logged in order so that a failed
+ * attempt can be rolled back to an earlier point.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Clear all bindings and test whether `pattern` matches `expr`. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  /*! \brief Memoizing dispatch used for every (sub-)pattern. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  // Matching rules, one per pattern node kind.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Roll back every binding logged at index `watermark` or later. */
+  void ClearMap(size_t watermark);
+  /*! \brief Current pattern -> expression bindings. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound patterns in the order they were bound (the roll-back log). */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+/*! \brief Entry point: forget earlier bindings, then match `pattern` against `expr`. */
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Each top-level match starts from an empty binding table and log.
+  matched_nodes_.clear();
+  memo_.clear();
+  const bool result = VisitDFPattern(pattern, expr);
+  return result;
+}
+
+/*! \brief Undo every binding logged at position `watermark` or later. */
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  const auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+/*!
+ * \brief Memoizing wrapper around the functor dispatch.
+ *
+ * A pattern that is already bound matches only the very same expression
+ * object. Otherwise the pattern is dispatched; on success the binding is
+ * recorded, and on failure every binding made during the attempt is undone.
+ */
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  const size_t rollback_point = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(rollback_point);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+/*! \brief An alternation matches when either branch matches; the left branch is tried first. */
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+/*!
+ * \brief Match an operator against an attribute pattern.
+ *
+ * `expr` must be an Op. Every (name, value) entry of the pattern's dictionary
+ * attributes must agree with the value registered for that operator via
+ * Op::GetAttr. Integer, float and string values are compared. A pattern value
+ * of a different kind, or an attribute the operator does not define, makes the
+ * match fail. An empty dictionary, non-dictionary attributes, or a non-Op
+ * expression never match. Registered values of any other type are unsupported
+ * and raise an exception.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  // Guard against non-dictionary attributes instead of dereferencing a null pointer.
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  // Becomes true only after at least one attribute has been checked. Every
+  // attribute must agree, so the first mismatch rejects the operator.
+  bool matches = false;
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      // The operator does not define this attribute, so it cannot satisfy the pattern.
+      return false;
+    }
+    const TVMRetValue& registered = op_map[op];
+    bool attr_matches = false;
+    switch (registered.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          attr_matches = val->value == registered.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          attr_matches = val->value == registered.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          attr_matches = val->value == registered.operator std::string();
+        }
+        break;
+      default:
+        // NOTE(review): throwing a string literal bypasses std::exception handlers;
+        // consider LOG(FATAL) or a std::exception subtype.
+        throw "Unsupported type";
+    }
+    if (!attr_matches) {
+      return false;
+    }
+    matches = true;
+  }
+  return matches;
+}
+
+/*! \brief Build a new array holding the elements of `args` from last to first. */
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
 
 Review comment:
   I didn't know a recursive lambda was possible, but I don't find it pretty :) It has to be declared with its full `std::function` type, it needs an explicit capture list, etc.
   
   Can we move this to a private member function? The variables it currently captures can be passed as arguments instead. `find_dominated` can also be lifted to a free function.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409117750
 
 

 ##########
 File path: tests/python/relay/test_df_pattern.py
 ##########
 @@ -0,0 +1,574 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay.df_pattern import *
+import numpy as np
+
+# NB: 1 corresponds to the C++ enum that specicfies this
+# we loose the type safety due to the Python/C++ calling
+# convention.
+K_ELEMWISE = 0
+K_BROADCAST = 1
+
+## NODE TESTS
+def test_expr_pattern():
+    ep = ExprPattern(relay.var('x', shape=(4, 1)))
+    print(ep)
+
+def test_var_pattern():
+    v = is_input("x")
+    print(v)
+
+def test_wildcard_pattern():
+    wc = wildcard()
+    print(wc)
+
+def test_CallPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    c = is_op("add")(wc1, wc2)
+    print(c)
+
+def test_TuplePattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    print(t)
+
+def test_TupleGetItemPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    tgi = TupleGetItemPattern(t, 1)
+    print(tgi)
+
+def test_AltPattern():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    print(is_add_or_sub)
+
+def test_TypePattern():
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    print(ty_pat)
+
+def test_AttrPattern():
+    op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
+    op_pat = op(wildcard(), wildcard())
+    print(op_pat)
+
+## MATCHER TESTS
+
+def test_match_op():
+    assert is_op('add').match(relay.op.op.get("add"))
+
+def test_no_match_op():
+    assert not is_op('add').match(relay.op.op.get("subtract"))
+
+def test_match_op_or():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    assert is_add_or_sub.match(relay.op.op.get("add"))
+    assert is_add_or_sub.match(relay.op.op.get("subtract"))
+
+def test_match_call_commutive():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x + y)
+    assert add_pattern.match(y + x)
+    mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
+    assert mul_pattern.match(x * y)
+    assert mul_pattern.match(y * x)
+
+def test_no_match_call_commutive():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('subtract')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x - y)
+    assert not add_pattern.match(y - x)
+    add_pattern = is_op('divide')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x / y)
+    assert not add_pattern.match(y / x)
+
+def test_match_call():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    assert add_pattern.match(x + y)
+
+def test_no_match_call():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    assert not add_pattern.match(x - y)
+
+def test_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+    assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
+
+def test_no_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
+    assert not tuple_pattern.match(relay.expr.Tuple((x,y,z)))
+
+def test_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+    tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1))
+
+def test_no_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add")))
+    tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+    assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2))
+
+def test_match_type():
+    x = relay.var('x', shape=(10, 10), dtype="float32")
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    assert ty_pat.match(x)
+
+def test_no_match_type():
+    x = relay.var('x', shape=(10, 10), dtype="int32")
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    assert not ty_pat.match(x)
+
+def test_match_attr():
+    op = is_op('add').has_attr("TOpPattern", K_BROADCAST)
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert op_pat.match(x + y)
+
+def test_no_match_attr():
+    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert not op_pat.match(relay.op.nn.dense(x, y))
+
+def test_match_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+def test_no_match_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(leaky_relu)
+    assert not diamond.match(relu)
+
+def test_match_fake_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    input1 = relay.var('input1')
+    weight1 = relay.var('weight1')
+    conv2d1 = relay.op.nn.conv2d(input1, weight1)
+    inp2 = relay.var('input2')
+    weight2 = relay.var('weight2')
+    conv2d2 = relay.op.nn.conv2d(inp2, weight2)
+    relu = relay.op.nn.relu(conv2d1)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+
+def test_match_dominator():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    relu = relay.op.tanh(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Check
+    assert diamond.match(out)
+    
+
+def test_not_match_dominator():
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Expr
+    input1 = relay.var('input1')
+    weight1 = relay.var('weight1')
+    conv2d1 = relay.op.nn.conv2d(input1, weight1)
+    inp2 = relay.var('input2')
+    weight2 = relay.var('weight2')
+    conv2d2 = relay.op.nn.conv2d(inp2, weight2)
+    relu = relay.op.nn.relu(conv2d1)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relu + relu
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(inp)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    relu = relay.op.nn.relu(inp)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Check
+    assert not diamond.match(out)
+
+def test_rewrite():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    sub_pattern = is_op('subtract')(wildcard(), wildcard())
+    def add_to_sub(pre, post):
+        return post.args[0] - post.args[1]
+    out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y)
+    assert sub_pattern.match(out)
+
+def test_not_fuse_multi_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+    out = out + conv2d
+    # Check
+    assert not diamond.match(out)
+
+def fuse_batchnorm(pre, post):
+    def left_right_call(post):
+        if isinstance(post.args[0], relay.Call):
+            return (post.args[1], post.args[0])
+        else:
+            return (post.args[0], post.args[1])
+    
+    beta, post = left_right_call(post)
+    assert isinstance(post, relay.Call)
+    
+    if post.op == relay.op.get("divide"):
+        numerator = post.args[0]
+        denominator = post.args[1]
+        gamma, numerator = left_right_call(numerator)
+    elif post.op == relay.op.get("multiply"):
+        gamma, quotient = left_right_call(post)
+        numerator = quotient.args[0]
+        denominator = quotient.args[1]
+    else:
+        raise "Found unexcepted op"
+
+    x = numerator.args[0]
+    mean = numerator.args[1]
+
+    var = denominator.args[0].args[0]
+    eps = denominator.args[0].args[1]
+    
+    out = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())
+    return out[0]
+
+def test_fuse_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+def test_no_fuse_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), fake_BN)
+    assert tvm.ir.structural_equal(out, fake_BN)
+
+def test_fuse_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2)
+
+    bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]
+    bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0]
+
+    assert tvm.ir.structural_equal(out, bn2)
+
+def test_partial_fuse_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
+    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2)
+
+    bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0]
+
+    assert tvm.ir.structural_equal(out, bn2)
+
+def test_fuse_batchnorm_commutation():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    #commute add
+    BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5))
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+    # associate divide/multiply
+    BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5))  + beta
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+    # associate multiply/divide
+    BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+def algebraic_simplify(expr):
+    pattern_callbacks = []
+
+    def elwise_zero_callback(pre, post):
+        if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | 
+            tvm.ir.structural_equal(post.args[0], relay.const(0.0))):
+            return post.args[1]
+        else:
+            return post.args[0]
+
+    def elwise_one_callback(pre, post):
+        if (tvm.ir.structural_equal(post.args[0], relay.const(1)) | 
+            tvm.ir.structural_equal(post.args[0], relay.const(1.0))):
+            return post.args[1]
+        else:
+            return post.args[0]
+
+    def return_zero_callback(pre, post):
+        if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | 
+            tvm.ir.structural_equal(post.args[0], relay.const(0.0))):
+            return post.args[0]
+        else:
+            return post.args[1]
+
+    zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
+    one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
+    add_pattern = wildcard() + zero
+    pattern_callbacks.append(DFPatternCallback(add_pattern, elwise_zero_callback))
+
+    sub_pattern = wildcard() - zero
+    pattern_callbacks.append(DFPatternCallback(sub_pattern, elwise_zero_callback))
+
+    mul_pattern = wildcard() * one
+    pattern_callbacks.append(DFPatternCallback(mul_pattern, elwise_one_callback))
+    
+    mul_zero_pattern = wildcard() * zero
+    pattern_callbacks.append(DFPatternCallback(mul_zero_pattern, return_zero_callback))
+
+    div_pattern = wildcard() / one
+    pattern_callbacks.append(DFPatternCallback(div_pattern, elwise_one_callback))
+
+    zero_div_pattern = zero / wildcard()
+    pattern_callbacks.append(DFPatternCallback(zero_div_pattern, return_zero_callback))
+
+    return rewrite(pattern_callbacks, expr);
+
+def test_algebraic_simplify():
+    x = relay.Var('x')
+    y = relay.Var('y')  
+
+    print(x + relay.const(0))
+    
+    one = relay.const(1)
+    zero = relay.const(0)
+    onef = relay.const(1.0)
+    zerof = relay.const(0.0)
+
+    assert algebraic_simplify(x + zero) == x
+    assert algebraic_simplify(x + zerof) == x
+    assert algebraic_simplify(zero + x) == x
+    assert algebraic_simplify(zerof + x) == x
+    
+    assert algebraic_simplify(x - zero) == x
+    assert algebraic_simplify(x - zerof) == x
+    
+    assert algebraic_simplify(x * one) == x
+    assert algebraic_simplify(x * onef) == x
+    assert algebraic_simplify(one * x) == x
+    assert algebraic_simplify(onef * x) == x
+    assert algebraic_simplify(x * zero) == zero
+    assert algebraic_simplify(x * zerof) == zerof
+    
+    assert algebraic_simplify(x / one) == x
+    assert algebraic_simplify(x / onef) == x
+    assert algebraic_simplify(zero / x) == zero
+    assert algebraic_simplify(zerof / x) == zerof
+
+    assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)
+
+if __name__ == "__main__":
+    #test_match_op()
+    #test_no_match_op()
+    #test_match_op_or()
+    #test_match_call()
+    #test_no_match_call()
+    #test_match_call_commutive()
+    #test_no_match_call_commutive()
+    #test_match_tuple()
+    #test_no_match_tuple()
+    #test_match_type()
+    #test_no_match_type()
+    #test_match_attr()
+    #test_no_match_attr()
+    #test_match_diamond()
+    #test_no_match_diamond()
+    #test_match_fake_diamond()
+    #test_rewrite()
+    #test_fuse_batchnorm()
+    #test_no_fuse_batchnorm()
+    #test_fuse_double_batchnorm()
+    #test_partial_fuse_double_batchnorm()
+    #test_fuse_batchnorm_commutation()
+    #test_match_dominator()
 
 Review comment:
   They already run in CI via pytest, but I'll uncomment these calls so that anyone running the file directly doesn't run into issues. Thanks for the catch!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408647307
 
 

 ##########
 File path: include/tvm/relay/dataflow_functor.h
 ##########
 @@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern argument.
+ *
+ * \tparam FType function signiture
+ *  This type is only defined for FType with function signature R(const DFPattern&,
+ * Args...)
+ */
+template <typename FType>
+class DFPatternFunctor;
+
+// functions to be overriden.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \
+    return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class DFPatternFunctor<R(const DFPattern& n, Args...)> {
+ private:
+  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+    return VisitDFPattern(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+    CHECK(n.defined());
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once.
+ */
+class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
+ public:
+  void VisitDFPattern(const DFPattern& pattern) override;
+  void VisitDFPattern_(const AltPatternNode* op) override;
+  void VisitDFPattern_(const AttrPatternNode* op) override;
+  void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const DominatorPatternNode* op) override;
+  void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
+  void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const TypePatternNode* op) override;
+  void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const WildcardPatternNode* op) override;
+
+ protected:
+  // set of already-visited nodes
+  std::unordered_set<const Object*> visited_;
+};
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ *  Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ *  Class is Templated and the implementaiton is in the header file so we can analyis both DFPattern
+ * and Expr with the same infrastructure.
+ *
+ *  IndexedGraph should be instantiated thorught the CreateIndexedGraph utilities.
+ */
+template <typename T>
+class IndexedGraph {
+ public:
+  /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */
+  struct Node {
+    /*! \brief Node Constructor
+     *  \param ref The input graph node
+     *  \param index The index of the node in toplogoical order
 
 Review comment:
   Typo: "toplogoical" should be "topological".

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408647818
 
 

 ##########
 File path: include/tvm/relay/dataflow_functor.h
 ##########
 @@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern argument.
+ *
+ * \tparam FType function signiture
+ *  This type is only defined for FType with function signature R(const DFPattern&,
+ * Args...)
+ */
+template <typename FType>
+class DFPatternFunctor;
+
+// functions to be overriden.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \
+    return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class DFPatternFunctor<R(const DFPattern& n, Args...)> {
+ private:
+  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+    return VisitDFPattern(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+    CHECK(n.defined());
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once.
+ */
+class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
+ public:
+  void VisitDFPattern(const DFPattern& pattern) override;
+  void VisitDFPattern_(const AltPatternNode* op) override;
+  void VisitDFPattern_(const AttrPatternNode* op) override;
+  void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const DominatorPatternNode* op) override;
+  void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
+  void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const TypePatternNode* op) override;
+  void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const WildcardPatternNode* op) override;
+
+ protected:
+  // set of already-visited nodes
+  std::unordered_set<const Object*> visited_;
+};
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ *  Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ *  Class is Templated and the implementaiton is in the header file so we can analyis both DFPattern
+ * and Expr with the same infrastructure.
+ *
+ *  IndexedGraph should be instantiated thorught the CreateIndexedGraph utilities.
+ */
+template <typename T>
+class IndexedGraph {
+ public:
+  /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */
+  struct Node {
+    /*! \brief Node Constructor
+     *  \param ref The input graph node
+     *  \param index The index of the node in toplogoical order
+     */
+    Node(const T& ref, const size_t index) : ref_(ref), index_(index) {}
+
+    /*! \brief The input node */
+    const T ref_;
+    /*! \brief The topological order index */
+    const size_t index_;
+
+    /*! \brief A boolean to determine if this node is external to the graph */
+    bool is_external_ = false;
+    /*! \brief The forward outputs/users of the node */
+    std::vector<Node*> outputs_;
+
+    /*! \brief The depth of the node in the dominator tree */
+    size_t depth_;
+    /*! \brief The dominator parent/final user of the outputs of this node */
+    Node* dominator_parent_;
+    /*! \brief The nodes this node dominates */
+    std::vector<Node*> dominator_children_;
+  };
+  /*! \brief Construct the domination create of the index graph */
 
 Review comment:
   domination create?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408458220
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
 
 Review comment:
   Nit: add a blank line between the previous function's closing brace and this definition.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408455606
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
 
 Review comment:
   Nit: add a blank line between the previous function's closing brace and this definition.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408496984
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+// Match a DominatorPattern: `child` must match `expr`, and every expression that
+// `expr` immediately dominates (other than those already bound to sub-patterns that
+// `child` immediately dominates) must match either `parent` or `path`.
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  // Collect the expressions bound to the sub-patterns that `child` immediately dominates.
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+  auto child_graph = CreateIndexedGraph(op->child);
+  for (const auto& node : child_graph.topological_order_) {
+    if (node->ref_.as<WildcardPatternNode>()) {
+      continue;
+    }
+    if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) {
+      // Use find() rather than operator[]: a sub-pattern may be unbound (e.g. an
+      // untaken Alt branch), and operator[] would insert a null Expr into memo_
+      // that ClearMap() never removes, breaking later lookups of that pattern.
+      auto it = memo_.find(node->ref_);
+      if (it != memo_.end()) {
+        dominated_exprs.insert(it->second);
+      }
+    }
+  }
+  ClearMap(watermark);
+  auto expr_graph = CreateIndexedGraph(expr);
+  for (const auto& node : expr_graph.topological_order_) {
+    if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) {
+      if (dominated_exprs.count(node->ref_) == 0) {
+        // Each dominated node is tried independently; reset bindings between tries.
+        bool node_matches = VisitDFPattern(op->parent, node->ref_);
+        ClearMap(watermark);
+        node_matches = node_matches || VisitDFPattern(op->path, node->ref_);
+        ClearMap(watermark);
+        if (!node_matches) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An ExprPattern matches any expression structurally equal to the one it wraps.
+  StructuralEqual equal;
+  return equal(op->expr, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  // Matches `tuple[index]` when the index is identical and the tuple operand matches;
+  // the index is compared first so the operand is only visited when it can matter.
+  const auto* getter = expr.as<TupleGetItemNode>();
+  if (getter == nullptr || getter->index != op->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, getter->tuple);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  // Matches a tuple of the same arity whose fields match pairwise, left to right,
+  // stopping at the first field that fails.
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr || tuple->fields.size() != op->fields.size()) {
+    return false;
+  }
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+Expr InferType(const Expr& expr) {
 
 Review comment:
   Good idea. Do you have a suggestion for a place? Maybe analysis.h and analysis/util.cc?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409155274
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Matches a dataflow pattern against a Relay expression.
+ *
+ * memo_ records which expression each sub-pattern is bound to; matched_nodes_ records
+ * the order in which bindings were made so ClearMap() can roll them back to a
+ * watermark when part of a match fails.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Drop all bindings, then return whether \p pattern matches \p expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Undo every binding recorded at or after position \p watermark. */
+  void ClearMap(size_t watermark);
+  /*! \brief Sub-pattern -> the expression it is currently bound to. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound sub-patterns, in the order they were bound. */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Start from a clean slate so bindings from a previous match cannot leak in.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  // Roll back every binding recorded at or after `watermark`. Nothing to undo when the
+  // watermark is at or past the end; guarding also avoids forming begin() + watermark
+  // beyond end(), which is undefined behavior.
+  if (watermark >= matched_nodes_.size()) {
+    return;
+  }
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  // A pattern that is already bound only matches the very same expression node.
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  // Otherwise dispatch on the pattern type; record the binding on success, or undo
+  // any partial bindings made during the failed attempt.
+  const size_t watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // Try the left alternative first; the right one is only tried if the left fails.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+// Matches an Op whose registered attributes (Op::GetAttr) equal every entry of the
+// pattern's attribute dictionary. Non-Op expressions never match.
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  // Guard the cast instead of dereferencing a possibly-null DictAttrsNode.
+  // An empty constraint dictionary is treated as a non-match.
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr || dict_attrs->dict.empty()) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_name = kv.first;
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+    // Every requested attribute must be registered on the op and compare equal;
+    // any miss (unregistered, wrong type, different value) fails the whole pattern.
+    if (!op_map.count(op)) {
+      return false;
+    }
+    TVMRetValue op_value = op_map[op];
+    bool attr_matches = false;
+    switch (op_value.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          attr_matches = val->value == op_value.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          attr_matches = val->value == op_value.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          attr_matches = val->value == op_value.operator std::string();
+        }
+        break;
+      default:
+        // Report through TVM's error path instead of throwing a bare const char*,
+        // which no error handler catches.
+        LOG(FATAL) << "Unsupported attribute type code " << op_value.type_code()
+                   << " for attribute " << attr_name;
+    }
+    if (!attr_matches) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  // Build a copy of `args` with the element order flipped (last element first).
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+// Match a CallPattern against a Relay Call. When the callee pattern matches the call's
+// op, the arguments are matched pairwise; for "add"/"multiply" patterns the reversed
+// argument order is also tried. When the callee pattern does NOT match, a few
+// divide/multiply re-associations of the pattern are tried instead.
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode a call pattern's callee refers to when the callee is an
+  // ExprPattern wrapping an Op; nullptr otherwise.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True if the call pattern's callee is the op named `op_type`.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True if `expr` is a Call whose callee is the op named `op_type`.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Number of bindings before this call pattern was visited; used to roll back.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Bindings after the callee matched: failed argument attempts roll back to
+      // here, so the callee binding survives a retry with reordered arguments.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match arguments pairwise (arity must agree), stopping at the first failure;
+      // on failure undo any argument bindings made.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching: for add/multiply patterns also try the reversed
+      // argument order.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      // The callee pattern did not match the call's op; drop any partial bindings
+      // and try re-associated forms of the pattern.
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern divide(multiply(a, b), c) against an expr multiply(x, y) where x or y
+      // is a divide: try multiply(b, divide(a, c)), then multiply(a, divide(b, c)).
+      // NOTE(review): op->args[0] / op->args[1] are indexed without an arity check;
+      // this assumes divide/multiply patterns always carry two arguments — confirm.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Pattern multiply with a divide(a, b) argument (in either position), against
+        // an expr divide(x, y) where x or y is a multiply: retry as
+        // divide(multiply(a, <other multiply arg>), b) using the first divide argument found.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
 
 Review comment:
   Yeah, it's not beautiful. 
   
   The other option is to make this a friend class of the matcher that just handles domination; there are too many required data members that are private to the dominator pattern to make this a more generally accessible part of the pattern matcher.
   
   I find friend classes equally horrendous, but I'm happy to move it, if you prefer.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-614270281
 
 
   Thank you for the detailed updates, @masahi! I'm super grateful someone takes the time to catch the little typos in the comments I miss.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409077909
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Matches a dataflow pattern against a Relay expression.
+ *
+ * memo_ records which expression each sub-pattern is bound to; matched_nodes_ records
+ * the order in which bindings were made so ClearMap() can roll them back to a
+ * watermark when part of a match fails.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Drop all bindings, then return whether \p pattern matches \p expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Undo every binding recorded at or after position \p watermark. */
+  void ClearMap(size_t watermark);
+  /*! \brief Sub-pattern -> the expression it is currently bound to. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound sub-patterns, in the order they were bound. */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Start from a clean slate so bindings from a previous match cannot leak in.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  // Roll back every binding recorded at or after `watermark`. Nothing to undo when the
+  // watermark is at or past the end; guarding also avoids forming begin() + watermark
+  // beyond end(), which is undefined behavior.
+  if (watermark >= matched_nodes_.size()) {
+    return;
+  }
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  // A pattern that is already bound only matches the very same expression node.
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  // Otherwise dispatch on the pattern type; record the binding on success, or undo
+  // any partial bindings made during the failed attempt.
+  const size_t watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // Try the left alternative first; the right one is only tried if the left fails.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+// Matches an Op whose registered attributes (Op::GetAttr) equal every entry of the
+// pattern's attribute dictionary. Non-Op expressions never match.
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  // Guard the cast instead of dereferencing a possibly-null DictAttrsNode.
+  // An empty constraint dictionary is treated as a non-match.
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr || dict_attrs->dict.empty()) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_name = kv.first;
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+    // Every requested attribute must be registered on the op and compare equal;
+    // any miss (unregistered, wrong type, different value) fails the whole pattern.
+    if (!op_map.count(op)) {
+      return false;
+    }
+    TVMRetValue op_value = op_map[op];
+    bool attr_matches = false;
+    switch (op_value.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          attr_matches = val->value == op_value.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          attr_matches = val->value == op_value.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          attr_matches = val->value == op_value.operator std::string();
+        }
+        break;
+      default:
+        // Report through TVM's error path instead of throwing a bare const char*,
+        // which no error handler catches.
+        LOG(FATAL) << "Unsupported attribute type code " << op_value.type_code()
+                   << " for attribute " << attr_name;
+    }
+    if (!attr_matches) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  // Build a copy of `args` with the element order flipped (last element first).
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+// Match a CallPattern against a Relay Call. When the callee pattern matches the call's
+// op, the arguments are matched pairwise; for "add"/"multiply" patterns the reversed
+// argument order is also tried. When the callee pattern does NOT match, a few
+// divide/multiply re-associations of the pattern are tried instead.
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode a call pattern's callee refers to when the callee is an
+  // ExprPattern wrapping an Op; nullptr otherwise.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True if the call pattern's callee is the op named `op_type`.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True if `expr` is a Call whose callee is the op named `op_type`.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Number of bindings before this call pattern was visited; used to roll back.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Bindings after the callee matched: failed argument attempts roll back to
+      // here, so the callee binding survives a retry with reordered arguments.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match arguments pairwise (arity must agree), stopping at the first failure;
+      // on failure undo any argument bindings made.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching: for add/multiply patterns also try the reversed
+      // argument order.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      // The callee pattern did not match the call's op; drop any partial bindings
+      // and try re-associated forms of the pattern.
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern divide(multiply(a, b), c) against an expr multiply(x, y) where x or y
+      // is a divide: try multiply(b, divide(a, c)), then multiply(a, divide(b, c)).
+      // NOTE(review): op->args[0] / op->args[1] are indexed without an arity check;
+      // this assumes divide/multiply patterns always carry two arguments — confirm.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Pattern multiply with a divide(a, b) argument (in either position), against
+        // an expr divide(x, y) where x or y is a multiply: retry as
+        // divide(multiply(a, <other multiply arg>), b) using the first divide argument found.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
+                   &expr_graph, &find_parent](
+                      const Expr& expr,
+                      const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs) {
+      bool out = true;
+      for (auto node : expr_graph.node_map_[expr]->dominator_children_) {
+        if (out && dominated_exprs.count(node->ref_) == 0) {
+          if (VisitDFPattern(op->parent, node->ref_)) {
+            backup_memo[op->parent] = memo_.at(op->parent);
+            backup_matched_nodes.push_back(op->parent);
+            memo_ = backup_memo;
+            matched_nodes_ = backup_matched_nodes;
+            watermark += 1;
+            return true;
+          } else {
+            if (VisitDFPattern(op->path, node->ref_)) {
+              auto new_dominated_exprs = find_dominated(op->path);
+              std::cout << watermark << std::endl;
 
 Review comment:
   Remove it

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] zhiics commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408536937
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Dataflow pattern matcher: decides whether a DFPattern matches a Relay Expr,
+ *        binding each sub-pattern to the expression it matched.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Clear all bindings and match \p pattern against \p expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  /*! \brief Memoizing entry point used for every (sub-)pattern visit. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+
+  // Per-pattern-type visitors.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Roll back every binding recorded at or after \p watermark. */
+  void ClearMap(size_t watermark);
+
+  /*! \brief Sub-pattern -> expression it is currently bound to. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound sub-patterns in binding order (the undo log read by ClearMap). */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Start from a clean slate so bindings from a previous match cannot leak in.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  // Roll back every binding recorded at or after `watermark`. Nothing to undo when the
+  // watermark is at or past the end; guarding also avoids forming begin() + watermark
+  // beyond end(), which is undefined behavior.
+  if (watermark >= matched_nodes_.size()) {
+    return;
+  }
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  // A pattern node that is already bound can only match the very same
+  // expression object it was bound to.
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  const auto watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    // Undo whatever the failed sub-match bound.
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // The left alternative has priority; the right one is only tried on failure.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+/*
+ * Match an AttrPattern: the wrapped pattern must match \p expr, and \p expr must
+ * be an Op for which every requested attribute is registered with an equal value.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  // An attribute pattern constrains another pattern, which has to match first.
+  if (!VisitDFPattern(attr_pattern->pattern, expr)) {
+    return false;
+  }
+  // Attributes are looked up in the operator attribute registry, so only an
+  // Op expression can satisfy this pattern.
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  CHECK(dict_attrs != nullptr) << "AttrPattern expects its attrs to be DictAttrs";
+  Op op = GetRef<Op>(op_node);
+  // Every requested attribute must be registered for the op and equal the
+  // requested value; a single mismatch rejects the whole pattern.
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_name = kv.first;
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+    if (!op_map.count(op)) {
+      return false;
+    }
+    TVMRetValue registered = op_map[op];
+    bool equal = false;
+    switch (registered.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          equal = val->value == registered.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          equal = val->value == registered.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          equal = val->value == registered.operator std::string();
+        }
+        break;
+      default:
+        LOG(FATAL) << "Unsupported type code " << registered.type_code()
+                   << " for operator attribute " << attr_name;
+    }
+    if (!equal) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
+  // Copy the patterns back-to-front into a fresh array.
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+/*
+ * Match a CallPattern. Besides the direct match (callee, then arguments
+ * position by position), this also tries:
+ *  - the reversed argument order when the callee pattern is the Op "add" or
+ *    "multiply";
+ *  - re-associated divide/multiply patterns when the callee does not match.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode held by the pattern's callee when that callee is an
+  // ExprPattern wrapping an Op; nullptr otherwise (including a null op).
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the pattern's callee is the Op named op_type.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True when expr is a Call whose callee is the Op named op_type.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Bindings recorded from here on belong to this call match.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Argument bindings start here; the callee binding made above survives a
+      // rolled-back argument attempt, so the commutative retry can reuse it.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match pattern_args against expr_args position by position (arity must
+      // agree), rolling back argument bindings on failure.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      // add and multiply are commutative, so also try the pattern's arguments
+      // in reverse order.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern divide(multiply(a, b), c) against an expr multiply(...) with a
+      // divide operand: retry as multiply(b, divide(a, c)), then as
+      // multiply(a, divide(b, c)).
+      // NOTE(review): op->args[0]/[1] and arg_node->args[0]/[1] are indexed
+      // without checking the pattern's arity; confirm divide/multiply patterns
+      // always carry two arguments.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Pattern multiply(divide(a, b), c) (divide in either position) against
+        // an expr divide(...) with a multiply operand: retry as
+        // divide(multiply(a, c), b). Returns the result for the first divide
+        // argument found without trying the other position.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  // Not a call, or no strategy succeeded. A callee binding left over from the
+  // matches_op branch is rolled back by VisitDFPattern (which dispatches here)
+  // when this returns false.
+  return false;
+}
+/*
+ * Match a DominatorPattern rooted at \p expr:
+ *  1. op->child must match expr itself;
+ *  2. every node of expr's graph whose dominator parent is expr must match
+ *     op->parent or op->path, unless it is one of the expressions bound to the
+ *     child's (non-wildcard) pattern nodes immediately dominated by op->child.
+ * All bindings made here are rolled back before returning.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+  auto child_graph = CreateIndexedGraph(op->child);
+  for (auto node : child_graph.topological_order_) {
+    if (node->ref_.as<WildcardPatternNode>()) {
+      continue;
+    }
+    if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) {
+      // Use find() rather than operator[]: a pattern node that was never bound
+      // (e.g. the untaken branch of an AltPattern) must not get a default null
+      // entry inserted into memo_, because ClearMap() only erases nodes listed
+      // in matched_nodes_ and would leave that entry behind.
+      auto bound = memo_.find(node->ref_);
+      if (bound != memo_.end()) {
+        dominated_exprs.insert(bound->second);
+      }
+    }
+  }
+  ClearMap(watermark);
+  auto expr_graph = CreateIndexedGraph(expr);
+  for (auto node : expr_graph.topological_order_) {
+    if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) {
+      if (dominated_exprs.count(node->ref_) == 0) {
+        bool node_matches = VisitDFPattern(op->parent, node->ref_);
+        ClearMap(watermark);
+        bool matches = node_matches || VisitDFPattern(op->path, node->ref_);
+        ClearMap(watermark);
+        if (!matches) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // Literal expression patterns compare by structure, not by object identity.
+  StructuralEqual structural_equal;
+  return structural_equal(op->expr, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  // The expression must be a TupleGetItem at the same index before its tuple
+  // operand is compared against the tuple pattern.
+  const auto* getitem = expr.as<TupleGetItemNode>();
+  if (getitem == nullptr || getitem->index != op->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, getitem->tuple);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  // Tuples match field by field; arity must agree and the first mismatching
+  // field stops the comparison.
+  const auto* tuple_node = expr.as<TupleNode>();
+  if (tuple_node == nullptr || tuple_node->fields.size() != op->fields.size()) {
+    return false;
+  }
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple_node->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+Expr InferType(const Expr& expr) {
 
 Review comment:
   I used to have a C++ helper, similar to the Python one, that ran passes on an expression and returned the updated expression. I believe people didn't like it because it went against the module-in, module-out style of transformations. I think the Python-side helper is mainly used by tests.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408668676
 
 

 ##########
 File path: include/tvm/relay/dataflow_pattern.h
 ##########
 @@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_pattern.h
+ * \brief A pattern language for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
+#define TVM_RELAY_DATAFLOW_PATTERN_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Base type of all dataflow patterns.
+ * \sa DFPattern
+ */
+class DFPatternNode : public Object {
+ public:
+  // NOTE(review): unlike the concrete patterns in this header, whose type keys
+  // use the "relay.df_pattern." prefix, this base key is the bare class name.
+  static constexpr const char* _type_key = "DFPatternNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow patterns.
+ *
+ * Every concrete pattern reference in this header (ExprPattern, VarPattern,
+ * CallPattern, ...) derives from DFPattern.
+ * \sa DFPatternNode
+ */
+class DFPattern : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
+};
+
+/*!
+ * \brief Pattern for Relay Expression.
+ *
+ * Container for ExprPattern: holds one concrete Relay expression to match.
+ */
+class ExprPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The expression to match. */
+  Expr expr;
+
+  /*! \brief Visit the fields of this node with the attribute visitor \p v. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("expr", &expr);
+  }
+
+  static constexpr const char* _type_key = "relay.df_pattern.ExprPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a literal expression.
+ *
+ * \note Uses structural equality on expressions to check equality.
+ *
+ */
+class ExprPattern : public DFPattern {
+ public:
+  /*!
+   * \brief Wrap \p expr as a pattern.
+   * \note This constructor is not explicit, so an Expr converts implicitly to
+   *       an ExprPattern.
+   */
+  TVM_DLL ExprPattern(Expr expr);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
+};
+
+
+/*!
+ * \brief A Pattern to Match a Relay Variable
+ */
+class VarPattern;
+/*! \brief Container for VarPattern. */
+class VarPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The name of the Var (optional).
+   */
+  std::string name;
+  /*!
+   * \brief Type annotation of the variable.
+   * This field records the user-provided type annotation of the Var.
+   * This field is optional and can be None.
+   */
+  Type type_annotation;
+
+  /*! \return The name hint of the variable */
+  const std::string& name_hint() const {
+    return name;
+  }
+
+  /*!
+   * \brief Visit the fields of this node with the attribute visitor \p v.
+   * Both fields are visited so that `name` is reachable through the visitor,
+   * not only `type_annotation`.
+   */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("type_annotation", &type_annotation);
+  }
+
+  TVM_DLL static VarPattern make(std::string name_hint, Type type_annotation);
+
+  static constexpr const char* _type_key = "relay.df_pattern.VarPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode);
+};
+
+/*! \brief Managed reference to VarPatternNode. */
+class VarPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
+};
+
+/*!
+ * \brief Call corresponds to operator invocation.
+ *  Corresponds to the operator in computational graph terminology.
+ */
+class CallPattern;
+/*! \brief CallPattern container. */
+class CallPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief Pattern for the operator(function) being invoked.
+   *
+   *  It is matched against the callee of a Call, which can be:
+   *  - a relay::Op, corresponding to the primitive operators;
+   *  - a user defined function (Function, GlobalVar, Var).
+   */
+  DFPattern op;
+
+  /*! \brief Patterns for the arguments(inputs) of the call, matched by position. */
+  tvm::Array<relay::DFPattern> args;
+
+  /*!
+   * \brief The additional attributes.
+   * NOTE(review): the matcher in dataflow_matcher.cc only copies attrs into
+   * re-associated patterns and does not appear to compare them; confirm intent.
+   */
+  Attrs attrs;
+
+  /*!
+   * \brief The type arguments passed to polymorphic(template) function.
+   *
+   * This is the advance feature that is only used when the function is
+   * polymorphic. It is safe to be ignored in most cases. For example, in the
+   * following code, the type_args of addone call is [int].
+   *
+   * \code
+   *
+   * template<typename T>
+   * T addone(T a) { return a + 1; }
+   *
+   * void main() {
+   *   int x = addone<int>(10);
+   * }
+   *
+   * \endcode
+   */
+  tvm::Array<Type> type_args;
+
+  /*! \brief Visit the fields of this node with the attribute visitor \p v. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("op", &op);
+    v->Visit("args", &args);
+    v->Visit("attrs", &attrs);
+    v->Visit("type_args", &type_args);
+  }
+
+  TVM_DLL static CallPattern make(DFPattern op, Array<DFPattern> args, Attrs attrs,
+                                  Array<Type> type_args);
+
+  static constexpr const char* _type_key = "relay.df_pattern.CallPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode);
+};
+
+/*! \brief Managed reference to CallPatternNode. */
+class CallPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
+};
+
+/*! \brief Tuple of multiple Exprs */
+class TuplePattern;
+/*! \brief TuplePattern container. */
+class TuplePatternNode : public DFPatternNode {
+ public:
+  /*! \brief Patterns for the fields of the tuple, matched by position. */
+  tvm::Array<DFPattern> fields;
+
+  /*! \brief Visit the fields of this node with the attribute visitor \p v. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("fields", &fields);
+  }
+
+  TVM_DLL static TuplePattern make(tvm::Array<DFPattern> fields);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TuplePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode);
+};
+
+/*! \brief Managed reference to TuplePatternNode. */
+class TuplePattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode);
+};
+
+/*! \brief Get index-th field out of a tuple. */
+class TupleGetItemPattern;
+/*! \brief TupleGetItemPattern container. */
+class TupleGetItemPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern for the tuple being indexed. */
+  DFPattern tuple;
+  /*! \brief Which field of the tuple to get. */
+  int index;
+
+  /*!
+   * \brief Visit the fields of this node with the attribute visitor \p v.
+   * `index` is visited too so it is reachable alongside the tuple pattern.
+   */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tuple_value", &tuple);
+    v->Visit("index", &index);
+  }
+
+  TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TupleGetItemPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode);
+};
+
+/*! \brief Managed reference to TupleGetItemPatternNode. */
+class TupleGetItemPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode);
+};
+
+class AltPattern;
+/*!
+ * \brief Pattern for Alternate Expressions.
+ */
+class AltPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The left optional pattern; the matcher tries it first. */
+  DFPattern left;
+  /*! \brief The right optional pattern; the matcher tries it only if left fails. */
+  DFPattern right;
+
+  /*! \brief Visit the fields of this node with the attribute visitor \p v. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("left", &left);
+    v->Visit("right", &right);
+  }
+
+  TVM_DLL static AltPattern make(DFPattern left, DFPattern right);
+
+  static constexpr const char* _type_key = "relay.df_pattern.AltPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches either of two patterns
+ * \sa AltPatternNode
+ */
+class AltPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode);
+};
+
+
+/*!
+ * \brief Wildcard Pattern.
+ */
+class WildcardPatternNode : public DFPatternNode {
+ public:
+  // A wildcard carries no fields, so there is nothing to visit.
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "relay.df_pattern.WildcardPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches anything.
+ * \sa WildcardPatternNode
+ */
+class WildcardPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
+};
+
+class TypePattern;
+/*!
+ * \brief Pattern for Types.
+ */
+class TypePatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern being constrained by type. */
+  DFPattern pattern;
+  /*! \brief The type to match */
+  Type type;
+
+  /*! \brief Visit the fields of this node with the attribute visitor \p v. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("type", &type);
+  }
+
+  TVM_DLL static TypePattern make(DFPattern pattern, Type type);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TypePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ * \sa TypePatternNode
+ */
+class TypePattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
+};
+
+class AttrPattern;
+/*!
+ * \brief Pattern for Attributes.
+ */
+class AttrPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern being constrained by attrs. */
+  DFPattern pattern;
+  /*!
+   * \brief The attributes to match.
+   * NOTE(review): the matcher in dataflow_matcher.cc reads these as DictAttrs.
+   */
+  Attrs attrs;
+
+  /*! \brief Visit the fields of this node with the attribute visitor \p v. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("attrs", &attrs);
+  }
+
+  TVM_DLL static AttrPattern make(DFPattern pattern, Attrs attrs);
+
+  static constexpr const char* _type_key = "relay.df_pattern.AttrPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
 
 Review comment:
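   This should say "attributes" rather than "a type", i.e. "A pattern which matches attributes in another pattern".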

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408529175
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*!
+   * \brief Match \p pattern against \p expr, starting from an empty set of bindings.
+   * \return true if \p expr matches \p pattern.
+   */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  /*! \brief Mapping from a bound pattern node to the expression it matched. */
+  using PatternBindings = std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual>;
+
+  /*! \brief Memoizing entry point: a bound pattern only re-matches its bound expression. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+
+  // Per-node matching rules.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Drop every binding recorded at or after position \p watermark. */
+  void ClearMap(size_t watermark);
+
+  /*! \brief Current pattern bindings. */
+  PatternBindings memo_;
+  /*! \brief Bound pattern nodes in binding order; lets ClearMap roll memo_ back. */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Every top-level match starts without bindings so that results of a
+  // previous Match() call cannot influence this one.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  // Unwind bindings newest-first until only the first `watermark` remain,
+  // removing each one from the memo as it is popped.
+  while (matched_nodes_.size() > watermark) {
+    memo_.erase(matched_nodes_.back());
+    matched_nodes_.pop_back();
+  }
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  // A pattern node that is already bound can only match the very same
+  // expression object it was bound to.
+  auto bound = memo_.find(pattern);
+  if (bound != memo_.end()) {
+    return expr.same_as(bound->second);
+  }
+  const auto watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    // Undo whatever the failed sub-match bound.
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // The left alternative has priority; the right one is only tried on failure.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+/*
+ * Match an AttrPattern: the wrapped pattern must match \p expr, and \p expr must
+ * be an Op for which every requested attribute is registered with an equal value.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  // An attribute pattern constrains another pattern, which has to match first.
+  if (!VisitDFPattern(attr_pattern->pattern, expr)) {
+    return false;
+  }
+  // Attributes are looked up in the operator attribute registry, so only an
+  // Op expression can satisfy this pattern.
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  CHECK(dict_attrs != nullptr) << "AttrPattern expects its attrs to be DictAttrs";
+  Op op = GetRef<Op>(op_node);
+  // Every requested attribute must be registered for the op and equal the
+  // requested value; a single mismatch rejects the whole pattern.
+  for (const auto& kv : dict_attrs->dict) {
+    const auto& attr_name = kv.first;
+    const auto& attr_value = kv.second;
+    auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+    if (!op_map.count(op)) {
+      return false;
+    }
+    TVMRetValue registered = op_map[op];
+    bool equal = false;
+    switch (registered.type_code()) {
+      case kDLInt:
+        if (const auto* val = attr_value.as<IntImmNode>()) {
+          equal = val->value == registered.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (const auto* val = attr_value.as<FloatImmNode>()) {
+          equal = val->value == registered.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (const auto* val = attr_value.as<tir::StringImmNode>()) {
+          equal = val->value == registered.operator std::string();
+        }
+        break;
+      default:
+        LOG(FATAL) << "Unsupported type code " << registered.type_code()
+                   << " for operator attribute " << attr_name;
+    }
+    if (!equal) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
+  // Copy the patterns back-to-front into a fresh array.
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+/*
+ * Match a CallPattern. Besides the direct match (callee, then arguments
+ * position by position), this also tries:
+ *  - the reversed argument order when the callee pattern is the Op "add" or
+ *    "multiply";
+ *  - re-associated divide/multiply patterns when the callee does not match.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode held by the pattern's callee when that callee is an
+  // ExprPattern wrapping an Op; nullptr otherwise (including a null op).
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the pattern's callee is the Op named op_type.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True when expr is a Call whose callee is the Op named op_type.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Bindings recorded from here on belong to this call match.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Argument bindings start here; the callee binding made above survives a
+      // rolled-back argument attempt, so the commutative retry can reuse it.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match pattern_args against expr_args position by position (arity must
+      // agree), rolling back argument bindings on failure.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      // add and multiply are commutative, so also try the pattern's arguments
+      // in reverse order.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern divide(multiply(a, b), c) against an expr multiply(...) with a
+      // divide operand: retry as multiply(b, divide(a, c)), then as
+      // multiply(a, divide(b, c)).
+      // NOTE(review): op->args[0]/[1] and arg_node->args[0]/[1] are indexed
+      // without checking the pattern's arity; confirm divide/multiply patterns
+      // always carry two arguments.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Pattern multiply(divide(a, b), c) (divide in either position) against
+        // an expr divide(...) with a multiply operand: retry as
+        // divide(multiply(a, c), b). Returns the result for the first divide
+        // argument found without trying the other position.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  // Not a call, or no strategy succeeded. A callee binding left over from the
+  // matches_op branch is rolled back by VisitDFPattern (which dispatches here)
+  // when this returns false.
+  return false;
+}
+/*
+ * Match a DominatorPattern rooted at \p expr:
+ *  1. op->child must match expr itself;
+ *  2. every node of expr's graph whose dominator parent is expr must match
+ *     op->parent or op->path, unless it is one of the expressions bound to the
+ *     child's (non-wildcard) pattern nodes immediately dominated by op->child.
+ * All bindings made here are rolled back before returning.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+  auto child_graph = CreateIndexedGraph(op->child);
+  for (auto node : child_graph.topological_order_) {
+    if (node->ref_.as<WildcardPatternNode>()) {
+      continue;
+    }
+    if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) {
+      // Use find() rather than operator[]: a pattern node that was never bound
+      // (e.g. the untaken branch of an AltPattern) must not get a default null
+      // entry inserted into memo_, because ClearMap() only erases nodes listed
+      // in matched_nodes_ and would leave that entry behind.
+      auto bound = memo_.find(node->ref_);
+      if (bound != memo_.end()) {
+        dominated_exprs.insert(bound->second);
+      }
+    }
+  }
+  ClearMap(watermark);
+  auto expr_graph = CreateIndexedGraph(expr);
+  for (auto node : expr_graph.topological_order_) {
+    if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) {
+      if (dominated_exprs.count(node->ref_) == 0) {
+        bool node_matches = VisitDFPattern(op->parent, node->ref_);
+        ClearMap(watermark);
+        bool matches = node_matches || VisitDFPattern(op->path, node->ref_);
+        ClearMap(watermark);
+        if (!matches) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // Literal expression patterns compare by structure, not by object identity.
+  StructuralEqual structural_equal;
+  return structural_equal(op->expr, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  // The expression must be a TupleGetItem at the same index before its tuple
+  // operand is compared against the tuple pattern.
+  const auto* getitem = expr.as<TupleGetItemNode>();
+  if (getitem == nullptr || getitem->index != op->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, getitem->tuple);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  // Tuples match field by field; arity must agree and the first mismatching
+  // field stops the comparison.
+  const auto* tuple_node = expr.as<TupleNode>();
+  if (tuple_node == nullptr || tuple_node->fields.size() != op->fields.size()) {
+    return false;
+  }
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple_node->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+Expr InferType(const Expr& expr) {
 
 Review comment:
   Sounds good to me. But there is already an InferType function with a different signature declared in transform.h:
   https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h#L385
   
   We can also put the new decl there and define it in type_infer.cc. We can reuse `InferType(const Expr& expr, const IRModule& mod)` there.
   
   (I find framing type inference as a transform a bit strange, though.)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409900186
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> FindDominated(const DFPattern& node);
+  bool FindParent(const Expr& expr,
+                  const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+                  const DominatorPatternNode* op);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
 
 Review comment:
   Remove this blank line to make it consistent with the other `auto` lambdas above and below.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] jroesch commented on issue #5231: [WIP][POC] Pattern Language and Matcher

Posted by GitBox <gi...@apache.org>.
jroesch commented on issue #5231: [WIP][POC] Pattern Language and Matcher
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-608621711
 
 
   cc @zhiics and @icemelon9 
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408529175
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches = true;
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+    auto child_graph = CreateIndexedGraph(op->child);
+    for (auto node : child_graph.topological_order_) {
+      if (node->ref_.as<WildcardPatternNode>()) {
+        continue;
+      }
+      if (node->dominator_parent_ && node->dominator_parent_->ref_ == op->child) {
+        dominated_exprs.insert(memo_[node->ref_]);
+      }
+    }
+    ClearMap(watermark);
+    auto expr_graph = CreateIndexedGraph(expr);
+    for (auto node : expr_graph.topological_order_) {
+      if (node->dominator_parent_ && node->dominator_parent_->ref_ == expr) {
+        if (dominated_exprs.count(node->ref_) == 0) {
+          bool node_matches = VisitDFPattern(op->parent, node->ref_);
+          ClearMap(watermark);
+          matches = node_matches || VisitDFPattern(op->path, node->ref_);
+          ClearMap(watermark);
+          if (!matches) {
+            return false;
+          }
+        }
+      }
+    }
+    return matches;
+  }
+  return false;
+}
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+Expr InferType(const Expr& expr) {
 
 Review comment:
   Sounds good to me. But there is already an InferType function with a different signature declared in transform.h:
   https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h#L385
   
   We can also put the new decl there and define it in type_infer.cc. We can reuse `InferType(const Expr& expr, const IRModule& mod)` there.
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408669180
 
 

 ##########
 File path: include/tvm/relay/dataflow_pattern.h
 ##########
 @@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_pattern.h
+ * \brief A pattern language for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
+#define TVM_RELAY_DATAFLOW_PATTERN_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Base type of all dataflow patterns.
+ * \sa DFPattern
+ */
+class DFPatternNode : public Object {
+ public:
+  static constexpr const char* _type_key = "DFPatternNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow patterns.
+ * \sa DFPatternNode
+ */
+class DFPattern : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
+};
+
+/*!
+ * \brief Pattern for Relay Expression.
+ */
+class ExprPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The expression to match. */
+  Expr expr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("expr", &expr);
+  }
+
+  static constexpr const char* _type_key = "relay.df_pattern.ExprPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a literal expression.
+ *
+ * \note Uses structural equality on expressions to check equality.
+ *
+ */
+class ExprPattern : public DFPattern {
+ public:
+  TVM_DLL ExprPattern(Expr expr);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
+};
+
+
+/*!
+ * \brief A Pattern to Match a Relay Variable
+ */
+class VarPattern;
+/*! \brief Container for Var */
+class VarPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The name of the Var (optional).
+   */
+  std::string name;
+  /*!
+   * \brief type annotaion of the variable.
+   * This field records user provided type annotation of the Var.
+   * This field is optional and can be None.
+   */
+  Type type_annotation;
+
+  /*! \return The name hint of the variable */
+  const std::string& name_hint() const {
+    return name;
+  }
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("type_annotation", &type_annotation);
+  }
+
+  TVM_DLL static VarPattern make(std::string name_hint, Type type_annotation);
+
+  static constexpr const char* _type_key = "relay.df_pattern.VarPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode);
+};
+
+class VarPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
+};
+
+/*!
+ * \brief Call corresponds to operator invocation.
+ *  Corresponds to the operator in computational graph terminology.
+ */
+class CallPattern;
+/*! \brief CallPattern container. */
+class CallPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The operator(function) being invoked
+   *
+   *  - It can be relay::Op which corresponds to the primitive operators.
+   *  - It can also be user defined functions (Function, GlobalVar, Var).
+   */
+  DFPattern op;
+
+  /*! \brief The arguments(inputs) of the call */
+  tvm::Array<relay::DFPattern> args;
+
+  /*! \brief The additional attributes */
+  Attrs attrs;
+
+  /*!
+   * \brief The type arguments passed to polymorphic(template) function.
+   *
+   * This is the advance feature that is only used when the function is
+   * polymorphic. It is safe to be ignored in most cases. For example, in the
+   * following code, the type_args of addone call is [int].
+   *
+   * \code
+   *
+   * template<typename T>
+   * T addone(T a) { return a + 1; }
+   *
+   * void main() {
+   *   int x = addone<int>(10);
+   * }
+   *
+   * \endcode
+   */
+  tvm::Array<Type> type_args;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("op", &op);
+    v->Visit("args", &args);
+    v->Visit("attrs", &attrs);
+    v->Visit("type_args", &type_args);
+  }
+
+  TVM_DLL static CallPattern make(DFPattern op, Array<DFPattern> args, Attrs attrs,
+                                  Array<Type> type_args);
+
+  static constexpr const char* _type_key = "relay.df_pattern.CallPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode);
+};
+
+class CallPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
+};
+
+/*! \brief Tuple of multiple Exprs */
+class TuplePattern;
+/*! \brief Tuple container */
+class TuplePatternNode : public DFPatternNode {
+ public:
+  /*! \brief the fields of the tuple */
+  tvm::Array<DFPattern> fields;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("fields", &fields);
+  }
+
+  TVM_DLL static TuplePattern make(tvm::Array<DFPattern> fields);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TuplePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode);
+};
+
+class TuplePattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode);
+};
+
+/*! \brief Get index-th field out of a tuple. */
+class TupleGetItemPattern;
+class TupleGetItemPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The tuple Expression */
+  DFPattern tuple;
+  /*! \brief which value to get */
+  int index;
+
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tuple_value", &tuple);
+  }
+
+  TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TupleGetItemPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode);
+};
+
+class TupleGetItemPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode);
+};
+
+class AltPattern;
+/*!
+ * \brief Pattern for Alternate Expressions.
+ */
+class AltPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The left optional pattern. */
+  DFPattern left;
+  /*! \brief The right optional pattern. */
+  DFPattern right;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("left", &left);
+    v->Visit("right", &right);
+  }
+
+  TVM_DLL static AltPattern make(DFPattern left, DFPattern right);
+
+  static constexpr const char* _type_key = "relay.df_pattern.AltPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches either of two patterns
+ */
+class AltPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode);
+};
+
+
+/*!
+ * \brief Wildcard Pattern.
+ */
+class WildcardPatternNode : public DFPatternNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "relay.df_pattern.WildcardPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches anything.
+ */
+class WildcardPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
+};
+
+class TypePattern;
+/*!
+ * \brief Pattern for Types.
+ */
+class TypePatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern. */
+  DFPattern pattern;
+  /*! \brief The type to match */
+  Type type;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("type", &type);
+  }
+
+  TVM_DLL static TypePattern make(DFPattern pattern, Type type);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TypePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class TypePattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
+};
+
+class AttrPattern;
+/*!
+ * \brief Pattern for Types.
+ */
+class AttrPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern. */
+  DFPattern pattern;
+  /*! \brief The attribute to match */
+  Attrs attrs;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("attrs", &attrs);
+  }
+
+  TVM_DLL static AttrPattern make(DFPattern pattern, Attrs attrs);
+
+  static constexpr const char* _type_key = "relay.df_pattern.AttrPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class AttrPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode);
+};
+
+class DominatorPattern;
+/*!
+ * \brief Pattern for Types.
 
 Review comment:
   dominator? (This doc comment says "Pattern for Types." but it documents the dominator pattern.)
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409902168
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> FindDominated(const DFPattern& node);
+  bool FindParent(const Expr& expr,
+                  const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+                  const DominatorPatternNode* op);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
 
 Review comment:
   Can you collapse nested ifs like this?
   
   ```
             if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
                 (is_expr_op(call_node->args[0], "divide") ||
                  is_expr_op(call_node->args[1], "divide"))) {
   ```
   Same for `multiply` below.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409058690
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  };
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              std::cout << op << " " << op_map[op].operator int64_t() <<  std::endl;
 
 Review comment:
   Remove this debug print.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409879725
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  friend DominatorMatcher;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Friend class to do recursive dominator matching
+class DominatorMatcher {
+ public:
+  DominatorMatcher(DFPatternMatcher* matcher, const DominatorPatternNode* op, const Expr& expr)
+      : matcher_(matcher), op_(op), expr_(expr) {
+    watermark_ = matcher_->matched_nodes_.size();
+    pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  }
+  bool Match() {
+    if (matcher_->VisitDFPattern(op_->child, expr_)) {
+      auto dominated_exprs = FindDominated(op_->child);
+      matcher_->ClearMap(watermark_);
+
+      bool matches = FindParent(expr_, dominated_exprs);
+      if (matches) {
+        matcher_->ClearMap(watermark_);
+        matcher_->memo_[op_->child] = expr_;
+        matcher_->matched_nodes_.push_back(op_->child);
+      }
+      return matches;
+    }
+    return false;
+  }
+
+ protected:
+  DFPatternMatcher* matcher_;
+  const DominatorPatternNode* op_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  Expr expr_;
+  size_t watermark_;
 
 Review comment:
   I spent some more time thinking about it and was able to reduce the state by ~50%, so I decided it was small enough to fold back into the main matcher. Thanks for the thoughts!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409322475
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  };
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
 
 Review comment:
   How about making `find_parent` a member function of `DFPatternMatcher`? I attempted this in https://github.com/masahi/tvm/commit/c41a0e5577ae4fc549b2c7dc8c0daeaf0d011623. It compiles and passes all your tests.
   
   At a minimum, that would let you remove `backup_memo` and `backup_matched_nodes` entirely.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409138818
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  };
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
+                   &expr_graph, &find_parent](
+                      const Expr& expr,
+                      const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs) {
+      bool out = true;
+      for (auto node : expr_graph.node_map_[expr]->dominator_children_) {
+        if (out && dominated_exprs.count(node->ref_) == 0) {
+          if (VisitDFPattern(op->parent, node->ref_)) {
+            backup_memo[op->parent] = memo_.at(op->parent);
+            backup_matched_nodes.push_back(op->parent);
+            memo_ = backup_memo;
+            matched_nodes_ = backup_matched_nodes;
+            watermark += 1;
+            return true;
+          } else {
+            if (VisitDFPattern(op->path, node->ref_)) {
+              auto new_dominated_exprs = find_dominated(op->path);
+              ClearMap(watermark);
+              out &= find_parent(node->ref_, new_dominated_exprs);
+            } else {
+              out = false;
+            }
+          }
+        }
+      }
+      return out;
+    };
+
+    auto dominated_exprs = find_dominated(op->child);
+    ClearMap(watermark);
+    bool matches = find_parent(expr, dominated_exprs);
+    if (matches) {
+      backup_memo[op->parent] = memo_.at(op->parent);
+      backup_memo[op->child] = expr;
+      memo_ = backup_memo;
 
 Review comment:
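   Why not update `memo_` directly? Writing into `backup_memo` and then assigning it back to `memo_` looks like an unnecessary round trip.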

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408455749
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
 
 Review comment:
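   Is a `&` missing? `args` is taken by value here; `const Array<DFPattern>& args` would avoid an unnecessary copy.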

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408646311
 
 

 ##########
 File path: include/tvm/relay/dataflow_functor.h
 ##########
 @@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_functor.h
+ * \brief Functors, visitors, and graph utilities for dataflow patterns.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on the first DFPattern argument.
+ *
+ * \tparam FType function signature
+ *  This type is only defined for FType with function signature R(const DFPattern&,
+ * Args...)
+ */
+template <typename FType>
+class DFPatternFunctor;
+
+// Default body for the dispatch functions; subclasses override the ones they need.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \
+    return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class DFPatternFunctor<R(const DFPattern& n, Args...)> {
+ private:
+  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+    return VisitDFPattern(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+    CHECK(n.defined());
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overridden by subclasses
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  DFPatternVisitor treats the pattern as a dataflow graph and only visits each node once.
+ */
+class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
+ public:
+  void VisitDFPattern(const DFPattern& pattern) override;
+  void VisitDFPattern_(const AltPatternNode* op) override;
+  void VisitDFPattern_(const AttrPatternNode* op) override;
+  void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const DominatorPatternNode* op) override;
+  void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
+  void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const TypePatternNode* op) override;
+  void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const WildcardPatternNode* op) override;
+
+ protected:
+  // set of already-visited nodes
+  std::unordered_set<const Object*> visited_;
+};
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ *  Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ *  Class is Templated and the implementaiton is in the header file so we can analyis both DFPattern
 
 Review comment:
   Typos: "Templated" → "templated", "implementaiton" → "implementation",
   and "we can analyis" → "we can analyze".

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408645794
 
 

 ##########
 File path: tests/python/relay/test_df_pattern.py
 ##########
 @@ -0,0 +1,574 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay.df_pattern import *
+import numpy as np
+
+# NB: these values mirror the C++ enum used for the "TOpPattern" attribute;
+# we lose type safety because of the Python/C++ calling
+# convention.
+K_ELEMWISE = 0
+K_BROADCAST = 1
+
+## NODE TESTS
+def test_expr_pattern():
+    ep = ExprPattern(relay.var('x', shape=(4, 1)))
+    print(ep)
+
+def test_var_pattern():
+    v = is_input("x")
+    print(v)
+
+def test_wildcard_pattern():
+    wc = wildcard()
+    print(wc)
+
+def test_CallPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    c = is_op("add")(wc1, wc2)
+    print(c)
+
+def test_TuplePattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    print(t)
+
+def test_TupleGetItemPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    tgi = TupleGetItemPattern(t, 1)
+    print(tgi)
+
+def test_AltPattern():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    print(is_add_or_sub)
+
+def test_TypePattern():
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    print(ty_pat)
+
+def test_AttrPattern():
+    op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
+    op_pat = op(wildcard(), wildcard())
+    print(op_pat)
+
+## MATCHER TESTS
+
+def test_match_op():
+    assert is_op('add').match(relay.op.op.get("add"))
+
+def test_no_match_op():
+    assert not is_op('add').match(relay.op.op.get("subtract"))
+
+def test_match_op_or():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    assert is_add_or_sub.match(relay.op.op.get("add"))
+    assert is_add_or_sub.match(relay.op.op.get("subtract"))
+
+def test_match_call_commutive():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x + y)
+    assert add_pattern.match(y + x)
+    mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
+    assert mul_pattern.match(x * y)
+    assert mul_pattern.match(y * x)
+
+def test_no_match_call_commutive():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('subtract')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x - y)
+    assert not add_pattern.match(y - x)
+    add_pattern = is_op('divide')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x / y)
+    assert not add_pattern.match(y / x)
+
+def test_match_call():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    assert add_pattern.match(x + y)
+
+def test_no_match_call():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    assert not add_pattern.match(x - y)
+
+def test_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+    assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
+
+def test_no_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
+    assert not tuple_pattern.match(relay.expr.Tuple((x,y,z)))
+
+def test_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+    tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+    assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1))
+
+def test_no_match_tuple():
+    x = relay.var('x')
+    y = relay.var('y')
+    z = relay.op.op.get("add")
+    tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add")))
+    tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+    assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2))
+
+def test_match_type():
+    x = relay.var('x', shape=(10, 10), dtype="float32")
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    assert ty_pat.match(x)
+
+def test_no_match_type():
+    x = relay.var('x', shape=(10, 10), dtype="int32")
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    assert not ty_pat.match(x)
+
+def test_match_attr():
+    op = is_op('add').has_attr("TOpPattern", K_BROADCAST)
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert op_pat.match(x + y)
+
+def test_no_match_attr():
+    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert not op_pat.match(relay.op.nn.dense(x, y))
+
+def test_match_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+def test_no_match_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(leaky_relu)
+    assert not diamond.match(relu)
+
+def test_match_fake_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    input1 = relay.var('input1')
+    weight1 = relay.var('weight1')
+    conv2d1 = relay.op.nn.conv2d(input1, weight1)
+    inp2 = relay.var('input2')
+    weight2 = relay.var('weight2')
+    conv2d2 = relay.op.nn.conv2d(inp2, weight2)
+    relu = relay.op.nn.relu(conv2d1)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+
+def test_match_dominator():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    relu = relay.op.tanh(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Check
+    assert diamond.match(out)
+    
+
+def test_not_match_dominator():
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Expr
+    input1 = relay.var('input1')
+    weight1 = relay.var('weight1')
+    conv2d1 = relay.op.nn.conv2d(input1, weight1)
+    inp2 = relay.var('input2')
+    weight2 = relay.var('weight2')
+    conv2d2 = relay.op.nn.conv2d(inp2, weight2)
+    relu = relay.op.nn.relu(conv2d1)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relu + relu
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(inp)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Expr
+    inp = relay.var('input')
+    relu = relay.op.nn.relu(inp)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Check
+    assert not diamond.match(out)
+
+def test_rewrite():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    sub_pattern = is_op('subtract')(wildcard(), wildcard())
+    def add_to_sub(pre, post):
+        return post.args[0] - post.args[1]
+    out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y)
+    assert sub_pattern.match(out)
+
+def test_not_fuse_multi_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+    out = out + conv2d
+    # Check
+    assert not diamond.match(out)
+
+def fuse_batchnorm(pre, post):
+    def left_right_call(post):
+        if isinstance(post.args[0], relay.Call):
+            return (post.args[1], post.args[0])
+        else:
+            return (post.args[0], post.args[1])
+    
+    beta, post = left_right_call(post)
+    assert isinstance(post, relay.Call)
+    
+    if post.op == relay.op.get("divide"):
+        numerator = post.args[0]
+        denominator = post.args[1]
+        gamma, numerator = left_right_call(numerator)
+    elif post.op == relay.op.get("multiply"):
+        gamma, quotient = left_right_call(post)
+        numerator = quotient.args[0]
+        denominator = quotient.args[1]
+    else:
+        raise "Found unexcepted op"
+
+    x = numerator.args[0]
+    mean = numerator.args[1]
+
+    var = denominator.args[0].args[0]
+    eps = denominator.args[0].args[1]
+    
+    out = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())
+    return out[0]
+
+def test_fuse_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+def test_no_fuse_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), fake_BN)
+    assert tvm.ir.structural_equal(out, fake_BN)
+
+def test_fuse_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2)
+
+    bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]
+    bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0]
+
+    assert tvm.ir.structural_equal(out, bn2)
+
+def test_partial_fuse_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
+    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN2)
+
+    bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0]
+
+    assert tvm.ir.structural_equal(out, bn2)
+
+def test_fuse_batchnorm_commutation():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    
+    BN_pattern = wildcard() * (wildcard() - wildcard())/is_op("sqrt")(wildcard() + wildcard()) + wildcard()
+    #commute add
+    BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5))
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+    # associate divide/multiply
+    BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5))  + beta
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+    # associate multiply/divide
+    BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta
+    out = rewrite(DFPatternCallback(BN_pattern, fuse_batchnorm), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+def algebraic_simplify(expr):
+    pattern_callbacks = []
+
+    def elwise_zero_callback(pre, post):
+        if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | 
+            tvm.ir.structural_equal(post.args[0], relay.const(0.0))):
+            return post.args[1]
+        else:
+            return post.args[0]
+
+    def elwise_one_callback(pre, post):
+        if (tvm.ir.structural_equal(post.args[0], relay.const(1)) | 
+            tvm.ir.structural_equal(post.args[0], relay.const(1.0))):
+            return post.args[1]
+        else:
+            return post.args[0]
+
+    def return_zero_callback(pre, post):
+        if (tvm.ir.structural_equal(post.args[0], relay.const(0)) | 
+            tvm.ir.structural_equal(post.args[0], relay.const(0.0))):
+            return post.args[0]
+        else:
+            return post.args[1]
+
+    zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
+    one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
+    add_pattern = wildcard() + zero
+    pattern_callbacks.append(DFPatternCallback(add_pattern, elwise_zero_callback))
+
+    sub_pattern = wildcard() - zero
+    pattern_callbacks.append(DFPatternCallback(sub_pattern, elwise_zero_callback))
+
+    mul_pattern = wildcard() * one
+    pattern_callbacks.append(DFPatternCallback(mul_pattern, elwise_one_callback))
+    
+    mul_zero_pattern = wildcard() * zero
+    pattern_callbacks.append(DFPatternCallback(mul_zero_pattern, return_zero_callback))
+
+    div_pattern = wildcard() / one
+    pattern_callbacks.append(DFPatternCallback(div_pattern, elwise_one_callback))
+
+    zero_div_pattern = zero / wildcard()
+    pattern_callbacks.append(DFPatternCallback(zero_div_pattern, return_zero_callback))
+
+    return rewrite(pattern_callbacks, expr);
+
+def test_algebraic_simplify():
+    x = relay.Var('x')
+    y = relay.Var('y')  
+
+    print(x + relay.const(0))
+    
+    one = relay.const(1)
+    zero = relay.const(0)
+    onef = relay.const(1.0)
+    zerof = relay.const(0.0)
+
+    assert algebraic_simplify(x + zero) == x
+    assert algebraic_simplify(x + zerof) == x
+    assert algebraic_simplify(zero + x) == x
+    assert algebraic_simplify(zerof + x) == x
+    
+    assert algebraic_simplify(x - zero) == x
+    assert algebraic_simplify(x - zerof) == x
+    
+    assert algebraic_simplify(x * one) == x
+    assert algebraic_simplify(x * onef) == x
+    assert algebraic_simplify(one * x) == x
+    assert algebraic_simplify(onef * x) == x
+    assert algebraic_simplify(x * zero) == zero
+    assert algebraic_simplify(x * zerof) == zerof
+    
+    assert algebraic_simplify(x / one) == x
+    assert algebraic_simplify(x / onef) == x
+    assert algebraic_simplify(zero / x) == zero
+    assert algebraic_simplify(zerof / x) == zerof
+
+    assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)
+
+if __name__ == "__main__":
+    #test_match_op()
+    #test_no_match_op()
+    #test_match_op_or()
+    #test_match_call()
+    #test_no_match_call()
+    #test_match_call_commutive()
+    #test_no_match_call_commutive()
+    #test_match_tuple()
+    #test_no_match_tuple()
+    #test_match_type()
+    #test_no_match_type()
+    #test_match_attr()
+    #test_no_match_attr()
+    #test_match_diamond()
+    #test_no_match_diamond()
+    #test_match_fake_diamond()
+    #test_rewrite()
+    #test_fuse_batchnorm()
+    #test_no_fuse_batchnorm()
+    #test_fuse_double_batchnorm()
+    #test_partial_fuse_double_batchnorm()
+    #test_fuse_batchnorm_commutation()
+    #test_match_dominator()
 
 Review comment:
   enable these tests?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-614321103
 
 
   Yes! Fuzzy diamond matching works as long as all of the nodes between the parent and the child match the path pattern.
   
   Just added this unit test, will upstream it with the refactor:
   ```
       # Fuzzy path/nested Diamond
       is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
       is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
       reduction = is_op('add')(wildcard(), wildcard())
       diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
   
       inp = relay.var('input')
       weight = relay.var('weight')
       conv2d = relay.op.nn.conv2d(inp, weight)
       relu = relay.op.nn.relu(conv2d)
       relu = relu + relu
       tanh = relay.op.tanh(relu)
       leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
       out = tanh + leaky_relu
   
       assert diamond.match(out)
   ```
   Any patterns in particular you want to see tested?
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409138177
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+/*!
+ * \brief Recursive matcher for dataflow patterns over Relay expressions.
+ *
+ * Each successful (pattern node -> expression) binding is stored in memo_ and
+ * appended to matched_nodes_; a pattern node visited again must bind to the
+ * same expression object. Bindings recorded by a failed sub-match are undone
+ * with ClearMap.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Clear all recorded bindings, then match \p pattern against \p expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  /*! \brief Memoizing entry point; forwards to DFPatternFunctor::VisitDFPattern. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+
+  // Per-pattern-node visitors.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Drop every binding recorded at index \p watermark or later in matched_nodes_. */
+  void ClearMap(size_t watermark);
+
+  /*! \brief Pattern node -> the expression it is currently bound to. */
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound pattern nodes in binding order; used to roll memo_ back. */
+  std::vector<DFPattern> matched_nodes_;
+};
+
+/*! \brief Match \p pattern against \p expr starting from a clean slate. */
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Forget bindings from any earlier top-level match.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+/*! \brief Roll back every binding recorded at position \p watermark or later. */
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+/*!
+ * \brief Memoizing dispatch: an already-bound pattern node must see the same
+ *        expression again; otherwise visit it and record or roll back the result.
+ */
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  auto memoized = memo_.find(pattern);
+  if (memoized != memo_.end()) {
+    return expr.same_as(memoized->second);
+  }
+
+  const size_t watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    // Undo whatever the failed sub-match bound.
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern] = expr;
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+/*! \brief Alternation: try the left pattern, falling back to the right one. */
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+/*!
+ * \brief Attribute pattern: \p expr must be an Op on which every attribute in the
+ *        pattern's dictionary is registered with an equal int, float or string value.
+ *
+ * An empty attribute dictionary matches nothing (same as before). Unsupported
+ * registered value types abort via LOG(FATAL).
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+  bool matches = false;
+  // All attributes must match: bail out on the first one that is missing or differs,
+  // instead of letting only the last checked attribute decide the result.
+  for (auto kv : attributes) {
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      return false;
+    }
+    TVMRetValue op_value = op_map[op];
+    matches = false;
+    switch (op_value.type_code()) {
+      case kDLInt:
+        if (auto* val = kv.second.as<IntImmNode>()) {
+          matches = val->value == op_value.operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (auto* val = kv.second.as<FloatImmNode>()) {
+          matches = val->value == op_value.operator double();
+        }
+        break;
+      case kTVMStr:
+        if (auto* val = kv.second.as<tir::StringImmNode>()) {
+          matches = val->value == op_value.operator std::string();
+        }
+        break;
+      default:
+        // Throwing a bare string literal bypasses TVM's error reporting; use LOG(FATAL).
+        LOG(FATAL) << "Unsupported type code " << op_value.type_code()
+                   << " for op attribute " << kv.first;
+    }
+    if (!matches) {
+      return false;
+    }
+  }
+  return matches;
+}
+
+/*! \brief Return a copy of \p args with the element order flipped. */
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t i = args.size(); i > 0; --i) {
+    reversed.push_back(args[i - 1]);
+  }
+  return reversed;
+}
+
+/*!
+ * \brief Match a call pattern against \p expr.
+ *
+ * If \p expr is a CallNode whose callee matches op->op, the arguments are matched
+ * positionally. If that fails and the pattern's op is "add" or "multiply", the
+ * pattern arguments are retried in reverse order (commutativity).
+ *
+ * Only when the callee does NOT match, two re-associations are attempted:
+ *  - pattern divide(multiply(a, b), c) vs. a multiply expression with a divide
+ *    argument: retried as multiply(b, divide(a, c)), then multiply(a, divide(b, c));
+ *  - pattern multiply(divide(a, b), c) (divide in either slot) vs. a divide
+ *    expression with a multiply argument: retried as divide(multiply(a, c), b).
+ *
+ * Returns false for non-call expressions.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode wrapped by the call pattern's op, if it is an ExprPattern of an Op.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the call pattern's op is the operator named op_type.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // True when expr is a call to the operator named op_type.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Bindings made while matching arguments are rolled back to here on failure,
+      // keeping the callee binding.
+      auto watermark2 = matched_nodes_.size();
+
+      // Positional argument match; arity must be equal.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      // Note: re-association below is NOT attempted when the callee matched.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // NOTE(review): op->args[0], op->args[1] and call_node->args[0..1] are indexed
+      // without an arity check; this assumes binary patterns/calls.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                // Try multiply(b, divide(a, c)) then multiply(a, divide(b, c)).
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // The first divide-pattern argument that qualifies decides the result;
+        // the other slot is not tried afterwards.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  // multiply(divide(a, b), c) -> divide(multiply(a, c), b)
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
 
 Review comment:
   I didn't know that a recursive lambda is possible, but I don't find it pretty :) Need to declare with full type, need to capture etc.
   
   Can we move this to a private member function? Variables captured can be passed as arguments.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409902806
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+/*!
+ * \brief Recursive matcher for dataflow patterns over Relay expressions.
+ *
+ * Bindings (pattern node -> matched expressions) are kept in memo_ and, in
+ * binding order, in matched_nodes_ so failed sub-matches can be rolled back.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Build the indexed graph of \p root_expr once, up front. */
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  /*! \brief Clear recorded bindings, then match \p pattern against \p expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  /*! \brief Snapshot of memo_ as a TVM Map: pattern node -> matched expressions. */
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  /*! \brief Memoizing entry point; forwards to DFPatternFunctor::VisitDFPattern. */
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+
+  // Per-pattern-node visitors.
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Drop every binding recorded at index \p watermark or later in matched_nodes_. */
+  void ClearMap(size_t watermark);
+  // Dominator-pattern helpers; their definitions are outside this excerpt.
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> FindDominated(const DFPattern& node);
+  bool FindParent(const Expr& expr,
+                  const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+                  const DominatorPatternNode* op);
+
+  /*! \brief Pattern node -> expressions it has been bound to. */
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Bound pattern nodes in binding order; used to roll memo_ back. */
+  std::vector<DFPattern> matched_nodes_;
+  /*! \brief Indexed graph of the root expression passed to the constructor. */
+  IndexedGraph<Expr> expr_graph_;
+  /*! \brief Indexed graph of a pattern (how it is populated is not shown here). */
+  IndexedGraph<DFPattern> pattern_graph_;
+  /*! \brief When false, VisitDFPattern skips the memo_ lookup and re-visits bound nodes. */
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+std::unordered_set<Expr, ObjectHash, ObjectEqual> DFPatternMatcher::FindDominated(
+    const DFPattern& node) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+  auto indexed_node = pattern_graph_.node_map_[node];
+  for (auto dominated : indexed_node->dominator_children_) {
+    if (dominated->ref_.as<WildcardPatternNode>()) {
+      continue;
+    }
+    if (memo_.count(dominated->ref_)) {
+      Array<Expr> matched = memo_[dominated->ref_];
+      dominated_exprs.insert(matched[matched.size() - 1]);
+    }
+  }
+  return dominated_exprs;
+}
+
+bool DFPatternMatcher::FindParent(
+    const Expr& expr, const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+    const DominatorPatternNode* op) {
+  bool out = true;
+  for (auto node : expr_graph_.node_map_[expr]->dominator_children_) {
+    if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as<OpNode>() == nullptr) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          auto new_dominated_exprs = FindDominated(op->path);
+          out &= FindParent(node->ref_, new_dominated_exprs, op);
+        } else {
+          out = false;
 
 Review comment:
   return false;

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408455537
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
 
 Review comment:
   Needs a blank line between the closing brace of `ClearMap` and this function definition.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409681253
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  friend DominatorMatcher;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Friend class to do recursive dominator matching
+class DominatorMatcher {
+ public:
+  DominatorMatcher(DFPatternMatcher* matcher, const DominatorPatternNode* op, const Expr& expr)
+      : matcher_(matcher), op_(op), expr_(expr) {
+    watermark_ = matcher_->matched_nodes_.size();
+    pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  }
+  bool Match() {
+    if (matcher_->VisitDFPattern(op_->child, expr_)) {
+      auto dominated_exprs = FindDominated(op_->child);
+      matcher_->ClearMap(watermark_);
+
+      bool matches = FindParent(expr_, dominated_exprs);
+      if (matches) {
+        matcher_->ClearMap(watermark_);
+        matcher_->memo_[op_->child] = expr_;
+        matcher_->matched_nodes_.push_back(op_->child);
+      }
+      return matches;
+    }
+    return false;
+  }
+
+ protected:
+  DFPatternMatcher* matcher_;
+  const DominatorPatternNode* op_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  Expr expr_;
+  size_t watermark_;
 
 Review comment:
   @masahi I did the friend class refactor last night. It's not beautiful, but even removing 3 state variables, there's still quite a bit of state involved in this recursion, and I think I'd prefer to keep this stuff out of the main matcher, but if you have other thoughts, I'd love to hear them.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409900411
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> FindDominated(const DFPattern& node);
+  bool FindParent(const Expr& expr,
+                  const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+                  const DominatorPatternNode* op);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+std::unordered_set<Expr, ObjectHash, ObjectEqual> DFPatternMatcher::FindDominated(
+    const DFPattern& node) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+  auto indexed_node = pattern_graph_.node_map_[node];
+  for (auto dominated : indexed_node->dominator_children_) {
+    if (dominated->ref_.as<WildcardPatternNode>()) {
+      continue;
+    }
+    if (memo_.count(dominated->ref_)) {
+      Array<Expr> matched = memo_[dominated->ref_];
+      dominated_exprs.insert(matched[matched.size() - 1]);
+    }
+  }
+  return dominated_exprs;
+}
+
+bool DFPatternMatcher::FindParent(
+    const Expr& expr, const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+    const DominatorPatternNode* op) {
+  bool out = true;
+  for (auto node : expr_graph_.node_map_[expr]->dominator_children_) {
+    if (out && dominated_exprs.count(node->ref_) == 0 && node->ref_.as<OpNode>() == nullptr) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          auto new_dominated_exprs = FindDominated(op->path);
+          out &= FindParent(node->ref_, new_dominated_exprs, op);
+        } else {
+          out = false;
+        }
+      }
+    }
+  }
+  return out;
+}
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
 
 Review comment:
   Add a blank line before this function definition (after the closing brace of `FindParent`).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-614286984
 
 
   > I'm super grateful someone takes time to catch the little typos in the comments I miss.
   
   Catching typos takes zero effort, so no problem :)
   
   Currently I'm reading the visitor for the dominator pattern. Since I like the diamond problem (see https://github.com/apache/incubator-tvm/pull/1548), I'd like to ask: are arbitrary diamond shapes supported? What about a diamond nested within another diamond? The current fusion algorithm can deal with them.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409899250
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+ protected:
 
 Review comment:
   Add a blank line before `protected`. `clang-format` can take care of this, including removing the blank line at L35, but unfortunately it doesn't seem to add blank lines between functions automatically.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408646552
 
 

 ##########
 File path: include/tvm/relay/dataflow_functor.h
 ##########
 @@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern argument.
+ *
+ * \tparam FType function signiture
+ *  This type is only defined for FType with function signature R(const DFPattern&,
+ * Args...)
+ */
+template <typename FType>
+class DFPatternFunctor;
+
+// functions to be overriden.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \
+    return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class DFPatternFunctor<R(const DFPattern& n, Args...)> {
+ private:
+  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+    return VisitDFPattern(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+    CHECK(n.defined());
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once.
+ */
+class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
+ public:
+  void VisitDFPattern(const DFPattern& pattern) override;
+  void VisitDFPattern_(const AltPatternNode* op) override;
+  void VisitDFPattern_(const AttrPatternNode* op) override;
+  void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const DominatorPatternNode* op) override;
+  void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
+  void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const TypePatternNode* op) override;
+  void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const WildcardPatternNode* op) override;
+
+ protected:
+  // set of already-visited nodes
+  std::unordered_set<const Object*> visited_;
+};
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ *  Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ *  Class is Templated and the implementaiton is in the header file so we can analyis both DFPattern
+ * and Expr with the same infrastructure.
+ *
+ *  IndexedGraph should be instantiated thorught the CreateIndexedGraph utilities.
 
 Review comment:
   Typo: "thorught" should be "through".

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408455749
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern> args) {
 
 Review comment:
   Is a `&` missing here (i.e., should the parameter be `const Array<DFPattern>& args`)? Not sure whether passing by value is intended.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r408668251
 
 

 ##########
 File path: include/tvm/relay/dataflow_pattern.h
 ##########
 @@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_pattern.h
+ * \brief A pattern language for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
+#define TVM_RELAY_DATAFLOW_PATTERN_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Base type of all dataflow patterns.
+ * \sa DFPattern
+ */
+class DFPatternNode : public Object {
+ public:
+  static constexpr const char* _type_key = "DFPatternNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow patterns.
+ * \sa DFPatternNode
+ */
+class DFPattern : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
+};
+
+/*!
+ * \brief Pattern for Relay Expression.
+ */
+class ExprPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The expression to match. */
+  Expr expr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("expr", &expr);
+  }
+
+  static constexpr const char* _type_key = "relay.df_pattern.ExprPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a literal expression.
+ *
+ * \note Uses structural equality on expressions to check equality.
+ *
+ */
+class ExprPattern : public DFPattern {
+ public:
+  TVM_DLL ExprPattern(Expr expr);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
+};
+
+
+/*!
+ * \brief A Pattern to Match a Relay Variable
+ */
+class VarPattern;
+/*! \brief Container for Var */
+class VarPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The name of the Var (optional).
+   */
+  std::string name;
+  /*!
+   * \brief type annotaion of the variable.
+   * This field records user provided type annotation of the Var.
+   * This field is optional and can be None.
+   */
+  Type type_annotation;
+
+  /*! \return The name hint of the variable */
+  const std::string& name_hint() const {
+    return name;
+  }
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("type_annotation", &type_annotation);
+  }
+
+  TVM_DLL static VarPattern make(std::string name_hint, Type type_annotation);
+
+  static constexpr const char* _type_key = "relay.df_pattern.VarPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode);
+};
+
+class VarPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
+};
+
+/*!
+ * \brief Call corresponds to operator invocation.
+ *  Corresponds to the operator in computational graph terminology.
+ */
+class CallPattern;
+/*! \brief CallPattern container. */
+class CallPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The operator(function) being invoked
+   *
+   *  - It can be relay::Op which corresponds to the primitive operators.
+   *  - It can also be user defined functions (Function, GlobalVar, Var).
+   */
+  DFPattern op;
+
+  /*! \brief The arguments(inputs) of the call */
+  tvm::Array<relay::DFPattern> args;
+
+  /*! \brief The additional attributes */
+  Attrs attrs;
+
+  /*!
+   * \brief The type arguments passed to polymorphic(template) function.
+   *
+   * This is the advance feature that is only used when the function is
+   * polymorphic. It is safe to be ignored in most cases. For example, in the
+   * following code, the type_args of addone call is [int].
+   *
+   * \code
+   *
+   * template<typename T>
+   * T addone(T a) { return a + 1; }
+   *
+   * void main() {
+   *   int x = addone<int>(10);
+   * }
+   *
+   * \endcode
+   */
+  tvm::Array<Type> type_args;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("op", &op);
+    v->Visit("args", &args);
+    v->Visit("attrs", &attrs);
+    v->Visit("type_args", &type_args);
+  }
+
+  TVM_DLL static CallPattern make(DFPattern op, Array<DFPattern> args, Attrs attrs,
+                                  Array<Type> type_args);
+
+  static constexpr const char* _type_key = "relay.df_pattern.CallPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode);
+};
+
+class CallPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
+};
+
+/*! \brief Tuple of multiple Exprs */
+class TuplePattern;
+/*! \brief Tuple container */
+class TuplePatternNode : public DFPatternNode {
+ public:
+  /*! \brief the fields of the tuple */
+  tvm::Array<DFPattern> fields;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("fields", &fields);
+  }
+
+  TVM_DLL static TuplePattern make(tvm::Array<DFPattern> fields);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TuplePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode);
+};
+
+class TuplePattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode);
+};
+
+/*! \brief Get index-th field out of a tuple. */
+class TupleGetItemPattern;
+class TupleGetItemPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The tuple Expression */
+  DFPattern tuple;
+  /*! \brief which value to get */
+  int index;
+
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tuple_value", &tuple);
+  }
+
+  TVM_DLL static TupleGetItemPattern make(DFPattern tuple, int index);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TupleGetItemPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode);
+};
+
+class TupleGetItemPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode);
+};
+
+class AltPattern;
+/*!
+ * \brief Pattern for Alternate Expressions.
+ */
+class AltPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The left optional pattern. */
+  DFPattern left;
+  /*! \brief The right optional pattern. */
+  DFPattern right;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("left", &left);
+    v->Visit("right", &right);
+  }
+
+  TVM_DLL static AltPattern make(DFPattern left, DFPattern right);
+
+  static constexpr const char* _type_key = "relay.df_pattern.AltPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches either of two patterns
+ */
+class AltPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode);
+};
+
+
+/*!
+ * \brief Wildcard Pattern.
+ */
+class WildcardPatternNode : public DFPatternNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "relay.df_pattern.WildcardPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches anything.
+ */
+class WildcardPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
+};
+
+class TypePattern;
+/*!
+ * \brief Pattern for Types.
+ */
+class TypePatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern. */
+  DFPattern pattern;
+  /*! \brief The type to match */
+  Type type;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("type", &type);
+  }
+
+  TVM_DLL static TypePattern make(DFPattern pattern, Type type);
+
+  static constexpr const char* _type_key = "relay.df_pattern.TypePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class TypePattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
+};
+
+class AttrPattern;
+/*!
+ * \brief Pattern for Types.
 
 Review comment:
   This should say "Attributes" (i.e., "Pattern for Attributes."), not "Types".

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409900186
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> FindDominated(const DFPattern& node);
+  bool FindParent(const Expr& expr,
+                  const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+                  const DominatorPatternNode* op);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
 
 Review comment:
   Remove this blank line to make it consistent with the other `auto` lambdas above/below.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409842978
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  friend DominatorMatcher;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Friend class to do recursive dominator matching
+class DominatorMatcher {
+ public:
+  DominatorMatcher(DFPatternMatcher* matcher, const DominatorPatternNode* op, const Expr& expr)
+      : matcher_(matcher), op_(op), expr_(expr) {
+    watermark_ = matcher_->matched_nodes_.size();
+    pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  }
+  bool Match() {
+    if (matcher_->VisitDFPattern(op_->child, expr_)) {
+      auto dominated_exprs = FindDominated(op_->child);
+      matcher_->ClearMap(watermark_);
+
+      bool matches = FindParent(expr_, dominated_exprs);
+      if (matches) {
+        matcher_->ClearMap(watermark_);
+        matcher_->memo_[op_->child] = expr_;
+        matcher_->matched_nodes_.push_back(op_->child);
+      }
+      return matches;
+    }
+    return false;
+  }
+
+ protected:
+  DFPatternMatcher* matcher_;
+  const DominatorPatternNode* op_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  Expr expr_;
+  size_t watermark_;
 
 Review comment:
   OK, if you don't want to add `FindParent` to `DFPatternMatcher`, I don't have a preference between the friend class and a recursive lambda solution.
   
   I can see the value of decoupling the dominator matching logic from the main matcher, but since `DFPatternMatcher` is a private class in this .cc file, I think it is fine to put `FindParent` there as an additional implementation detail for handling dominator patterns. `FindDominated` can be lifted to a free function if you pass `memo_` to it, so it doesn't need to be part of `DFPatternMatcher`.
   
   Adding `FindParent` (perhaps with a better name) as a private member function would definitely be a lighter-weight and less controversial change than introducing a friend class. But this is not a strong opinion; I think decoupling `DominatorMatcher`, which requires more complicated logic, at the cost of a friend declaration is reasonable.
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] jroesch commented on issue #5231: [WIP][POC] Pattern Language and Matcher

Posted by GitBox <gi...@apache.org>.
jroesch commented on issue #5231: [WIP][POC] Pattern Language and Matcher
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-608621546
 
 
   cc @jonso4 you asked and it is delivered 😆 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409902404
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> FindDominated(const DFPattern& node);
+  bool FindParent(const Expr& expr,
+                  const std::unordered_set<Expr, ObjectHash, ObjectEqual>& dominated_exprs,
+                  const DominatorPatternNode* op);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
 
 Review comment:
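   Use `return true;` here: inside the `if (out)` branch `out` is always true, so returning the literal states the intent more clearly.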

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, and Rewriter V0
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r409322475
 
 

 ##########
 File path: src/relay/ir/dataflow_matcher.cc
 ##########
 @@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  bool Match(const DFPattern& pattern, const Expr& expr);
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  };
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memo_.count(pattern)) {
+    return expr.same_as(memo_[pattern]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern] = expr;
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply")) {
+            if (is_expr_op(expr, "multiply")) {
+              if (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide")) {
+                bool out = false;
+                for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+                  auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                                   op->attrs, op->type_args);
+                  auto mul =
+                      CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                            arg_node->attrs, arg_node->type_args);
+                  out = VisitDFPattern(mul, expr);
+                  if (out) {
+                    return out;
+                  } else {
+                    ClearMap(watermark);
+                  }
+                }
+                return out;
+              }
+            }
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide")) {
+              if (is_expr_op(expr, "divide")) {
+                if (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply")) {
+                  auto mul =
+                      CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                            op->attrs, op->type_args);
+                  auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                                   arg_node->attrs, arg_node->type_args);
+                  return VisitDFPattern(div, expr);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  auto watermark = matched_nodes_.size();
+  auto backup_memo = memo_;
+  auto backup_matched_nodes = matched_nodes_;
+
+  if (VisitDFPattern(op->child, expr)) {
+    auto child_graph = CreateIndexedGraph(GetRef<DFPattern>(op));
+    auto expr_graph = CreateIndexedGraph(expr);
+    auto find_dominated = [&child_graph, this](const DFPattern& node) {
+      std::unordered_set<Expr, ObjectHash, ObjectEqual> dominated_exprs;
+      auto indexed_node = child_graph.node_map_[node];
+      for (auto dominated : indexed_node->dominator_children_) {
+        if (dominated->ref_.as<WildcardPatternNode>() || dominated->ref_.as<OpNode>()) {
+          continue;
+        }
+        dominated_exprs.insert(memo_[dominated->ref_]);
+      }
+      return dominated_exprs;
+    };
+    std::function<bool(const Expr&, const std::unordered_set<Expr, ObjectHash, ObjectEqual>&)>
+        find_parent;
+    find_parent = [this, &op, &watermark, &backup_memo, &backup_matched_nodes, &find_dominated,
 
 Review comment:
   How about making `find_parent` a member function of `DFPatternMatcher`? I attempted this in https://github.com/masahi/tvm/commit/c41a0e5577ae4fc549b2c7dc8c0daeaf0d011623 
   
   At the very least, this lets you remove `backup_memo` and `backup_matched_nodes` entirely. 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] zhiics commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r419596978



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  TVM_DLL static DFPatternCallback make(DFPattern pattern, PackedFunc callback);

Review comment:
       We now prefer defining constructors on the ObjectRef class instead of using "make". The same applies to all other new nodes.

##########
File path: include/tvm/relay/dataflow_pattern.h
##########
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_pattern.h
+ * \brief A pattern language for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
+#define TVM_RELAY_DATAFLOW_PATTERN_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Base type of all dataflow patterns.
+ * \sa DFPattern
+ */
+class DFPatternNode : public Object {
+ public:
+  static constexpr const char* _type_key = "DFPatternNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow patterns.
+ * \sa DFPatternNode
+ */
+class DFPattern : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
+};
+
+/*!
+ * \brief Pattern for Relay Expression.
+ */
+class ExprPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The expression to match. */
+  Expr expr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("expr", &expr);
+  }
+
+  static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a literal expression.
+ *
+ * \note Uses structural equality on expressions to check equality.
+ *
+ */
+class ExprPattern : public DFPattern {
+ public:
+  TVM_DLL ExprPattern(Expr expr);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
+};
+
+
+/*!
+ * \brief A Pattern to Match a Relay Variable
+ */
+class VarPattern;
+/*! \brief Container for Var */
+class VarPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The name of the Var (optional).
+   */
+  std::string name;

Review comment:
       Maybe we can use tvm::String instead of std::string, since we are migrating uses of std::string in the node system to the former.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen edited a comment on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-618705083


   cc @junrushao1994 @ajtulloch @u99127 @yzhliu @abcdabcd987


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r419582381



##########
File path: include/tvm/relay/dataflow_functor.h
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern argument.
+ *
+ * \tparam FType function signiture
+ *  This type is only defined for FType with function signature R(const DFPattern&,
+ * Args...)
+ */
+template <typename FType>
+class DFPatternFunctor;
+
+// functions to be overriden.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \
+    return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class DFPatternFunctor<R(const DFPattern& n, Args...)> {
+ private:
+  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+    return VisitDFPattern(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+    CHECK(n.defined());
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once.
+ */
+class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
+ public:
+  void VisitDFPattern(const DFPattern& pattern) override;
+  void VisitDFPattern_(const AltPatternNode* op) override;
+  void VisitDFPattern_(const AttrPatternNode* op) override;
+  void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const DominatorPatternNode* op) override;
+  void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
+  void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const TypePatternNode* op) override;
+  void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const WildcardPatternNode* op) override;
+
+ protected:
+  // set of already-visited nodes
+  std::unordered_set<const Object*> visited_;
+};
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ *  Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ *  This class is templated and the implementaiton is in the header file so we can analyze both
+ * DFPattern and Expr with the same infrastructure.
+ *
+ *  IndexedGraph should be instantiated through the CreateIndexedGraph utilities.
+ */
+template <typename T>

Review comment:
       Moved the implementation to a private header file. I don't see the advantage of making Node an Object.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-625934854


   @zhiics Thanks for the comments! Completed those refactors.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424840667



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <stack>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            CHECK(false) << "Unsupported type in Type Pattern Node";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+              (is_expr_op(call_node->args[0], "divide") ||
+               is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
+                                     op->type_args);
+              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                     arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                (is_expr_op(call_node->args[0], "multiply") ||
+                 is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                     op->attrs, op->type_args);
+              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
+                                     arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches_path = MatchesPath(op, expr);
+    memoize_ = true;
+    if (matches_path) {
+      return DominatesParent(op, expr);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+
+Expr InferType(const Expr& expr) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::InferType()(mod);
+  if (expr.as<FunctionNode>()) {
+    return mod->Lookup("main");
+  } else {
+    return mod->Lookup("main").as<FunctionNode>()->body;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* var_node = expr.as<VarNode>()) {
+    matches = true;
+    if (op->name_hint() != "") {
+      matches &= op->name_hint() == var_node->name_hint();
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) {
+      return DFPatternMatcher(expr).Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {

Review comment:
       How about returning `groups_` here and removing `GetGroups()`, given that `GetGroups()` is always called immediately after `GroupMatches()`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen edited a comment on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-618704483


   Thanks @mbrookhart! Now that we have a concrete POC, it would be nice to have another round of API review with the folks, possibly by opening another thread on the discuss forum that provides examples of what relay.dataflow_pattern can do so far, and to get feedback on the API choices.
   
   I think the design choices that would be particularly interesting are:
   - Examples of the dominator pattern API (since that was not very well covered)
   - The API of match, rewrite and partition
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbaret commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r414451866



##########
File path: include/tvm/relay/dataflow_functor.h
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern argument.
+ *
+ * \tparam FType function signiture

Review comment:
       signature

##########
File path: docs/langref/relay_pattern.rst
##########
@@ -0,0 +1,143 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+=========================
+Pattern Matching in Relay
+=========================
+
+There are many places in TVM where we identify pure data-flow sub-graphs of the Relay program and attempt to transform them in some way example passes include fusion, quantization, external code generation, and device specific optimizations such as bitpacking, and layer slicing used by VTA. 
+
+Many of these passes today require a lots of boring boilerplate code in order to implement as well as requiring users to think in terms of visitors and AST matching. Many of these transformations can easily be described in terms of graph rewrites. In order to build a rewriter or other advanced machinery we first need a language of patterns to describe what we can match. 
+
+Such a language is not just useful for building a rewriter but also providing extension points for existing passes. For example the fusion pass could be parametrized by a set of fusion patterns which describes the capability of your hardware, and the quantization pass could take a set of patterns which describe which operators can be quantized on a given platform.

Review comment:
       parameterized

##########
File path: include/tvm/relay/dataflow_pattern.h
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_pattern.h
+ * \brief A pattern language for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
+#define TVM_RELAY_DATAFLOW_PATTERN_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Base type of all dataflow patterns.
+ * \sa DFPattern
+ */
+class DFPatternNode : public Object {
+ public:
+  static constexpr const char* _type_key = "DFPatternNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow patterns.
+ * \sa DFPatternNode
+ */
+class DFPattern : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
+};
+
+/*!
+ * \brief Pattern for Relay Expression.
+ */
+class ExprPatternNode : public DFPatternNode {
+ public:
+  /*! \brief The expression to match. */
+  Expr expr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("expr", &expr);
+  }
+
+  static constexpr const char* _type_key = "relay.df_pattern.ExprPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a literal expression.
+ *
+ * \note Uses structural equality on expressions to check equality.
+ *
+ */
+class ExprPattern : public DFPattern {
+ public:
+  TVM_DLL ExprPattern(Expr expr);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
+};
+
+
+/*!
+ * \brief A Pattern to Match a Relay Variable
+ */
+class VarPattern;
+/*! \brief Container for Var */
+class VarPatternNode : public DFPatternNode {
+ public:
+  /*!
+   * \brief The name of the Var (optional).
+   */
+  std::string name;
+  /*!
+   * \brief type annotaion of the variable.

Review comment:
       annotation

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches_path = MatchesPath(op, expr);
+    memoize_ = true;
+    if (matches_path) {
+      return DominatesParent(op, expr);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {

Review comment:
       Should there be a check against the Var type annotation if it's present in the pattern?

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches_path = MatchesPath(op, expr);
+    memoize_ = true;
+    if (matches_path) {
+      return DominatesParent(op, expr);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* var_node = expr.as<VarNode>()) {
+    matches = true;
+    if (op->name_hint() != "") {
+      matches &= op->name_hint() == var_node->name_hint();
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) {
+  return DFPatternMatcher(expr).Match(pattern, expr);
+});
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarly needed to suppor the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assingnments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    groups_ = {Group()};
+    gid_assignments_.clear();
+    visit_counter_.clear();
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+  }
+
+ protected:
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis */
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    Expr VisitExpr(const Expr& pre) override {
+      if (inputs_.count(pre)) {
+        return inputs_.at(pre);
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
+                                      std::to_string(var_number_),
+                                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number_++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Parition directly,

Review comment:
       Typo: "Parition" should be "Partition".

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches_path = MatchesPath(op, expr);
+    memoize_ = true;
+    if (matches_path) {
+      return DominatesParent(op, expr);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* var_node = expr.as<VarNode>()) {
+    matches = true;
+    if (op->name_hint() != "") {
+      matches &= op->name_hint() == var_node->name_hint();
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) {
+  return DFPatternMatcher(expr).Match(pattern, expr);
+});
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarly needed to suppor the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assingnments of expressions */

Review comment:
       Typo: "assingnments" should be "assignments".

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively check the inputs of `expr` for the dominator's parent pattern.
+//
+// Each data input of `expr` (the callee operator of a call is skipped) is handled as follows:
+//  * if it matches `op->parent`, the search ends successfully;
+//  * otherwise it must match `op->path`, and the search recurses into that input;
+//  * an input matching neither pattern makes the check fail.
+// An expression with no data inputs passes trivially (`out` starts as true).
+// NOTE(review): the early `return true` on a parent match means the remaining inputs of
+// `expr` are not examined; confirm this is the intended semantics.
+//
+// Side effect: `memoize_` is switched on while testing `op->parent` (so the parent must keep
+// binding to the same expression) and off while testing `op->path` (so the path pattern may
+// bind to several expressions). It is left in whichever state the last test used; the
+// DominatorPatternNode visitor resets it to true afterwards.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    // Follow only data inputs; the operator of a call is not part of the path.
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Depth-first search through `dominator_children_` links starting at `expr`: succeeds as
+// soon as any reachable node matches `op->parent`, fails once every node has been explored.
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> seen;
+  std::stack<Expr> pending;
+  pending.push(expr);
+  while (!pending.empty()) {
+    const Expr current = pending.top();
+    pending.pop();
+    for (auto child : expr_graph_.node_map_[current]->dominator_children_) {
+      const Expr candidate = child->ref_;
+      // insert() reports false when this node was already explored.
+      if (!seen.insert(candidate).second) {
+        continue;
+      }
+      if (VisitDFPattern(op->parent, candidate)) {
+        return true;
+      }
+      pending.push(candidate);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  // Index the dominator pattern itself before matching any of its pieces.
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  const bool path_matches = MatchesPath(op, expr);
+  // MatchesPath may leave memoization switched off; always switch it back on.
+  memoize_ = true;
+  return path_matches && DominatesParent(op, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An expression pattern accepts any expression structurally equal to the stored one.
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  const auto* get_item = expr.as<TupleGetItemNode>();
+  if (get_item == nullptr) {
+    return false;
+  }
+  // Compare the index before recursing into the tuple sub-pattern.
+  if (op->index != get_item->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, get_item->tuple);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr || op->fields.size() != tuple->fields.size()) {
+    return false;
+  }
+  // Fields are matched left to right; stop at the first mismatch.
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  // Run type inference on the candidate to obtain its checked type.
+  auto typed = InferType(expr);
+  auto expr_type = typed.as<ExprNode>()->checked_type();
+  if (!StructuralEqual()(op->type, expr_type)) {
+    return false;
+  }
+  return VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  const auto* var = expr.as<VarNode>();
+  if (var == nullptr) {
+    return false;
+  }
+  // An empty name hint in the pattern accepts a variable of any name.
+  return op->name_hint() == "" || op->name_hint() == var->name_hint();
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  // A wildcard accepts every expression; neither argument is inspected.
+  (void)op;
+  (void)expr;
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) -> bool {
+      // Build a matcher indexed on `expr` and test the pattern against that same root.
+      DFPatternMatcher matcher(expr);
+      return matcher.Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    /* \brief The expression at which the whole pattern matched. */
+    Expr root_node;
+    /* \brief Group id; equal to this group's index in GetGroups(). */
+    int gid;
+    /* \brief Snapshot of the matcher memo: pattern node -> matched expressions. */
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    /* \brief The matched subgraph extracted as a function of the group inputs. */
+    Function function;
+    /* \brief Expressions replaced by the function's parameters, in parameter order. */
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups (index 0 is an empty placeholder) */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern
+   *
+   * Discards any results from a previous call, then visits `pre` and records a Group for
+   * every match that does not overlap an already-recorded group.
+   */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    groups_ = {Group()};
+    gid_assignments_.clear();
+    visit_counter_.clear();
+    // Group ids double as indices into groups_, so restart them along with it.
+    gid_ = 0;
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+    // `matcher` is destroyed when this function returns; don't keep a dangling pointer.
+    matcher_ = nullptr;
+  }
+
+ protected:
+  /* \brief Visitor hook: try the pattern at `pre` and record a group when it matches */
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis */
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    // Group inputs are replaced by their parameter Vars; everything else goes through the
+    // default ExprMutator.
+    Expr VisitExpr(const Expr& pre) override {
+      auto it = inputs_.find(pre);
+      if (it != inputs_.end()) {
+        return it->second;
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns: expressions bound to a dominator's parent/path patterns are never
+    // turned into function inputs.
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          // A sub-pattern may never have been visited (e.g. an empty path, when the child
+          // consumes the parent directly), so it can be absent from the memo.
+          if (node_map.count(fuzzy_op) == 0) {
+            continue;
+          }
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      // Only leaf patterns (patterns without pattern inputs) can bind group inputs.
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            // Operators, functions, and constants stay inside the extracted body.
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
+                                      std::to_string(var_number_),
+                                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number_++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Partition directly,
+    // used to determine Group overlap in other passes
+    auto extractor = MatchExtractor(inputs);
+    auto body = extractor.Mutate(expr);
+
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+
+    // Check to make sure we aren't overlapping with another group
+    for (const auto& kv : extractor.GetMemo()) {
+      if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
+          kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
+          kv.first.as<ConstantNode>() == nullptr) {
+        // Exit due to overlapping partitions
+        return;
+      }
+    }
+    // Assign Group Ids
+    group.gid = ++gid_;
+    for (const auto& kv : extractor.GetMemo()) {
+      gid_assignments_[kv.first] = gid_;
+    }
+
+    // Save Group
+    groups_.emplace_back(std::move(group));
+    CHECK_EQ(groups_[gid_].gid, gid_);
+  }
+
+  // Internal State
+  DFPattern pattern_;
+  std::vector<Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  /* \brief Matcher of the running GroupMatches call; null outside of it. */
+  DFPatternMatcher* matcher_ = nullptr;
+  IndexedGraph<DFPattern> pattern_graph_;
+  int gid_ = 0;
+  int var_number_ = 0;
+  int graph_number_ = 0;
+};
+
+// Rewrite
+
+DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) {
+  // Allocate the node, move both arguments into it, and wrap it in a reference.
+  auto node = make_object<DFPatternCallbackNode>();
+  node->function_ = std::move(function);
+  node->pattern_ = std::move(pattern);
+  return DFPatternCallback(node);
+}
+
+// Register the callback node type, and expose DFPatternCallbackNode::make as the global
+// function "relay.df_pattern.DFPatternCallback".
+TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
+    .set_body_typed(DFPatternCallbackNode::make);
+
+/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
+ * function to rewrtie those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+  PatternRewriter() {}
+  /*! \brief Rewrite can take a number of callbakcs and will repeatedly rewrite the graph with the
+   * callbacks until it stops changing */
+  Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
+    auto post = pre;
+    auto last = post;
+    // rewrite the graph until it stops changing to make sure all rewrites are complete
+    do {

Review comment:
       Might it make sense to have a maximum number of rewrite passes here in case of infinite mutation?

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+// NOTE(review): DominatorMatcher is neither defined nor used in the visible part of this file;
+// confirm whether this forward declaration is still needed.
+class DominatorMatcher;
+
+/*!
+ * \brief Matches a dataflow pattern against a Relay expression.
+ *
+ * The matcher is built over `root_expr`; its indexed graph (node inputs and dominator
+ * children) is consulted by dominator patterns. Successful sub-matches are recorded in
+ * `memo_` and can be read back with GetMemo() after Match().
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  /*! \brief Clear previous match state and test whether `pattern` matches `expr`. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  /*! \brief Return a new Map holding the pattern -> matched-expressions memo. */
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Drop memo entries recorded at or after position `watermark` of matched_nodes_. */
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  /*! \brief Pattern -> expressions it successfully matched. */
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Patterns in the order their matches were recorded; used to roll back memo_. */
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  /*! \brief When true, a pattern already in memo_ only matches that same expression again. */
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Start from a clean slate so bindings from earlier calls can't leak into this one.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  // Everything recorded at or after `watermark` belongs to the attempt being undone.
+  auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  // With memoization on, a pattern that is already bound only matches that same expression.
+  if (memoize_) {
+    auto bound = memo_.find(pattern);
+    if (bound != memo_.end()) {
+      CHECK_EQ(bound->second.size(), 1);
+      return expr.same_as(bound->second[0]);
+    }
+  }
+  const auto watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    // Roll back whatever the failed attempt recorded.
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern].push_back(expr);
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // Try the left alternative first; fall back to the right one only if it fails.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  // Attribute patterns constrain operators only; any other expression never matches.
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+  // Every requested attribute must be registered on the operator with an equal value of the
+  // same kind; one mismatch rejects the pattern. An empty dictionary reports no match.
+  bool matches = false;
+  for (auto kv : attributes) {
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      // The operator does not define this attribute at all.
+      return false;
+    }
+    bool attr_matches = false;
+    switch (op_map[op].type_code()) {
+      case kDLInt:
+        if (auto* val = kv.second.as<IntImmNode>()) {
+          attr_matches = val->value == op_map[op].operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (auto* val = kv.second.as<FloatImmNode>()) {
+          attr_matches = val->value == op_map[op].operator double();
+        }
+        break;
+      case kTVMStr:
+        if (auto* val = kv.second.as<tir::StringImmNode>()) {
+          attr_matches = val->value == op_map[op].operator std::string();
+        }
+        break;
+      default:
+        // Use TVM's fatal logging instead of throwing a bare string literal: a `const char*`
+        // is not a std::exception, so handlers catching exception objects would miss it.
+        LOG(FATAL) << "Unsupported attribute type code in AttrPattern: "
+                   << op_map[op].type_code();
+    }
+    if (!attr_matches) {
+      return false;
+    }
+    matches = true;
+  }
+  return matches;
+}
+
+// Return a copy of `args` with the element order reversed.
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+// Match a call pattern against `expr`.
+//
+// The callee pattern is matched first, then the arguments pairwise in order. For patterns whose
+// operator is `add` or `multiply` the arguments are also tried in reversed order. When the
+// callee does not match, two algebraic regroupings of the pattern are tried instead:
+//   divide(multiply(a, b), c)  ->  multiply(b, divide(a, c))  or  multiply(a, divide(b, c))
+//     (only when `expr` is a multiply with a divide argument)
+//   multiply(divide(a, b), c)  ->  divide(multiply(a, c), b)
+//     (only when `expr` is a divide with a multiply argument)
+// NOTE(review): these regroupings are exact for real arithmetic but not for integer division;
+// confirm that is acceptable for the dtypes this is used with.
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the operator a call pattern refers to when its callee is an ExprPattern wrapping
+  // an Op; nullptr otherwise.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the call pattern's operator is named `op_type`.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  // True when `expr` is a call to the operator named `op_type`.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Memo size before this call pattern; anything recorded after it can be rolled back.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Memo size after the callee matched; failed argument attempts roll back to here.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match arguments pairwise in order (counts must agree); undo their memo entries on failure.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // NOTE(review): op->args[0] is read without a size check; this presumes binary patterns.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            // Try each multiply operand as the numerator of the regrouped divide.
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Returns on the first operand that is a divide pattern and meets the expr conditions,
+        // without trying the other operand if that regrouping fails to match.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively check the inputs of `expr` for the dominator's parent pattern.
+//
+// Each data input of `expr` (the callee operator of a call is skipped) is handled as follows:
+//  * if it matches `op->parent`, the search ends successfully;
+//  * otherwise it must match `op->path`, and the search recurses into that input;
+//  * an input matching neither pattern makes the check fail.
+// An expression with no data inputs passes trivially (`out` starts as true).
+// NOTE(review): the early `return true` on a parent match means the remaining inputs of
+// `expr` are not examined; confirm this is the intended semantics.
+//
+// Side effect: `memoize_` is switched on while testing `op->parent` (so the parent must keep
+// binding to the same expression) and off while testing `op->path` (so the path pattern may
+// bind to several expressions). It is left in whichever state the last test used; the
+// DominatorPatternNode visitor resets it to true afterwards.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    // Follow only data inputs; the operator of a call is not part of the path.
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Depth-first search through `dominator_children_` links starting at `expr`: succeeds as
+// soon as any reachable node matches `op->parent`, fails once every node has been explored.
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> seen;
+  std::stack<Expr> pending;
+  pending.push(expr);
+  while (!pending.empty()) {
+    const Expr current = pending.top();
+    pending.pop();
+    for (auto child : expr_graph_.node_map_[current]->dominator_children_) {
+      const Expr candidate = child->ref_;
+      // insert() reports false when this node was already explored.
+      if (!seen.insert(candidate).second) {
+        continue;
+      }
+      if (VisitDFPattern(op->parent, candidate)) {
+        return true;
+      }
+      pending.push(candidate);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  // Index the dominator pattern itself before matching any of its pieces.
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  const bool path_matches = MatchesPath(op, expr);
+  // MatchesPath may leave memoization switched off; always switch it back on.
+  memoize_ = true;
+  return path_matches && DominatesParent(op, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An expression pattern accepts any expression structurally equal to the stored one.
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  const auto* get_item = expr.as<TupleGetItemNode>();
+  if (get_item == nullptr) {
+    return false;
+  }
+  // Compare the index before recursing into the tuple sub-pattern.
+  if (op->index != get_item->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, get_item->tuple);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr || op->fields.size() != tuple->fields.size()) {
+    return false;
+  }
+  // Fields are matched left to right; stop at the first mismatch.
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  // Run type inference on the candidate to obtain its checked type.
+  auto typed = InferType(expr);
+  auto expr_type = typed.as<ExprNode>()->checked_type();
+  if (!StructuralEqual()(op->type, expr_type)) {
+    return false;
+  }
+  return VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  const auto* var = expr.as<VarNode>();
+  if (var == nullptr) {
+    return false;
+  }
+  // An empty name hint in the pattern accepts a variable of any name.
+  return op->name_hint() == "" || op->name_hint() == var->name_hint();
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  // A wildcard accepts every expression; neither argument is inspected.
+  (void)op;
+  (void)expr;
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) -> bool {
+      // Build a matcher indexed on `expr` and test the pattern against that same root.
+      DFPatternMatcher matcher(expr);
+      return matcher.Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    /* \brief The expression at which the whole pattern matched. */
+    Expr root_node;
+    /* \brief Group id; equal to this group's index in GetGroups(). */
+    int gid;
+    /* \brief Snapshot of the matcher memo: pattern node -> matched expressions. */
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    /* \brief The matched subgraph extracted as a function of the group inputs. */
+    Function function;
+    /* \brief Expressions replaced by the function's parameters, in parameter order. */
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups (index 0 is an empty placeholder) */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern
+   *
+   * Discards any results from a previous call, then visits `pre` and records a Group for
+   * every match that does not overlap an already-recorded group.
+   */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    groups_ = {Group()};
+    gid_assignments_.clear();
+    visit_counter_.clear();
+    // Group ids double as indices into groups_, so restart them along with it.
+    gid_ = 0;
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+    // `matcher` is destroyed when this function returns; don't keep a dangling pointer.
+    matcher_ = nullptr;
+  }
+
+ protected:
+  /* \brief Visitor hook: try the pattern at `pre` and record a group when it matches */
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis */
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    // Group inputs are replaced by their parameter Vars; everything else goes through the
+    // default ExprMutator.
+    Expr VisitExpr(const Expr& pre) override {
+      auto it = inputs_.find(pre);
+      if (it != inputs_.end()) {
+        return it->second;
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns: expressions bound to a dominator's parent/path patterns are never
+    // turned into function inputs.
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          // A sub-pattern may never have been visited (e.g. an empty path, when the child
+          // consumes the parent directly), so it can be absent from the memo.
+          if (node_map.count(fuzzy_op) == 0) {
+            continue;
+          }
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      // Only leaf patterns (patterns without pattern inputs) can bind group inputs.
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            // Operators, functions, and constants stay inside the extracted body.
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
+                                      std::to_string(var_number_),
+                                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number_++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Partition directly,
+    // used to determine Group overlap in other passes
+    auto extractor = MatchExtractor(inputs);
+    auto body = extractor.Mutate(expr);
+
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+
+    // Check to make sure we aren't overlapping with another group
+    for (const auto& kv : extractor.GetMemo()) {
+      if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
+          kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
+          kv.first.as<ConstantNode>() == nullptr) {
+        // Exit due to overlapping partitions
+        return;
+      }
+    }
+    // Assign Group Ids
+    group.gid = ++gid_;
+    for (const auto& kv : extractor.GetMemo()) {
+      gid_assignments_[kv.first] = gid_;
+    }
+
+    // Save Group
+    groups_.emplace_back(std::move(group));
+    CHECK_EQ(groups_[gid_].gid, gid_);
+  }
+
+  // Internal State
+  DFPattern pattern_;
+  std::vector<Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  /* \brief Matcher of the running GroupMatches call; null outside of it. */
+  DFPatternMatcher* matcher_ = nullptr;
+  IndexedGraph<DFPattern> pattern_graph_;
+  int gid_ = 0;
+  int var_number_ = 0;
+  int graph_number_ = 0;
+};
+
+// Rewrite
+
+DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) {
+  // Allocate the node, move both arguments into it, and wrap it in a reference.
+  auto node = make_object<DFPatternCallbackNode>();
+  node->function_ = std::move(function);
+  node->pattern_ = std::move(pattern);
+  return DFPatternCallback(node);
+}
+
+// Register the callback node type, and expose DFPatternCallbackNode::make as the global
+// function "relay.df_pattern.DFPatternCallback".
+TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
+    .set_body_typed(DFPatternCallbackNode::make);
+
+/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
+ * function to rewrtie those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+  PatternRewriter() {}
+  /*! \brief Rewrite can take a number of callbakcs and will repeatedly rewrite the graph with the

Review comment:
       callbacks

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+// NOTE(review): DominatorMatcher is forward-declared but neither defined nor referenced in the
+// code visible here; it may be a leftover.
+class DominatorMatcher;
+
+/*!
+ * \brief Matches a dataflow pattern (DFPattern) against a Relay expression.
+ *
+ * Each successful sub-match records the matched expression in memo_ and the pattern in
+ * matched_nodes_, so partial matches can be rolled back to a "watermark" with ClearMap when a
+ * later part of the match fails.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Build the indexed graph of root_expr, used by the dominator analysis. */
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  /*! \brief Clear previous match state and return whether pattern matches expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  /*! \brief Return a copy of the pattern -> matched expressions map from the last match. */
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Forget every match recorded at or after index `watermark` of matched_nodes_. */
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  /*! \brief Expressions matched by each pattern node during the current match. */
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Pattern nodes in the order they matched; ClearMap rolls back along this list. */
+  std::vector<DFPattern> matched_nodes_;
+  /*! \brief Indexed graph of the root expression (node inputs and dominator children). */
+  IndexedGraph<Expr> expr_graph_;
+  /*! \brief Indexed graph of the most recently visited dominator pattern. */
+  IndexedGraph<DFPattern> pattern_graph_;
+  /*! \brief When false, VisitDFPattern ignores previously memoized bindings. */
+  bool memoize_ = true;
+};
+
+/*! \brief Start a fresh match: drop state left by any earlier call, then match from the root. */
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  matched_nodes_.clear();
+  memo_.clear();
+  const bool result = VisitDFPattern(pattern, expr);
+  return result;
+}
+
+/*! \brief Roll back every match recorded at position `watermark` or later. */
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+/*!
+ * \brief Memoizing entry point for every pattern visit.
+ *
+ * With memoization on, a pattern that already matched must bind to the very same expression.
+ * Otherwise the pattern is dispatched; success records the binding, failure rolls back anything
+ * recorded during this attempt.
+ */
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    const auto& bound = memo_[pattern];
+    CHECK_EQ(bound.size(), 1);
+    return expr.same_as(bound[0]);
+  }
+  const size_t watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern].push_back(expr);
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // Try the left alternative first; fall back to the right one only if it fails.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+/*!
+ * \brief Match an operator against a set of required operator attributes.
+ *
+ * Succeeds only when expr is an Op, the pattern carries a non-empty DictAttrs, and every
+ * attribute in that dictionary is registered for the Op with an equal value. An attribute the
+ * Op does not register, or a value of the wrong kind, makes the whole pattern fail.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr || dict_attrs->dict.size() == 0) {
+    // Nothing to compare against: report no match rather than dereferencing null.
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  for (auto kv : dict_attrs->dict) {
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      // The op does not register this attribute at all.
+      return false;
+    }
+    bool matches = false;
+    switch (op_map[op].type_code()) {
+      case kDLInt:
+        if (auto* val = kv.second.as<IntImmNode>()) {
+          matches = val->value == op_map[op].operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (auto* val = kv.second.as<FloatImmNode>()) {
+          matches = val->value == op_map[op].operator double();
+        }
+        break;
+      case kTVMStr:
+        if (auto* val = kv.second.as<tir::StringImmNode>()) {
+          matches = val->value == op_map[op].operator std::string();
+        }
+        break;
+      default:
+        LOG(FATAL) << "Unsupported attribute type code " << op_map[op].type_code()
+                   << " for attribute " << kv.first;
+    }
+    // Every attribute must match; one mismatch rejects the pattern.
+    if (!matches) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*! \brief Return a copy of args with the element order reversed. */
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t i = args.size(); i > 0; --i) {
+    reversed.push_back(args[i - 1]);
+  }
+  return reversed;
+}
+
+/*!
+ * \brief Match a call pattern against an expression.
+ *
+ * Attempts, in order: (1) the callee and the arguments in order; (2) for "add" and "multiply"
+ * patterns, the arguments reversed; (3) when the callee itself does not match, the
+ * divide/multiply reassociations handled in the else-branch below.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode wrapped by the pattern's callee when it is an ExprPattern over an Op,
+  // otherwise nullptr.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the pattern's callee is the Op named op_type.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  // True when expr is a Call whose callee is the Op named op_type.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Rollback point covering everything this call pattern records.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Rollback point after the callee matched: a failed argument attempt undoes only its
+      // own records and keeps the callee binding.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match pattern_args[i] against expr_args[i] pairwise, stopping at the first failure.
+      // Differing argument counts are a failure. On failure, roll back to watermark2.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      // reverse() flips the whole argument list; for add/multiply (presumably binary here)
+      // this swaps the two operands.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern (a * b) / c versus an expression x * y where x or y is a divide: retry with the
+      // equivalent patterns b * (a / c), then a * (b / c).
+      // NOTE(review): op->args[0], op->args[1] and call_node->args[0..1] are indexed without
+      // checking arity; this assumes binary calls.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Pattern x * (a / b) or (a / b) * x versus a divide expression with a multiply
+        // argument: retry as (a * x) / b. Only the first divide argument found is tried.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  // Records still present here are rolled back by VisitDFPattern when this returns false.
+  return false;
+}
+
+/*!
+ * \brief Check that every input path of expr reaches op->parent through op->path nodes.
+ *
+ * Walks backwards from expr over its inputs, skipping a call's callee. An input that matches
+ * op->parent ends the search successfully. Otherwise the input must match op->path and its own
+ * inputs must satisfy the same condition recursively. Any failing sub-path rejects
+ * immediately, so a later input that happens to match the parent cannot mask it.
+ *
+ * memoize_ is on while probing for the parent, so it must bind to the same expression as
+ * before. It is off while probing path nodes, so op->path may match many expressions.
+ */
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Depth-first (iterative) search through the dominator_children_ links below expr for a
+ *        node matching op->parent; returns true on the first such node.
+ */
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> seen;
+  std::stack<Expr> pending;
+  pending.push(expr);
+  while (!pending.empty()) {
+    const Expr current = pending.top();
+    pending.pop();
+    for (auto child : expr_graph_.node_map_[current]->dominator_children_) {
+      if (!seen.insert(child->ref_).second) {
+        continue;  // already examined
+      }
+      if (VisitDFPattern(op->parent, child->ref_)) {
+        return true;
+      }
+      pending.push(child->ref_);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  // Requires: the child matches expr, every input path reaches the parent, and the parent is
+  // reachable through expr's dominator children.
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  const bool path_ok = MatchesPath(op, expr);
+  // MatchesPath may leave memoization disabled; restore it before continuing.
+  memoize_ = true;
+  return path_ok && DominatesParent(op, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An ExprPattern matches any expression structurally equal to the wrapped expression.
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  const auto* get_item = expr.as<TupleGetItemNode>();
+  if (get_item == nullptr) {
+    return false;
+  }
+  // Compare the index first so the tuple sub-pattern is only visited when it can matter.
+  if (op->index != get_item->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, get_item->tuple);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr || op->fields.size() != tuple->fields.size()) {
+    return false;
+  }
+  // Fields are matched left to right; stop at the first mismatch.
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  // Infer expr's type and compare it before matching the inner pattern.
+  auto typed = InferType(expr);
+  const auto inferred = typed.as<ExprNode>()->checked_type();
+  if (!StructuralEqual()(op->type, inferred)) {
+    return false;
+  }
+  return VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  const auto* var = expr.as<VarNode>();
+  if (var == nullptr) {
+    return false;
+  }
+  // An empty name hint in the pattern accepts any variable name.
+  return op->name_hint() == "" || op->name_hint() == var->name_hint();
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* /*op*/, const Expr& /*expr*/) {
+  // A wildcard matches every expression unconditionally.
+  return true;
+}
+
+// Frontend entry point: match `pattern` against `expr`, using `expr` as the graph root.
+TVM_REGISTER_GLOBAL("relay.df_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) {
+      DFPatternMatcher matcher(expr);
+      return matcher.Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    // The expression at which the whole pattern matched.
+    Expr root_node;
+    // Group id; also this group's index in groups_.
+    int gid;
+    // Pattern node -> expressions it matched, as reported by the matcher.
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    // The matched subgraph extracted into a Function over the group's inputs.
+    Function function;
+    // The original expressions bound to the Function's parameters, in parameter order.
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  // Index 0 is a placeholder inserted by GroupMatches, so groups_[gid] is group gid.
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  // Maps expressions visited while extracting a group's body to that group's id. Later groups
+  // may overwrite entries for shared Op/Function/Constant nodes.
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern */
+  // Walks `pre` with MixedModeVisitor; VisitLeaf tries the pattern at each visited node.
+  // NOTE(review): gid_ is not reset here while groups_ is, so calling this twice on the same
+  // grouper would break the groups_[gid_] indexing. PatternRewriter::Rewrite creates a fresh
+  // grouper per callback.
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    groups_ = {Group()};
+    gid_assignments_.clear();
+    visit_counter_.clear();
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    // NOTE(review): matcher_ points at this local and dangles once GroupMatches returns.
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+  }
+
+ protected:
+  // Try the pattern rooted at `pre`; on a match, attempt to form a group there.
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis */
+  // Copies an expression, replacing each group input with its parameter Var.
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    // ExprMutator's memo of visited expressions. Inputs return early from VisitExpr below, so
+    // they are presumably never added to it.
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    Expr VisitExpr(const Expr& pre) override {
+      if (inputs_.count(pre)) {
+        return inputs_.at(pre);
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  // Builds the group's inputs and extracted Function, then records the group unless it
+  // overlaps an existing one.
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns
+    // Expressions matched by a dominator's parent/path are collected so they are not turned
+    // into function inputs below.
+    // NOTE(review): node_map[fuzzy_op] assumes both parent and path were recorded. If the path
+    // pattern never matched (e.g. the parent is a direct input), this lookup may fail.
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    // Matches of leaf pattern nodes (no inputs) become function parameters. Ops, Functions,
+    // Constants and fuzzy matches stay inside the body instead.
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
+                                      std::to_string(var_number_),
+                                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number_++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Partition directly,
+    // used to determine Group overlap in other passes
+    auto extractor = MatchExtractor(inputs);
+    auto body = extractor.Mutate(expr);
+
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+
+    // Check to make sure we aren't overlapping with another group
+    // Shared Op/Function/Constant nodes and this group's own inputs do not count as overlap.
+    for (auto kv : extractor.GetMemo()) {
+      if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
+          kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
+          kv.first.as<ConstantNode>() == nullptr) {
+        // Exit due to overlapping partitions
+        return;
+      }
+    }
+    // Assign Group Ids
+    group.gid = ++gid_;
+    for (auto kv : extractor.GetMemo()) {
+      gid_assignments_[kv.first] = gid_;
+    }
+
+    // Save Group
+    // groups_[0] is the placeholder from GroupMatches, so the new group lands at index gid_.
+    groups_.emplace_back(std::move(group));
+    CHECK_EQ(groups_[gid_].gid, gid_);
+  }
+
+  // Internal State
+  // Pattern currently being grouped.
+  DFPattern pattern_;
+  std::vector<Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  // Matcher owned by the active GroupMatches call (not owned by this object).
+  DFPatternMatcher* matcher_ = nullptr;
+  IndexedGraph<DFPattern> pattern_graph_;
+  // Last assigned group id.
+  int gid_ = 0;
+  // Counters used to give generated parameter Vars unique names.
+  int var_number_ = 0;
+  int graph_number_ = 0;
+};
+
+// Rewrite
+
+/*! \brief Pair a pattern with the PackedFunc invoked on each of its matches. */
+DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) {
+  auto node = make_object<DFPatternCallbackNode>();
+  node->function_ = std::move(function);
+  node->pattern_ = std::move(pattern);
+  return DFPatternCallback(node);
+}
+
+TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
+    .set_body_typed(DFPatternCallbackNode::make);
+
+/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
+ * functions to rewrite those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+  PatternRewriter() {}
+  /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the
+   * callbacks until it stops changing */
+  // NOTE(review): "stops changing" is tested with `last != post` (presumably object identity),
+  // so a callback that always returns a freshly built node could keep the loop running.
+  Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
+    auto post = pre;
+    auto last = post;
+    // rewrite the graph until it stops changing to make sure all rewrites are complete
+    do {
+      last = post;
+      for (auto callback : callbacks) {
+        callback_ = callback;
+        // Group the matches on the current graph before mutating it.
+        auto grouper = PatternGrouper();
+        grouper.GroupMatches(callback_->pattern_, post);
+        groups_ = grouper.GetGroups();
+        gid_assignments_ = grouper.GetGIDAssignments();
+        // Drop mutation results from the previous pass; `post` may be a different graph now.
+        memo_.clear();
+        post = this->VisitExpr(post);
+      }
+    } while (last != post);
+    return post;
+  }
+
+ protected:
+  // After the default mutation of `pre`, invoke the user callback when `pre` is a group root.
+  Expr DispatchVisitExpr(const Expr& pre) override {
+    auto post = MixedModeMutator::DispatchVisitExpr(pre);
+    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
+      // Convert the pre-rewrite node map to a post-rewrite node map
+      // (memo_ is assumed to hold pre -> post for every already-visited node).
+      auto group = groups_[gid_assignments_[pre]];
+      std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> node_map;
+      for (auto kv : group.matched_nodes) {
+        Array<Expr> tmp;
+        for (size_t i = 0; i < kv.second.size(); ++i) {
+          tmp.push_back(this->memo_[kv.second[i]]);
+        }
+        node_map.insert({kv.first, tmp});
+      }
+      // run the user callback function
+      return callback_->function_(pre, post, Map<DFPattern, Array<Expr>>(node_map));
+    }
+    return post;
+  }
+
+  // Callback currently being applied by Rewrite.
+  DFPatternCallback callback_;
+  // Groups and group-id assignments produced by PatternGrouper for the current callback.
+  std::vector<PatternGrouper::Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+};
+
+/*! \brief Apply the callbacks to expr repeatedly until the expression stops changing. */
+Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr) {
+  PatternRewriter rewriter;
+  return rewriter.Rewrite(callbacks, expr);
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns);
+
+/* \brief PatternParitioner replaces expressions that match a pattern with function call that

Review comment:
       PatternPartitioner

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+// NOTE(review): DominatorMatcher is forward-declared but neither defined nor referenced in the
+// code visible here; it may be a leftover.
+class DominatorMatcher;
+
+/*!
+ * \brief Matches a dataflow pattern (DFPattern) against a Relay expression.
+ *
+ * Each successful sub-match records the matched expression in memo_ and the pattern in
+ * matched_nodes_, so partial matches can be rolled back to a "watermark" with ClearMap when a
+ * later part of the match fails.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  /*! \brief Build the indexed graph of root_expr, used by the dominator analysis. */
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  /*! \brief Clear previous match state and return whether pattern matches expr. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  /*! \brief Return a copy of the pattern -> matched expressions map from the last match. */
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Forget every match recorded at or after index `watermark` of matched_nodes_. */
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  /*! \brief Expressions matched by each pattern node during the current match. */
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Pattern nodes in the order they matched; ClearMap rolls back along this list. */
+  std::vector<DFPattern> matched_nodes_;
+  /*! \brief Indexed graph of the root expression (node inputs and dominator children). */
+  IndexedGraph<Expr> expr_graph_;
+  /*! \brief Indexed graph of the most recently visited dominator pattern. */
+  IndexedGraph<DFPattern> pattern_graph_;
+  /*! \brief When false, VisitDFPattern ignores previously memoized bindings. */
+  bool memoize_ = true;
+};
+
+/*! \brief Start a fresh match: drop state left by any earlier call, then match from the root. */
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  matched_nodes_.clear();
+  memo_.clear();
+  const bool result = VisitDFPattern(pattern, expr);
+  return result;
+}
+
+/*! \brief Roll back every match recorded at position `watermark` or later. */
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  auto first_stale = matched_nodes_.begin() + watermark;
+  for (auto it = first_stale; it != matched_nodes_.end(); ++it) {
+    memo_.erase(*it);
+  }
+  matched_nodes_.erase(first_stale, matched_nodes_.end());
+}
+
+/*!
+ * \brief Memoizing entry point for every pattern visit.
+ *
+ * With memoization on, a pattern that already matched must bind to the very same expression.
+ * Otherwise the pattern is dispatched; success records the binding, failure rolls back anything
+ * recorded during this attempt.
+ */
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    const auto& bound = memo_[pattern];
+    CHECK_EQ(bound.size(), 1);
+    return expr.same_as(bound[0]);
+  }
+  const size_t watermark = matched_nodes_.size();
+  const bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern].push_back(expr);
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // Try the left alternative first; fall back to the right one only if it fails.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+/*!
+ * \brief Match an operator against a set of required operator attributes.
+ *
+ * Succeeds only when expr is an Op, the pattern carries a non-empty DictAttrs, and every
+ * attribute in that dictionary is registered for the Op with an equal value. An attribute the
+ * Op does not register, or a value of the wrong kind, makes the whole pattern fail.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  if (dict_attrs == nullptr || dict_attrs->dict.size() == 0) {
+    // Nothing to compare against: report no match rather than dereferencing null.
+    return false;
+  }
+  Op op = GetRef<Op>(op_node);
+  for (auto kv : dict_attrs->dict) {
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    if (!op_map.count(op)) {
+      // The op does not register this attribute at all.
+      return false;
+    }
+    bool matches = false;
+    switch (op_map[op].type_code()) {
+      case kDLInt:
+        if (auto* val = kv.second.as<IntImmNode>()) {
+          matches = val->value == op_map[op].operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (auto* val = kv.second.as<FloatImmNode>()) {
+          matches = val->value == op_map[op].operator double();
+        }
+        break;
+      case kTVMStr:
+        if (auto* val = kv.second.as<tir::StringImmNode>()) {
+          matches = val->value == op_map[op].operator std::string();
+        }
+        break;
+      default:
+        LOG(FATAL) << "Unsupported attribute type code " << op_map[op].type_code()
+                   << " for attribute " << kv.first;
+    }
+    // Every attribute must match; one mismatch rejects the pattern.
+    if (!matches) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*! \brief Return a copy of args with the element order reversed. */
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t i = args.size(); i > 0; --i) {
+    reversed.push_back(args[i - 1]);
+  }
+  return reversed;
+}
+
+/*!
+ * \brief Match a call pattern against an expression.
+ *
+ * Attempts, in order: (1) the callee and the arguments in order; (2) for "add" and "multiply"
+ * patterns, the arguments reversed; (3) when the callee itself does not match, the
+ * divide/multiply reassociations handled in the else-branch below.
+ */
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // Returns the OpNode wrapped by the pattern's callee when it is an ExprPattern over an Op,
+  // otherwise nullptr.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True when the pattern's callee is the Op named op_type.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  // True when expr is a Call whose callee is the Op named op_type.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Rollback point covering everything this call pattern records.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      // Rollback point after the callee matched: a failed argument attempt undoes only its
+      // own records and keeps the callee binding.
+      auto watermark2 = matched_nodes_.size();
+
+      // Match pattern_args[i] against expr_args[i] pairwise, stopping at the first failure.
+      // Differing argument counts are a failure. On failure, roll back to watermark2.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      // reverse() flips the whole argument list; for add/multiply (presumably binary here)
+      // this swaps the two operands.
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern (a * b) / c versus an expression x * y where x or y is a divide: retry with the
+      // equivalent patterns b * (a / c), then a * (b / c).
+      // NOTE(review): op->args[0], op->args[1] and call_node->args[0..1] are indexed without
+      // checking arity; this assumes binary calls.
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        // Pattern x * (a / b) or (a / b) * x versus a divide expression with a multiply
+        // argument: retry as (a * x) / b. Only the first divide argument found is tried.
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  // Records still present here are rolled back by VisitDFPattern when this returns false.
+  return false;
+}
+
+/*!
+ * \brief Check that every input path of expr reaches op->parent through op->path nodes.
+ *
+ * Walks backwards from expr over its inputs, skipping a call's callee. An input that matches
+ * op->parent ends the search successfully. Otherwise the input must match op->path and its own
+ * inputs must satisfy the same condition recursively. Any failing sub-path rejects
+ * immediately, so a later input that happens to match the parent cannot mask it.
+ *
+ * memoize_ is on while probing for the parent, so it must bind to the same expression as
+ * before. It is off while probing path nodes, so op->path may match many expressions.
+ */
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Depth-first (iterative) search through the dominator_children_ links below expr for a
+ *        node matching op->parent; returns true on the first such node.
+ */
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> seen;
+  std::stack<Expr> pending;
+  pending.push(expr);
+  while (!pending.empty()) {
+    const Expr current = pending.top();
+    pending.pop();
+    for (auto child : expr_graph_.node_map_[current]->dominator_children_) {
+      if (!seen.insert(child->ref_).second) {
+        continue;  // already examined
+      }
+      if (VisitDFPattern(op->parent, child->ref_)) {
+        return true;
+      }
+      pending.push(child->ref_);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  // Requires: the child matches expr, every input path reaches the parent, and the parent is
+  // reachable through expr's dominator children.
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  const bool path_ok = MatchesPath(op, expr);
+  // MatchesPath may leave memoization disabled; restore it before continuing.
+  memoize_ = true;
+  return path_ok && DominatesParent(op, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An ExprPattern matches any expression structurally equal to the wrapped expression.
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  const auto* get_item = expr.as<TupleGetItemNode>();
+  if (get_item == nullptr) {
+    return false;
+  }
+  // Compare the index first so the tuple sub-pattern is only visited when it can matter.
+  if (op->index != get_item->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, get_item->tuple);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr || op->fields.size() != tuple->fields.size()) {
+    return false;
+  }
+  // Fields are matched left to right; stop at the first mismatch.
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  // Infer expr's type and compare it before matching the inner pattern.
+  auto typed = InferType(expr);
+  const auto inferred = typed.as<ExprNode>()->checked_type();
+  if (!StructuralEqual()(op->type, inferred)) {
+    return false;
+  }
+  return VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  const auto* var = expr.as<VarNode>();
+  if (var == nullptr) {
+    return false;
+  }
+  // An empty name hint in the pattern accepts any variable name.
+  return op->name_hint() == "" || op->name_hint() == var->name_hint();
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* /*op*/, const Expr& /*expr*/) {
+  // A wildcard matches every expression unconditionally.
+  return true;
+}
+
+// Frontend entry point: match `pattern` against `expr`, using `expr` as the graph root.
+TVM_REGISTER_GLOBAL("relay.df_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) {
+      DFPatternMatcher matcher(expr);
+      return matcher.Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarly needed to suppor the post-dominator analysis required for dominator pattern

Review comment:
       primarily/support

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches_path = MatchesPath(op, expr);
+    memoize_ = true;
+    if (matches_path) {
+      return DominatesParent(op, expr);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* var_node = expr.as<VarNode>()) {
+    matches = true;
+    if (op->name_hint() != "") {
+      matches &= op->name_hint() == var_node->name_hint();
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) {
+  return DFPatternMatcher(expr).Match(pattern, expr);
+});
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarly needed to suppor the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assingnments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    groups_ = {Group()};
+    gid_assignments_.clear();
+    visit_counter_.clear();
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+  }
+
+ protected:
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis */
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    Expr VisitExpr(const Expr& pre) override {
+      if (inputs_.count(pre)) {
+        return inputs_.at(pre);
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
+                                      std::to_string(var_number_),
+                                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number_++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Parition directly,
+    // used to determine Group overlap in other passes
+    auto extractor = MatchExtractor(inputs);
+    auto body = extractor.Mutate(expr);
+
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+
+    // Check to make sure we aren't overlapping with another group
+    for (auto kv : extractor.GetMemo()) {
+      if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
+          kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
+          kv.first.as<ConstantNode>() == nullptr) {
+        // Exit due to overlapping partitions
+        return;
+      }
+    }
+    // Assign Group Ids
+    group.gid = ++gid_;
+    for (auto kv : extractor.GetMemo()) {
+      gid_assignments_[kv.first] = gid_;
+    }
+
+    // Save Group
+    groups_.emplace_back(std::move(group));
+    CHECK_EQ(groups_[gid_].gid, gid_);
+  }
+
+  // Internal State
+  DFPattern pattern_;
+  std::vector<Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  DFPatternMatcher* matcher_ = nullptr;
+  IndexedGraph<DFPattern> pattern_graph_;
+  int gid_ = 0;
+  int var_number_ = 0;
+  int graph_number_ = 0;
+};
+
+// Rewrite
+
+DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) {
+  ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
+  n->pattern_ = std::move(pattern);
+  n->function_ = std::move(function);
+  return DFPatternCallback(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
+    .set_body_typed(DFPatternCallbackNode::make);
+
+/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
+ * function to rewrtie those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+  PatternRewriter() {}
+  /*! \brief Rewrite can take a number of callbakcs and will repeatedly rewrite the graph with the
+   * callbacks until it stops changing */
+  Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
+    auto post = pre;
+    auto last = post;
+    // rewrite the graph until it stops changing to make sure all rewrites are complete
+    do {
+      last = post;
+      for (auto callback : callbacks) {
+        callback_ = callback;
+        auto grouper = PatternGrouper();
+        grouper.GroupMatches(callback_->pattern_, post);
+        groups_ = grouper.GetGroups();
+        gid_assignments_ = grouper.GetGIDAssignments();
+        memo_.clear();
+        post = this->VisitExpr(post);
+      }
+    } while (last != post);
+    return post;
+  }
+
+ protected:
+  Expr DispatchVisitExpr(const Expr& pre) override {
+    auto post = MixedModeMutator::DispatchVisitExpr(pre);
+    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
+      // Convert the pre-rewrite node map to a post-rewrite node map
+      auto group = groups_[gid_assignments_[pre]];
+      std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> node_map;
+      for (auto kv : group.matched_nodes) {
+        Array<Expr> tmp;
+        for (size_t i = 0; i < kv.second.size(); ++i) {
+          tmp.push_back(this->memo_[kv.second[i]]);
+        }
+        node_map.insert({kv.first, tmp});
+      }
+      // run the user callback function
+      return callback_->function_(pre, post, Map<DFPattern, Array<Expr>>(node_map));
+    }
+    return post;
+  }
+
+  DFPatternCallback callback_;
+  std::vector<PatternGrouper::Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+};
+
+Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr) {
+  return PatternRewriter().Rewrite(callbacks, expr);
+}
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite").set_body_typed(RewritePatterns);
+
+/* \brief PatternParitioner replaces expressions that match a pattern with function call that
+ * perform the same computation but allow for further analysis and lowering.
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternPartitioner : protected MixedModeMutator {
+ public:
+  Expr Partition(const DFPattern& pattern, const Expr& pre) {
+    auto grouper = PatternGrouper();
+    grouper.GroupMatches(pattern, pre);
+    groups_ = grouper.GetGroups();
+    gid_assignments_ = grouper.GetGIDAssignments();
+    return this->VisitExpr(pre);
+  }
+
+ protected:
+  Expr RewriteParition(const PatternGrouper::Group& group) {

Review comment:
       Typo in the method name: "RewriteParition" should be "RewritePartition".

##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,634 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <stack>
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            throw "Unsupported type";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+                 (is_expr_op(call_node->args[0], "divide") ||
+                  is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPatternNode::make(op->op, {arg_node->args[arg_id], op->args[1]},
+                                               op->attrs, op->type_args);
+              auto mul =
+                  CallPatternNode::make(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                        arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                   (is_expr_op(call_node->args[0], "multiply") ||
+                    is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul =
+                  CallPatternNode::make(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                        op->attrs, op->type_args);
+              auto div = CallPatternNode::make(arg_node->op, {mul, arg_node->args[1]},
+                                               arg_node->attrs, arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  bool out = true;
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (VisitDFPattern(op->path, node->ref_)) {
+          out &= MatchesPath(op, node->ref_);
+        } else {
+          return false;
+        }
+      }
+    }
+  }
+  return out;
+}
+
+// Depth-first walk over the dominator-tree children reachable from `expr`; succeeds as soon as
+// one of them matches the dominator's parent pattern.
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> seen;
+  std::stack<Expr> frontier;
+  frontier.push(expr);
+  while (!frontier.empty()) {
+    Expr current = frontier.top();
+    frontier.pop();
+    for (auto child : expr_graph_.node_map_[current]->dominator_children_) {
+      Expr candidate = child->ref_;
+      if (seen.count(candidate) != 0) {
+        continue;
+      }
+      if (VisitDFPattern(op->parent, candidate)) {
+        return true;
+      }
+      frontier.push(candidate);
+      seen.insert(candidate);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  // The child must match `expr` itself before any path/parent analysis is attempted.
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  const bool path_ok = MatchesPath(op, expr);
+  // MatchesPath may leave memo lookups disabled; re-enable them unconditionally.
+  memoize_ = true;
+  return path_ok && DominatesParent(op, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An expression pattern accepts anything structurally equal to the wrapped expression.
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  // Requires a TupleGetItem at the same index whose tuple operand matches the inner pattern.
+  const auto* getitem = expr.as<TupleGetItemNode>();
+  if (getitem == nullptr || op->index != getitem->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, getitem->tuple);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  // Requires a Tuple of the same arity whose fields match pairwise, checked left to right.
+  const auto* tuple_node = expr.as<TupleNode>();
+  if (tuple_node == nullptr || op->fields.size() != tuple_node->fields.size()) {
+    return false;
+  }
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple_node->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  // Infer `expr`'s checked type; both the type and the inner pattern must match.
+  Type expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  if (!StructuralEqual()(op->type, expr_type)) {
+    return false;
+  }
+  return VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  const auto* var_node = expr.as<VarNode>();
+  if (var_node == nullptr) {
+    return false;
+  }
+  // An empty name hint on the pattern accepts any variable name.
+  return op->name_hint() == "" || op->name_hint() == var_node->name_hint();
+}
+
+// The wildcard pattern accepts every expression.
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode*, const Expr&) { return true; }
+
+// Expose matching to the FFI: build a matcher over `expr` and test `pattern` at its root.
+TVM_REGISTER_GLOBAL("relay.df_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) {
+      DFPatternMatcher matcher(expr);
+      return matcher.Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups. Entry 0 is an empty placeholder, so a group's gid is
+   * also its index. */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern
+   *
+   * All per-run state is reset first, so the same grouper can be reused for another pattern.
+   */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    // groups_[0] is a placeholder and CreateGroup checks groups_[gid_].gid == gid_, so gid_ has
+    // to restart from 0 whenever groups_ is reset.
+    groups_ = {Group()};
+    gid_ = 0;
+    gid_assignments_.clear();
+    visit_counter_.clear();
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+    // `matcher` is a local that dies on return; don't keep a dangling pointer to it.
+    matcher_ = nullptr;
+  }
+
+ protected:
+  /* \brief Try the pattern at every visited expression and group each match. */
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis. Expressions listed in `inputs` are replaced by their Var; everything
+   * else is mutated (and recorded in memo_) as usual. */
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    Expr VisitExpr(const Expr& pre) override {
+      if (inputs_.count(pre)) {
+        return inputs_.at(pre);
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns: expressions matched by a dominator's parent/path patterns never become
+    // function inputs.
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          // MatchesPath can succeed on a parent without ever visiting the path pattern, so a
+          // fuzzy pattern may have no recorded matches; guard the lookup like the input loop below.
+          if (node_map.count(fuzzy_op) == 0) {
+            continue;
+          }
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      // Only leaf patterns (no pattern inputs) can bind function inputs.
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            // Ops, functions and constants are not turned into params; they stay in the body.
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
+                                      std::to_string(var_number_),
+                                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number_++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Partition directly,
+    // used to determine Group overlap in other passes
+    auto extractor = MatchExtractor(inputs);
+    auto body = extractor.Mutate(expr);
+
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+
+    // Check to make sure we aren't overlapping with another group
+    for (const auto& kv : extractor.GetMemo()) {
+      if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
+          kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
+          kv.first.as<ConstantNode>() == nullptr) {
+        // Exit due to overlapping partitions
+        return;
+      }
+    }
+    // Assign Group Ids
+    group.gid = ++gid_;
+    for (const auto& kv : extractor.GetMemo()) {
+      gid_assignments_[kv.first] = gid_;
+    }
+
+    // Save Group
+    groups_.emplace_back(std::move(group));
+    CHECK_EQ(groups_[gid_].gid, gid_);
+  }
+
+  // Internal State
+  DFPattern pattern_;
+  std::vector<Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  // Only valid while GroupMatches is running.
+  DFPatternMatcher* matcher_ = nullptr;
+  IndexedGraph<DFPattern> pattern_graph_;
+  int gid_ = 0;
+  int var_number_ = 0;
+  int graph_number_ = 0;
+};
+
+// Rewrite
+
+// Build a callback node that pairs a pattern with the packed function invoked on its matches.
+DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) {
+  auto node = make_object<DFPatternCallbackNode>();
+  node->pattern_ = std::move(pattern);
+  node->function_ = std::move(function);
+  return DFPatternCallback(node);
+}
+
+TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
+
+TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
+    .set_body_typed(DFPatternCallbackNode::make);
+
+/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
+ * function to rewrtie those matches

Review comment:
       rewrite




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r414721399



##########
File path: python/tvm/relay/df_pattern/__init__.py
##########
@@ -0,0 +1,488 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Relay Pattern Language and tooling."""
+from tvm.relay import Expr
+from ...ir.base import Node
+from ...ir import make_node
+from ...runtime import Object
+from ... import _ffi as tvm_ffi
+from ..op import get
+from . import _ffi as ffi
+
+
+def register_df_node(type_key=None):
+    """Register a Relay dataflow-pattern node type with the FFI object registry.
+
+    Two call modes are supported:
+
+    * Called with a class (e.g. used bare as ``@register_df_node``): the class is
+      registered under the key ``"relay.df_pattern." + cls.__name__`` and the
+      result of applying ``register_object(key)`` to the class is returned.
+    * Called with a string key: ``register_object(type_key)`` is returned, to be
+      applied to the class as a decorator.
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node, or the class being registered.
+
+    Returns
+    -------
+    The result of ``register_object`` as described above.
+    """
+    # NOTE(review): the default ``type_key=None`` falls into the non-str branch and
+    # would fail on ``None.__name__``; confirm no caller relies on the default.
+    if not isinstance(type_key, str):
+        return tvm_ffi.register_object(
+            "relay.df_pattern." + type_key.__name__)(type_key)
+    return tvm_ffi.register_object(type_key)

Review comment:
       @jroesch can you comment on this? This was one of your contributions to the python API.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-625942189


   @tqchen @mbaret care to take another look after I updated the code based on your comments?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424840667



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <stack>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+// NOTE(review): DominatorMatcher is forward-declared but not referenced in the visible code;
+// possibly a leftover.
+class DominatorMatcher;
+
+/*!
+ * \brief Matches a DFPattern against a Relay expression.
+ *
+ * The indexed dataflow graph of `root_expr` is built up front and used by the dominator
+ * analysis. Successful sub-matches are recorded in `memo_`; while memoization is on, a pattern
+ * that already matched only matches that same expression again.
+ */
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  /*! \brief Match `pattern` against `expr`, discarding bindings from any previous match. */
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  /*! \brief Copy of the pattern -> matched-expressions bindings recorded so far. */
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  /*! \brief Roll back every binding recorded at or after position `watermark`. */
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  /*! \brief Expressions each pattern has matched. */
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Patterns in the order they matched; lets ClearMap undo bindings. */
+  std::vector<DFPattern> matched_nodes_;
+  /*! \brief Indexed graph of the root expression (inputs and dominator children). */
+  IndexedGraph<Expr> expr_graph_;
+  // NOTE(review): assigned by the DominatorPattern visitor; not read in the visible code.
+  IndexedGraph<DFPattern> pattern_graph_;
+  /*! \brief When false, VisitDFPattern skips memo lookups (set while matching dominator paths). */
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  // Start from a clean slate: forget every binding from a previous match.
+  matched_nodes_.clear();
+  memo_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  // Undo, newest first, every binding recorded after `watermark`.
+  while (matched_nodes_.size() > watermark) {
+    memo_.erase(matched_nodes_.back());
+    matched_nodes_.pop_back();
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_) {
+    auto cached = memo_.find(pattern);
+    if (cached != memo_.end()) {
+      // A memoized pattern is bound to exactly one expression and only matches that one again.
+      CHECK_EQ(cached->second.size(), 1);
+      return expr.same_as(cached->second[0]);
+    }
+  }
+  size_t watermark = matched_nodes_.size();
+  bool matched = DFPatternFunctor::VisitDFPattern(pattern, expr);
+  if (!matched) {
+    // Drop any partial bindings made while trying this pattern.
+    ClearMap(watermark);
+    return false;
+  }
+  memo_[pattern].push_back(expr);
+  matched_nodes_.push_back(pattern);
+  return true;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  // Try the left alternative first; fall back to the right one only if it fails.
+  if (VisitDFPattern(op->left, expr)) {
+    return true;
+  }
+  return VisitDFPattern(op->right, expr);
+}
+
+// Match an operator against a set of required operator attributes.
+//
+// `expr` must be an Op, and every attribute in the pattern's DictAttrs must be registered on that
+// Op with an equal value (int, float or string). Previously each attribute overwrote the result,
+// so only the last attribute decided the match and attributes missing from the Op were ignored.
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  const auto* op_node = expr.as<OpNode>();
+  if (op_node == nullptr) {
+    return false;
+  }
+  const auto* dict_attrs = attr_pattern->attrs.as<DictAttrsNode>();
+  CHECK(dict_attrs != nullptr) << "AttrPattern attributes must be DictAttrs";
+  Op op = GetRef<Op>(op_node);
+  for (auto kv : dict_attrs->dict) {
+    auto op_map = Op::GetAttr<TVMRetValue>(kv.first);
+    // An attribute the operator doesn't register can't match.
+    if (!op_map.count(op)) {
+      return false;
+    }
+    bool attr_matches = false;
+    switch (op_map[op].type_code()) {
+      case kDLInt:
+        if (auto* val = kv.second.as<IntImmNode>()) {
+          attr_matches = val->value == op_map[op].operator int64_t();
+        }
+        break;
+      case kDLFloat:
+        if (auto* val = kv.second.as<FloatImmNode>()) {
+          attr_matches = val->value == op_map[op].operator double();
+        }
+        break;
+      case kTVMStr:
+        if (auto* val = kv.second.as<tir::StringImmNode>()) {
+          attr_matches = val->value == op_map[op].operator std::string();
+        }
+        break;
+      default:
+        CHECK(false) << "Unsupported type in Attr Pattern Node";
+    }
+    if (!attr_matches) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Return a copy of `args` in reverse order (used to retry commutative calls with swapped operands).
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> reversed;
+  for (size_t remaining = args.size(); remaining > 0; --remaining) {
+    reversed.push_back(args[remaining - 1]);
+  }
+  return reversed;
+}
+
+// Match a call pattern against a Call expression.
+//
+// When the pattern's op matches the call's op, the arguments are matched in order; for "add" and
+// "multiply" a second attempt is made with the pattern arguments reversed. When the op does not
+// match, a divide/multiply reassociation of the pattern is tried instead.
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  // The OpNode wrapped by an ExprPattern in the call pattern's op slot, or nullptr.
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  // True if the call pattern's op is the Op named `op_type`.
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  // True if `expr` is a Call whose op is the Op named `op_type`.
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  // Bindings recorded after this point are rolled back by ClearMap on a failed attempt.
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      // Match arguments pairwise; on failure undo the argument bindings so another ordering can
+      // be tried from the same state.
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      // Pattern divide(multiply(a, b), c) against an expression multiply(...) that has a divide
+      // argument: try multiply(b, divide(a, c)), then multiply(a, divide(b, c)).
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+              (is_expr_op(call_node->args[0], "divide") ||
+               is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
+                                     op->type_args);
+              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                     arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      // Pattern multiply(divide(a, b), d) (either operand order) against an expression
+      // divide(...) that has a multiply argument: try divide(multiply(a, d), b).
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                (is_expr_op(call_node->args[0], "multiply") ||
+                 is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                     op->attrs, op->type_args);
+              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
+                                     arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Walk the dataflow inputs of `expr` looking for the dominator's parent pattern.
+// Success is reported as soon as one input matches the parent; every other input visited
+// must match the path pattern and recursively lead to the parent, otherwise the walk fails.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  const auto* call = expr.as<CallNode>();
+  for (auto input : expr_graph_.node_map_[expr]->inputs_) {
+    // The operator of a call is not a dataflow input.
+    if (call && input->ref_ == call->op) {
+      continue;
+    }
+    memoize_ = true;
+    if (VisitDFPattern(op->parent, input->ref_)) {
+      return true;
+    }
+    // Without memo lookups the path pattern may bind to several expressions.
+    memoize_ = false;
+    bool on_path = VisitDFPattern(op->path, input->ref_) && MatchesPath(op, input->ref_);
+    if (!on_path) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Depth-first walk over the dominator-tree children reachable from `expr`; succeeds as soon as
+// one of them matches the dominator's parent pattern.
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> seen;
+  std::stack<Expr> frontier;
+  frontier.push(expr);
+  while (!frontier.empty()) {
+    Expr current = frontier.top();
+    frontier.pop();
+    for (auto child : expr_graph_.node_map_[current]->dominator_children_) {
+      Expr candidate = child->ref_;
+      if (seen.count(candidate) != 0) {
+        continue;
+      }
+      if (VisitDFPattern(op->parent, candidate)) {
+        return true;
+      }
+      frontier.push(candidate);
+      seen.insert(candidate);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  // The child must match `expr` itself before any path/parent analysis is attempted.
+  if (!VisitDFPattern(op->child, expr)) {
+    return false;
+  }
+  const bool path_ok = MatchesPath(op, expr);
+  // MatchesPath may leave memo lookups disabled; re-enable them unconditionally.
+  memoize_ = true;
+  return path_ok && DominatesParent(op, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  // An expression pattern accepts anything structurally equal to the wrapped expression.
+  StructuralEqual structurally_equal;
+  return structurally_equal(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  // Requires a TupleGetItem at the same index whose tuple operand matches the inner pattern.
+  const auto* getitem = expr.as<TupleGetItemNode>();
+  if (getitem == nullptr || op->index != getitem->index) {
+    return false;
+  }
+  return VisitDFPattern(op->tuple, getitem->tuple);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  // Requires a Tuple of the same arity whose fields match pairwise, checked left to right.
+  const auto* tuple_node = expr.as<TupleNode>();
+  if (tuple_node == nullptr || op->fields.size() != tuple_node->fields.size()) {
+    return false;
+  }
+  for (size_t idx = 0; idx < op->fields.size(); ++idx) {
+    if (!VisitDFPattern(op->fields[idx], tuple_node->fields[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Run type inference on a module built from `expr` and return the typed counterpart:
+// the whole "main" function when `expr` is a function, otherwise that function's body.
+Expr InferType(const Expr& expr) {
+  IRModule typed_mod = transform::InferType()(IRModule::FromExpr(expr));
+  Expr main_fn = typed_mod->Lookup("main");
+  if (expr.as<FunctionNode>() != nullptr) {
+    return main_fn;
+  }
+  return main_fn.as<FunctionNode>()->body;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  // Infer `expr`'s checked type; both the type and the inner pattern must match.
+  Type expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  if (!StructuralEqual()(op->type, expr_type)) {
+    return false;
+  }
+  return VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  const auto* var_node = expr.as<VarNode>();
+  if (var_node == nullptr) {
+    return false;
+  }
+  // An empty name hint on the pattern accepts any variable name.
+  return op->name_hint() == "" || op->name_hint() == var_node->name_hint();
+}
+
+// The wildcard pattern accepts every expression.
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode*, const Expr&) { return true; }
+
+// Expose matching to the FFI: build a matcher over `expr` and test `pattern` at its root.
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) {
+      DFPatternMatcher matcher(expr);
+      return matcher.Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {

Review comment:
       How about returning `groups_` here and removing `GetGroups()`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424846955



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+// Forward declaration of the managed reference type, defined after the node.
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  // NOTE(review): no fields are visited, so pattern_/function_ are not exposed through attribute
+  // reflection; confirm this is intentional.
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "DFPatternCallbackNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow pattern callbacks.
+ * \sa DFPatternCallbackNode
+ */
+class DFPatternCallback : public ObjectRef {

Review comment:
       Since this header is fairly small and used only by dataflow_matcher.cc, how about moving the content to dataflow_matcher.cc and remove this header??




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r414162532



##########
File path: include/tvm/relay/dataflow_functor.h
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information

Review comment:
       naming: perhaps it should be `dataflow_pattern_functor.h`? since `dataflow_functor` is a bit confusing

##########
File path: include/tvm/relay/dataflow_functor.h
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern argument.
+ *
+ * \tparam FType function signiture
+ *  This type is only defined for FType with function signature R(const DFPattern&,
+ * Args...)
+ */
+template <typename FType>
+class DFPatternFunctor;
+
+// functions to be overriden.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \
+    return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+template <typename R, typename... Args>
+class DFPatternFunctor<R(const DFPattern& n, Args...)> {
+ private:
+  using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+    return VisitDFPattern(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+    CHECK(n.defined());
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once.
+ */
+class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
+ public:
+  void VisitDFPattern(const DFPattern& pattern) override;
+  void VisitDFPattern_(const AltPatternNode* op) override;
+  void VisitDFPattern_(const AttrPatternNode* op) override;
+  void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const DominatorPatternNode* op) override;
+  void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
+  void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const TypePatternNode* op) override;
+  void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const WildcardPatternNode* op) override;
+
+ protected:
+  // set of already-visited nodes
+  std::unordered_set<const Object*> visited_;
+};
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ *  Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ *  This class is templated and the implementaiton is in the header file so we can analyze both
+ * DFPattern and Expr with the same infrastructure.
+ *
+ *  IndexedGraph should be instantiated through the CreateIndexedGraph utilities.
+ */
+template <typename T>

Review comment:
       Given that T is always a subclass of Object, perhaps it is easier to get rid of the shared_ptr and make the Node an Object (so that we can use the object system).

##########
File path: python/tvm/relay/df_pattern/__init__.py
##########
@@ -0,0 +1,488 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Relay Pattern Language and tooling."""
+from tvm.relay import Expr

Review comment:
       Shall we use the same name as on the C++ side, "dataflow_pattern"?

##########
File path: tests/python/relay/test_df_pattern.py
##########
@@ -0,0 +1,783 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay.df_pattern import *
+import numpy as np
+
+# NB: 1 corresponds to the C++ enum that specicfies this
+# we loose the type safety due to the Python/C++ calling
+# convention.
+K_ELEMWISE = 0
+K_BROADCAST = 1
+
+## NODE TESTS
+def test_expr_pattern():
+    ep = ExprPattern(relay.var('x', shape=(4, 1)))
+    print(ep)
+
+def test_var_pattern():
+    v = is_input("x")
+    print(v)
+
+def test_wildcard_pattern():
+    wc = wildcard()
+    print(wc)
+
+def test_CallPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    c = is_op("add")(wc1, wc2)
+    print(c)
+
+def test_TuplePattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    print(t)
+
+def test_TupleGetItemPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    tgi = TupleGetItemPattern(t, 1)
+    print(tgi)
+
+def test_AltPattern():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    print(is_add_or_sub)
+
+def test_TypePattern():
+    ty_pat = has_type(relay.TensorType((10, 10), "float32"))
+    print(ty_pat)

Review comment:
       Remove the prints and use asserts instead, e.g. assert the type of the pattern and its subfields.

##########
File path: python/tvm/relay/df_pattern/__init__.py
##########
@@ -0,0 +1,488 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Relay Pattern Language and tooling."""
+from tvm.relay import Expr
+from ...ir.base import Node
+from ...ir import make_node
+from ...runtime import Object
+from ... import _ffi as tvm_ffi
+from ..op import get
+from . import _ffi as ffi
+
+
+def register_df_node(type_key=None):
+    """Register a Relay node type.
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node.
+    """
+    if not isinstance(type_key, str):
+        return tvm_ffi.register_object(
+            "relay.df_pattern." + type_key.__name__)(type_key)
+    return tvm_ffi.register_object(type_key)
+
+
+class DFPattern(Node):
+    """Base class of all Patterns.
+    """
+
+    def __call__(self, *args):
+        return CallPattern(self, list(args))
+
+    def __or__(self, other):
+        return AltPattern(self, other)
+
+    def __add__(self, other):
+        return is_op("add")(self, other)
+
+    def __sub__(self, other):
+        return is_op("subtract")(self, other)
+
+    def __mul__(self, other):
+        return is_op("multiply")(self, other)
+
+    def __truediv__(self, other):
+        return is_op("divide")(self, other)
+
+    def has_attr(self, attr_name: str, attr_value):
+        """
+        Add an attribute constraint to this pattern
+
+        Parameters
+        ----------
+        attr_name: str
+            The name of the attribute to match
+        attr_value: Any
+            The value of the attribute to match
+        """
+        attrs = make_node("DictAttrs", **{attr_name: attr_value})
+        return AttrPattern(self, attrs)
+
+    def has_type(self, ttype):
+        """
+        Add a type constraint to this pattern
+
+        Parameters
+        ----------
+        ttype: tvm.relay.Type
+            The type to match
+        """
+        return has_type(ttype, self)
+
+    def match(self, expr: Expr) -> bool:
+        """
+        Match this pattern to an expression
+
+        Parameters
+        ----------
+        expr : tvm.relay.Expr
+            The expression to match.
+        """
+        return match(self, expr)
+
+    def partition(self, expr: Expr) -> bool:
+        """
+        Parition the expression into functions defined by this pattern
+
+        Parameters
+        ----------
+        expr : tvm.relay.Expr
+            The expression to match.
+        """
+        return partition(self, expr)
+
+    def dominates(self, parent, path=None):
+        """
+        Create a dominator for this partern
+
+        Parameters
+        ----------
+        parent: tvm.relay.df_pattern.DFPattern
+            The parent pattern this pattern dominates.
+        path: tvm.relay.df_pattern.DFPattern
+            The fuzzy path pattern.
+        """
+        if path is None:
+            path = wildcard()
+        return DominatorPattern(parent, path, self)
+
+
+def is_input(name: str = "") -> DFPattern:
+    """
+    Syntatic sugar for creating an optionally named VarPattern
+
+    Parameters
+    ----------
+    name: str
+        The name of the input pattern to match
+    """
+    return VarPattern(name)
+
+
+def is_op(op_name: str) -> DFPattern:
+    """
+    Syntatic sugar for creating an operator ExprPattern
+
+    Parameters
+    ----------
+    op_name: String
+        The name of the relay op
+    """
+    op = get(op_name)
+    return ExprPattern(op)
+
+
+def wildcard() -> DFPattern:
+    """
+    Syntatic sugar for creating a WildcardPattern
+    """
+    return WildcardPattern()
+
+
+def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
+    """
+    Syntatic sugar for creating a TypePattern
+
+    Parameters
+    ----------
+    pattern: tvm.relay.df_pattern.DFPattern
+        The pattern that needs type annotation
+
+    ttype: tvm.relay.Type
+        The type to match
+    """
+    if pattern is None:
+        pattern = wildcard()
+    return TypePattern(pattern, ttype)
+
+
+def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern:
+    """
+    Syntatic sugar for creating an AttrPattern
+
+    Parameters
+    ----------
+    pattern: tvm.relay.df_pattern.DFPattern
+        The input pattern.
+
+    attrs: tvm.Attrs
+        The attributes to match
+    """
+    if pattern is None:
+        pattern = wildcard()
+    return pattern.has_attr(attr_name, attr_value)
+
+
+def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
+    """
+    Syntatic sugar for creating an Dominator pattern
+
+    Parameters
+    ----------
+    parent: tvm.relay.df_pattern.DFPattern
+        The parent pattern.
+    path: tvm.relay.df_pattern.DFPattern
+        The fuzzy path pattern.
+    child: tvm.relay.df_pattern.DFPattern
+        The child pattern.
+    """
+    return DominatorPattern(parent, path, child)
+
+
+def match(pattern: DFPattern, expr: Expr) -> bool:
+    """
+    Match a pattern to an expression
+
+    Parameters
+    ----------
+    pattern: tvm.relay.df_pattern.DFPattern
+        The input pattern.
+    expr : tvm.relay.Expr
+        The expression to match.
+    """
+    return ffi.match(pattern, expr)
+
+
+@register_df_node
+class ExprPattern(DFPattern):
+    """A pattern which matches a constant expression.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The expression to match.
+    """
+
+    def __init__(self, expr: Expr):
+        self.__init_handle_by_constructor__(ffi.ExprPattern, expr)
+
+
+@register_df_node
+class VarPattern(DFPattern):
+    """A local variable in Relay.
+
+    Local variable can be used to declare input
+    arguments to a function, or intermediate variables.
+
+    Parameters
+    ----------
+    name_hint: str
+        The name of the variable.
+        This name only acts as a hint, and is not used
+        for equality.
+
+    type_annotation: tvm.relay.Type, optional
+        The type annotation on the variable.
+    """
+
+    def __init__(self, name_hint: str, type_annotation=None):
+        self.__init_handle_by_constructor__(
+            ffi.VarPattern, name_hint, type_annotation)
+
+
+@register_df_node
+class CallPattern(DFPattern):
+    """A pattern matching a function call node in Relay.
+
+    Parameters
+    ----------
+    op: realy.df_pattern.DFPattern
+        The operation to be called.
+
+    args: List[realy.df_pattern.DFPattern]
+        The arguments to the call.
+
+    attrs: Optional[tvm.Attrs]
+        Attributes to the call, can be None
+
+    type_args: Optional[List[tvm.relay.Type]]
+        The additional type arguments, this is only
+        used in advanced usecase of template functions.
+    """
+
+    def __init__(self, op, args, attrs=None, type_args=None):
+        if not type_args:
+            type_args = []
+        self.__init_handle_by_constructor__(
+            ffi.CallPattern, op, args, attrs, type_args)
+
+
+@register_df_node
+class TuplePattern(DFPattern):
+    """A patern matching a Relay Tuple.
+
+    Parameters
+    ----------
+    fields : List[tvm.relay.df_pattern.DFPattern]
+        The fields in the tuple.
+    """
+
+    def __init__(self, fields):
+        self.__init_handle_by_constructor__(ffi.TuplePattern, fields)
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("TuplePattern index out of range")
+        return self.fields[index]
+
+    def __len__(self):
+        return len(self.fields)
+
+    def astype(self, _):
+        raise TypeError("astype cannot be used on TuplePattern")
+
+
+@register_df_node
+class TupleGetItemPattern(DFPattern):
+    """Get index-th item from a TuplePattern.
+
+    Parameters
+    ----------
+    tuple_value: tvm.relay.df_pattern.DFPattern
+        The input tuple expression.
+
+    index: int
+        The index.
+    """
+
+    def __init__(self, tuple_value: DFPattern, index):
+        self.__init_handle_by_constructor__(
+            ffi.TupleGetItemPattern, tuple_value, index)
+
+
+@register_df_node
+class AltPattern(DFPattern):
+    """Create a Pattern that can match one of two conditions
+
+    Parameters
+    ----------
+    left: tvm.relay.df_pattern.DFPattern
+        One possible matching Pattern
+    right: tvm.relay.df_pattern.DFPattern
+        One possible matching Pattern
+    """
+
+    def __init__(self, left: DFPattern, right: DFPattern):
+        self.__init_handle_by_constructor__(
+            ffi.AltPattern, left, right)
+
+
+@register_df_node
+class WildcardPattern(DFPattern):
+    """A pattern which matches anything.
+    """
+
+    def __init__(self):
+        self.__init_handle_by_constructor__(ffi.WildcardPattern)
+
+
+@register_df_node
+class TypePattern(DFPattern):
+    """Get index-th item from a TuplePattern.
+
+    Parameters
+    ----------
+    pattern: tvm.relay.df_pattern.DFPattern
+        The input pattern that needs type annotation
+
+    ttype: tvm.relay.Type
+        The type to match
+    """
+
+    def __init__(self, pattern: DFPattern, ttype):
+        self.__init_handle_by_constructor__(
+            ffi.TypePattern, pattern, ttype)
+
+
+@register_df_node
+class AttrPattern(DFPattern):
+    """Get match an expression with a certain attributes.
+    Currently only supports Op Attributes, not call Attributes
+
+    Parameters
+    ----------
+    pattern: tvm.relay.df_pattern.DFPattern
+        The input pattern.
+
+    attrs: tvm.Attrs
+        The attributes to match
+    """
+
+    def __init__(self, pattern: DFPattern, attrs):
+        self.__init_handle_by_constructor__(
+            ffi.AttrPattern, pattern, attrs)
+
+
+@register_df_node
+class DominatorPattern(DFPattern):
+    """Match a domination graph.
+
+    Parameters
+    ----------
+    parent: tvm.relay.df_pattern.DFPattern
+        The parent, i.e., the single node which produces something,
+        later aggregated by the child
+    path: tvm.relay.df_pattern.DFPattern
+        The fuzzy path pattern between parent and child,
+        typically matches elementwise ops
+    child: tvm.relay.df_pattern.DFPattern
+        The last node in the domination which is the end user
+        for all nodes in the path and the parent
+    """
+
+    def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern):
+        self.__init_handle_by_constructor__(
+            ffi.DominatorPattern, parent, path, child)
+
+
+class DFPatternCallback:
+    """A Callback for Pattern Rewriting
+
+    When rewrite is called on this DFPatternCallback, the backend will find matches for the
+    pattern, call the callback function, and replace the matched expression with whatever
+    the callback returns.
+
+    Users are expect to inherit from this class and provide a "self.pattern" to match
+    """
+
+    def rewrite(self, expr: Expr) -> Expr:
+        """
+        Rewrite expression with this callback
+
+        Parameters
+        ----------
+        expr : tvm.relay.Expr
+            The expression to rewrite.
+        """
+        return rewrite(self, expr)
+
+    def callback(self, pre, post, node_map):
+        """
+        Callback function to use when we found a match to the pattern
+
+        Parameters
+        ----------
+        pre : tvm.relay.Expr
+            The matching expression from the original graph.
+        post : tvm.relay.Expr
+            The matching expression with rewritten inputs
+        node_map : Map(DFPattern, List(Expr))
+            The map between patterns and matched expressions
+        """
+        raise "Unimplemented"
+
+class _DFPatternCallback(Object):
+    """C++ implemenation"""
+    def __init__(self, pattern, callback):
+        self.__init_handle_by_constructor__(
+            ffi.DFPatternCallback, pattern, callback)
+
+
+def rewrite(callbacks, expr: Expr) -> Expr:
+    """
+    Rewrite expression with the given callbacks
+
+    Parameters
+    ----------
+    callbacks: tvm.relay.df_pattern.DFPatternCallback
+        The input callback or list of callbacks.
+    expr : tvm.relay.Expr
+        The expression to rewrite.
+    """
+    if isinstance(callbacks, DFPatternCallback):
+        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
+    else:
+        tmp = []
+        for callback in callbacks:
+            tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
+
+    return ffi.rewrite(tmp, expr)
+
+def partition(pattern: DFPattern, expr: Expr) -> Expr:
+    """
+    Parition the expression into a series of functions that match the pattern
+
+    Parameters
+    ----------
+    partion: tvm.relay.df_pattern.DFPattern

Review comment:
       Need to document the return value.

##########
File path: python/tvm/relay/df_pattern/__init__.py
##########
@@ -0,0 +1,488 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Relay Pattern Language and tooling."""
+from tvm.relay import Expr
+from ...ir.base import Node
+from ...ir import make_node
+from ...runtime import Object
+from ... import _ffi as tvm_ffi
+from ..op import get
+from . import _ffi as ffi
+
+
+def register_df_node(type_key=None):
+    """Register a Relay node type.
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node.
+    """
+    if not isinstance(type_key, str):
+        return tvm_ffi.register_object(
+            "relay.df_pattern." + type_key.__name__)(type_key)
+    return tvm_ffi.register_object(type_key)

Review comment:
       Let us directly call tvm._ffi.register_object (inline the call). This will make life easier for https://github.com/tqchen/ffi-navigator, and it also makes the code more explicit.

##########
File path: python/tvm/relay/df_pattern/__init__.py
##########
@@ -0,0 +1,488 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Relay Pattern Language and tooling."""
+from tvm.relay import Expr
+from ...ir.base import Node
+from ...ir import make_node
+from ...runtime import Object
+from ... import _ffi as tvm_ffi

Review comment:
       We are moving towards absolute imports in Python,
   
   so do `import tvm._ffi` here.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424838458



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <stack>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            CHECK(false) << "Unsupported type in Type Pattern Node";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+              (is_expr_op(call_node->args[0], "divide") ||
+               is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
+                                     op->type_args);
+              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                     arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                (is_expr_op(call_node->args[0], "multiply") ||
+                 is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                     op->attrs, op->type_args);
+              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
+                                     arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));
+  if (VisitDFPattern(op->child, expr)) {
+    bool matches_path = MatchesPath(op, expr);
+    memoize_ = true;
+    if (matches_path) {
+      return DominatesParent(op, expr);
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
+  return StructuralEqual()(op->expr, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
+    matches = (op->index == tuple_get_item_node->index) &&
+              VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* tuple_node = expr.as<TupleNode>()) {
+    if (op->fields.size() == tuple_node->fields.size()) {
+      matches = true;
+      size_t i = 0;
+      while (matches && i < op->fields.size()) {
+        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+        ++i;
+      }
+    }
+  }
+  return matches;
+}
+
+Expr InferType(const Expr& expr) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::InferType()(mod);
+  if (expr.as<FunctionNode>()) {
+    return mod->Lookup("main");
+  } else {
+    return mod->Lookup("main").as<FunctionNode>()->body;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
+  bool matches = false;
+  if (const auto* var_node = expr.as<VarNode>()) {
+    matches = true;
+    if (op->name_hint() != "") {
+      matches &= op->name_hint() == var_node->name_hint();
+    }
+  }
+  return matches;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match")
+    .set_body_typed([](DFPattern pattern, Expr expr) {
+      return DFPatternMatcher(expr).Match(pattern, expr);
+    });
+
+/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required for dominator pattern
+ * matching.
+ */
+class PatternGrouper : protected MixedModeVisitor {
+ public:
+  /* \brief Internal Group class for storing analysis */
+  struct Group {
+    Expr root_node;
+    int gid;
+    Map<DFPattern, Array<Expr>> matched_nodes;
+    Function function;
+    Array<Expr> args;
+  };
+
+  /* \brief Return the discovered groups */
+  const std::vector<Group>& GetGroups() { return this->groups_; }
+
+  /* \brief Return the group assignments of expressions */
+  const std::unordered_map<Expr, int, ObjectHash, ObjectEqual>& GetGIDAssignments() {
+    return gid_assignments_;
+  }
+  /* \brief Group expressions that match the pattern */
+  void GroupMatches(const DFPattern& pattern, const Expr& pre) {
+    groups_ = {Group()};
+    gid_assignments_.clear();
+    visit_counter_.clear();
+
+    pattern_ = pattern;
+    pattern_graph_ = CreateIndexedGraph(pattern_);
+    auto matcher = DFPatternMatcher(pre);
+    matcher_ = &matcher;
+    this->VisitExpr(pre);
+  }
+
+ protected:
+  void VisitLeaf(const Expr& pre) override {
+    if (matcher_->Match(pattern_, pre)) {
+      CreateGroup(pre);
+    }
+  }
+
+  /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
+   * group overlap analysis */
+  class MatchExtractor : public ExprMutator {
+   public:
+    explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
+        : inputs_(inputs) {}
+    const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+
+   protected:
+    Expr VisitExpr(const Expr& pre) override {
+      if (inputs_.count(pre)) {
+        return inputs_.at(pre);
+      }
+      return ExprMutator::VisitExpr(pre);
+    }
+    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
+  };
+
+  /* \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    var_number_ = 0;

Review comment:
       It seems this variable can be a local variable instead of a member.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen edited a comment on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-618704483






----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424846955



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "DFPatternCallbackNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow pattern callbacks.
+ * \sa DFPatternCallbackNode
+ */
+class DFPatternCallback : public ObjectRef {

Review comment:
       Since this header is fairly small and used only by dataflow_matcher.cc, how about moving its content to dataflow_matcher.cc and removing this header?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart edited a comment on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart edited a comment on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-625942189


   @tqchen @mbaret Thanks for the comments! Care to take another look now that I've updated the code based on your comments?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r425433993



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "DFPatternCallbackNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow pattern callbacks.
+ * \sa DFPatternCallbackNode
+ */
+class DFPatternCallback : public ObjectRef {

Review comment:
       Something got lost in a refactor. I want users to be able to write pattern-based passes in C++, which requires this in a header, but I don't seem to have the pass functions exposed. Will fix.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r425411117



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <stack>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            CHECK(false) << "Unsupported type in Type Pattern Node";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+              (is_expr_op(call_node->args[0], "divide") ||
+               is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
+                                     op->type_args);
+              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                     arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                (is_expr_op(call_node->args[0], "multiply") ||
+                 is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                     op->attrs, op->type_args);
+              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
+                                     arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));

Review comment:
       :+1: Looks like it got left behind in a refactor




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #5231: Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-629445467


   Thanks @mbrookhart @jroesch @mbaret @masahi @yzhliu !


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-618704483


   Thanks @mbrookhart! Now that we have a concrete POC, it would be nice to have another round of API review with the folks, possibly by opening another thread on the discuss forum that gives examples of what relay.dataflow_pattern can do so far, to get feedback about the API choices.
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r425440445



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "DFPatternCallbackNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow pattern callbacks.
+ * \sa DFPatternCallbackNode
+ */
+class DFPatternCallback : public ObjectRef {

Review comment:
       I can see these pass functions being useful for op fusion and BYOC-related passes.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424847028



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;

Review comment:
       I think you can remove this forward decl.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen merged pull request #5231: Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424846955



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "DFPatternCallbackNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow pattern callbacks.
+ * \sa DFPatternCallbackNode
+ */
+class DFPatternCallback : public ObjectRef {

Review comment:
       Since this header is fairly small and used only by dataflow_matcher.cc, how about moving the content to dataflow_matcher.cc?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424844266



##########
File path: include/tvm/relay/dataflow_pattern_functor.h
##########
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h

Review comment:
       This should be `dataflow_pattern_functor.h`, not `dataflow_matcher.h`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424847028



##########
File path: include/tvm/relay/dataflow_matcher.h
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;

Review comment:
       I think you can remove this forward decl.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#issuecomment-618705083


   cc @junrushao1994 @ajtulloch @u99127 @yzhliu 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r424827496



##########
File path: src/relay/ir/dataflow_matcher.cc
##########
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <stack>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  IndexedGraph<Expr> expr_graph_;
+  IndexedGraph<DFPattern> pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+    memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
+  if (memoize_ && memo_.count(pattern)) {
+    CHECK_EQ(memo_[pattern].size(), 1);
+    return expr.same_as(memo_[pattern][0]);
+  } else {
+    auto watermark = matched_nodes_.size();
+    auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+    if (out) {
+      memo_[pattern].push_back(expr);
+      matched_nodes_.push_back(pattern);
+    } else {
+      ClearMap(watermark);
+    }
+    return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as<OpNode>()) {
+    Op op = GetRef<Op>(op_node);
+    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
+    for (auto kv : attributes) {
+      auto attr_name = kv.first;
+      auto attr_value = kv.second;
+      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      if (op_map.count(op)) {
+        switch (op_map[op].type_code()) {
+          case kDLInt:
+            if (auto* val = kv.second.as<IntImmNode>()) {
+              matches = val->value == op_map[op].operator int64_t();
+            }
+            break;
+          case kDLFloat:
+            if (auto* val = kv.second.as<FloatImmNode>()) {
+              matches = val->value == op_map[op].operator double();
+            }
+            break;
+          case kTVMStr:
+            if (auto* val = kv.second.as<tir::StringImmNode>()) {
+              matches = val->value == op_map[op].operator std::string();
+            }
+            break;
+          default:
+            CHECK(false) << "Unsupported type in Type Pattern Node";
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+Array<DFPattern> reverse(const Array<DFPattern>& args) {
+  Array<DFPattern> new_args;
+  for (auto it = args.rbegin(); it != args.rend(); ++it) {
+    new_args.push_back(*it);
+  }
+  return new_args;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
+  // utilities
+  auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
+    if (op) {
+      if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
+        return expr_pattern->expr.as<OpNode>();
+      }
+    }
+    return nullptr;
+  };
+  auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
+    if (const auto* op_node = get_op_node(op)) {
+      if (op_node->name == op_type) {
+        return true;
+      }
+    }
+    return false;
+  };
+  auto is_expr_op = [](const Expr& expr, std::string op_type) {
+    if (const auto* call_node = expr.as<CallNode>()) {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        if (op_node->name == op_type) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+  // logic
+  auto watermark = matched_nodes_.size();
+  if (const auto* call_node = expr.as<CallNode>()) {
+    auto matches_op = VisitDFPattern(op->op, call_node->op);
+    if (matches_op) {
+      auto watermark2 = matched_nodes_.size();
+
+      auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args,
+                                            const Array<Expr> expr_args) {
+        bool matches = true;
+        size_t i = 0;
+        if (pattern_args.size() == expr_args.size()) {
+          while (matches && i < pattern_args.size()) {
+            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+            ++i;
+          }
+        } else {
+          matches = false;
+        }
+        if (!matches) {
+          ClearMap(watermark2);
+        }
+        return matches;
+      };
+
+      // Standard case
+      if (match_args(op->args, call_node->args)) {
+        return true;
+      }
+      // Commutative Matching
+      if (const OpNode* op_node = get_op_node(op)) {
+        if ((op_node->name == "add") || (op_node->name == "multiply")) {
+          if (match_args(reverse(op->args), call_node->args)) {
+            return true;
+          }
+        }
+      }
+    } else {
+      ClearMap(watermark);
+      // associate divide/multiply
+      if (is_pattern_op(op, "divide")) {
+        if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
+          if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") &&
+              (is_expr_op(call_node->args[0], "divide") ||
+               is_expr_op(call_node->args[1], "divide"))) {
+            bool out = false;
+            for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+              auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs,
+                                     op->type_args);
+              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div},
+                                     arg_node->attrs, arg_node->type_args);
+              out = VisitDFPattern(mul, expr);
+              if (out) {
+                return true;
+              } else {
+                ClearMap(watermark);
+              }
+            }
+            return out;
+          }
+        }
+      }
+      if (is_pattern_op(op, "multiply")) {
+        // associate multiply/divide
+        for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
+          if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
+            if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") &&
+                (is_expr_op(call_node->args[0], "multiply") ||
+                 is_expr_op(call_node->args[1], "multiply"))) {
+              auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]},
+                                     op->attrs, op->type_args);
+              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs,
+                                     arg_node->type_args);
+              return VisitDFPattern(div, expr);
+            }
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Recursively find the Dominator parent along all inputs paths.
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
+  auto call_node = expr.as<CallNode>();
+  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+    if (!(call_node && node->ref_ == call_node->op)) {
+      memoize_ = true;
+      if (VisitDFPattern(op->parent, node->ref_)) {
+        return true;
+      } else {
+        memoize_ = false;
+        if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Iteratively ensure that the parent is dominated somewhere by the child or the path
+bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
+  std::stack<Expr> stack;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> visited;
+  stack.push(expr);
+  while (!stack.empty()) {
+    Expr current = stack.top();
+    stack.pop();
+    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+      if (visited.count(node->ref_) == 0) {
+        if (VisitDFPattern(op->parent, node->ref_)) {
+          return true;
+        } else {
+          stack.push(node->ref_);
+        }
+        visited.insert(node->ref_);
+      }
+    }
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) {
+  pattern_graph_ = CreateIndexedGraph(GetRef<DFPattern>(op));

Review comment:
       `pattern_graph_` doesn't seem to be used inside `DFPatternMatcher`. Can we remove it from `DFPatternMatcher`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org