You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/06/23 14:59:03 UTC

[incubator-mxnet] branch master updated: [master] Node elimination graph pass (#21046)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c486a0e304 [master] Node elimination graph pass (#21046)
c486a0e304 is described below

commit c486a0e304b61ab5f90b1ab0fc3afd1481bb7d66
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Thu Jun 23 16:58:47 2022 +0200

    [master] Node elimination graph pass (#21046)
    
    * Added node elimination graph pass
    
    * Fixed typo
    
    * Fix typo
    
    * Fix test_conv_subgraph.py
    
    * Fix test_amp_concat
    
    * Fix review
    
    Co-authored-by: Bartosz Kuncer <ba...@intel.com>
---
 src/c_api/c_api_symbolic.cc                        |  8 +++-
 src/c_api/c_api_test.cc                            |  1 +
 .../subgraph/eliminate_common_nodes_pass.cc        | 46 ++++++++++++++++++++++
 tests/python/dnnl/subgraphs/test_amp_subgraph.py   | 17 ++++----
 tests/python/dnnl/subgraphs/test_conv_subgraph.py  |  2 -
 5 files changed, 62 insertions(+), 12 deletions(-)

diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e17d7beb2e..b70edf9852 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1070,6 +1070,7 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle,
     nnvm::Graph g = Symbol2Graph(*s);
     property->SetAttr("graph", g);
     g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
+    g                            = ApplyPass(std::move(g), "EliminateCommonNodesPass");
     g                            = ApplyPass(std::move(g), "BuildSubgraph");
     property->RemoveAttr("graph");
     g.attrs.erase("subgraph_property");
@@ -1151,7 +1152,12 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   NDArray** in_aux_ptr    = reinterpret_cast<NDArray**>(in_aux_handle);
 
   auto init_graph = [&](auto s) {
-    nnvm::Graph g                        = Symbol2Graph(*s);
+    nnvm::Graph g = Symbol2Graph(*s);
+
+    // EliminateCommonNodesPass must be performed before first call to the indexed graph,
+    // because otherwise changing graph via other passes will result in an error, due to the fact
+    // that once indexed_graph is created, it cannot be changed.
+    g                                    = ApplyPass(std::move(g), "EliminateCommonNodesPass");
     const auto& indexed_graph            = g.indexed_graph();
     const auto& mutable_nodes            = indexed_graph.mutable_input_nodes();
     std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
index 600de84c41..735667bcb6 100644
--- a/src/c_api/c_api_test.cc
+++ b/src/c_api/c_api_test.cc
@@ -50,6 +50,7 @@ int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
       property->SetAttr("graph", g);
       property->SetAttr("op_names", op_name_set);
       g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
+      g                            = nnvm::ApplyPass(std::move(g), "EliminateCommonNodesPass");
       g                            = nnvm::ApplyPass(std::move(g), "BuildSubgraph");
       property->RemoveAttr("graph");
       g.attrs.erase("subgraph_property");
diff --git a/src/operator/subgraph/eliminate_common_nodes_pass.cc b/src/operator/subgraph/eliminate_common_nodes_pass.cc
new file mode 100644
index 0000000000..6239c3142d
--- /dev/null
+++ b/src/operator/subgraph/eliminate_common_nodes_pass.cc
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file eliminate_common_nodes_pass.cc
+ * \brief Graph pass to eliminate common nodes from the input graph
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+
+#include "imperative/exec_pass.h"
+
+namespace mxnet {
+
+nnvm::Graph EliminateCommonNodesPass(nnvm::Graph&& g) {  // Dedupe common nodes in g.
+  const int enabled = dmlc::GetEnv("MXNET_NODE_ELIMINATION", 1);  // opt-out: set to 0 to skip
+  if (enabled == 0) {
+    LOG(INFO) << "Skipping common nodes elimination.";
+    return std::move(g);  // g is an lvalue here; move avoids a copy
+  }
+  // Enabled: let exec::EliminateCommonExpr do the actual elimination.
+  return exec::EliminateCommonExpr(std::move(g));
+}
+
+NNVM_REGISTER_PASS(EliminateCommonNodesPass)  // name used by ApplyPass(g, "EliminateCommonNodesPass")
+    .describe("Removes additional Nodes with identical inputs and function.")
+    .set_body(EliminateCommonNodesPass)
+    .set_change_graph(true);  // declares that this pass changes the graph
+
+}  // namespace mxnet
diff --git a/tests/python/dnnl/subgraphs/test_amp_subgraph.py b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
index 2c5c6e1b45..b66ea44cde 100644
--- a/tests/python/dnnl/subgraphs/test_amp_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
@@ -23,6 +23,7 @@ from mxnet.test_utils import assert_almost_equal
 from subgraph_common import SG_PASS_NAME, QUANTIZE_SG_PASS_NAME
 from test_matmul_subgraph import MultiHeadAttention
 
+import os
 import sys
 from pathlib import Path
 curr_path = Path(__file__).resolve().parent
@@ -177,7 +178,6 @@ def test_amp_concat():
       super(TestNet, self).__init__()
       self.fc1 = nn.Dense(16)
       self.fc2 = nn.Dense(16)
-      self.fc2.share_parameters(self.fc1.collect_params())
 
     def forward(self, x):
       x1 = self.fc1(x)
@@ -192,19 +192,18 @@ def test_amp_concat():
 
   exp_data = mx.symbol.Variable('data')
   exp_amp_data = mx.symbol.amp_cast(exp_data, dtype=AMP_DTYPE)
-  exp_weight = mx.symbol.Variable('weight')
-  exp_bias = mx.symbol.Variable('bias')
-  exp_fc = [mx.symbol.FullyConnected(exp_amp_data, exp_weight, exp_bias, num_hidden=1)
-            for _ in range(2)]
+
+  exp_weight = [mx.symbol.Variable(f"fc{i}_weight") for i in range(2)]
+  exp_bias = [mx.symbol.Variable(f"fc{i}_bias") for i in range(2)]
+  exp_fc = [mx.symbol.FullyConnected(exp_amp_data, exp_weight[i], exp_bias[i], num_hidden=1)
+            for i in range(2)]
   exp_sym = mx.symbol.Concat(*exp_fc)
   exp_sym = mx.symbol.amp_cast(exp_sym, dtype='float32')
   exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
   check_amp_fuse(net, [data_example], exp_sym)
 
-  amp_weight = mx.symbol.amp_cast(exp_weight, dtype=AMP_DTYPE)
-  amp_bias = mx.symbol.amp_cast(exp_bias, dtype=AMP_DTYPE)
-  exp_fc[0] = mx.symbol.FullyConnected(exp_amp_data, amp_weight, amp_bias, num_hidden=1)
-  exp_fc[1] = mx.symbol.FullyConnected(exp_data, exp_weight, exp_bias, num_hidden=1)
+  exp_fc[0] = mx.symbol.FullyConnected(exp_amp_data, exp_weight[0], exp_bias[0], num_hidden=1)
+  exp_fc[1] = mx.symbol.FullyConnected(exp_data, exp_weight[1], exp_bias[1], num_hidden=1)
   exp_sym = mx.symbol.Concat(*exp_fc)
   exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
   check_amp_fuse(net, [data_example], exp_sym, ['sg_onednn_fully_connected_1'])
diff --git a/tests/python/dnnl/subgraphs/test_conv_subgraph.py b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
index c700f96564..c6855d6eb0 100644
--- a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -235,7 +235,6 @@ def test_pos_conv_act_add(data_shape, alg, quantize, use_bias):
         else:
           self.act = nn.Activation(activation = alg)
         self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
-        self.conv1.share_parameters(self.conv0.collect_params())
 
     def forward(self, x):
         out = self.act(self.conv0(x)) + self.conv1(x)
@@ -308,7 +307,6 @@ def test_pos_conv_bn_sum_act(use_bias, data_shape, alg, quantize):
         super(ConvBNSumAct, self).__init__(**kwargs)
         self.conv0 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
         self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1)
-        self.conv1.share_parameters(self.conv0.collect_params())
         self.bn = nn.BatchNorm()
         if alg == "relu6":
           self.act = RELU6()