You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/03/26 08:15:11 UTC

[incubator-tvm] branch master updated: [RELAY] Added a AnnotatedRegion utility class (#5030)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new b5ec071  [RELAY] Added a AnnotatedRegion utility class (#5030)
b5ec071 is described below

commit b5ec07114b303be20b6ce091750aae5fa4b4dee5
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Thu Mar 26 08:14:57 2020 +0000

    [RELAY] Added a AnnotatedRegion utility class (#5030)
    
    * [RELAY] Added an AnnotatedRegionSet utility class
    
    In many of the passes involved in graph partitioning,
    we need to extract and manipulate annotated regions.
    This class simplifies the extraction of regions from a relay
    expression containing region begin and end annotations
    as well as providing utility functions to query these
    regions and merge them.
    
    Co-authored-by: Ramana Radhakrishnan  <ra...@arm.com>
    
    Change-Id: Ia912fea0b99f64b6a7197aa6da2347e58f469fbb
    
    * Rename fix
    
    * Update MakeRegions
    
    * Fix __init__
    
    * Indentation
    
    * Code style
    
    * Remove 'Region' from docs
    
    * Overload [] to get region
    
    * Use src/dest for MergeRegions
    
    * Simplify merge
    
    * Tidy const loop vars
---
 python/tvm/relay/analysis/__init__.py          |   3 +
 python/tvm/relay/analysis/annotated_regions.py |  62 ++++++
 src/relay/analysis/annotated_region_set.cc     | 233 ++++++++++++++++++++
 src/relay/analysis/annotated_region_set.h      | 286 +++++++++++++++++++++++++
 tests/python/relay/test_annotated_regions.py   | 121 +++++++++++
 5 files changed, 705 insertions(+)

diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index e5185ea..a1833c3 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -19,6 +19,9 @@
 # Analysis passes
 from .analysis import *
 
+# Annotations
+from .annotated_regions import AnnotatedRegionSet
+
 # Call graph
 from . import call_graph
 from .call_graph import CallGraph
diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py
new file mode 100644
index 0000000..fc8e85a
--- /dev/null
+++ b/python/tvm/relay/analysis/annotated_regions.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
+"""Regions used in Relay."""
+
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+class AnnotatedRegionSet(Object):
+    """Class to represent a relay expression split into regions."""
+
+    def __init__(self, expr, region_begin_op, region_end_op):
+        """Construct regions from an expression.
+
+        Parameters
+        ----------
+        expr : tvm.relay.Expr
+            The expression from which to construct the regions.
+        region_begin_op : tvm.relay.Op
+            The region begin annotation.
+        region_end_op : tvm.relay.Op
+            The region end annotation.
+
+        """
+        self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet,
+                                            expr,
+                                            region_begin_op,
+                                            region_end_op)
+
+    def __len__(self):
+        return len(self.regions)
+
+    def get_region(self, expr):
+        """Get the region an expression belongs to.
+
+        Parameters
+        ----------
+        expr : tvm.relay.Expr
+            The expression.
+
+        Returns
+        -------
+        region
+            The region containing the expression.
+            None if not found.
+        """
+        return _ffi_api.GetRegion(self, expr)
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
new file mode 100644
index 0000000..f8e951b
--- /dev/null
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "annotated_region_set.h"
+
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+
+namespace tvm {
+namespace relay {
+
+AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
+  // Scan the regions for the first one whose node set contains `expr`.
+  auto found = std::find_if(regions_.begin(), regions_.end(),
+                            [&expr](const AnnotatedRegion& region) {
+                              return region->nodes.count(expr) != 0;
+                            });
+  if (found == regions_.end()) {
+    // Not part of any region: hand back an undefined reference.
+    return AnnotatedRegion(nullptr);
+  }
+  return *found;
+}
+
+void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
+                                          AnnotatedRegion dest) {
+  if (dest == src) {
+    return;
+  }
+
+  // Move all of src's nodes, inputs and outputs into dest.
+  dest->nodes.insert(src->nodes.begin(), src->nodes.end());
+  dest->ins.insert(dest->ins.end(), src->ins.begin(), src->ins.end());
+  dest->outs.insert(dest->outs.end(), src->outs.begin(), src->outs.end());
+
+  // If any of the outputs of src are inputs of the merged region, that
+  // input/output pair becomes internal, so drop it from both ins and outs.
+  // The inputs to drop are collected first: removing from dest->ins while
+  // range-iterating over it would invalidate the loop's iterator.
+  std::vector<Expr> internal_ins;
+  for (const auto& input : dest->ins) {
+    auto call = Downcast<Call>(input);
+    auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
+    if (it != src->outs.end()) {
+      dest->outs.remove(*it);
+      internal_ins.push_back(input);
+    }
+  }
+  for (const auto& input : internal_ins) {
+    dest->ins.remove(input);
+  }
+  regions_.erase(src);
+}
+
+void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
+  auto owner = GetRegion(expr);
+  if (owner.defined()) {
+    // expr already lives in a region: fold that region into `region`.
+    // MergeRegions erases its *src* argument, so `region` must be the dest;
+    // otherwise the caller's handle would point at an erased region and any
+    // further additions through it would be lost.
+    MergeRegions(owner, region);
+  } else {
+    region->nodes.insert(expr);
+  }
+}
+
+AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
+  // Allocate an empty region, stamp it with the next ID and register it.
+  AnnotatedRegion region;
+  region->id = region_id_++;
+  regions_.insert(region);
+  return region;
+}
+
+/*!
+ * \brief Visitor that builds an AnnotatedRegionSet from an expression.
+ *
+ * An 'end' annotation joins the region of its argument, or starts a new
+ * region if the argument is not in one yet. Regions then grow towards the
+ * inputs: operands of a node that is in a region are added to that region.
+ * 'begin' annotations stop this propagation and are recorded as the
+ * region's inputs.
+ */
+class AnnotatedRegionSet::Creator : public ExprVisitor {
+ public:
+  Creator(const Op& region_begin_op, const Op& region_end_op) :
+    begin_op_(region_begin_op), end_op_(region_end_op) {}
+
+  /*!
+   * \brief Visit an expression and return the regions found in it.
+   * \param expr The expression to analyse.
+   * \return The constructed region set.
+   */
+  AnnotatedRegionSet Create(const Expr& expr) {
+    VisitExpr(expr);
+    return std::move(region_set_);
+  }
+
+  void VisitExpr_(const CallNode* call) override {
+    auto op_node = call->op.as<OpNode>();
+
+    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      // Ordinary call (not an op carrying CompilerAttrs): propagate region to arguments.
+      auto region = region_set_->GetRegion(GetRef<Call>(call));
+      if (region.defined()) {
+        for (auto arg : call->args) {
+          region_set_->AddToRegion(region, arg);
+        }
+      }
+    } else if (call->op == begin_op_) {
+      // The annotation node is inserted on edge so it must have only one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      // The 'begin' annotation must already belong to a region (added when its
+      // consumer was visited). It becomes an input of that region; its argument
+      // is deliberately not added, so the region stops here.
+      auto region = region_set_->GetRegion(GetRef<Call>(call));
+      if (!region.defined()) {
+        throw Error(ErrorBuilder()
+                      << "Cannot find the corresponding region for start annotation:\n"
+                      << AsText(GetRef<Call>(call), false));
+      }
+      region->ins.push_back(GetRef<Call>(call));
+    } else {
+      CHECK_EQ(call->op, end_op_);
+      // The annotation node is inserted on edge so it must have only one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      // Check if the argument already belongs to a region; if not, start one.
+      auto region = region_set_->GetRegion(call->args[0]);
+      if (!region.defined()) {
+        region = region_set_->MakeRegion();
+        region->nodes.insert(call->args[0]);
+      }
+      region->nodes.insert(GetRef<Call>(call));
+      region->outs.push_back(GetRef<Call>(call));
+    }
+    ExprVisitor::VisitExpr_(call);
+  }
+
+  void VisitExpr_(const TupleNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<Tuple>(op));
+    if (region.defined()) {
+      for (auto field : op->fields) {
+        region_set_->AddToRegion(region, field);
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* g) override {
+    auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, g->tuple);
+    }
+    ExprVisitor::VisitExpr_(g);
+  }
+
+  void VisitExpr_(const FunctionNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<Function>(op));
+    if (region.defined()) {
+      for (auto param : op->params) {
+        region_set_->AddToRegion(region, param);
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const LetNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<Let>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->var);
+      region_set_->AddToRegion(region, op->value);
+      region_set_->AddToRegion(region, op->body);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const IfNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<If>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->cond);
+      region_set_->AddToRegion(region, op->true_branch);
+      region_set_->AddToRegion(region, op->false_branch);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const RefCreateNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->value);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const RefReadNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<RefRead>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->ref);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const RefWriteNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
+    if (region.defined()) {
+      // Both operands of the write belong to the region, like the operands
+      // of every other node kind handled above.
+      region_set_->AddToRegion(region, op->ref);
+      region_set_->AddToRegion(region, op->value);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+ private:
+  /*! \brief The region set being constructed.*/
+  AnnotatedRegionSet region_set_;
+  /*! \brief Region 'begin' annotation operator. */
+  const Op begin_op_;
+  /*! \brief Region 'end' annotation operator. */
+  const Op end_op_;
+};
+
+AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) {
+  // Run a fresh Creator visitor over expr; it hands back the populated set.
+  Creator creator(begin, end);
+  return creator.Create(expr);
+}
+
+// Register the region node types with TVM's object system.
+TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode);
+TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode);
+
+// Constructor used by relay.analysis.AnnotatedRegionSet on the Python side.
+TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet")
+.set_body_typed([](Expr root, Op begin_op, Op end_op) {
+  return AnnotatedRegionSet::Create(root, begin_op, end_op);
+});
+
+// Lookup of the region that contains an expression.
+TVM_REGISTER_GLOBAL("relay.analysis.GetRegion")
+.set_body_typed([](AnnotatedRegionSet regions, Expr target) {
+  return regions->GetRegion(target);
+});
+
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h
new file mode 100644
index 0000000..c5db2cc
--- /dev/null
+++ b/src/relay/analysis/annotated_region_set.h
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/annotated_region_set.h
+ * \brief Define data structures to extract and manipulate regions from
+ * a relay function. Regions are denoted by region_begin and region_end
+ * annotations that exist on all the input and output edges of the region.
+ */
+
+#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include <list>
+
+namespace tvm {
+namespace relay {
+
+class AnnotatedRegion;
+class AnnotatedRegionSet;
+
+/*!
+ * \brief Node holding the ID, member nodes, inputs and outputs of one region.
+ */
+class AnnotatedRegionNode : public Object {
+ public:
+  void VisitAttrs(AttrVisitor* v) {
+    // Copy the std containers into Arrays so they can be visited. These are
+    // local copies, so anything a visitor writes into them is not kept.
+    Array<Expr> node_snapshot(nodes.begin(), nodes.end());
+    Array<Expr> input_snapshot(ins.begin(), ins.end());
+    Array<Expr> output_snapshot(outs.begin(), outs.end());
+    v->Visit("id", &id);
+    v->Visit("nodes", &node_snapshot);
+    v->Visit("args", &input_snapshot);
+    v->Visit("rets", &output_snapshot);
+  }
+
+  /*! \return The region's ID. */
+  int GetID() const { return id; }
+
+  /*! \return A copy of the region's inputs. */
+  std::list<Expr> GetInputs() const { return ins; }
+
+  /*! \return A copy of the region's outputs. */
+  std::list<Expr> GetOutputs() const { return outs; }
+
+  /*! \return A copy of the set of nodes in the region. */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { return nodes; }
+
+  static constexpr const char* _type_key = "relay.AnnotatedRegion";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
+
+ protected:
+  /*! \brief The region ID; -1 until one is assigned. */
+  int id{-1};
+  /*! \brief The inputs to this region. */
+  std::list<Expr> ins;
+  /*! \brief The outputs of this region. */
+  std::list<Expr> outs;
+  /*! \brief Nodes in this region. */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+
+  friend class AnnotatedRegionSet;
+  friend class AnnotatedRegionSetNode;
+};
+
+/*!
+ * \brief Reference to an AnnotatedRegionNode: the properties of one region
+ * as used by the AnnotatedRegionSet class. Treat it as read-only.
+ */
+class AnnotatedRegion : public ObjectRef {
+ public:
+  /*! \brief Construct a reference to a new, empty region node. */
+  AnnotatedRegion() : ObjectRef(make_object<AnnotatedRegionNode>()) {}
+
+  /*!
+   * \brief Construct from an object pointer.
+   * \param n The object pointer (may be null, giving an undefined region).
+   */
+  explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {}
+
+  /*! \return Mutable pointer to the node; CHECK-fails on an undefined region. */
+  AnnotatedRegionNode* operator->() const {
+    auto* node = get_mutable();
+    CHECK(node != nullptr);
+    return static_cast<AnnotatedRegionNode*>(node);
+  }
+};
+
+/*!
+ * \brief Node holding the regions of an AnnotatedRegionSet.
+ */
+class AnnotatedRegionSetNode : public Object {
+  using UnorderedRegionSet =
+    std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+  // Create iterator alias for a RegionSet object.
+  using iterator = UnorderedRegionSet::iterator;
+  using const_iterator = UnorderedRegionSet::const_iterator;
+
+ public:
+  /*! \brief Default constructor. */
+  AnnotatedRegionSetNode() = default;
+
+  /*! \return The begin iterator */
+  iterator begin() {
+    return regions_.begin();
+  }
+  /*! \return The end iterator */
+  iterator end() {
+    return regions_.end();
+  }
+  /*! \return The const begin iterator */
+  const_iterator begin() const {
+    return regions_.begin();
+  }
+  /*! \return The const end iterator */
+  const_iterator end() const {
+    return regions_.end();
+  }
+
+  /*!
+   * \brief Get the region that an expression belongs to.
+   *
+   * \param expr Which expr to get the region for.
+   *
+   * \return The region containing expr, or an undefined AnnotatedRegion
+   * (one whose defined() is false) if the expression doesn't belong to
+   * any region.
+   */
+  AnnotatedRegion GetRegion(const Expr& expr) const;
+
+  /*!
+   * \brief Merge src region into dest region.
+   *
+   * \param src The region to merge - will be erased.
+   * \param dest The region into which src will be merged.
+   */
+  void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest);
+
+  void VisitAttrs(AttrVisitor* v) {
+    // Expose a snapshot of the regions as an Array named "regions".
+    Array<AnnotatedRegion> regions_array(regions_.begin(), regions_.end());
+    v->Visit("regions", &regions_array);
+  }
+
+  static constexpr const char* _type_key = "relay.AnnotatedRegionSet";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object);
+
+ private:
+  /*!
+   * \brief Add an expression to a region.
+   *
+   * If expr already belongs to a region, the two regions are merged with
+   * MergeRegions (a no-op when it is the same region) instead of inserting
+   * expr directly.
+   *
+   * \param region The region to add the expression to.
+   * \param expr The expression.
+   */
+  void AddToRegion(AnnotatedRegion region, const Expr& expr);
+
+  /*!
+   * \brief Make a new, empty region with the next free ID and add it to the set.
+   *
+   * \return The new region.
+   */
+  AnnotatedRegion MakeRegion();
+
+  /*! \brief The regions in this set. */
+  UnorderedRegionSet regions_;
+  /*! \brief The next region ID to assign. */
+  int region_id_{0};
+
+  friend class AnnotatedRegionSet;
+};
+
+/*!
+ * \brief A class to hold a set of regions produced from a relay expression
+ * that contains 'region_begin' and 'region_end' style annotations. The
+ * regions should be disjoint. The class provides both a method to construct
+ * the region set of a given relay expression as well as additional methods
+ * to update and query regions.
+ */
+class AnnotatedRegionSet : public ObjectRef {
+  using UnorderedRegionSet =
+    std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+  // Create iterator alias for a RegionSet object.
+  using iterator = UnorderedRegionSet::iterator;
+  using const_iterator = UnorderedRegionSet::const_iterator;
+
+ public:
+  /*! \brief Construct a reference to a new, empty region set. */
+  AnnotatedRegionSet() {
+    auto n = make_object<AnnotatedRegionSetNode>();
+    data_ = std::move(n);
+  }
+
+  /*!
+   * \brief Construct from an object pointer.
+   *
+   * \param n The object pointer.
+   */
+  explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {}
+
+  /*! \return The begin iterator. */
+  iterator begin() {
+    auto* n = operator->();
+    CHECK(n);
+    return n->begin();
+  }
+  /*! \return The end iterator. */
+  iterator end() {
+    auto* n = operator->();
+    CHECK(n);
+    return n->end();
+  }
+  /*! \return The const begin iterator. */
+  const_iterator begin() const {
+    const auto* n = operator->();
+    CHECK(n);
+    return n->begin();
+  }
+  /*! \return The const end iterator. */
+  const_iterator end() const {
+    const auto *n = operator->();
+    CHECK(n);
+    return n->end();
+  }
+
+  /*! \return Mutable pointer to the node. */
+  AnnotatedRegionSetNode* operator->() const {
+    auto* ptr = get_mutable();
+    CHECK(ptr != nullptr);
+    return static_cast<AnnotatedRegionSetNode*>(ptr);
+  }
+
+  /*!
+   * \brief Get the region an expression belongs to.
+   *
+   * Does not modify the set, so it is usable on a const AnnotatedRegionSet.
+   *
+   * \param expr The expression to look up.
+   *
+   * \return The region containing expr, or an undefined AnnotatedRegion if
+   * expr is not in any region.
+   */
+  AnnotatedRegion operator[](const Expr& expr) const {
+    const auto *n = operator->();
+    CHECK(n);
+    return n->GetRegion(expr);
+  }
+
+  /*!
+   * \brief Create a RegionSet from a relay expression.
+   *
+   * \param expr The relay expr from which to construct the set.
+   * \param begin Region begin annotation operator.
+   * \param end Region end annotation operator.
+   *
+   * \return The created RegionSet for the expression.
+   */
+  static AnnotatedRegionSet Create(const Expr& expr,
+                                   const Op& begin,
+                                   const Op& end);
+
+ private:
+  /*! \brief Helper class to construct a RegionSet from an expr.*/
+  class Creator;
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py
new file mode 100644
index 0000000..a246398
--- /dev/null
+++ b/tests/python/relay/test_annotated_regions.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+from tvm import relay
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+
+
+def check_region(region_set, args, nodes, rets):
+    region = region_set.get_region(args[0])
+    assert region
+    assert set(args) == set(region.args)
+    assert set(nodes) == set(region.nodes)
+    assert set(rets) == set(region.rets)
+
+
+def test_region_set_creator_diamond():
+    data = relay.var('data', shape=(10, 10))
+    cb_1 = compiler_begin(data, 'test_target')
+    O_1 = relay.abs(cb_1)
+    ce_1 = compiler_end(O_1, 'test_target')
+    ce_2 = compiler_end(O_1, 'test_target')
+    cb_2 = compiler_begin(ce_1, 'test_target')
+    O_2 = relay.nn.relu(cb_2)
+    ce_3 = compiler_end(O_2, 'test_target')
+    cb_d = compiler_begin(ce_2, "default")
+    X = relay.tanh(cb_d)
+    ce_d = compiler_end(X, 'default')
+    cb_3 = compiler_begin(ce_3, 'test_target')
+    cb_4 = compiler_begin(ce_d, 'test_target')
+    O_3 = relay.add(cb_3, cb_4)
+    ce_4 = compiler_end(O_3, 'test_target')
+    diamond = relay.Function([data], ce_4)
+
+    region_set = relay.analysis.AnnotatedRegionSet(diamond,
+                                                   relay.op.get("annotation.compiler_begin"),
+                                                   relay.op.get("annotation.compiler_end"))
+    assert len(region_set) == 4
+    check_region(
+        region_set,
+        [cb_1],
+        [cb_1, O_1, ce_1, ce_2],
+        [ce_1, ce_2],
+    )
+    check_region(
+        region_set,
+        [cb_2],
+        [cb_2, O_2, ce_3],
+        [ce_3],
+    )
+    check_region(
+        region_set,
+        [cb_d],
+        [cb_d, X, ce_d],
+        [ce_d],
+    )
+    check_region(
+        region_set,
+        [cb_3, cb_4],
+        [cb_3, cb_4, O_3, ce_4],
+        [ce_4],
+    )
+
+
+def test_region_set_creator_merged():
+    data = relay.var('data', shape=(10, 10))
+    cb_1 = compiler_begin(data, 'test_target')
+    O_1 = relay.abs(cb_1)
+    ce_2 = compiler_end(O_1, 'test_target')
+    O_2 = relay.nn.relu(O_1)
+    ce_3 = compiler_end(O_2, 'test_target')
+    cb_d = compiler_begin(ce_2, "default")
+    X = relay.tanh(cb_d)
+    ce_d = compiler_end(X, 'default')
+    cb_3 = compiler_begin(ce_3, 'test_target')
+    cb_4 = compiler_begin(ce_d, 'test_target')
+    O_3 = relay.add(cb_3, cb_4)
+    ce_4 = compiler_end(O_3, 'test_target')
+    merged = relay.Function([data], ce_4)
+
+    region_set = relay.analysis.AnnotatedRegionSet(merged,
+                                                   relay.op.get("annotation.compiler_begin"),
+                                                   relay.op.get("annotation.compiler_end"))
+    assert len(region_set) == 3
+    check_region(
+        region_set,
+        [cb_1],
+        [cb_1, O_1, O_2, ce_2, ce_3],
+        [ce_2, ce_3],
+    )
+    check_region(
+        region_set,
+        [cb_d],
+        [cb_d, X, ce_d],
+        [ce_d],
+    )
+    check_region(
+        region_set,
+        [cb_3, cb_4],
+        [cb_3, cb_4, O_3, ce_4],
+        [ce_4],
+    )
+
+
+if __name__ == "__main__":
+    test_region_set_creator_diamond()
+    test_region_set_creator_merged()
+