You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/04/03 16:33:21 UTC
[incubator-tvm] branch master updated: [RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 54975a3 [RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)
54975a3 is described below
commit 54975a3fd24fa45b815be39075f4614e53009444
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Fri Apr 3 17:33:15 2020 +0100
[RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)
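For certain network topologies, the MergeCompilerRegions (MCR) pass could hang.
This patch fixes that case: regions that have already been merged are now skipped, parent regions are de-duplicated before merging, and MergeRegions matches the destination's inputs against the source region's nodes rather than only its outputs.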
Change-Id: I3edd8a8a6b452b2b838b777720adea22a3b995b4
---
src/relay/analysis/annotated_region_set.cc | 5 ++---
src/relay/transforms/merge_compiler_regions.cc | 5 +++--
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index f7b9b42..ad2b9e1 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -22,7 +22,6 @@
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
-#include <algorithm>
#include <unordered_map>
#include <vector>
@@ -58,8 +57,8 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input);
- auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
- if (it != src->outs.end()) {
+ auto it = src->nodes.find(call->args[0]);
+ if (it != src->nodes.end()) {
dest->outs.remove(*it);
ins_to_remove.push_back(input);
}
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index 4a8ff64..5253010 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -263,6 +263,7 @@ class RegionMerger : public ExprVisitor {
void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call));
+ if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
// set the region target
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler;
@@ -281,13 +282,13 @@ class RegionMerger : public ExprVisitor {
}
}
// get the mergeable regions now all the parents have been visited
- std::vector<AnnotatedRegion> mergeable_regions;
+ std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]);
if (!parent_region.defined()) continue;
- mergeable_regions.push_back(parent_region);
+ mergeable_regions.insert(parent_region);
}
auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) {