You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/04/03 16:33:21 UTC

[incubator-tvm] branch master updated: [RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 54975a3  [RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)
54975a3 is described below

commit 54975a3fd24fa45b815be39075f4614e53009444
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Fri Apr 3 17:33:15 2020 +0100

    [RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)
    
    For certain network topologies, the MergeCompilerRegions (MCR) pass could hang.
    This patch fixes that case.
    
    Change-Id: I3edd8a8a6b452b2b838b777720adea22a3b995b4
---
 src/relay/analysis/annotated_region_set.cc     | 5 ++---
 src/relay/transforms/merge_compiler_regions.cc | 5 +++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index f7b9b42..ad2b9e1 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -22,7 +22,6 @@
 #include <tvm/relay/expr.h>
 #include <tvm/ir/error.h>
 
-#include <algorithm>
 #include <unordered_map>
 #include <vector>
 
@@ -58,8 +57,8 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
   std::vector<Expr> ins_to_remove;
   for (const auto& input : dest->ins) {
     auto call = Downcast<Call>(input);
-    auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
-    if (it != src->outs.end()) {
+    auto it = src->nodes.find(call->args[0]);
+    if (it != src->nodes.end()) {
       dest->outs.remove(*it);
       ins_to_remove.push_back(input);
     }
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index 4a8ff64..5253010 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -263,6 +263,7 @@ class RegionMerger : public ExprVisitor {
   void VisitExpr_(const CallNode* call) final {
     if (call->op == compiler_end_op) {
       auto region = regions_->GetRegion(GetRef<Call>(call));
+      if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
       // set the region target
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
       region_targets_[region->GetID()] = compiler_attrs->compiler;
@@ -281,13 +282,13 @@ class RegionMerger : public ExprVisitor {
         }
       }
       // get the mergeable regions now all the parents have been visited
-      std::vector<AnnotatedRegion> mergeable_regions;
+      std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
       for (const auto& arg : region->GetInputs()) {
         auto begin = Downcast<Call>(arg);
         CHECK_EQ(begin->op, compiler_begin_op);
         auto parent_region = regions_->GetRegion(begin->args[0]);
         if (!parent_region.defined()) continue;
-        mergeable_regions.push_back(parent_region);
+        mergeable_regions.insert(parent_region);
       }
       auto& region_restrictions = region_restrictions_[region->GetID()];
       for (const auto& parent_region : mergeable_regions) {