You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/03/07 15:35:06 UTC

[GitHub] [incubator-mxnet] bartekkuncer commented on a change in pull request #20753: [WIP] Improve AMP, bf16 support. Support oneDNN ops in AMP

bartekkuncer commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r820780134



##########
File path: include/mxnet/c_api.h
##########
@@ -2016,27 +2016,21 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle,
  */
 MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
                                       SymbolHandle* ret_sym_handle,
-                                      uint32_t num_args,
-                                      const int* arg_type_data,
-                                      uint32_t num_ind_ptr,
-                                      const int* ind_ptr,
-                                      const int* target_dtype,
-                                      const int cast_optional_params,
-                                      const uint32_t num_target_dtype_op_names,
-                                      const uint32_t num_fp32_op_names,
-                                      const uint32_t num_widest_dtype_op_names,
-                                      const uint32_t num_conditional_fp32_op_names,
+                                      const int target_dtype,

Review comment:
       The function description does not match the function's parameters.

##########
File path: src/operator/subgraph/dnnl/dnnl_post_amp_property.h
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "../../tensor/amp_cast.h"
+#include "../common.h"

Review comment:
       Please try to avoid relative paths with dots (e.g. `../../`) in includes.

##########
File path: python/mxnet/amp/lists/symbol_bf16.py
##########
@@ -21,8 +21,13 @@
 # Functions that should be cast to lower precision
 BF16_FUNCS = [
     'Convolution',

Review comment:
       What about batch_dot? Should it be added to this list as well?

##########
File path: python/mxnet/amp/amp.py
##########
@@ -427,10 +425,10 @@ def unscale(optimizer_or_trainer):
         raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                         "an optimizer, instead is %s" % type(optimizer_or_trainer))
 
-def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
-                   fp32_ops=None, conditional_fp32_ops=None,
-                   excluded_sym_names=None, data_names=None,
-                   cast_optional_params=False):
+
+def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", target_dtype_ops=None,
+                   fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=[],
+                   cast_params_offline=False):

Review comment:
       The parameter list in the docstring below does not match the parameters of this function.

##########
File path: tests/python/gpu/test_amp.py
##########
@@ -111,84 +111,61 @@ def test_amp_conversion_rnn(amp_tests):
         mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
 
 
-def test_fp16_casting(amp_tests):
-    data = mx.sym.var("data")
-    out1 = mx.sym.amp_cast(data, dtype="float16")
-    out2 = mx.sym.amp_cast(data, dtype="float32")
-    out3 = mx.sym.amp_cast(data, dtype="float16")
-    # When two ops from data, with different dtypes,
-    # data should be float32
-    res = mx.sym.Group([out1, out2])
-    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # When two ops from data, both casted to float16,
-    # data should be float16
-    res = mx.sym.Group([out1, out3])
-    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # AMP Multicast test where one node is float32, another is float16
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
-    final_res = amp.convert_symbol(out4, cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # AMP Multicast test where two non input nodes are float16,
-    # and one input node is float32
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    data3 = mx.sym.var("data3", dtype=np.float16)
-    out5 = mx.sym.amp_multicast(data,
-                                mx.sym.elemwise_add(data2, data3),
-                                num_outputs=2)
-    final_res = amp.convert_symbol(out5, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # AMP Multicast test where three input nodes one fp16, one fp32
-    # one unknown
-    data = mx.sym.var("data", dtype=np.float16)
-    data2 = mx.sym.var("data2", dtype=np.float32)
-    data3 = mx.sym.var("data3")
-    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
-    final_res = amp.convert_symbol(out6, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2),
-                                data3=(1, 2))
-    assert exe.arg_arrays[2].dtype == np.float32
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp16, it should still be fp16
-    data = mx.sym.var("data", dtype=np.float16)
-    data2 = mx.sym.var("data2", dtype=np.float32)
-    out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
-    final_res = amp.convert_symbol(out7, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp32, it should be changed to fp16
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
-    final_res = amp.convert_symbol(out8, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # Check for symbol which has slice channel
-    data = mx.sym.var("data")
-    data2 = mx.sym.var("data2")
-    data._set_attr(__dtype__="-1")
-    data2._set_attr(__dtype__="-1")
-    concat_res = mx.sym.concat(data, data2)
-    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
-    final_res = amp.convert_symbol(out)
-
+@mx.util.use_np
+def test_bf16_offline_casting():
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Conv2D(4, 3)
+      self.lp16_op2 = nn.Conv2DTranspose(4, 3)
+      self.fp32_op = nn.Dense(4)
+
+    def forward(self, x):
+      x = self.lp16_op1(x)
+      x = self.lp16_op2(x)
+      x = x.reshape(x.shape[0], -1)
+      x = self.fp32_op(x)
+      return x
+
+  net = TestNet()
+  net.initialize()
+  data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
+  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
+                                    target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
+                                    cast_params_offline=True, device=mx.current_context())
+  lp_net(data_example)
+  for name, data in lp_net.collect_params().items():
+    assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
+
+
+@mx.util.use_np
+def test_bf16_offline_casting_shared_params():
+  COMMON_SIZE = 4
+
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Dense(COMMON_SIZE)
+      self.lp16_op2 = nn.Dense(COMMON_SIZE)
+      self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
+      self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
+      self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
+
+    def forward(self, x):
+      x = self.lp16_op1(x)
+      x1 = self.lp16_op2(x)
+      x2 = mx.np.expand_dims(x, 1)
+      x2 = self.fp32_op(x2)
+      x2 = nn.Flatten()(x2)
+      x = mx.np.concat((x1, x2), axis=1)
+      return x
+
+  net = TestNet()
+  net.initialize()
+  data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
+  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
+                                    target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
+                                    cast_params_offline=True, device=mx.current_context())
+  lp_net(data_example)
+  for name, data in lp_net.collect_params().items():
+    assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')

Review comment:
       Missing newline at the end of the file.

##########
File path: python/mxnet/amp/amp.py
##########
@@ -534,62 +532,37 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
                             Op %s not in any of them''' % (illegal_ops)
 
     widest_dtype_ops = list_widest_type_cast(target_dtype)
-    if target_dtype == bfloat16:
-        target_dtype = _DTYPE_NP_TO_MX[bfloat16]
-    else:
-        target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]
 
-    # Prepare a data_names list based on list_inputs if its not provided
-    # Add all names in list for the nodes in the symbol which don't have
-    # __dtype__ set
-    attr_dict = sym.attr_dict()
-    if data_names is None:
-        data_names = []
-        for sym_name in sym.list_inputs():
-            if not sym_name in attr_dict:
-                data_names.append(sym_name)
-                continue
-            if not "__dtype__" in attr_dict[sym_name]:
-                data_names.append(sym_name)
-    model_param_names = list(set(sym.list_inputs()) - set(data_names))
-
-    # Since assumption is that it is a FP32 model, set dtypes for all
-    # data_names to float32
-    str_keys = []
-    sdata = []
-    for k in data_names:
-        str_keys.append(k)
-        sdata.append(0)
-    keys = c_str_array(str_keys)
+    input_names = list(input_dtypes.keys())
+    all_arg_names, all_arg_types = [], []
+
+    for name, dtype in {**input_dtypes, **param_dtypes}.items():
+        all_arg_names.append(name)
+        all_arg_types.append(dtype_np_to_mx(dtype))
     out = SymbolHandle()
     check_call(_LIB.MXReducePrecisionSymbol(sym.handle,
                                             ctypes.byref(out),
-                                            mx_uint(len(sdata)),
-                                            c_array_buf(ctypes.c_int, array('i', sdata)),
-                                            mx_uint(len(indptr)),
-                                            c_array_buf(ctypes.c_int, array('i', indptr)),
-                                            ctypes.byref(ctypes.c_int(target_dtype)),
-                                            ctypes.c_int(cast_optional_params),
-                                            mx_uint(len(target_dtype_ops)),
-                                            mx_uint(len(fp32_ops)),
-                                            mx_uint(len(widest_dtype_ops)),
-                                            mx_uint(len(conditional_op_names)),
-                                            mx_uint(len(excluded_sym_names)),
-                                            mx_uint(len(model_param_names)),
+                                            ctypes.c_int(dtype_np_to_mx(target_dtype)),
+                                            ctypes.c_int(cast_params_offline),
+                                            ctypes.c_uint(len(input_names)),
+                                            ctypes.c_uint(len(all_arg_names)),
+                                            ctypes.c_uint(len(target_dtype_ops)),
+                                            ctypes.c_uint(len(fp32_ops)),
+                                            ctypes.c_uint(len(widest_dtype_ops)),
+                                            ctypes.c_uint(len(excluded_sym_names)),
+                                            c_str_array(input_names),
+                                            c_str_array(all_arg_names),
+                                            c_array_buf(ctypes.c_int, array('i', all_arg_types)),
                                             c_str_array(target_dtype_ops),
                                             c_str_array(fp32_ops),
                                             c_str_array(widest_dtype_ops),
-                                            c_str_array(conditional_op_names),
-                                            c_str_array(excluded_sym_names),
-                                            c_str_array(param_names),
-                                            c_str_array(param_vals),
-                                            c_str_array(model_param_names),
-                                            keys))
-    return Symbol(out)
-
-def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
-                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None,
-                  cast_optional_params=False):
+                                            c_str_array(excluded_sym_names)))
+    return type(sym)(out)
+
+
+def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float16",
+                  target_dtype_ops=None, fp32_ops=None, conditional_fp32_ops=None,
+                  excluded_sym_names=[], cast_params_offline=False):

Review comment:
       The list of parameter descriptions does not match the actual parameters.

##########
File path: python/mxnet/amp/amp.py
##########
@@ -700,70 +659,50 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
         from being quantized
     device : Context
         Context on which model parameters should live
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be in LP16
         because of a cast layer following it, but will reduce the computation and memory
         overhead of the model if casted.
     """
     from ..gluon import HybridBlock, SymbolBlock
+    from ..ndarray import NDArray as ND_NDArray
+    from ..numpy import ndarray as NP_NDArray
+
     assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
+    if not isinstance(data_example, list):
+        assert isinstance(data_example, (ND_NDArray, NP_NDArray))
+        data_example = [data_example]
+
     if not block._cached_graph:
-        raise RuntimeError(
-            "Please first call block.hybridize() and then run forward with "
-            "this block at least once before calling export.")
-
-    # Prepare inputs to pass to the convert_symbol API
-    inputs, sym = block._cached_graph
-    input_names = []
-    for inp in inputs:
-        input_names.append(inp.name)
-    converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
-                                   fp32_ops, conditional_fp32_ops,
-                                   excluded_sym_names, data_names=input_names,
-                                   cast_optional_params=cast_optional_params)
-
-    arg_names = set(converted_sym.list_arguments())
-    aux_names = set(converted_sym.list_auxiliary_states())
-    arg_dict = {}
-
-    # If dtype for the param was set in the json, cast the
-    # param to this dtype
-    attr_dict = converted_sym.attr_dict()
-    for param in block.collect_params().values():
-        name = param.name
-        if name in arg_names:
-            arg_dict['arg:%s'%name] = param._reduce()
-            if name in attr_dict and "__dtype__" in attr_dict[name]:
-                if attr_dict[name]["__dtype__"] != "-1":
-                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
-                    if typ == bfloat16:
-                        arg_dict['arg:%s' % name] = _cast_symbol_NDArray(arg_dict['arg:%s' % name], bfloat16)
-                    else:
-                        arg_dict['arg:%s'%name] = arg_dict['arg:%s'%name].astype(typ)
+        block.hybridize()
+        block(*data_example)
+
+    sym, params = block.export(None, remove_amp_cast=False)
+    args, auxs = {}, {}
+    for name, data in params.items():
+        if name.startswith('arg:'):
+            arg_name = name[len('arg:'):]
+            args[arg_name] = data
         else:
-            assert name in aux_names
-            arg_dict['aux:%s'%name] = param._reduce()
-            if name in attr_dict and "__dtype__" in attr_dict[name]:
-                if attr_dict[name]["__dtype__"] != "-1":
-                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
-                    if typ == bfloat16:
-                        arg_dict['aux:%s' % name] = _cast_symbol_NDArray(arg_dict['aux:%s' % name], 'bfloat16')
-                    else:
-                        arg_dict['aux:%s'%name] = arg_dict['aux:%s'%name].astype(typ)
+            assert name.startswith('aux:')
+            aux_name = name[len('aux:'):]
+            auxs[aux_name] = data
+
+    input_names = set(sym.list_arguments()) - (set(args.keys()) | set(auxs.keys()))
+    input_names_ordered = HybridBlock.generate_arg_names(len(data_example))
+    assert input_names == set(input_names_ordered)  # TODO: message

Review comment:
       Please resolve this TODO (the assertion still needs a message).

##########
File path: src/operator/subgraph/dnnl/dnnl_post_amp_property.h
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "../../tensor/amp_cast.h"
+#include "../common.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool IsSupportedAMPFuseOp(const nnvm::Node& node) {
+  const auto& op = node.op();
+  return (op != nullptr &&
+          (op == Op::Get("_sg_onednn_conv") || op == Op::Get("_sg_onednn_fully_connected") ||
+           op == Op::Get("_sg_onednn_selfatt_qk") || op == Op::Get("_sg_onednn_selfatt_valatt")));
+}
+
+class SgDNNLPostAMPSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum class SelectStatus {
+    kFail = 0,
+    kStart,
+    kSuccess,
+  };
+
+ private:
+  SelectStatus status;
+  std::vector<const nnvm::Node*> matched_list;
+
+ public:
+  bool Select(const nnvm::Node& n) override {
+    if (IsSupportedAMPFuseOp(n)) {
+      status = SelectStatus::kStart;
+      matched_list.clear();
+      matched_list.push_back(&n);
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+    if (status == SelectStatus::kFail || new_node.is_variable())
+      return false;
+    // If n isn't the last matched node, then we encoutered a internal

Review comment:
       ```suggestion
       // If 'n' is not the last matched node, then we have encountered an internal
   ```

##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +457,73 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be in LP16
         because of a cast layer following it, but will reduce the computation and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should be Symbol"
-
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    import json
 
-    if target_dtype == 'bfloat16':
-        target_dtype = bfloat16
-
-    if target_dtype_ops is not None:
-        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
-    else:
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should be Symbol"

Review comment:
       ```suggestion
       assert isinstance(sym, Symbol), "First argument to convert_symbol should be a Symbol"
   ```

##########
File path: src/operator/subgraph/dnnl/dnnl_post_amp_property.h
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "../../tensor/amp_cast.h"
+#include "../common.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool IsSupportedAMPFuseOp(const nnvm::Node& node) {
+  const auto& op = node.op();
+  return (op != nullptr &&
+          (op == Op::Get("_sg_onednn_conv") || op == Op::Get("_sg_onednn_fully_connected") ||
+           op == Op::Get("_sg_onednn_selfatt_qk") || op == Op::Get("_sg_onednn_selfatt_valatt")));
+}
+
+class SgDNNLPostAMPSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum class SelectStatus {
+    kFail = 0,
+    kStart,
+    kSuccess,
+  };
+
+ private:
+  SelectStatus status;
+  std::vector<const nnvm::Node*> matched_list;
+
+ public:
+  bool Select(const nnvm::Node& n) override {
+    if (IsSupportedAMPFuseOp(n)) {
+      status = SelectStatus::kStart;
+      matched_list.clear();
+      matched_list.push_back(&n);
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+    if (status == SelectStatus::kFail || new_node.is_variable())
+      return false;
+    // If n isn't the last matched node, then we encoutered a internal
+    // branch, we should pop out the node behind n and stop fusion.

Review comment:
       ```suggestion
       // branch and we should pop out the node behind 'n' and stop fusion.
   ```

##########
File path: python/mxnet/amp/amp.py
##########
@@ -622,61 +595,47 @@ def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dt
     excluded_sym_names : list of strs
         A list of strings that represent the names of symbols that users want to exclude
         from being executed in lower precision.
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be in LP16
         because of a cast layer following it, but will reduce the computation and memory
         overhead of the model if casted.
     """
-    if excluded_sym_names is None:
-        excluded_sym_names = []
-        if not isinstance(excluded_sym_names, list):
-            raise ValueError('excluded_sym_names must be a list of strings representing'
-                             ' the names of the symbols that should not be casted,'
-                             ' while received type %s' % str(type(excluded_sym_names)))
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
-
     assert isinstance(sym, Symbol), "First argument to convert_model should be Symbol"

Review comment:
       ```suggestion
       assert isinstance(sym, Symbol), "First argument to convert_model should be a Symbol"
   ```

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,322 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool is_cast_op(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& mirror_map) {
-  static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  void convert(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    dtype = new_dtype;
+  }
+
+  const NodeEntry& as_original() {
+    return as_type(original_dtype);
+  }
+
+  const NodeEntry& as_type(const int target_dtype) {
+    if (dtype == target_dtype) {
+      return entry;
+    }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      cast_entry = cast(target_dtype);
+      CHECK(cast_entry.node);
     }
+    return cast_entry;
   }
-  return suffix;
-}
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
-}
+  bool has_dtype_entry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  bool can_be_cast_offline_to(const int target_dtype) const {

Review comment:
       Why does the number of casts to a particular dtype determine whether casting to this dtype can be performed offline?

##########
File path: python/mxnet/amp/amp.py
##########
@@ -622,61 +595,47 @@ def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dt
     excluded_sym_names : list of strs
         A list of strings that represent the names of symbols that users want to exclude
         from being executed in lower precision.
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be in LP16
         because of a cast layer following it, but will reduce the computation and memory
         overhead of the model if casted.
     """
-    if excluded_sym_names is None:
-        excluded_sym_names = []
-        if not isinstance(excluded_sym_names, list):
-            raise ValueError('excluded_sym_names must be a list of strings representing'
-                             ' the names of the symbols that should not be casted,'
-                             ' while received type %s' % str(type(excluded_sym_names)))
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
-
     assert isinstance(sym, Symbol), "First argument to convert_model should be Symbol"
-    assert isinstance(arg_params, dict), "Second argument to convert_model should be a dict of name to ndarray"
-    assert isinstance(aux_params, dict), "Third argument to convert_model should be a dict of name to ndarray"
-
-    param_names = list(arg_params.keys()) + list(aux_params.keys())
-
-    # Only pass non params as data_names, param types can be inferred
-    data_names = list(set(sym.list_inputs()) - set(param_names))
-    sym = convert_symbol(sym, target_dtype, target_dtype_ops,
-                         fp32_ops, conditional_fp32_ops,
-                         excluded_sym_names, data_names,
-                         cast_optional_params)
+    assert isinstance(
+        arg_params, dict), "Second argument to convert_model should be a dict of name to ndarray"
+    assert isinstance(
+        aux_params, dict), "Third argument to convert_model should be a dict of name to ndarray"
+
+    arg_params = arg_params.copy()
+    aux_params = aux_params.copy()
+    param_dtypes = {name: data.dtype for name, data in arg_params.items()}
+    param_dtypes.update({name: data.dtype for name, data in aux_params.items()})
+    sym = convert_symbol(sym, input_dtypes, param_dtypes, target_dtype, target_dtype_ops,
+                         fp32_ops, conditional_fp32_ops, excluded_sym_names, cast_params_offline)
 
     # If dtype is set for params, cast the param to that dtype
     attr_dict = sym.attr_dict()
     for sym_name in sym.list_arguments():
-        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
-            if attr_dict[sym_name]["__dtype__"] != "-1":
-                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
-                if typ == bfloat16:
-                    arg_params[sym_name] = _cast_symbol_NDArray(arg_params[sym_name], bfloat16)
-                else:
-                    arg_params[sym_name] = arg_params[sym_name].astype(typ)
+        if attr_dict.get(sym_name, {}).get("__dtype__", "-1") != "-1" and sym_name in arg_params:
+            typ = dtype_mx_to_np(int(attr_dict[sym_name]["__dtype__"]))
+            if arg_params[sym_name].dtype != typ:
+                arg_params[sym_name] = arg_params[sym_name].astype(typ)
 
     for sym_name in sym.list_auxiliary_states():
-        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
-            if attr_dict[sym_name]["__dtype__"] != "-1":
-                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
-                if typ == bfloat16:
-                    aux_params[sym_name] = _cast_symbol_NDArray(aux_params[sym_name], bfloat16)
-                else:
-                    aux_params[sym_name] = aux_params[sym_name].astype(typ)
+        if attr_dict.get(sym_name, {}).get("__dtype__", "-1") != "-1" and sym_name in aux_params:
+            typ = dtype_mx_to_np(int(attr_dict[sym_name]["__dtype__"]))
+            if aux_params[sym_name].dtype != typ:
+                aux_params[sym_name] = aux_params[sym_name].astype(typ)
 
     # Return the converted symbol and casted params
     return sym, arg_params, aux_params
 
+
 @wrap_ctx_to_device_func
-def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
+def convert_hybrid_block(block, data_example, target_dtype="float16", target_dtype_ops=None,
                          fp32_ops=None, conditional_fp32_ops=None,
-                         excluded_sym_names=None, device=gpu(0),
-                         cast_optional_params=False):
+                         excluded_sym_names=[], device=gpu(0),
+                         cast_params_offline=False):

Review comment:
       The docstring is missing descriptions for some of this function's parameters.

##########
File path: tests/python/dnnl/subgraphs/test_amp_subgraph.py
##########
@@ -0,0 +1,243 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import mxnet as mx
+import mxnet.gluon.nn as nn
+from mxnet import amp
+from mxnet.amp.amp import bfloat16
+from mxnet.test_utils import assert_almost_equal
+from subgraph_common import SG_PASS_NAME, QUANTIZE_SG_PASS_NAME
+from test_matmul_subgraph import MultiHeadAttention
+
+AMP_SG_PASS_NAME = 'ONEDNN_AMP'
+AMP_DTYPE = bfloat16
+
+
+# Checks if amp (after the AMP_SG_PASS_NAME fuse) changes the name of tensors for calibration
+def check_amp_with_quantization(net, data_example, quantized_nodes):
+  net.optimize_for(data_example, backend=QUANTIZE_SG_PASS_NAME)
+  symnet = net.export(None)[0]
+  nodes = {n['name'] for n in json.loads(symnet.tojson())['nodes'] if n['op'] != 'null'}
+  quant_excluded_nodes = list(nodes - set(quantized_nodes))
+
+  _, calib_tensors1 = mx.contrib.quantization._quantize_symbol(
+      symnet, mx.current_context(), excluded_symbols=quant_excluded_nodes)
+
+  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
+                                    excluded_sym_names=quantized_nodes, cast_params_offline=True,
+                                    device=mx.current_context())
+  lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
+  lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
+  _, calib_tensors2 = mx.contrib.quantization._quantize_symbol(
+      lp_symnet, mx.cpu(), excluded_symbols=quant_excluded_nodes)
+  assert calib_tensors1 == calib_tensors2
+
+
+def same_graph_structure(symnet1, symnet2, expected):
+  nodes1 = json.loads(symnet1.tojson(remove_amp_cast=False))['nodes']
+  nodes2 = json.loads(symnet2.tojson(remove_amp_cast=False))['nodes']
+  assert (len(nodes1) == len(nodes2)) == expected
+  for node1, node2 in zip(nodes1, nodes2):
+    if node1['op'] != node2['op'] or node1['inputs'] != node2['inputs']:
+      assert expected == False
+      break
+
+
+def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rtol=0.05):
+  net.hybridize()
+  out_ref = net(*data_example)
+
+  net.optimize_for(data_example, backend=SG_PASS_NAME)  # amp pass works only on onednn nodes

Review comment:
       ```suggestion
     net.optimize_for(data_example, backend=SG_PASS_NAME)  # amp pass works only on oneDNN nodes
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org