You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/05/17 12:34:25 UTC

[incubator-mxnet] branch master updated: AMP improvements + enable bf16 input for quantize_v2 (#20983)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3098af4a80 AMP improvements + enable bf16 input for quantize_v2 (#20983)
3098af4a80 is described below

commit 3098af4a808f905e285a4100c6b96b8e126c2c5a
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Tue May 17 14:34:09 2022 +0200

    AMP improvements + enable bf16 input for quantize_v2 (#20983)
    
    * AMP improvements + enable bf16 input for quantize_v2
    
    * Fix sanity
    
    * Improve tests, AMP conversion interface, fix forward hooks
    
    * Fix tests
    
    * Fix imports in tests
    
    * Use different lp16_fp32 op in test
    
    * Add amp.disable_amp() context, fix tests
    
    * Add tests, generalize optimization disabling
    
    * Fix sanity
    
    * Review fixes
    
    * Use is_integral<>::value
    
    * Review fixes # rerun CI
    Change flag type to unsigned int
    Add a warning for an incorrect flag attribute value
    
    * Add message to static_assert
    
    * Cast flag attribute value to int before using
    
    * Test amp node excluding, change names of tests
    
    * Fix static_cast type
    
    * Fix sanity
---
 include/mxnet/c_api.h                            |  17 +-
 include/mxnet/imperative.h                       |  21 ++
 python/mxnet/amp/amp.py                          |  56 +++---
 python/mxnet/amp/lists/symbol_bf16.py            |  18 +-
 python/mxnet/gluon/block.py                      |  46 ++++-
 python/mxnet/ndarray/ndarray.py                  |   2 +
 src/c_api/c_api_ndarray.cc                       |  13 ++
 src/c_api/c_api_symbolic.cc                      |   7 +-
 src/common/utils.h                               |  32 +++
 src/imperative/imperative.cc                     |   7 +
 src/nnvm/low_precision_pass.cc                   |  88 ++++++--
 src/operator/quantization/quantize_v2-inl.h      |  12 ++
 tests/python/amp/common.py                       | 245 +++++++++++++++++++++++
 tests/python/dnnl/subgraphs/test_amp_subgraph.py |  45 ++++-
 tests/python/dnnl/test_amp.py                    | 149 +++-----------
 tests/python/gpu/test_amp.py                     | 176 ++++------------
 tests/python/quantization/test_quantization.py   |  17 ++
 17 files changed, 626 insertions(+), 325 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index eb13feedb1..d8757c7f9f 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1225,6 +1225,19 @@ MXNET_DLL int MXAutogradIsRecording(bool* curr);
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXAutogradIsTraining(bool* curr);
+/*!
+ * \brief set what optimization constraints to apply
+ * \param constraints state composed of OptConstraint flags.
+ * \param prev returns the previous status before this set.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSetOptimizationConstraints(unsigned int constraints, unsigned int* prev);
+/*!
+ * \brief get current optimization constraints
+ * \param curr returns the current status
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXGetOptimizationConstraints(unsigned int* curr);
 /*!
  * \brief get whether numpy compatibility is on
  * \param curr returns the current status
@@ -2025,9 +2038,7 @@ MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
                                       const uint32_t num_fp32_ops,
                                       const char** const fp32_ops_p,
                                       const uint32_t num_widest_dtype_ops,
-                                      const char** const widest_dtype_ops_p,
-                                      const uint32_t num_excluded_symbols,
-                                      const char** const excluded_syms_p);
+                                      const char** const widest_dtype_ops_p);
 
 /*!
  * \brief Set calibration table to node attributes in the sym
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index e4e3f6a938..42876f7bf4 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,14 @@
 #include "./ndarray.h"
 
 namespace mxnet {
+
+constexpr char OPT_CONSTRAINT_ATTR[] = "__opt_constraint__";
+enum class OptConstraint : unsigned int {
+  None       = 0,
+  DisableAMP = 1 << 0
+  // DisableQuantization = 1 << 1
+};
+
 /*! \brief there are three numpy shape flags based on priority.
  * GlobalOn
  *   turn on numpy shape flag globally, it includes thread local.
@@ -47,6 +55,7 @@ namespace mxnet {
  * */
 enum NumpyShape { Off, ThreadLocalOn, GlobalOn };
 typedef NumpyShape NumpyDefaultDtype;
+
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -237,6 +246,16 @@ class Imperative {
     }
     return old;
   }
+  /*! \brief return the current (thread-local) optimization constraints. */
+  OptConstraint get_opt_constraints() const {
+    return opt_constraints_;
+  }
+  /*! \brief set (thread-local) optimization constraints; returns the previous value. */
+  OptConstraint set_opt_constraints(OptConstraint constraints) {
+    OptConstraint old = opt_constraints_;
+    opt_constraints_  = constraints;
+    return old;
+  }
   /*! \brief to record operator, return corresponding node. */
   void RecordOp(nnvm::NodeAttrs&& attrs,
                 const std::vector<NDArray*>& inputs,
@@ -321,6 +340,7 @@ class Imperative {
   static thread_local bool is_train_;
   static thread_local bool is_recording_;
   static thread_local bool is_deferred_compute_;
+  static thread_local OptConstraint opt_constraints_;
   // TOOD(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
   static thread_local bool is_np_shape_thread_local_;
@@ -328,6 +348,7 @@ class Imperative {
   static MX_THREAD_LOCAL bool is_train_;
   static MX_THREAD_LOCAL bool is_recording_;
   static MX_THREAD_LOCAL bool is_deferred_compute_;
+  static MX_THREAD_LOCAL OptConstraint opt_constraints_;
   // TOOD(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
   static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
diff --git a/python/mxnet/amp/amp.py b/python/mxnet/amp/amp.py
index 2d370ce7ab..b73e48e846 100644
--- a/python/mxnet/amp/amp.py
+++ b/python/mxnet/amp/amp.py
@@ -39,7 +39,7 @@ from ..symbol import contrib as symbol_contrib
 from .. import ndarray
 from ..ndarray import NDArray, dtype_np_to_mx, get_dtype_type, get_dtype_name, bfloat16
 from . import lists
-from ..gluon import Block, trainer
+from ..gluon import Block, HybridBlock, trainer
 from .. import base
 from ..base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
                     _NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
@@ -428,7 +428,7 @@ def unscale(optimizer_or_trainer):
                         "an optimizer, instead is %s" % type(optimizer_or_trainer))
 
 
-def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", target_dtype_ops=None,
+def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype, target_dtype_ops=None,
                    fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=[],
                    cast_params_offline=False):
     """Given a symbol object representing a neural network of data type FP32 and target_dtype,
@@ -464,9 +464,7 @@ def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", targ
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
     cast_params_offline : bool, default False
-        Whether to cast the arg_params and aux_params that don't require to be in LP16
-        because of a cast layer following it, but will reduce the computation and memory
-        overhead of the model if casted.
+        Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
     """
     import json
 
@@ -497,22 +495,30 @@ def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", targ
             "conditional_fp32_ops should be a list of (str, str, list of str)"
         cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
 
-    nodes_attr = sym.attr_dict()
+    nodes_attrs = sym.attr_dict()
     nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
-    if not set(excluded_sym_names).issubset(set(nodes_op.keys())):
-        logging.warning("excluded_sym_names are not present in the network. Missing layers: {}".format(
-            set(excluded_sym_names) - set(nodes_op.keys())))
-
     for node_name, node_op in nodes_op.items():
         if node_op not in cond_ops:
             continue
-        node_attrs = nodes_attr[node_name]
+        node_attrs = nodes_attrs[node_name]
         for attr_name, attr_vals in cond_ops[node_op].items():
             assert attr_name in node_attrs
             if node_attrs[attr_name] in attr_vals:
-                excluded_sym_names += node_name
+                excluded_sym_names = excluded_sym_names + [node_name]  # don't mutate caller's/default list
                 break
-    excluded_sym_names = list(set(excluded_sym_names))
+
+    excluded_sym_names = set(excluded_sym_names)
+    for node in sym.get_internals():
+        if node.name in excluded_sym_names:
+            excluded_sym_names.remove(node.name)
+            opt_constraints = node.attr('__opt_constraint__')
+            opt_constraints = 0 if opt_constraints is None else int(opt_constraints)
+            opt_constraints |= HybridBlock.OptConstraint.Flag.DisableAMP.value
+            node._set_attr(__opt_constraint__=str(opt_constraints))
+
+    if len(excluded_sym_names) > 0:
+        logging.warning("excluded_sym_names are not present in the network. Missing nodes: {}".format(
+            excluded_sym_names))
 
     # Op lists should not intersect
     common_ops = set(target_dtype_ops) & set(fp32_ops)
@@ -561,13 +567,11 @@ def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", targ
                                             ctypes.c_uint(len(fp32_ops)),
                                             c_str_array(fp32_ops),
                                             ctypes.c_uint(len(widest_dtype_ops)),
-                                            c_str_array(widest_dtype_ops),
-                                            ctypes.c_uint(len(excluded_sym_names)),
-                                            c_str_array(excluded_sym_names)))
+                                            c_str_array(widest_dtype_ops)))
     return type(sym)(out)
 
 
-def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float16",
+def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype,
                   target_dtype_ops=None, fp32_ops=None, conditional_fp32_ops=None,
                   excluded_sym_names=[], cast_params_offline=False):
     """API for converting a model from FP32 model to a mixed precision model.
@@ -605,9 +609,7 @@ def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float
         A list of strings that represent the names of symbols that users want to exclude
         from being executed in lower precision.
     cast_params_offline : bool, default False
-        Whether to cast the arg_params and aux_params that don't require to be in LP16
-        because of a cast layer following it, but will reduce the computation and memory
-        overhead of the model if casted.
+        Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
     """
     assert isinstance(sym, Symbol), "First argument to convert_model should be a Symbol"
     assert isinstance(
@@ -641,9 +643,9 @@ def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float
 
 
 @wrap_ctx_to_device_func
-def convert_hybrid_block(block, data_example, target_dtype="float16", target_dtype_ops=None,
+def convert_hybrid_block(block, data_example, target_dtype, target_dtype_ops=None,
                          fp32_ops=None, conditional_fp32_ops=None,
-                         excluded_sym_names=[], device=gpu(0),
+                         excluded_sym_names=[], device=None,
                          cast_params_offline=False):
     """Given a hybrid block/symbol block representing a FP32 model and a target_dtype,
     return a block with mixed precision support which can be used for inference use cases.
@@ -668,14 +670,12 @@ def convert_hybrid_block(block, data_example, target_dtype="float16", target_dty
     excluded_sym_names : list of strs
         A list of strings that represent the names of symbols that users want to exclude
         from being quantized
-    device : Context
-        Context on which model parameters should live
+    device : Device
+        Device on which model parameters should live. Default value: current device.
     cast_params_offline : bool, default False
-        Whether to cast the arg_params and aux_params that don't require to be in LP16
-        because of a cast layer following it, but will reduce the computation and memory
-        overhead of the model if casted.
+        Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
     """
-    from ..gluon import HybridBlock, SymbolBlock
+    from ..gluon import SymbolBlock
     from ..ndarray import NDArray as ND_NDArray, waitall
     from ..numpy import ndarray as NP_NDArray
 
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index bab52688b4..566990c411 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -18,16 +18,22 @@
 # coding: utf-8
 """Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""
 
+from ...runtime import Features
+
 # Functions that should be cast to lower precision
 BF16_FUNCS = [
     'Convolution',
     'Deconvolution',
-    'FullyConnected',
-    '_sg_onednn_conv',
-    '_sg_onednn_fully_connected',
-    '_sg_onednn_selfatt_qk',
-    '_sg_onednn_selfatt_valatt'
+    'FullyConnected'
 ]
+if Features.instance.is_enabled('ONEDNN'):
+    BF16_FUNCS.extend([
+        '_sg_onednn_conv',
+        '_sg_onednn_fully_connected',
+        '_sg_onednn_selfatt_qk',
+        '_sg_onednn_selfatt_valatt'
+    ])
+
 
 # Functions that should not be casted, either because
 # they are irrelevant (not used in the network itself
@@ -45,6 +51,7 @@ BF16_FP32_FUNCS = [
     'sqrt',
     'square',
     'tanh',
+    '_contrib_quantize_v2',
 ]
 
 # Functions that when running with Bfloat16, the params that still need float32.
@@ -98,7 +105,6 @@ FP32_FUNCS = [
     '_contrib_quadratic',
     '_contrib_quantize',
     '_contrib_quantize_asym',
-    '_contrib_quantize_v2',
     '_contrib_quantized_concat',
     '_contrib_quantized_conv',
     '_contrib_quantized_flatten',
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 5b327b1a8d..cff346b9f4 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -20,6 +20,8 @@
 """Base container class for all neural network models."""
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
+import enum
+import ctypes
 import copy
 import warnings
 import weakref
@@ -1045,6 +1047,32 @@ class HybridBlock(Block):
         `Hybridize - A Hybrid of Imperative and Symbolic Programming
         <https://mxnet.apache.org/versions/master/api/python/docs/tutorials/packages/gluon/blocks/hybridize.html>`_
     """
+    class OptConstraint:
+        class Flag(enum.Flag):
+            DisableAMP = enum.auto()
+
+        def __init__(self, flag) -> None:
+            self.flag = flag
+            self.enter_state = None
+
+        def __enter__(self):
+            self.enter_state = HybridBlock.OptConstraint.Flag(get_optimization_constraints())
+            set_optimization_constraints(self.enter_state | self.flag)
+
+        def __exit__(self, ptype, value, trace):
+            set_optimization_constraints(self.enter_state)
+
+        @staticmethod
+        def disable_all():
+            opt_flag = HybridBlock.OptConstraint.Flag(0)  # Flag() needs an explicit value
+            for flag in HybridBlock.OptConstraint.Flag:
+                opt_flag |= flag
+            return HybridBlock.OptConstraint(opt_flag)  # same contract as disable_amp()
+
+        @staticmethod
+        def disable_amp():
+            return HybridBlock.OptConstraint(HybridBlock.OptConstraint.Flag.DisableAMP)
+
     def __init__(self):
         super(HybridBlock, self).__init__()
         assert hasattr(self, "hybrid_forward") is False, (
@@ -1550,7 +1578,6 @@ class HybridBlock(Block):
 
         if remove_amp_cast:
             handle = SymbolHandle()
-            import ctypes
             check_call(_LIB.MXSymbolRemoveAmpCast(sym.handle, ctypes.byref(handle)))
             sym = type(sym)(handle)
         return sym, arg_dict
@@ -1571,7 +1598,6 @@ class HybridBlock(Block):
         """
         def c_callback(name, op_name, array):
             """wrapper for user callback"""
-            import ctypes
             array = ctypes.cast(array, NDArrayHandle)
             array = NDArray(array, writable=False)
             name = py_str(name)
@@ -1820,12 +1846,12 @@ class SymbolBlock(HybridBlock):
     def __call__(self, x, *args):
         """Calls forward. Only accepts positional arguments."""
         for hook in self._forward_pre_hooks.values():
-            hook(self, [x] + args)
+            hook(self, [x, *args])
 
         out = self.forward(x, *args)
 
         for hook in self._forward_hooks.values():
-            hook(self, [x] + args, out)
+            hook(self, [x, *args], out)
 
         return out
 
@@ -1943,3 +1969,15 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dt
             aux_types.append(default_dtype)
 
     return (arg_types, aux_types)
+
+
+def set_optimization_constraints(state):
+    prev_state = ctypes.c_uint()
+    check_call(_LIB.MXSetOptimizationConstraints(ctypes.c_uint(state.value), ctypes.byref(prev_state)))
+    return HybridBlock.OptConstraint.Flag(prev_state.value)
+
+
+def get_optimization_constraints():
+    curr = ctypes.c_uint()
+    check_call(_LIB.MXGetOptimizationConstraints(ctypes.byref(curr)))
+    return HybridBlock.OptConstraint.Flag(curr.value)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 18fc8e7261..0e6432a999 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2653,6 +2653,8 @@ fixed-size items.
         array([[1, 1, 1],
                [1, 1, 1]], dtype=int32)
         """
+        if self.dtype == bfloat16:
+            return self.astype(np.float32).asnumpy()
         data = np.empty(self.shape, dtype=self.dtype)
         check_call(_LIB.MXNDArraySyncCopyToCPU(
             self.handle,
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 2e9c0a3736..b91a997b7c 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -291,6 +291,19 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
   API_END();
 }
 
+int MXSetOptimizationConstraints(unsigned int constraints, unsigned int* prev) {
+  API_BEGIN();
+  *prev =
+      static_cast<unsigned int>(Imperative::Get()->set_opt_constraints(OptConstraint(constraints)));
+  API_END();
+}
+
+int MXGetOptimizationConstraints(unsigned int* curr) {
+  API_BEGIN();
+  *curr = static_cast<unsigned int>(Imperative::Get()->get_opt_constraints());
+  API_END();
+}
+
 int MXIsNumpyShape(int* curr) {
   API_BEGIN();
   *curr = Imperative::Get()->is_np_shape();
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 6f9eeb2dac..e17d7beb2e 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -986,9 +986,7 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle,
                             const uint32_t num_fp32_ops,
                             const char** const fp32_ops_p,
                             const uint32_t num_widest_dtype_ops,
-                            const char** const widest_dtype_ops_p,
-                            const uint32_t num_excluded_symbols,
-                            const char** const excluded_syms_p) {
+                            const char** const widest_dtype_ops_p) {
   nnvm::Symbol* result_sym = new nnvm::Symbol();
   API_BEGIN();
   nnvm::Symbol* sym                   = static_cast<nnvm::Symbol*>(sym_handle);
@@ -1002,8 +1000,6 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle,
   std::unordered_set<std::string> fp32_ops(fp32_ops_p, fp32_ops_p + num_fp32_ops);
   std::unordered_set<std::string> widest_dtype_ops(widest_dtype_ops_p,
                                                    widest_dtype_ops_p + num_widest_dtype_ops);
-  std::unordered_set<std::string> excluded_syms(excluded_syms_p,
-                                                excluded_syms_p + num_excluded_symbols);
 
   nnvm::DTypeVector arg_types(num_all_args);
   std::unordered_map<std::string, int> node_name_to_type_map;
@@ -1022,7 +1018,6 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle,
   g.attrs["target_dtype_ops"] = std::make_shared<nnvm::any>(std::move(target_dtype_ops));
   g.attrs["fp32_ops"]         = std::make_shared<nnvm::any>(std::move(fp32_ops));
   g.attrs["widest_dtype_ops"] = std::make_shared<nnvm::any>(std::move(widest_dtype_ops));
-  g.attrs["excluded_syms"]    = std::make_shared<nnvm::any>(std::move(excluded_syms));
   g                           = ApplyPass(std::move(g), "ReducePrecision");
 
   result_sym->outputs                      = g.outputs;
diff --git a/src/common/utils.h b/src/common/utils.h
index d839a6db40..fe8413f18e 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -438,6 +438,38 @@ inline std::string attr_value_string(const nnvm::NodeAttrs& attrs,
   return attrs.dict.at(attr_name);
 }
 
+/*! \brief Seeks an attribute in a node and its subgraphs and invokes a function on each. */
+template <typename Fn>
+inline void attr_foreach(const nnvm::NodeAttrs& attrs, const std::string& attr_name, const Fn& fn) {
+  const auto& found_it = attrs.dict.find(attr_name);
+  if (found_it != attrs.dict.end()) {
+    fn(found_it->second);
+  }
+  for (const auto& subgraph : attrs.subgraphs) {
+    DFSVisit(subgraph->outputs,
+             [&](const nnvm::ObjectPtr& node) { attr_foreach(node->attrs, attr_name, fn); });
+  }
+}
+
+template <typename ValueType>
+inline ValueType flag_attr_accumulate(const nnvm::NodeAttrs& attrs, const std::string& attr_name) {
+  static_assert(std::is_integral<ValueType>::value, "ValueType must be an integral type.");
+
+  ValueType result = 0;
+  attr_foreach(attrs, attr_name, [&](const std::string& attr_value) {
+    std::istringstream ss(attr_value);
+    ValueType temp;
+    ss >> temp;
+    result |= temp;
+
+    if (ss.fail() || !ss.eof()) {
+      LOG(WARNING) << "Incorrect value of an attribute: " << attr_name
+                   << ". Expected an integer, while got: " << attr_value;
+    }
+  });
+  return result;
+}
+
 /*! \brief get string representation of the operator stypes */
 inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index b9bdaac947..fb123c18c9 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -33,11 +33,13 @@ namespace mxnet {
 thread_local bool Imperative::is_train_                 = false;
 thread_local bool Imperative::is_recording_             = false;
 thread_local bool Imperative::is_deferred_compute_      = false;
+thread_local OptConstraint Imperative::opt_constraints_ = OptConstraint::None;
 thread_local bool Imperative::is_np_shape_thread_local_ = false;
 #else
 MX_THREAD_LOCAL bool Imperative::is_train_                 = false;
 MX_THREAD_LOCAL bool Imperative::is_recording_             = false;
 MX_THREAD_LOCAL bool Imperative::is_deferred_compute_      = false;
+MX_THREAD_LOCAL OptConstraint Imperative::opt_constraints_ = OptConstraint::None;
 MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
 #endif
 
@@ -367,6 +369,11 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs&& attrs,
     node_count_++;
   }
 
+  if (get_opt_constraints() != OptConstraint::None) {
+    node->attrs.dict[OPT_CONSTRAINT_ATTR] =
+        std::to_string(static_cast<std::underlying_type_t<OptConstraint>>(get_opt_constraints()));
+  }
+
   for (uint32_t i = 0; i < outputs.size(); ++i) {
     outputs[i]->deferredcompute_entry_ = nnvm::NodeEntry{node, i, 0};
   }
diff --git a/src/nnvm/low_precision_pass.cc b/src/nnvm/low_precision_pass.cc
index 192390eacb..2c2cce5b99 100644
--- a/src/nnvm/low_precision_pass.cc
+++ b/src/nnvm/low_precision_pass.cc
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include <functional>
 #include "operator/operator_common.h"
+#include "common/utils.h"
 
 namespace mxnet {
 using nnvm::Graph;
@@ -59,8 +60,9 @@ class MappedNodeEntry {
    * converted and dtypes of its outputs may have changed
    */
   void UpdateDTypeAfterConversion(const int new_dtype) {
-    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    CHECK(dtype == original_dtype || dtype == new_dtype);  // may only be re-set to the same dtype
     CHECK(entry.node->op());
+    CHECK_NE(new_dtype, -1);  // -1 marks an uninferred dtype
     dtype = new_dtype;
   }
 
@@ -91,6 +93,8 @@ class MappedNodeEntry {
 
   /*! \brief Returns whether this entry has the specified dtype or an existing cast to that dtype */
   bool HasDTypeEntry(const int target_dtype) const {
+    CHECK_NE(target_dtype, -1);
+
     return dtype == target_dtype || casts.count(target_dtype) > 0;
   }
 
@@ -99,6 +103,8 @@ class MappedNodeEntry {
    * input entires of a node before its conversion.
    */
   bool CanBeCastTo(const int target_dtype) {
+    CHECK_NE(target_dtype, -1);
+
     static const auto& amp_cast_op = Op::Get("amp_cast");
     static const auto& infertype   = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType")[amp_cast_op];
     nnvm::NodeAttrs dummy_atts;
@@ -113,6 +119,8 @@ class MappedNodeEntry {
   /*! \brief Returns whether this NodeEntry (of a parameter) can be cast offline */
   bool CanBeCastOfflineTo(const int target_dtype) const {
     CHECK(entry.node->is_variable());
+    CHECK_NE(target_dtype, -1);
+
     return casts.count(target_dtype) > 0;
   }
 
@@ -177,8 +185,21 @@ static bool TryLowPrecision(const int target_dtype,
   static const auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
 
   std::vector<int> in_types(old_node->inputs.size(), -1);
+  bool has_lp_input = false;
+  for (int i = 0; i < old_node->inputs.size(); ++i) {
+    if (entry_map->at(old_node->inputs[i]).HasDTypeEntry(target_dtype)) {
+      in_types[i]  = target_dtype;
+      has_lp_input = true;
+    }
+  }
+  if (!has_lp_input) {
+    // when inputs are not already in low precision, assume the first input should be in low
+    // precision in order to convert this op
+    in_types[0] = target_dtype;
+  }
+
+  // infer types of other inputs
   std::vector<int> out_types(old_node->num_outputs(), -1);
-  in_types[0] = target_dtype;
   if (infertype.count(old_node->op()) == 0 ||
       infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == false) {
     return false;
@@ -226,21 +247,38 @@ static void HandleWidestDtypeNode(const int target_dtype,
                                   EntryMap_t* const entry_map) {
   static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
 
-  std::vector<int> in_types(old_node->inputs.size(), target_dtype);
-  std::vector<int> out_types(old_node->num_outputs(), -1);
-  const bool inferred = (infertype.count(old_node->op()) > 0 &&
-                         infertype[old_node->op()](old_node->attrs, &in_types, &out_types));
-
-  bool has_lp_inputs = inferred;
-  for (int i = 0; has_lp_inputs && i < old_node->inputs.size(); ++i) {
-    const NodeEntry& input = old_node->inputs[i];
-    has_lp_inputs &= entry_map->at(input).HasDTypeEntry(in_types[i]);
+  // gather info about current dtypes of inputs
+  // if there is already at least one input with target dtype, we try converting to low precision
+  bool try_lp = false;
+  std::vector<int> in_types(old_node->inputs.size(), -1);
+  for (int i = 0; i < old_node->inputs.size(); ++i) {
+    if (entry_map->at(old_node->inputs[i]).HasDTypeEntry(target_dtype)) {
+      in_types[i] = target_dtype;  // set only lp inputs
+      try_lp      = true;
+    }
   }
 
-  if (!has_lp_inputs ||
-      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
-    KeepOriginalNode(old_node, node_map, entry_map);
+  if (try_lp) {
+    // run infertype, to see what other input types this op needs with the current lp inputs
+    std::vector<int> out_types(old_node->num_outputs(), -1);
+    try_lp = (infertype.count(old_node->op()) > 0 &&
+              infertype[old_node->op()](old_node->attrs, &in_types, &out_types));
+
+    if (try_lp) {
+      // if we have to add casts to inputs, this op shouldn't run in low precision
+      for (int i = 0; i < old_node->inputs.size(); ++i) {
+        const NodeEntry& old_input_ne = old_node->inputs[i];
+        if (in_types[i] != -1 && !entry_map->at(old_input_ne).HasDTypeEntry(in_types[i])) {
+          try_lp = false;
+          break;
+        }
+      }
+      if (try_lp && TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
+        return;
+      }
+    }
   }
+  KeepOriginalNode(old_node, node_map, entry_map);
 }
 /*!
  * \brief Tries to convert the node to low precision if some of its inputs already are converted.
@@ -309,7 +347,6 @@ Graph ReducePrecision(Graph&& src) {
   const auto& target_dtype_ops = src.GetAttr<std::unordered_set<std::string>>("target_dtype_ops");
   const auto& fp32_ops         = src.GetAttr<std::unordered_set<std::string>>("fp32_ops");
   const auto& widest_dtype_ops = src.GetAttr<std::unordered_set<std::string>>("widest_dtype_ops");
-  const auto& excluded_syms    = src.GetAttr<std::unordered_set<std::string>>("excluded_syms");
   auto src_dtypes              = src.GetAttr<nnvm::DTypeVector>("dtype");  // copy, not reference
 
   CHECK(target_dtype == mshadow::kFloat16 || target_dtype == mshadow::kBfloat16)
@@ -360,7 +397,7 @@ Graph ReducePrecision(Graph&& src) {
   }
 
   // convert the model
-  DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
+  const auto convert_node_fn = [&](const ObjectPtr& old_node) {
     if (old_node->is_variable() || old_node->op() == Op::Get("amp_multicast") ||
         IsCastOp(old_node->op())) {
       const ObjectPtr& new_node = node_map.at(old_node.get());
@@ -370,8 +407,10 @@ Graph ReducePrecision(Graph&& src) {
       }
       return;
     }
-
-    if (fp32_ops.count(old_node->op()->name) > 0 || excluded_syms.count(old_node->attrs.name) > 0) {
+    auto opt_constraints = common::flag_attr_accumulate<std::underlying_type_t<OptConstraint>>(
+        old_node->attrs, OPT_CONSTRAINT_ATTR);
+    if (fp32_ops.count(old_node->op()->name) > 0 ||
+        (opt_constraints & static_cast<int>(OptConstraint::DisableAMP))) {
       KeepOriginalNode(old_node, node_map, &entry_map);
     } else if (target_dtype_ops.count(old_node->op()->name) > 0) {
       if (!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, &entry_map)) {
@@ -384,7 +423,20 @@ Graph ReducePrecision(Graph&& src) {
     } else {
       HandleDTypeNeutralNode(target_dtype, old_node, node_map, nodes_entries, &entry_map);
     }
+  };
+
+  // Because some nodes depend on casts present in the graph, the order of visited nodes will
+  // determine whether some nodes are converted or not. To avoid this, first we make a virtual
+  // conversion pass in order to have all the necessary casts already present (in the
+  // MappedNodeEntry instances) during the second (true) conversion pass
+
+  // virtual conversion pass
+  DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
+    convert_node_fn(old_node);
+    node_map[old_node.get()]->inputs.clear();  // make this pass "virtual" by removing edges
   });
+  // true conversion pass
+  DFSVisit(src.outputs, [&](const ObjectPtr& old_node) { convert_node_fn(old_node); });
 
   std::vector<NodeEntry> outputs;
   for (const auto& out_ne : src.outputs) {
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index 3337c0a833..7f297d6e57 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -151,8 +151,20 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 3U);
   const QuantizeV2Param& param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+
+#if MXNET_USE_ONEDNN == 1
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kBfloat16 ||
+          in_attrs->at(0) == mshadow::kUint8 || in_attrs->at(0) == mshadow::kInt8);
+  } else {
+    CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
+          in_attrs->at(0) == mshadow::kInt8);
+  }
+#else
   CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
         in_attrs->at(0) == mshadow::kInt8);
+#endif
+
   auto out_type = GetQuantizeOutputType(param);
   if (out_type == mshadow::kUint8) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
diff --git a/tests/python/amp/common.py b/tests/python/amp/common.py
new file mode 100644
index 0000000000..3e221c613e
--- /dev/null
+++ b/tests/python/amp/common.py
@@ -0,0 +1,245 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import warnings
+import collections
+import mxnet as mx
+from mxnet import amp
+from mxnet.gluon import nn
+from mxnet.operator import get_all_registered_operators_grouped
+
+
+def test_amp_coverage(lp_dtype, lp_name):  # lp_name: list-name prefix such as 'BF16', used only in messages
+    conditional = [item[0] for item in amp.list_conditional_fp32_ops(lp_dtype)]
+    lp16_ops = amp.list_lp16_ops(lp_dtype)
+    lp16_fp32_ops = amp.list_lp16_fp32_ops(lp_dtype)
+    fp32_ops = amp.list_fp32_ops(lp_dtype)
+    widest_ops = amp.list_widest_type_cast(lp_dtype)
+    all_lp_lists = [lp16_ops, lp16_fp32_ops, fp32_ops, widest_ops, conditional]
+
+    # Check for duplicates inside each AMP list
+    for op_list in all_lp_lists:
+        ret = [op for op, count in collections.Counter(op_list).items() if count > 1]
+        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
+
+    all_lp_ops = [op for op_list in all_lp_lists for op in op_list]  # an op may appear in only one list
+    ret = [op for op, count in collections.Counter(all_lp_ops).items() if count > 1]
+    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
+
+    # Check the coverage: registered ops without "backward" in their name should be listed
+    covered_ops = set(all_lp_ops)
+    all_mxnet_ops = get_all_registered_operators_grouped()
+    required_ops = {op for op in all_mxnet_ops if "backward" not in op}
+
+    extra_ops = covered_ops - required_ops
+    assert not extra_ops, f"{len(extra_ops)} operators are not needed in the AMP lists: {sorted(extra_ops)}"
+
+    guidelines = f"""Please follow these guidelines for choosing a proper list:
+    - if your operator is not to be used in a computational graph
+      (e.g. image manipulation operators, optimizers) or does not have
+      inputs, put it in {lp_name.upper()}_FP32_FUNCS list,
+    - if your operator requires FP32 inputs or is not safe to use with lower
+      precision, put it in FP32_FUNCS list,
+    - if your operator supports both FP32 and lower precision, has
+      multiple inputs and expects all inputs to be of the same
+      type, put it in WIDEST_TYPE_CASTS list,
+    - if your operator supports both FP32 and lower precision and has
+      either a single input or supports inputs of different type,
+      put it in {lp_name.upper()}_FP32_FUNCS list,
+    - if your operator is both safe to use in lower precision and
+      it is highly beneficial to use it in lower precision, then
+      put it in {lp_name.upper()}_FUNCS (this is unlikely for new operators)
+    - If you are not sure which list to choose, FP32_FUNCS is the
+      safest option"""
+    missing_ops = required_ops - covered_ops
+    # Unlisted operators are reported with a warning rather than failing the test
+    if missing_ops:
+        warnings.warn(f"{len(missing_ops)} operators {sorted(missing_ops)} do not exist in AMP lists "
+                      f"(in python/mxnet/amp/lists/symbol_{lp_name.lower()}.py) - please add them. \n{guidelines}")
+
+
+def test_amp_basic_use(lp_dtype):
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.fc1 = nn.Dense(4)
+      self.fc2 = nn.Dense(4)
+
+    def forward(self, x):
+      hidden = self.fc2(self.fc1(x))
+      # the 4 output features of each sample are folded into a 2x2 block
+      return hidden.reshape((-1, 2, 2))
+
+  data_example = mx.np.random.uniform(-1, 1, (4, 4))
+
+  fp32_net = TestNet()
+  fp32_net.initialize()
+  lp_net = amp.convert_hybrid_block(fp32_net, data_example, lp_dtype)
+
+  # casts into lp16: the network input, plus weight and bias of `fc1` and of `fc2`
+  lp16_casts = 1 + 2 + 2
+  # casts out of lp16: the network output back to float32
+  other_casts = 1
+
+  # lp16 tensors: cast input (1), `fc1` weight/bias/output (3),
+  # `fc2` weight/bias/output (3) and the reshape output (1)
+  lp16_tensors = 1 + 3 + 3 + 1
+  check_amp_net_stats(lp_dtype, lp_net, data_example,
+                      lp16_tensors_num=lp16_tensors,
+                      lp16_casts_num=lp16_casts,
+                      other_casts_num=other_casts)
+
+
+def test_amp_offline_casting(lp_dtype):
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Conv2D(4, 3)
+      self.lp16_op2 = nn.Conv2DTranspose(4, 3)
+      self.fp32_op = nn.Dense(4)
+
+    def forward(self, x):
+      out = self.lp16_op2(self.lp16_op1(x))
+      out = out.reshape(out.shape[0], -1)
+      # the dense layer is kept in float32 by excluding it from AMP
+      with nn.HybridBlock.OptConstraint.disable_amp():
+        out = self.fp32_op(out)
+      return out
+
+  net = TestNet()
+  net.initialize()
+  data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
+  lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
+
+  check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4,
+                      lp16_casts_num=1, other_casts_num=1)
+  for param_name, param in lp_net.collect_params().items():
+    assert mx.nd.get_dtype_name(param.dtype) == ('float32' if 'fp32_op' in param_name else lp_dtype)
+
+
+def test_amp_offline_casting_shared_params(lp_dtype):
+  UNITS = 4
+
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Dense(UNITS)
+      self.lp16_op2 = nn.Dense(UNITS)
+      self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})  # shared by two lp16 layers
+      self.fp32_op = nn.Dense(UNITS)
+      self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})  # shared by an lp16 and an fp32 layer
+
+    def forward(self, x):
+      shared_in = self.lp16_op1(x)
+      lp16_branch = self.lp16_op2(shared_in)
+      with nn.HybridBlock.OptConstraint.disable_amp():
+        fp32_branch = self.fp32_op(shared_in)
+      # join the two branches along the feature axis
+      return mx.np.concat((lp16_branch, fp32_branch), axis=1)
+
+  net = TestNet()
+  net.initialize()
+  data_example = mx.np.random.uniform(-1, 1, (4, UNITS))
+  lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
+
+  check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4,
+                      lp16_casts_num=2, other_casts_num=2)
+  for param_name, param in lp_net.collect_params().items():
+    assert mx.nd.get_dtype_name(param.dtype) == ('float32' if 'fp32_op' in param_name else lp_dtype)
+
+
+def test_lp16_fp32_ops_order_independence(lp_dtype):  # AMP stats must match whichever branch comes first
+  class TestNet(nn.HybridBlock):
+    def __init__(self, lp16_fp32_is_first):
+      super().__init__()
+      if lp16_fp32_is_first:
+        self.first = mx.npx.batch_flatten  # lp16_fp32_op
+        self.second = nn.Dense(4)
+      else:
+        self.first = nn.Dense(4)
+        self.second = mx.npx.batch_flatten  # lp16_fp32_op
+
+    def forward(self, x):
+      x = 2**x  # a single producer shared by both branches
+      x1 = self.first(x)
+      x2 = self.second(x)
+      return x1, x2
+
+  data_example = mx.np.random.uniform(-1, 1, (4, 16))
+
+  for lp16_fp32_is_first in [False, True]:
+    net = TestNet(lp16_fp32_is_first)
+    net.initialize()
+    net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
+    check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=3,
+                        lp16_casts_num=1, other_casts_num=2)
+
+
+def test_amp_node_excluding(lp_dtype):
+  DISABLE_AMP_ATTR_DICT = {'__opt_constraint__': str(
+      mx.gluon.HybridBlock.OptConstraint.Flag.DisableAMP.value)}
+
+  data = mx.sym.var('data')
+  wei = mx.sym.var('weights')
+  bias = mx.sym.var('bias')
+  # manually excluded via the DisableAMP node attribute
+  fc1 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc1', attr=DISABLE_AMP_ATTR_DICT)
+  # to be excluded using the conversion API (second conversion below)
+  fc2 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc2')
+  symnet = mx.sym.Group([fc1, fc2])
+
+  net = mx.gluon.SymbolBlock(symnet, [data])
+  net.initialize()
+
+  # exclude only nodes with set attribute (only 1 node - `fc1`), so only `fc2` is converted
+  data_example = mx.np.random.uniform(-1, 1, (4, 16))
+  net_1_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype)
+
+  lp16_tensors = 4  # cast `data`, weights and bias of `fc2`, `fc2` output
+  lp16_casts = 3  # `data` cast, casts for weights and bias of `fc2`
+  other_casts = 1  # cast of the `fc2` network output (from lp16 to f32)
+  check_amp_net_stats(lp_dtype, net_1_excluded, data_example, lp16_tensors_num=lp16_tensors,
+                      lp16_casts_num=lp16_casts, other_casts_num=other_casts)
+
+  # exclude using the `excluded_sym_names` argument (both nodes)
+  net_2_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype,
+                                            excluded_sym_names=['fc1', 'fc2'])
+  check_amp_net_stats(lp_dtype, net_2_excluded, data_example, lp16_tensors_num=0,
+                      lp16_casts_num=0, other_casts_num=0)
+
+
+def check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num, lp16_casts_num, other_casts_num):
+  """Run `net` once on `data_example` and check AMP statistics of its op outputs:
+  the number of distinct `lp_dtype` tensors, of `amp_cast` outputs in `lp_dtype`,
+  and of all other `amp_cast` outputs."""
+  lp16_tensors, lp16_casts, other_casts = set(), set(), set()
+
+  def inspect_output(tensor_name, op_name, tensor):
+    is_lp16 = mx.nd.get_dtype_name(tensor.dtype) == lp_dtype
+    if is_lp16:
+      lp16_tensors.add(tensor_name)
+    if op_name != 'amp_cast':
+      return
+    # amp_cast outputs are split by whether they produce lp_dtype
+    (lp16_casts if is_lp16 else other_casts).add(tensor_name)
+
+  net.register_op_hook(inspect_output)
+  net(data_example)
+
+  assert len(lp16_tensors) == lp16_tensors_num, f'Bad lp16 tensors! Present tensors: {sorted(lp16_tensors)}'
+  assert len(lp16_casts) == lp16_casts_num, f'Bad lp16 casts! Present casts: {sorted(lp16_casts)}'
+  assert len(other_casts) == other_casts_num, f'Bad casts! Present casts: {sorted(other_casts)}'
diff --git a/tests/python/dnnl/subgraphs/test_amp_subgraph.py b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
index 4dd7337855..2c5c6e1b45 100644
--- a/tests/python/dnnl/subgraphs/test_amp_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
@@ -17,15 +17,21 @@
 
 import json
 import mxnet as mx
-import mxnet.gluon.nn as nn
 from mxnet import amp
-from mxnet.amp.amp import bfloat16
+from mxnet.gluon import nn
 from mxnet.test_utils import assert_almost_equal
 from subgraph_common import SG_PASS_NAME, QUANTIZE_SG_PASS_NAME
 from test_matmul_subgraph import MultiHeadAttention
 
+import sys
+from pathlib import Path
+curr_path = Path(__file__).resolve().parent
+sys.path.insert(0, str(curr_path.parent.parent))
+
+from amp.common import check_amp_net_stats
+
 AMP_SG_PASS_NAME = 'ONEDNN_AMP'
-AMP_DTYPE = bfloat16
+AMP_DTYPE = 'bfloat16'
 
 
 # Checks if amp (after the AMP_SG_PASS_NAME fuse) changes the name of tensors for calibration
@@ -86,7 +92,7 @@ def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rto
 
 @mx.util.use_np
 def test_amp_fc():
-  class TestNet(mx.gluon.HybridBlock):
+  class TestNet(nn.HybridBlock):
     def __init__(self):
       super(TestNet, self).__init__()
       self.fc1 = nn.Dense(16)
@@ -115,7 +121,7 @@ def test_amp_fc():
 
 @mx.util.use_np
 def test_amp_conv():
-  class TestNet(mx.gluon.HybridBlock):
+  class TestNet(nn.HybridBlock):
     def __init__(self):
       super(TestNet, self).__init__()
       self.conv1 = nn.Conv2D(16, (3, 3))
@@ -166,7 +172,7 @@ def test_amp_transformers():
 
 @mx.util.use_np
 def test_amp_concat():
-  class TestNet(mx.gluon.HybridBlock):
+  class TestNet(nn.HybridBlock):
     def __init__(self):
       super(TestNet, self).__init__()
       self.fc1 = nn.Dense(16)
@@ -241,3 +247,30 @@ def test_amp_fuse_with_branch():
   exp_sym = mx.sym.Group([lp16_op_2, f32_op])
   exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
   check_amp_fuse(net, [data_example], exp_sym)
+
+
+def test_amp_excluding_after_graph_pass():
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.fc1 = nn.Dense(16)
+      self.fc2 = nn.Dense(16)
+
+    def forward(self, x):
+      out = self.fc1(x)
+      with nn.HybridBlock.OptConstraint.disable_amp():
+        out = self.fc2(out)
+      return out
+
+  data_example = mx.np.random.uniform(-1, 1, (1, 8))
+  net = TestNet()
+  net.initialize()
+
+  def convert_and_check(block):
+    lp_block = amp.convert_hybrid_block(block, data_example, AMP_DTYPE, cast_params_offline=True)
+    check_amp_net_stats(AMP_DTYPE, lp_block, data_example, lp16_tensors_num=2,
+                        lp16_casts_num=1, other_casts_num=1)
+
+  convert_and_check(net)  # before any graph pass
+  net.optimize_for(data_example, backend=SG_PASS_NAME)  # introduces new nodes
+  convert_and_check(net)  # `fc2` must still be excluded after the pass
diff --git a/tests/python/dnnl/test_amp.py b/tests/python/dnnl/test_amp.py
index 90a66bc608..73be9bb9a8 100644
--- a/tests/python/dnnl/test_amp.py
+++ b/tests/python/dnnl/test_amp.py
@@ -15,133 +15,42 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
 import sys
+from pathlib import Path
+curr_path = Path(__file__).resolve().parent
+sys.path.insert(0, str(curr_path.parent))
+
 import mxnet as mx
-import numpy as np
-import warnings
-import collections
-import ctypes
-from mxnet import amp
-from mxnet.amp.amp import bfloat16
-from mxnet.gluon import nn
-from mxnet.operator import get_all_registered_operators_grouped
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-
-
-def test_amp_coverage():
-    conditional = [item[0] for item in amp.lists.symbol_bf16.CONDITIONAL_FP32_FUNCS]
-
-    # Check for duplicates
-    for a in [amp.lists.symbol_bf16.BF16_FUNCS,
-              amp.lists.symbol_bf16.BF16_FP32_FUNCS,
-              amp.lists.symbol_bf16.FP32_FUNCS,
-              amp.lists.symbol_bf16.WIDEST_TYPE_CASTS,
-              conditional]:
-        ret = [item for item, count in collections.Counter(a).items() if count > 1]
-        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
-
-    t = []
-    for a in [amp.lists.symbol_bf16.BF16_FUNCS,
-              amp.lists.symbol_bf16.BF16_FP32_FUNCS,
-              amp.lists.symbol_bf16.FP32_FUNCS,
-              amp.lists.symbol_bf16.WIDEST_TYPE_CASTS,
-              conditional]:
-        t += a
-    ret = [item for item, count in collections.Counter(t).items() if count > 1]
-    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
-
-    # Check the coverage
-    covered = set(t)
-    ops = get_all_registered_operators_grouped()
-    required = set(k for k in ops
-                   if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
-                   not k.endswith("_backward"))
-
-    extra = covered - required
-    assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"
-
-    guidelines = """Please follow these guidelines for choosing a proper list:
-    - if your operator is not to be used in a computational graph
-      (e.g. image manipulation operators, optimizers) or does not have
-      inputs, put it in BF16_FP32_FUNCS list,
-    - if your operator requires FP32 inputs or is not safe to use with lower
-      precision, put it in FP32_FUNCS list,
-    - if your operator supports both FP32 and lower precision, has
-      multiple inputs and expects all inputs to be of the same
-      type, put it in WIDEST_TYPE_CASTS list,
-    - if your operator supports both FP32 and lower precision and has
-      either a single input or supports inputs of different type,
-      put it in BF16_FP32_FUNCS list,
-    - if your operator is both safe to use in lower precision and
-      it is highly beneficial to use it in lower precision, then
-      put it in BF16_FUNCS (this is unlikely for new operators)
-    - If you are not sure which list to choose, FP32_FUNCS is the
-      safest option"""
-    diff = required - covered
-
-    if len(diff) > 0:
-      warnings.warn(f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in "
-                    f"python/mxnet/amp/lists/symbol_bf16.py) - please add them. "
-                    f"\n{guidelines}")
+import amp.common as amp_common_tests
+
+
+AMP_DTYPE = 'bfloat16'
+
+
+def test_bf16_coverage():  # delegates to the shared AMP list checks in amp/common.py
+    amp_common_tests.test_amp_coverage(lp_dtype=AMP_DTYPE, lp_name='BF16')
+
+
+@mx.util.use_np
+def test_bf16_basic_use():  # two Dense layers converted to bfloat16
+    amp_common_tests.test_amp_basic_use(lp_dtype=AMP_DTYPE)
 
 
 @mx.util.use_np
 def test_bf16_offline_casting():
-  class TestNet(nn.HybridBlock):
-    def __init__(self):
-      super().__init__()
-      self.lp16_op1 = nn.Conv2D(4, 3)
-      self.lp16_op2 = nn.Conv2DTranspose(4, 3)
-      self.fp32_op = nn.Dense(4)
-
-    def forward(self, x):
-      x = self.lp16_op1(x)
-      x = self.lp16_op2(x)
-      x = x.reshape(x.shape[0], -1)
-      x = self.fp32_op(x)
-      return x
-
-  net = TestNet()
-  net.initialize()
-  data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
-  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
-                                    target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
-                                    cast_params_offline=True, device=mx.current_context())
-  lp_net(data_example)
-  for name, data in lp_net.collect_params().items():
-    assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)
+    amp_common_tests.test_amp_offline_casting(AMP_DTYPE)  # shared body replaces the bf16-only one removed above
 
 
 @mx.util.use_np
 def test_bf16_offline_casting_shared_params():
-  COMMON_SIZE = 4
-
-  class TestNet(nn.HybridBlock):
-    def __init__(self):
-      super().__init__()
-      self.lp16_op1 = nn.Dense(COMMON_SIZE)
-      self.lp16_op2 = nn.Dense(COMMON_SIZE)
-      self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
-      self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
-      self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
-
-    def forward(self, x):
-      x = self.lp16_op1(x)
-      x1 = self.lp16_op2(x)
-      x2 = mx.np.expand_dims(x, 1)
-      x2 = self.fp32_op(x2)
-      x2 = mx.npx.batch_flatten(x2)
-      x = mx.np.concat((x1, x2), axis=1)
-      return x
-
-  net = TestNet()
-  net.initialize()
-  data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
-  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
-                                    target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
-                                    cast_params_offline=True, device=mx.current_context())
-  lp_net(data_example)
-  for name, data in lp_net.collect_params().items():
-    assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)
+    amp_common_tests.test_amp_offline_casting_shared_params(AMP_DTYPE)  # shared body replaces the bf16-only one removed above
+
+
+@mx.util.use_np
+def test_bf16_fp32_ops_order_independence():  # conversion must not depend on branch order
+    amp_common_tests.test_lp16_fp32_ops_order_independence(lp_dtype=AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_bf16_test_node_excluding():  # attribute-based and name-based node exclusion
+    amp_common_tests.test_amp_node_excluding(lp_dtype=AMP_DTYPE)
diff --git a/tests/python/gpu/test_amp.py b/tests/python/gpu/test_amp.py
index abc8f326e3..0c8ce79a71 100644
--- a/tests/python/gpu/test_amp.py
+++ b/tests/python/gpu/test_amp.py
@@ -15,86 +15,54 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
 import sys
+from pathlib import Path
+curr_path = Path(__file__).resolve().parent
+sys.path.insert(0, str(curr_path.parent))
+sys.path.insert(0, str(curr_path.parent/'unittest'))
+
 import mxnet as mx
-import numpy as np
-from random import randint
-import warnings
-import collections
-import ctypes
-from mxnet import amp
 import pytest
-from mxnet.test_utils import set_default_device, same_symbol_structure
-from mxnet.gluon.model_zoo.vision import get_model
-from mxnet.gluon import SymbolBlock, nn, rnn
-from mxnet.operator import get_all_registered_operators_grouped
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from mxnet import amp
+from mxnet.test_utils import set_default_device
+from mxnet.gluon import nn, rnn
+
+import amp.common as amp_common_tests
 from common import assert_raises_cudnn_not_satisfied
-sys.path.insert(0, os.path.join(curr_path, '../train'))
+
+AMP_DTYPE = 'float16'
+
 set_default_device(mx.gpu(0))
 
-@pytest.fixture()
-def amp_tests(request):
-    def teardown():
-        mx.nd.waitall()
-
-    request.addfinalizer(teardown)
-
-def test_amp_coverage(amp_tests):
-    conditional = [item[0] for item in amp.lists.symbol_fp16.CONDITIONAL_FP32_FUNCS]
-
-    # Check for duplicates
-    for a in [amp.lists.symbol_fp16.FP16_FUNCS,
-          amp.lists.symbol_fp16.FP16_FP32_FUNCS,
-          amp.lists.symbol_fp16.FP32_FUNCS,
-          amp.lists.symbol_fp16.WIDEST_TYPE_CASTS,
-          conditional]:
-        ret = [item for item, count in collections.Counter(a).items() if count > 1]
-        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
-
-    t = []
-    for a in [amp.lists.symbol_fp16.FP16_FUNCS,
-              amp.lists.symbol_fp16.FP16_FP32_FUNCS,
-              amp.lists.symbol_fp16.FP32_FUNCS,
-              amp.lists.symbol_fp16.WIDEST_TYPE_CASTS,
-              conditional]:
-        t += a
-    ret = [item for item, count in collections.Counter(t).items() if count > 1]
-    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
-
-    # Check the coverage
-    covered = set(t)
-    ops = get_all_registered_operators_grouped()
-    required = set(k for k in ops
-                   if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
-                   not k.endswith("_backward"))
-
-    extra = covered - required
-    assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"
-
-    guidelines = """Please follow these guidelines for choosing a proper list:
-    - if your operator is not to be used in a computational graph
-      (e.g. image manipulation operators, optimizers) or does not have
-      inputs, put it in FP16_FP32_FUNCS list,
-    - if your operator requires FP32 inputs or is not safe to use with lower
-      precision, put it in FP32_FUNCS list,
-    - if your operator supports both FP32 and lower precision, has
-      multiple inputs and expects all inputs to be of the same
-      type, put it in WIDEST_TYPE_CASTS list,
-    - if your operator supports both FP32 and lower precision and has
-      either a single input or supports inputs of different type,
-      put it in FP16_FP32_FUNCS list,
-    - if your operator is both safe to use in lower precision and
-      it is highly beneficial to use it in lower precision, then
-      put it in FP16_FUNCS (this is unlikely for new operators)
-    - If you are not sure which list to choose, FP32_FUNCS is the
-                     safest option"""
-    diff = required - covered
-    assert not diff, f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in " \
-        f"python/mxnet/amp/lists/symbol_fp16.py) - please add them. " \
-        f"\n{guidelines}"
+
+def test_fp16_coverage():  # delegates to the shared AMP list checks in amp/common.py
+    amp_common_tests.test_amp_coverage(lp_dtype=AMP_DTYPE, lp_name='FP16')
+
+
+@mx.util.use_np
+def test_fp16_basic_use():  # two Dense layers converted to float16
+    amp_common_tests.test_amp_basic_use(lp_dtype=AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_offline_casting():  # parameters cast ahead of time, dense layer kept in float32
+    amp_common_tests.test_amp_offline_casting(lp_dtype=AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_offline_casting_shared_params():  # offline casting with parameters shared across layers
+    amp_common_tests.test_amp_offline_casting_shared_params(lp_dtype=AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_fp32_ops_order_independence():  # conversion must not depend on branch order
+    amp_common_tests.test_lp16_fp32_ops_order_independence(lp_dtype=AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_test_node_excluding():  # attribute-based and name-based node exclusion
+    amp_common_tests.test_amp_node_excluding(lp_dtype=AMP_DTYPE)
+
 
 @pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@@ -109,63 +77,3 @@ def test_amp_conversion_rnn(amp_tests):
         new_model = amp.convert_hybrid_block(model)
         out2 = new_model(mx.nd.ones((2, 3, 4)))
         mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
-
-
-@mx.util.use_np
-def test_fp16_offline_casting():
-  class TestNet(nn.HybridBlock):
-    def __init__(self):
-      super().__init__()
-      self.lp16_op1 = nn.Conv2D(4, 3)
-      self.lp16_op2 = nn.Conv2DTranspose(4, 3)
-      self.fp32_op = nn.Dense(4)
-
-    def forward(self, x):
-      x = self.lp16_op1(x)
-      x = self.lp16_op2(x)
-      x = x.reshape(x.shape[0], -1)
-      x = self.fp32_op(x)
-      return x
-
-  net = TestNet()
-  net.initialize()
-  data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
-  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
-                                    target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
-                                    cast_params_offline=True, device=mx.current_context())
-  lp_net(data_example)
-  for name, data in lp_net.collect_params().items():
-    assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
-
-
-@mx.util.use_np
-def test_fp16_offline_casting_shared_params():
-  COMMON_SIZE = 4
-
-  class TestNet(nn.HybridBlock):
-    def __init__(self):
-      super().__init__()
-      self.lp16_op1 = nn.Dense(COMMON_SIZE)
-      self.lp16_op2 = nn.Dense(COMMON_SIZE)
-      self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
-      self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
-      self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
-
-    def forward(self, x):
-      x = self.lp16_op1(x)
-      x1 = self.lp16_op2(x)
-      x2 = mx.np.expand_dims(x, 1)
-      x2 = self.fp32_op(x2)
-      x2 = nn.Flatten()(x2)
-      x = mx.np.concat((x1, x2), axis=1)
-      return x
-
-  net = TestNet()
-  net.initialize()
-  data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
-  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
-                                    target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
-                                    cast_params_offline=True, device=mx.current_context())
-  lp_net(data_example)
-  for name, data in lp_net.collect_params().items():
-    assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 65d973171f..5b992bcde3 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -75,6 +75,23 @@ def test_quantize_float32_to_int8():
     qdata_np = (onp.sign(data_np) * onp.minimum(onp.abs(data_np) * scale + 0.5, quantized_range)).astype(onp.int8)
     assert_almost_equal(qdata.asnumpy(), qdata_np, atol = 1)
 
+def test_calibrated_quantize_v2_bfloat16_to_int8():
+    # NOTE(review): quantize_v2 accepts bf16 input only in oneDNN builds with calib ranges; confirm non-oneDNN CI skips this
+    data = mx.nd.random.normal(0, 1, rand_shape_nd(4)).astype('bfloat16')
+    min_range = mx.nd.min(data).asscalar()
+    max_range = mx.nd.max(data).asscalar()
+    qdata, min_val, max_val = mx.nd.contrib.quantize_v2(
+        data, out_type='int8', min_calib_range=min_range, max_calib_range=max_range)
+    assert qdata.dtype == onp.int8
+    assert min_val.dtype == onp.float32
+    assert max_val.dtype == onp.float32
+    data_np = data.asnumpy()
+    real_range = onp.maximum(onp.abs(min_range), onp.abs(max_range))
+    assert same(min_val.asscalar(), -real_range)
+    assert same(max_val.asscalar(), real_range)
+    scale = 127.0 / real_range  # symmetric int8 range
+    qdata_np = (onp.sign(data_np) * onp.minimum(onp.abs(data_np) * scale + 0.5, 127.0)).astype(onp.int8)
+    assert_almost_equal(qdata.asnumpy(), qdata_np, atol=1)
 
 def test_dequantize_int8_to_float32():