You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/04/29 18:23:26 UTC
[incubator-tvm] branch master updated: [BYOC] Bind constant tuples in graph partitioner (#5476)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 046b0d9 [BYOC] Bind constant tuples in graph partitioner (#5476)
046b0d9 is described below
commit 046b0d98a08153a4829a12cc81a4fa856be6efcd
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Wed Apr 29 19:23:15 2020 +0100
[BYOC] Bind constant tuples in graph partitioner (#5476)
* Bind constant tuples in the graph partitioner
Change-Id: I815b32b5445a536c1837369b04f67dbbb0aed900
* Add partitioning test
Change-Id: I3a492ec8d1beab4830214e3bc8da2a7c80771ca4
* Rename test target
Change-Id: Ie32f37c1395ff597c0047ad3a93ed04ce3f3125d
---
src/relay/transforms/partition_graph.cc | 22 ++++++++++++---
tests/python/relay/test_pass_partition_graph.py | 37 +++++++++++++++++++++++++
2 files changed, 55 insertions(+), 4 deletions(-)
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 15ad60b..3b0d6bc 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -393,12 +393,26 @@ class Partitioner : public ExprMutator {
Array<Var> params;
Array<Expr> param_expr;
- std::unordered_map<std::string, runtime::NDArray> params_bind;
+ Map<Var, Expr> params_bind;
+
+ auto IsConstant = [](const Expr& expr) {
+ if (expr->IsInstance<ConstantNode>())
+ return true;
+ if (expr->IsInstance<TupleNode>()) {
+ auto tuple = expr.as<TupleNode>();
+ for (const auto& field : tuple->fields) {
+ if (!field->IsInstance<ConstantNode>())
+ return false;
+ }
+ return true;
+ }
+ return false;
+ };
for (auto pair : region_args[region]) {
params.push_back(pair.first);
- if (const auto* cn = pair.second.as<ConstantNode>()) {
- params_bind[pair.first->name_hint()] = cn->data;
+ if (IsConstant(pair.second)) {
+ params_bind.Set(pair.first, pair.second);
} else {
param_expr.push_back(pair.second);
}
@@ -428,7 +442,7 @@ class Partitioner : public ExprMutator {
// Constant propagation
if (!params_bind.empty()) {
- global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+ global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}
std::string fname = name;
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 2a4fd31..d78b9ea 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1155,6 +1155,42 @@ def test_duplicate_merge_and_tuplegetitem():
partitioned = seq(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+def test_constant_tuples():
+ @reg.register("qnn.concatenate", "target.const_tuples")
+ def add(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ def create_graph():
+ a = relay.var('a', shape=(10, 10), dtype="uint8")
+ b = relay.var('b', shape=(10, 10), dtype="uint8")
+ a1 = relay.abs(a)
+
+ zeroi = relay.const(1, "int32")
+ zerof = relay.const(0, "float32")
+ con = relay.qnn.op.concatenate((a1, b),
+ input_scales=(zerof, zerof),
+ input_zero_points=(zeroi, zeroi),
+ output_scale=zerof,
+ output_zero_point=zeroi,
+ axis=1)
+
+ f = relay.Function([a, b], con)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ seq = tvm.transform.Sequential([
+ transform.AnnotateTarget("const_tuples"),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph(),
+ ])
+
+ partitioned = seq(create_graph())
+ concat = partitioned["const_tuples_0"].body
+ assert type(concat.args[1]) == relay.Tuple
+ assert type(concat.args[2]) == relay.Tuple
+ assert type(concat.args[3]) == relay.Constant
+ assert type(concat.args[4]) == relay.Constant
+
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
@@ -1171,3 +1207,4 @@ if __name__ == "__main__":
test_multiple_use_of_an_output()
test_duplicate_outputs()
test_duplicate_merge_and_tuplegetitem()
+ test_constant_tuples()