You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/11 07:23:06 UTC

[tvm] branch main updated: [RELAY] Fix bug in MergeCompilerRegions pass (#15211)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e4a120955b [RELAY] Fix bug in MergeCompilerRegions pass (#15211)
e4a120955b is described below

commit e4a120955b02674c67ceec6d88884ec5768ff7c2
Author: chenxinli <39...@users.noreply.github.com>
AuthorDate: Tue Jul 11 15:22:57 2023 +0800

    [RELAY] Fix bug in MergeCompilerRegions pass (#15211)
---
 src/relay/transforms/merge_compiler_regions.cc     | 36 ++++++++++++-
 .../relay/test_pass_merge_compiler_regions.py      | 62 ++++++++++++++++++++++
 2 files changed, 97 insertions(+), 1 deletion(-)

diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index d70c7480e9..92e881fa61 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -53,6 +53,35 @@ class RegionMerger : public MixedModeVisitor {
  public:
   explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
 
+  void find_control_flow_regions(
+      const Expr op,  // NOTE(review): const Expr& would avoid a refcount copy per call
+      std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>& correlative_regions) {
+    // Add to correlative_regions the annotated regions reachable from `op` through (nested) If
+    // nodes, so the caller can propagate their restrictions. Only IfNode is handled; any other
+    // expression adds nothing. An If operand with a region adds it; otherwise it is recursed into.
+    const IfNode* if_node = op.as<IfNode>();
+    if (if_node) {
+      auto cond_region = regions_->GetRegion(if_node->cond);
+      auto true_branch_region = regions_->GetRegion(if_node->true_branch);
+      auto false_branch_region = regions_->GetRegion(if_node->false_branch);
+      if (cond_region.defined()) {
+        correlative_regions.insert(cond_region);
+      } else {
+        find_control_flow_regions(if_node->cond, correlative_regions);
+      }
+      if (true_branch_region.defined()) {
+        correlative_regions.insert(true_branch_region);
+      } else {
+        find_control_flow_regions(if_node->true_branch, correlative_regions);
+      }
+      if (false_branch_region.defined()) {
+        correlative_regions.insert(false_branch_region);
+      } else {
+        find_control_flow_regions(if_node->false_branch, correlative_regions);
+      }
+    }
+  }
+
   void VisitExpr_(const CallNode* call) final {
     if (call->op == CompilerEndOp()) {
       auto region = regions_->GetRegion(GetRef<Call>(call));
@@ -84,18 +113,23 @@ class RegionMerger : public MixedModeVisitor {
 
       // Collect unmerged parent regions.
       std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> mergeable_regions;
+      // Collect correlative regions to propagate restrictions.
+      std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> correlative_regions;
       for (const auto& arg : region->GetInputs()) {
         auto begin = Downcast<Call>(arg);
         ICHECK_EQ(begin->op, CompilerBeginOp());
         auto parent_region = regions_->GetRegion(begin->args[0]);
         if (parent_region.defined()) {
           mergeable_regions.insert(parent_region);
+          correlative_regions.insert(parent_region);
+        } else {
+          find_control_flow_regions(begin->args[0], correlative_regions);
         }
       }
 
       // Propogate all the parent restrictions to the current region.
       auto& region_restrictions = region_restrictions_[region->GetID()];
-      for (const auto& parent_region : mergeable_regions) {
+      for (const auto& parent_region : correlative_regions) {
         auto parent_restrictions = region_restrictions_[parent_region->GetID()];
         region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
       }
diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py
index ba94021d3f..b67eac6abd 100644
--- a/tests/python/relay/test_pass_merge_compiler_regions.py
+++ b/tests/python/relay/test_pass_merge_compiler_regions.py
@@ -17,6 +17,7 @@
 """Unit tests for merge compiler regions."""
 import tvm
 from tvm import relay
+import tvm.relay.transform as transform
 from tvm.relay.op.annotation import compiler_begin, compiler_end
 from tvm.relay.testing import run_opt_pass
 
@@ -214,6 +215,67 @@ def test_example_graph():
     assert tvm.ir.structural_equal(mod, ref_mod)
 
 
+def test_if_else():
+    """
+    Tests that region restrictions propagate successfully through
+    If-else control flow in MergeCompilerRegions.
+
+    O = supported by target
+    X = not supported by target
+
+
+           O1 - - - |             O1 ------|
+            |       |              |       |
+            X       |              X       |
+            |       |              |       |
+    If cond ? O1: X | -->  If cond ? O1: X |
+            |       |              |       |
+           O2 <- - -|             O2 <-----|
+
+
+    O1 must not be merged into the same region as O2.
+    """
+
+    target = "test_if_else"  # only sigmoid, erf and add are registered below
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(expr):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(expr):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("add", "target." + target)
+    def add(expr):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes merges regions correctly."""
+
+    def get_mod():
+        data = relay.var("data", shape=(1, 32))
+        add0 = relay.add(data, data)  # O (add is registered)
+        sub0 = relay.subtract(add0, data)  # X (subtract is not registered)
+        eq = relay.equal(relay.sum(add0), relay.sum(sub0))  # X (equal/sum not registered)
+
+        true_branch = relay.sigmoid(add0)
+        false_branch = relay.sigmoid(sub0)
+        ife = relay.If(eq, true_branch, false_branch)
+        erf = relay.erf(ife)
+        out = relay.add(add0, erf)  # O; also consumes add0 directly
+        func = relay.Function([data], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    for annotate_non_call_ops in [True, False]:
+        result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
+        merge = transform.MergeCompilerRegions()(result)
+        # Smoke test only: passes if PartitionGraph completes (no crash); nothing is asserted.
+        partition = transform.PartitionGraph()(merge)
+
+
 if __name__ == "__main__":
     test_diamond_graph_fanouts()
     test_example_graph()
+    test_if_else()  # smoke test for restriction propagation through If-else