You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/03/26 08:15:11 UTC
[incubator-tvm] branch master updated: [RELAY] Added an
AnnotatedRegion utility class (#5030)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new b5ec071 [RELAY] Added an AnnotatedRegion utility class (#5030)
b5ec071 is described below
commit b5ec07114b303be20b6ce091750aae5fa4b4dee5
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Thu Mar 26 08:14:57 2020 +0000
[RELAY] Added an AnnotatedRegion utility class (#5030)
* [RELAY] Added an AnnotatedRegionSet utility class
In many of the passes involved in graph partitioning,
we need to extract and manipulate annotated regions.
This class simplifies the extraction of regions from a relay
expression containing region begin and end annotations
as well as providing utility functions to query these
regions and merge them.
Co-authored-by: Ramana Radhakrishnan <ra...@arm.com>
Change-Id: Ia912fea0b99f64b6a7197aa6da2347e58f469fbb
* Rename fix
* Update MakeRegions
* Fix __init__
* Indentation
* Code style
* Remove 'Region' from docs
* Overload [] to get region
* Use src/dest for MergeRegions
* Simplify merge
* Tidy const loop vars
---
python/tvm/relay/analysis/__init__.py | 3 +
python/tvm/relay/analysis/annotated_regions.py | 62 ++++++
src/relay/analysis/annotated_region_set.cc | 233 ++++++++++++++++++++
src/relay/analysis/annotated_region_set.h | 286 +++++++++++++++++++++++++
tests/python/relay/test_annotated_regions.py | 121 +++++++++++
5 files changed, 705 insertions(+)
diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index e5185ea..a1833c3 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -19,6 +19,9 @@
# Analysis passes
from .analysis import *
+# Annotations
+from .annotated_regions import AnnotatedRegionSet
+
# Call graph
from . import call_graph
from .call_graph import CallGraph
diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py
new file mode 100644
index 0000000..fc8e85a
--- /dev/null
+++ b/python/tvm/relay/analysis/annotated_regions.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
+"""Regions used in Relay."""
+
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+class AnnotatedRegionSet(Object):
+ """Class to represent a relay expression split into regions."""
+
+ def __init__(self, expr, region_begin_op, region_end_op):
+ """Construct regions from an expression.
+
+ Parameters
+ ----------
+ expr : tvm.relay.Expr
+ The expression from which to construct the regions.
+ region_begin_op : tvm.relay.Op
+ The region begin annotation.
+ region_end_op : tvm.relay.Op
+ The region end annotation.
+
+ """
+ self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet,
+ expr,
+ region_begin_op,
+ region_end_op)
+
+ def __len__(self):
+ return len(self.regions)
+
+ def get_region(self, expr):
+ """Get the region an expression belongs to.
+
+ Parameters
+ ----------
+ expr : tvm.relay.Expr
+ The expression.
+
+ Returns
+ -------
+ region
+ The region containing the expression.
+ None if not found.
+ """
+ return _ffi_api.GetRegion(self, expr)
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
new file mode 100644
index 0000000..f8e951b
--- /dev/null
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "annotated_region_set.h"
+
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Find the region whose node set contains expr.
+ *
+ * \param expr The expression to look up.
+ * \return The containing region, or an undefined AnnotatedRegion if expr
+ * is not part of any region.
+ */
+AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
+  // Iterate by const reference: copying each AnnotatedRegion handle would
+  // cost a reference-count increment/decrement per candidate.
+  for (const auto& candidate : regions_) {
+    if (candidate->nodes.find(expr) != candidate->nodes.end()) {
+      return candidate;
+    }
+  }
+  return AnnotatedRegion(nullptr);
+}
+
+/*!
+ * \brief Merge the region src into dest and remove src from the set.
+ *
+ * Nodes, inputs and outputs of src are appended to dest. An input of the
+ * merged region whose argument is one of src's outputs now describes an
+ * internal edge, so that input and that output are both dropped.
+ *
+ * \param src The region to merge; it is erased from the set.
+ * \param dest The region that receives src's contents.
+ */
+void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
+                                          AnnotatedRegion dest) {
+  if (dest == src) {
+    return;
+  }
+
+  // Merge src to dest and erase src.
+  dest->nodes.insert(src->nodes.begin(), src->nodes.end());
+  dest->ins.insert(dest->ins.end(), src->ins.begin(), src->ins.end());
+  dest->outs.insert(dest->outs.end(), src->outs.begin(), src->outs.end());
+  // If any of the outputs of src are inputs of dest, they become internal
+  // nodes, so remove them from outs and drop the matching input.
+  // Advance with the iterator returned by erase(): erasing the element a
+  // range-for is positioned on would invalidate the loop iterator (UB).
+  for (auto in_it = dest->ins.begin(); in_it != dest->ins.end();) {
+    auto call = Downcast<Call>(*in_it);
+    auto out_it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
+    if (out_it != src->outs.end()) {
+      dest->outs.remove(*out_it);
+      in_it = dest->ins.erase(in_it);
+    } else {
+      ++in_it;
+    }
+  }
+  regions_.erase(src);
+}
+
+/*!
+ * \brief Add expr to region, merging regions if expr already has one.
+ *
+ * \param region The region to add the expression to; it stays in the set.
+ * \param expr The expression.
+ */
+void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
+  auto owner = GetRegion(expr);
+  if (owner.defined()) {
+    // Merge the region that already owns expr INTO `region`, not the other
+    // way round. MergeRegions erases its src argument, and callers such as
+    // the Creator keep using `region` (e.g. for the remaining call arguments)
+    // after this returns.
+    MergeRegions(owner, region);
+  } else {
+    region->nodes.insert(expr);
+  }
+}
+
+/*!
+ * \brief Create a new, empty region, give it the next ID and add it to the set.
+ *
+ * \return The newly created region.
+ */
+AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
+  // A default-constructed AnnotatedRegion owns a freshly allocated node.
+  AnnotatedRegion fresh;
+  fresh->id = region_id_++;
+  regions_.insert(fresh);
+  return fresh;
+}
+
+/*!
+ * \brief Visitor that builds an AnnotatedRegionSet from an expression.
+ *
+ * Each region_end annotation seeds (or extends) a region with its argument.
+ * Region membership is then pushed from every in-region node to its
+ * children. region_begin annotations reached this way are recorded as the
+ * region's inputs.
+ */
+class AnnotatedRegionSet::Creator : public ExprVisitor {
+ public:
+  Creator(const Op& region_begin_op, const Op& region_end_op) :
+    begin_op_(region_begin_op), end_op_(region_end_op) {}
+
+  /*!
+   * \brief Visit expr and return the region set built from its annotations.
+   * \param expr The annotated relay expression.
+   * \return The constructed region set.
+   */
+  AnnotatedRegionSet Create(const Expr& expr) {
+    VisitExpr(expr);
+    return std::move(region_set_);
+  }
+
+  void VisitExpr_(const CallNode* call) override {
+    auto op_node = call->op.as<OpNode>();
+
+    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      // Not a region annotation: propagate this call's region to its arguments.
+      auto region = region_set_->GetRegion(GetRef<Call>(call));
+      if (region.defined()) {
+        for (auto arg : call->args) {
+          region_set_->AddToRegion(region, arg);
+        }
+      }
+    } else if (call->op == begin_op_) {
+      // The annotation node is inserted on edge so it must have only one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      // A begin annotation must already have been placed in a region by
+      // propagation from the node consuming it; it is an input of that region.
+      auto region = region_set_->GetRegion(GetRef<Call>(call));
+      if (!region.defined()) {
+        throw Error(ErrorBuilder()
+                      << "Cannot find the corresponding region for start annotation:\n"
+                      << AsText(GetRef<Call>(call), false));
+      }
+      region->ins.push_back(GetRef<Call>(call));
+    } else {
+      CHECK_EQ(call->op, end_op_);
+      // The annotation node is inserted on edge so it must have only one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      // Check if the argument already belongs to a region
+      auto region = region_set_->GetRegion(call->args[0]);
+      if (!region.defined()) {
+        region = region_set_->MakeRegion();
+        region->nodes.insert(call->args[0]);
+      }
+      region->nodes.insert(GetRef<Call>(call));
+      region->outs.push_back(GetRef<Call>(call));
+    }
+    ExprVisitor::VisitExpr_(call);
+  }
+
+  void VisitExpr_(const TupleNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<Tuple>(op));
+    if (region.defined()) {
+      for (auto field : op->fields) {
+        region_set_->AddToRegion(region, field);
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* g) override {
+    auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, g->tuple);
+    }
+    ExprVisitor::VisitExpr_(g);
+  }
+
+  void VisitExpr_(const FunctionNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<Function>(op));
+    if (region.defined()) {
+      for (auto param : op->params) {
+        region_set_->AddToRegion(region, param);
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const LetNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<Let>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->var);
+      region_set_->AddToRegion(region, op->value);
+      region_set_->AddToRegion(region, op->body);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const IfNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<If>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->cond);
+      region_set_->AddToRegion(region, op->true_branch);
+      region_set_->AddToRegion(region, op->false_branch);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const RefCreateNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->value);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const RefReadNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<RefRead>(op));
+    if (region.defined()) {
+      region_set_->AddToRegion(region, op->ref);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const RefWriteNode* op) override {
+    auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
+    if (region.defined()) {
+      // Both operands of the write belong to the region, just as Let and If
+      // propagate to all of their children. Skipping `value` left e.g. a
+      // region_begin annotation feeding the write without any region.
+      region_set_->AddToRegion(region, op->ref);
+      region_set_->AddToRegion(region, op->value);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+ private:
+  /*! \brief The region set being constructed.*/
+  AnnotatedRegionSet region_set_;
+  /*! \brief Region 'begin' annotation operator. */
+  const Op begin_op_;
+  /*! \brief Region 'end' annotation operator. */
+  const Op end_op_;
+};
+
+AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) {
+  // A single-use visitor walks expr and records the regions it finds.
+  Creator creator(begin, end);
+  return creator.Create(expr);
+}
+
+TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode);
+TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode);
+
+// FFI constructor: build a region set from an annotated expression.
+TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet")
+.set_body_typed([](Expr annotated_expr, Op region_begin, Op region_end) {
+  return AnnotatedRegionSet::Create(annotated_expr, region_begin, region_end);
+});
+
+// FFI query: the region containing an expression, or an undefined region.
+TVM_REGISTER_GLOBAL("relay.analysis.GetRegion")
+.set_body_typed([](AnnotatedRegionSet regions, Expr target) {
+  return regions->GetRegion(target);
+});
+
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h
new file mode 100644
index 0000000..c5db2cc
--- /dev/null
+++ b/src/relay/analysis/annotated_region_set.h
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/annotated_region_set.h
+ * \brief Define data structures to extract and manipulate regions from
+ * a relay function. Regions are denoted by region_begin and region_end
+ * annotations that exist on all the input and output edges of the region.
+ */
+
+#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include <list>
+
+namespace tvm {
+namespace relay {
+
+class AnnotatedRegion;
+class AnnotatedRegionSet;
+
+/*!
+ * \brief Node storing one annotated region: its ID, the region's input
+ * annotation calls (ins), its output annotation calls (outs) and every
+ * expression inside it (nodes).
+ */
+class AnnotatedRegionNode : public Object {
+ public:
+  void VisitAttrs(AttrVisitor* v) {
+    // The containers are copied into Arrays for reflection; inputs and
+    // outputs are exposed under the names "args" and "rets".
+    Array<Expr> nodes_view(nodes.begin(), nodes.end());
+    Array<Expr> args_view(ins.begin(), ins.end());
+    Array<Expr> rets_view(outs.begin(), outs.end());
+    v->Visit("id", &id);
+    v->Visit("nodes", &nodes_view);
+    v->Visit("args", &args_view);
+    v->Visit("rets", &rets_view);
+  }
+
+  /*! \brief Get the region ID. */
+  int GetID() const { return id; }
+
+  /*! \brief Get a copy of the region's inputs. */
+  std::list<Expr> GetInputs() const { return ins; }
+
+  /*! \brief Get a copy of the region's outputs. */
+  std::list<Expr> GetOutputs() const { return outs; }
+
+  /*! \brief Get a copy of the region's nodes. */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { return nodes; }
+
+  static constexpr const char* _type_key = "relay.AnnotatedRegion";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
+
+ protected:
+  /*! \brief The region ID; -1 until one is assigned by the owning set. */
+  int id{-1};
+  /*! \brief The inputs to this region. */
+  std::list<Expr> ins;
+  /*! \brief The outputs of this region */
+  std::list<Expr> outs;
+  /*! \brief Nodes in this region. */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+
+  friend class AnnotatedRegionSet;
+  friend class AnnotatedRegionSetNode;
+};
+
+/*!
+ * \brief Reference-counted handle to an AnnotatedRegionNode, holding the
+ * properties of a region as used by the AnnotatedRegionSet class. This
+ * should be considered read-only.
+ */
+class AnnotatedRegion : public ObjectRef {
+ public:
+  /*! \brief Create a handle that owns a fresh, empty region node. */
+  AnnotatedRegion() : ObjectRef(make_object<AnnotatedRegionNode>()) {}
+
+  /*!
+   * \brief Construct from an object pointer.
+   * \param n The object pointer.
+   */
+  explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(std::move(n)) {}
+
+  /*! \return Mutable pointer to the node; CHECK-fails on a null handle. */
+  AnnotatedRegionNode* operator->() const {
+    Object* raw = get_mutable();
+    CHECK(raw != nullptr);
+    return static_cast<AnnotatedRegionNode*>(raw);
+  }
+};
+
+/*!
+ * \brief Node holding the regions extracted from an annotated relay
+ * expression, plus the counter used to number new regions.
+ */
+class AnnotatedRegionSetNode : public Object {
+  using UnorderedRegionSet =
+    std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+  // Create iterator alias for a RegionSet object.
+  using iterator = UnorderedRegionSet::iterator;
+  using const_iterator = UnorderedRegionSet::const_iterator;
+
+ public:
+  /*! \brief Default constructor. */
+  AnnotatedRegionSetNode() = default;
+
+  /*! \return The begin iterator */
+  iterator begin() {
+    return regions_.begin();
+  }
+  /*! \return The end iterator */
+  iterator end() {
+    return regions_.end();
+  }
+  /*! \return The const begin iterator */
+  const_iterator begin() const {
+    return regions_.begin();
+  }
+  /*! \return The const end iterator */
+  const_iterator end() const {
+    return regions_.end();
+  }
+
+  /*!
+   * \brief Get the region that an expression belongs to.
+   *
+   * \param expr Which expr to get the region for.
+   *
+   * \return The region containing expr, or an undefined AnnotatedRegion
+   * (test with defined()) if the expression doesn't belong to a region.
+   */
+  AnnotatedRegion GetRegion(const Expr& expr) const;
+
+  /*!
+   * \brief Merge src region into dest region.
+   *
+   * \param src The region to merge - will be erased.
+   * \param dest The region into which src will be merged.
+   */
+  void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest);
+
+  void VisitAttrs(AttrVisitor* v) {
+    // Expose the regions to reflection as an Array named "regions".
+    Array<AnnotatedRegion> regions_array(regions_.begin(), regions_.end());
+    v->Visit("regions", &regions_array);
+  }
+
+  static constexpr const char* _type_key = "relay.AnnotatedRegionSet";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object);
+
+ private:
+  /*!
+   * \brief Add an expression to a region. If the expression already belongs
+   * to another region, the two regions are merged.
+   *
+   * \param region The region to add the expression to.
+   * \param expr The expression.
+   */
+  void AddToRegion(AnnotatedRegion region, const Expr& expr);
+
+  /*!
+   * \brief Make a new region.
+   *
+   * \return The new region.
+   */
+  AnnotatedRegion MakeRegion();
+
+  /*! \brief The regions in this set. */
+  UnorderedRegionSet regions_;
+  /*! \brief The next region ID to assign. */
+  int region_id_{0};
+
+  friend class AnnotatedRegionSet;
+};
+
+/*!
+ * \brief A class to hold a set of regions produced from a relay expression
+ * that contains 'region_begin' and 'region_end' style annotations. The
+ * regions should be disjoint. The class provides both a method to construct
+ * the region set of a given relay expression as well as additional methods
+ * to update and query regions.
+ */
+class AnnotatedRegionSet : public ObjectRef {
+  using UnorderedRegionSet =
+    std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+  // Create iterator alias for a RegionSet object.
+  using iterator = UnorderedRegionSet::iterator;
+  using const_iterator = UnorderedRegionSet::const_iterator;
+
+ public:
+  /*! \brief Create a handle owning a fresh, empty region set node. */
+  AnnotatedRegionSet() {
+    auto n = make_object<AnnotatedRegionSetNode>();
+    data_ = std::move(n);
+  }
+
+  /*!
+   * \brief Construct from an object pointer.
+   *
+   * \param n The object pointer.
+   */
+  explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {}
+
+  /*! \return The begin iterator. */
+  iterator begin() {
+    auto* n = operator->();
+    CHECK(n);
+    return n->begin();
+  }
+  /*! \return The end iterator. */
+  iterator end() {
+    auto* n = operator->();
+    CHECK(n);
+    return n->end();
+  }
+  /*! \return The begin iterator. */
+  const_iterator begin() const {
+    const auto* n = operator->();
+    CHECK(n);
+    return n->begin();
+  }
+  /*! \return The end iterator. */
+  const_iterator end() const {
+    const auto *n = operator->();
+    CHECK(n);
+    return n->end();
+  }
+
+  /*! \return mutable pointers to the node. */
+  AnnotatedRegionSetNode* operator->() const {
+    auto* ptr = get_mutable();
+    CHECK(ptr != nullptr);
+    return static_cast<AnnotatedRegionSetNode*>(ptr);
+  }
+
+  /*!
+   * \brief Get the region an expression belongs to.
+   *
+   * const-qualified because the lookup does not modify the set; this allows
+   * indexing through a const reference.
+   *
+   * \param expr The expression to look up.
+   * \return The containing region, or an undefined region if there is none.
+   */
+  AnnotatedRegion operator[](const Expr& expr) const {
+    const auto *n = operator->();
+    CHECK(n);
+    return n->GetRegion(expr);
+  }
+
+  /*! \brief Create a RegionSet from a relay expression.
+   *
+   * \param expr The relay expr from which to construct the set.
+   * \param begin Region begin annotation operator.
+   * \param end Region end annotation operator.
+   *
+   * \return The created RegionSet for the expression.
+   */
+  static AnnotatedRegionSet Create(const Expr& expr,
+                                   const Op& begin,
+                                   const Op& end);
+
+ private:
+  /*! \brief Helper class to construct a RegionSet from an expr.*/
+  class Creator;
+};
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py
new file mode 100644
index 0000000..a246398
--- /dev/null
+++ b/tests/python/relay/test_annotated_regions.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+from tvm import relay
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+
+
+def check_region(region_set, args, nodes, rets):
+ region = region_set.get_region(args[0])
+ assert region
+ assert set(args) == set(region.args)
+ assert set(nodes) == set(region.nodes)
+ assert set(rets) == set(region.rets)
+
+
+def test_region_set_creator_diamond():
+ data = relay.var('data', shape=(10, 10))
+ cb_1 = compiler_begin(data, 'test_target')
+ O_1 = relay.abs(cb_1)
+ ce_1 = compiler_end(O_1, 'test_target')
+ ce_2 = compiler_end(O_1, 'test_target')
+ cb_2 = compiler_begin(ce_1, 'test_target')
+ O_2 = relay.nn.relu(cb_2)
+ ce_3 = compiler_end(O_2, 'test_target')
+ cb_d = compiler_begin(ce_2, "default")
+ X = relay.tanh(cb_d)
+ ce_d = compiler_end(X, 'default')
+ cb_3 = compiler_begin(ce_3, 'test_target')
+ cb_4 = compiler_begin(ce_d, 'test_target')
+ O_3 = relay.add(cb_3, cb_4)
+ ce_4 = compiler_end(O_3, 'test_target')
+ diamond = relay.Function([data], ce_4)
+
+ region_set = relay.analysis.AnnotatedRegionSet(diamond,
+ relay.op.get("annotation.compiler_begin"),
+ relay.op.get("annotation.compiler_end"))
+ assert len(region_set) == 4
+ check_region(
+ region_set,
+ [cb_1],
+ [cb_1, O_1, ce_1, ce_2],
+ [ce_1, ce_2],
+ )
+ check_region(
+ region_set,
+ [cb_2],
+ [cb_2, O_2, ce_3],
+ [ce_3],
+ )
+ check_region(
+ region_set,
+ [cb_d],
+ [cb_d, X, ce_d],
+ [ce_d],
+ )
+ check_region(
+ region_set,
+ [cb_3, cb_4],
+ [cb_3, cb_4, O_3, ce_4],
+ [ce_4],
+ )
+
+
+def test_region_set_creator_merged():
+ data = relay.var('data', shape=(10, 10))
+ cb_1 = compiler_begin(data, 'test_target')
+ O_1 = relay.abs(cb_1)
+ ce_2 = compiler_end(O_1, 'test_target')
+ O_2 = relay.nn.relu(O_1)
+ ce_3 = compiler_end(O_2, 'test_target')
+ cb_d = compiler_begin(ce_2, "default")
+ X = relay.tanh(cb_d)
+ ce_d = compiler_end(X, 'default')
+ cb_3 = compiler_begin(ce_3, 'test_target')
+ cb_4 = compiler_begin(ce_d, 'test_target')
+ O_3 = relay.add(cb_3, cb_4)
+ ce_4 = compiler_end(O_3, 'test_target')
+ merged = relay.Function([data], ce_4)
+
+ region_set = relay.analysis.AnnotatedRegionSet(merged,
+ relay.op.get("annotation.compiler_begin"),
+ relay.op.get("annotation.compiler_end"))
+ assert len(region_set) == 3
+ check_region(
+ region_set,
+ [cb_1],
+ [cb_1, O_1, O_2, ce_2, ce_3],
+ [ce_2, ce_3],
+ )
+ check_region(
+ region_set,
+ [cb_d],
+ [cb_d, X, ce_d],
+ [ce_d],
+ )
+ check_region(
+ region_set,
+ [cb_3, cb_4],
+ [cb_3, cb_4, O_3, ce_4],
+ [ce_4],
+ )
+
+
+if __name__ == "__main__":
+ test_region_set_creator_diamond()
+ test_region_set_creator_merged()
+