You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/06/23 14:59:03 UTC
[incubator-mxnet] branch master updated: [master] Node elimination graph pass (#21046)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c486a0e304 [master] Node elimination graph pass (#21046)
c486a0e304 is described below
commit c486a0e304b61ab5f90b1ab0fc3afd1481bb7d66
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Thu Jun 23 16:58:47 2022 +0200
[master] Node elimination graph pass (#21046)
* Added node elimination graph pass
* Fixed typo
* Fix typo
* Fix test_conv_subgraph.py
* Fix test_amp_concat
* Fix review
Co-authored-by: Bartosz Kuncer <ba...@intel.com>
---
src/c_api/c_api_symbolic.cc | 8 +++-
src/c_api/c_api_test.cc | 1 +
.../subgraph/eliminate_common_nodes_pass.cc | 46 ++++++++++++++++++++++
tests/python/dnnl/subgraphs/test_amp_subgraph.py | 17 ++++----
tests/python/dnnl/subgraphs/test_conv_subgraph.py | 2 -
5 files changed, 62 insertions(+), 12 deletions(-)
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e17d7beb2e..b70edf9852 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1070,6 +1070,7 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle,
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
+ g = ApplyPass(std::move(g), "EliminateCommonNodesPass");
g = ApplyPass(std::move(g), "BuildSubgraph");
property->RemoveAttr("graph");
g.attrs.erase("subgraph_property");
@@ -1151,7 +1152,12 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
NDArray** in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
auto init_graph = [&](auto s) {
- nnvm::Graph g = Symbol2Graph(*s);
+ nnvm::Graph g = Symbol2Graph(*s);
+
+ // EliminateCommonNodesPass must run before the first call to g.indexed_graph():
+ // once the indexed graph has been created it cannot be changed, so modifying the
+ // graph in a later pass would result in an error.
+ g = ApplyPass(std::move(g), "EliminateCommonNodesPass");
const auto& indexed_graph = g.indexed_graph();
const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
index 600de84c41..735667bcb6 100644
--- a/src/c_api/c_api_test.cc
+++ b/src/c_api/c_api_test.cc
@@ -50,6 +50,7 @@ int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
property->SetAttr("graph", g);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
+ g = nnvm::ApplyPass(std::move(g), "EliminateCommonNodesPass");
g = nnvm::ApplyPass(std::move(g), "BuildSubgraph");
property->RemoveAttr("graph");
g.attrs.erase("subgraph_property");
diff --git a/src/operator/subgraph/eliminate_common_nodes_pass.cc b/src/operator/subgraph/eliminate_common_nodes_pass.cc
new file mode 100644
index 0000000000..6239c3142d
--- /dev/null
+++ b/src/operator/subgraph/eliminate_common_nodes_pass.cc
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file eliminate_common_nodes_pass.cc
+ * \brief Graph pass to eliminate common nodes from the input graph
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+
+#include "imperative/exec_pass.h"
+
+namespace mxnet {
+
+nnvm::Graph EliminateCommonNodesPass(nnvm::Graph&& g) {  // Pass body: takes the graph by rvalue reference and returns the resulting graph.
+ const int enabled = dmlc::GetEnv("MXNET_NODE_ELIMINATION", 1);  // Defaults to 1 when the variable is not set.
+ if (enabled == 0) {  // Only the value 0 disables the pass.
+ LOG(INFO) << "Skipping common nodes elimination.";
+ return std::move(g);  // g is a named rvalue reference, so std::move is needed to move it out rather than copy.
+ }
+
+ return exec::EliminateCommonExpr(std::move(g));  // Delegates the elimination to exec::EliminateCommonExpr and returns its result.
+}
+
+NNVM_REGISTER_PASS(EliminateCommonNodesPass)  // Registers EliminateCommonNodesPass as an NNVM pass.
+ .describe("Removes additional Nodes with identical inputs and function.")
+ .set_body(EliminateCommonNodesPass)  // The registered body is the EliminateCommonNodesPass function.
+ .set_change_graph(true);  // NOTE(review): presumably marks the pass as changing graph structure; confirm in nnvm/pass.h.
+
+} // namespace mxnet
diff --git a/tests/python/dnnl/subgraphs/test_amp_subgraph.py b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
index 2c5c6e1b45..b66ea44cde 100644
--- a/tests/python/dnnl/subgraphs/test_amp_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
@@ -23,6 +23,7 @@ from mxnet.test_utils import assert_almost_equal
from subgraph_common import SG_PASS_NAME, QUANTIZE_SG_PASS_NAME
from test_matmul_subgraph import MultiHeadAttention
+import os
import sys
from pathlib import Path
curr_path = Path(__file__).resolve().parent
@@ -177,7 +178,6 @@ def test_amp_concat():
super(TestNet, self).__init__()
self.fc1 = nn.Dense(16)
self.fc2 = nn.Dense(16)
- self.fc2.share_parameters(self.fc1.collect_params())
def forward(self, x):
x1 = self.fc1(x)
@@ -192,19 +192,18 @@ def test_amp_concat():
exp_data = mx.symbol.Variable('data')
exp_amp_data = mx.symbol.amp_cast(exp_data, dtype=AMP_DTYPE)
- exp_weight = mx.symbol.Variable('weight')
- exp_bias = mx.symbol.Variable('bias')
- exp_fc = [mx.symbol.FullyConnected(exp_amp_data, exp_weight, exp_bias, num_hidden=1)
- for _ in range(2)]
+
+ exp_weight = [mx.symbol.Variable(f"fc{i}_weight") for i in range(2)]
+ exp_bias = [mx.symbol.Variable(f"fc{i}_bias") for i in range(2)]
+ exp_fc = [mx.symbol.FullyConnected(exp_amp_data, exp_weight[i], exp_bias[i], num_hidden=1)
+ for i in range(2)]
exp_sym = mx.symbol.Concat(*exp_fc)
exp_sym = mx.symbol.amp_cast(exp_sym, dtype='float32')
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym)
- amp_weight = mx.symbol.amp_cast(exp_weight, dtype=AMP_DTYPE)
- amp_bias = mx.symbol.amp_cast(exp_bias, dtype=AMP_DTYPE)
- exp_fc[0] = mx.symbol.FullyConnected(exp_amp_data, amp_weight, amp_bias, num_hidden=1)
- exp_fc[1] = mx.symbol.FullyConnected(exp_data, exp_weight, exp_bias, num_hidden=1)
+ exp_fc[0] = mx.symbol.FullyConnected(exp_amp_data, exp_weight[0], exp_bias[0], num_hidden=1)
+ exp_fc[1] = mx.symbol.FullyConnected(exp_data, exp_weight[1], exp_bias[1], num_hidden=1)
exp_sym = mx.symbol.Concat(*exp_fc)
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym, ['sg_onednn_fully_connected_1'])
diff --git a/tests/python/dnnl/subgraphs/test_conv_subgraph.py b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
index c700f96564..c6855d6eb0 100644
--- a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -235,7 +235,6 @@ def test_pos_conv_act_add(data_shape, alg, quantize, use_bias):
else:
self.act = nn.Activation(activation = alg)
self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
- self.conv1.share_parameters(self.conv0.collect_params())
def forward(self, x):
out = self.act(self.conv0(x)) + self.conv1(x)
@@ -308,7 +307,6 @@ def test_pos_conv_bn_sum_act(use_bias, data_shape, alg, quantize):
super(ConvBNSumAct, self).__init__(**kwargs)
self.conv0 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1)
- self.conv1.share_parameters(self.conv0.collect_params())
self.bn = nn.BatchNorm()
if alg == "relu6":
self.act = RELU6()