You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/11 07:23:06 UTC
[tvm] branch main updated: [RELAY] Fix bug in MergeCompilerRegions pass (#15211)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e4a120955b [RELAY] Fix bug in MergeCompilerRegions pass (#15211)
e4a120955b is described below
commit e4a120955b02674c67ceec6d88884ec5768ff7c2
Author: chenxinli <39...@users.noreply.github.com>
AuthorDate: Tue Jul 11 15:22:57 2023 +0800
[RELAY] Fix bug in MergeCompilerRegions pass (#15211)
---
src/relay/transforms/merge_compiler_regions.cc | 36 ++++++++++++-
.../relay/test_pass_merge_compiler_regions.py | 62 ++++++++++++++++++++++
2 files changed, 97 insertions(+), 1 deletion(-)
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index d70c7480e9..92e881fa61 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -53,6 +53,35 @@ class RegionMerger : public MixedModeVisitor {
public:
/*! \brief Constructs the merger; \p regions is the annotated region set later queried via GetRegion(). */
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
+  /*!
+   * \brief Collect annotated regions reachable from \p op through If nodes.
+   *
+   * Called when a compiler_begin input is not itself produced by an annotated
+   * region. The producer may then be an unannotated If node (presumably when
+   * AnnotateTarget runs with annotate_non_call_ops=false -- TODO confirm). The
+   * regions feeding that If still constrain the consuming region, so their
+   * restrictions must be propagated too.
+   *
+   * For each of cond / true_branch / false_branch: if it belongs to a region,
+   * that region is inserted into \p correlative_regions; otherwise the search
+   * recurses into it. Only IfNode is traversed. Any other expression kind that
+   * is not inside a region contributes nothing.
+   *
+   * \param op The expression to inspect (taken by reference to avoid copying
+   *        the ref-counted handle on every recursive call).
+   * \param correlative_regions Output set; discovered regions are inserted.
+   */
+  void find_control_flow_regions(
+      const Expr& op,
+      std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>& correlative_regions) {
+    const auto* if_node = op.as<IfNode>();
+    if (if_node == nullptr) {
+      return;
+    }
+    // The condition and both branches are handled identically. Iterate over
+    // pointers so that no Expr handles are copied.
+    for (const Expr* operand : {&if_node->cond, &if_node->true_branch, &if_node->false_branch}) {
+      auto region = regions_->GetRegion(*operand);
+      if (region.defined()) {
+        correlative_regions.insert(region);
+      } else {
+        find_control_flow_regions(*operand, correlative_regions);
+      }
+    }
+  }
+
void VisitExpr_(const CallNode* call) final {
if (call->op == CompilerEndOp()) {
auto region = regions_->GetRegion(GetRef<Call>(call));
@@ -84,18 +113,23 @@ class RegionMerger : public MixedModeVisitor {
// Collect unmerged parent regions.
std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> mergeable_regions;
+ // Collect correlative regions to propagate restrictions.
+ std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> correlative_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
ICHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);
if (parent_region.defined()) {
mergeable_regions.insert(parent_region);
+ correlative_regions.insert(parent_region);
+ } else {
+ find_control_flow_regions(begin->args[0], correlative_regions);
}
}
// Propogate all the parent restrictions to the current region.
auto& region_restrictions = region_restrictions_[region->GetID()];
- for (const auto& parent_region : mergeable_regions) {
+ for (const auto& parent_region : correlative_regions) {
auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
}
diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py
index ba94021d3f..b67eac6abd 100644
--- a/tests/python/relay/test_pass_merge_compiler_regions.py
+++ b/tests/python/relay/test_pass_merge_compiler_regions.py
@@ -17,6 +17,7 @@
"""Unit tests for merge compiler regions."""
import tvm
from tvm import relay
+import tvm.relay.transform as transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass
@@ -214,6 +215,67 @@ def test_example_graph():
assert tvm.ir.structural_equal(mod, ref_mod)
+def test_if_else():
+ """
+ This tests that the restriction regions propagate successful in
+ if_else control flow.
+
+ O = supported by target
+ X = not supported by target
+
+
+ O1 - - - | O1 --|
+ | | |
+ X | X
+ | | |
+ If cond ? O1: X | --> + + If cond ? O1: X +
+ | | |
+ O2 <- - -| O2 <-|
+
+
+ Avoid O1 merge to O2.
+ """
+
+ target = "test_if_else"
+
+ @tvm.ir.register_op_attr("sigmoid", "target." + target)
+ def sigmoid(expr): # pylint: disable=unused-variable
+ return True
+
+ @tvm.ir.register_op_attr("erf", "target." + target)
+ def erf(expr): # pylint: disable=unused-variable
+ return True
+
+ @tvm.ir.register_op_attr("add", "target." + target)
+ def add(expr): # pylint: disable=unused-variable
+ return True
+
+ """Test that If-else nodes merges regions correctly."""
+
+ def get_mod():
+ data = relay.var("data", shape=(1, 32))
+ add0 = relay.add(data, data)
+ sub0 = relay.subtract(add0, data)
+ eq = relay.equal(relay.sum(add0), relay.sum(sub0))
+
+ true_branch = relay.sigmoid(add0)
+ false_branch = relay.sigmoid(sub0)
+ ife = relay.If(eq, true_branch, false_branch)
+ erf = relay.erf(ife)
+ out = relay.add(add0, erf)
+ func = relay.Function([data], out)
+ mod = tvm.IRModule.from_expr(func)
+
+ return mod
+
+ for annotate_non_call_ops in [True, False]:
+ result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
+ merge = transform.MergeCompilerRegions()(result)
+ # Ensure partition finished without segment fault.
+ partition = transform.PartitionGraph()(merge)
+
+
if __name__ == "__main__":
test_diamond_graph_fanouts()
test_example_graph()
+ test_if_else()  # regression test for the If-node restriction fix in #15211