You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/05/17 12:34:25 UTC
[incubator-mxnet] branch master updated: AMP improvements + enable bf16 input for quantize_v2 (#20983)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3098af4a80 AMP improvements + enable bf16 input for quantize_v2 (#20983)
3098af4a80 is described below
commit 3098af4a808f905e285a4100c6b96b8e126c2c5a
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Tue May 17 14:34:09 2022 +0200
AMP improvements + enable bf16 input for quantize_v2 (#20983)
* AMP improvements + enable bf16 input for quantize_v2
* Fix sanity
* Improve tests, AMP conversion interface, fix forward hooks
* Fix tests
* Fix imports in tests
* Use different lp16_fp32 op in test
* Add amp.disable_amp() context, fix tests
* Add tests, generalize optimization disabling
* Fix sanity
* Review fixes
* Use is_integral<>::value
* Review fixes # rerun CI
Change flag type to unsigned int
Add a warning for an incorrect flag attribute value
* Add message to static_assert
* Cast flag attribute value to int before using
* Test amp node excluding, change names of tests
* Fix static_cast type
* Fix sanity
---
include/mxnet/c_api.h | 17 +-
include/mxnet/imperative.h | 21 ++
python/mxnet/amp/amp.py | 56 +++---
python/mxnet/amp/lists/symbol_bf16.py | 18 +-
python/mxnet/gluon/block.py | 46 ++++-
python/mxnet/ndarray/ndarray.py | 2 +
src/c_api/c_api_ndarray.cc | 13 ++
src/c_api/c_api_symbolic.cc | 7 +-
src/common/utils.h | 32 +++
src/imperative/imperative.cc | 7 +
src/nnvm/low_precision_pass.cc | 88 ++++++--
src/operator/quantization/quantize_v2-inl.h | 12 ++
tests/python/amp/common.py | 245 +++++++++++++++++++++++
tests/python/dnnl/subgraphs/test_amp_subgraph.py | 45 ++++-
tests/python/dnnl/test_amp.py | 149 +++-----------
tests/python/gpu/test_amp.py | 176 ++++------------
tests/python/quantization/test_quantization.py | 17 ++
17 files changed, 626 insertions(+), 325 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index eb13feedb1..d8757c7f9f 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1225,6 +1225,19 @@ MXNET_DLL int MXAutogradIsRecording(bool* curr);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXAutogradIsTraining(bool* curr);
+/*!
+ * \brief set what optimization constraints to apply
+ * \param constraints state composed of OptConstraint flags.
+ * \param prev returns the previous status before this set.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSetOptimizationConstraints(unsigned int constraints, unsigned int* prev);
+/*!
+ * \brief get current optimization constraints
+ * \param curr returns the current status
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXGetOptimizationConstraints(unsigned int* curr);
/*!
* \brief get whether numpy compatibility is on
* \param curr returns the current status
@@ -2025,9 +2038,7 @@ MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
const uint32_t num_fp32_ops,
const char** const fp32_ops_p,
const uint32_t num_widest_dtype_ops,
- const char** const widest_dtype_ops_p,
- const uint32_t num_excluded_symbols,
- const char** const excluded_syms_p);
+ const char** const widest_dtype_ops_p);
/*!
* \brief Set calibration table to node attributes in the sym
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index e4e3f6a938..42876f7bf4 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,14 @@
#include "./ndarray.h"
namespace mxnet {
+
+constexpr char OPT_CONSTRAINT_ATTR[] = "__opt_constraint__";
+enum class OptConstraint : unsigned int {
+ None = 0,
+ DisableAMP = 1 << 0
+ // DisableQuantization = 1 << 1
+};
+
/*! \brief there are three numpy shape flags based on priority.
* GlobalOn
* turn on numpy shape flag globally, it includes thread local.
@@ -47,6 +55,7 @@ namespace mxnet {
* */
enum NumpyShape { Off, ThreadLocalOn, GlobalOn };
typedef NumpyShape NumpyDefaultDtype;
+
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -237,6 +246,16 @@ class Imperative {
}
return old;
}
+ /*! \brief return current optimization constraints. */
+ OptConstraint get_opt_constraints() const {
+ return opt_constraints_;
+ }
+ /*! \brief set optimization constraints. */
+ OptConstraint set_opt_constraints(OptConstraint constraints) {
+ OptConstraint old = opt_constraints_;
+ opt_constraints_ = constraints;
+ return old;
+ }
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
@@ -321,6 +340,7 @@ class Imperative {
static thread_local bool is_train_;
static thread_local bool is_recording_;
static thread_local bool is_deferred_compute_;
+ static thread_local OptConstraint opt_constraints_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_shape_thread_local_;
@@ -328,6 +348,7 @@ class Imperative {
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
static MX_THREAD_LOCAL bool is_deferred_compute_;
+ static MX_THREAD_LOCAL OptConstraint opt_constraints_;
// TOOD(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
diff --git a/python/mxnet/amp/amp.py b/python/mxnet/amp/amp.py
index 2d370ce7ab..b73e48e846 100644
--- a/python/mxnet/amp/amp.py
+++ b/python/mxnet/amp/amp.py
@@ -39,7 +39,7 @@ from ..symbol import contrib as symbol_contrib
from .. import ndarray
from ..ndarray import NDArray, dtype_np_to_mx, get_dtype_type, get_dtype_name, bfloat16
from . import lists
-from ..gluon import Block, trainer
+from ..gluon import Block, HybridBlock, trainer
from .. import base
from ..base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
_NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
@@ -428,7 +428,7 @@ def unscale(optimizer_or_trainer):
"an optimizer, instead is %s" % type(optimizer_or_trainer))
-def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", target_dtype_ops=None,
+def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype, target_dtype_ops=None,
fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=[],
cast_params_offline=False):
"""Given a symbol object representing a neural network of data type FP32 and target_dtype,
@@ -464,9 +464,7 @@ def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", targ
data_names : list of strs, optional
A list of strings that represent input data tensor names to the model
cast_params_offline : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in LP16
- because of a cast layer following it, but will reduce the computation and memory
- overhead of the model if casted.
+ Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
"""
import json
@@ -497,22 +495,30 @@ def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", targ
"conditional_fp32_ops should be a list of (str, str, list of str)"
cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
- nodes_attr = sym.attr_dict()
+ nodes_attrs = sym.attr_dict()
nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
- if not set(excluded_sym_names).issubset(set(nodes_op.keys())):
- logging.warning("excluded_sym_names are not present in the network. Missing layers: {}".format(
- set(excluded_sym_names) - set(nodes_op.keys())))
-
for node_name, node_op in nodes_op.items():
if node_op not in cond_ops:
continue
- node_attrs = nodes_attr[node_name]
+ node_attrs = nodes_attrs[node_name]
for attr_name, attr_vals in cond_ops[node_op].items():
assert attr_name in node_attrs
if node_attrs[attr_name] in attr_vals:
- excluded_sym_names += node_name
+ excluded_sym_names.append(node_name)
break
- excluded_sym_names = list(set(excluded_sym_names))
+
+ excluded_sym_names = set(excluded_sym_names)
+ for node in sym.get_internals():
+ if node.name in excluded_sym_names:
+ excluded_sym_names.remove(node.name)
+ opt_constraints = node.attr('__opt_constraint__')
+ opt_constraints = 0 if opt_constraints is None else int(opt_constraints)
+ opt_constraints |= HybridBlock.OptConstraint.Flag.DisableAMP.value
+ node._set_attr(__opt_constraint__=str(opt_constraints))
+
+ if len(excluded_sym_names) > 0:
+ logging.warning("excluded_sym_names are not present in the network. Missing nodes: {}".format(
+ excluded_sym_names))
# Op lists should not intersect
common_ops = set(target_dtype_ops) & set(fp32_ops)
@@ -561,13 +567,11 @@ def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", targ
ctypes.c_uint(len(fp32_ops)),
c_str_array(fp32_ops),
ctypes.c_uint(len(widest_dtype_ops)),
- c_str_array(widest_dtype_ops),
- ctypes.c_uint(len(excluded_sym_names)),
- c_str_array(excluded_sym_names)))
+ c_str_array(widest_dtype_ops)))
return type(sym)(out)
-def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float16",
+def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype,
target_dtype_ops=None, fp32_ops=None, conditional_fp32_ops=None,
excluded_sym_names=[], cast_params_offline=False):
"""API for converting a model from FP32 model to a mixed precision model.
@@ -605,9 +609,7 @@ def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float
A list of strings that represent the names of symbols that users want to exclude
from being executed in lower precision.
cast_params_offline : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in LP16
- because of a cast layer following it, but will reduce the computation and memory
- overhead of the model if casted.
+ Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
"""
assert isinstance(sym, Symbol), "First argument to convert_model should be a Symbol"
assert isinstance(
@@ -641,9 +643,9 @@ def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float
@wrap_ctx_to_device_func
-def convert_hybrid_block(block, data_example, target_dtype="float16", target_dtype_ops=None,
+def convert_hybrid_block(block, data_example, target_dtype, target_dtype_ops=None,
fp32_ops=None, conditional_fp32_ops=None,
- excluded_sym_names=[], device=gpu(0),
+ excluded_sym_names=[], device=None,
cast_params_offline=False):
"""Given a hybrid block/symbol block representing a FP32 model and a target_dtype,
return a block with mixed precision support which can be used for inference use cases.
@@ -668,14 +670,12 @@ def convert_hybrid_block(block, data_example, target_dtype="float16", target_dty
excluded_sym_names : list of strs
A list of strings that represent the names of symbols that users want to exclude
from being quantized
- device : Context
- Context on which model parameters should live
+ device : Device
+ Device on which model parameters should live. Default value: current device.
cast_params_offline : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in LP16
- because of a cast layer following it, but will reduce the computation and memory
- overhead of the model if casted.
+ Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
"""
- from ..gluon import HybridBlock, SymbolBlock
+ from ..gluon import SymbolBlock
from ..ndarray import NDArray as ND_NDArray, waitall
from ..numpy import ndarray as NP_NDArray
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index bab52688b4..566990c411 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -18,16 +18,22 @@
# coding: utf-8
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""
+from ...runtime import Features
+
# Functions that should be cast to lower precision
BF16_FUNCS = [
'Convolution',
'Deconvolution',
- 'FullyConnected',
- '_sg_onednn_conv',
- '_sg_onednn_fully_connected',
- '_sg_onednn_selfatt_qk',
- '_sg_onednn_selfatt_valatt'
+ 'FullyConnected'
]
+if Features.instance.is_enabled('ONEDNN'):
+ BF16_FUNCS.extend([
+ '_sg_onednn_conv',
+ '_sg_onednn_fully_connected',
+ '_sg_onednn_selfatt_qk',
+ '_sg_onednn_selfatt_valatt'
+ ])
+
# Functions that should not be casted, either because
# they are irrelevant (not used in the network itself
@@ -45,6 +51,7 @@ BF16_FP32_FUNCS = [
'sqrt',
'square',
'tanh',
+ '_contrib_quantize_v2',
]
# Functions that when running with Bfloat16, the params that still need float32.
@@ -98,7 +105,6 @@ FP32_FUNCS = [
'_contrib_quadratic',
'_contrib_quantize',
'_contrib_quantize_asym',
- '_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
'_contrib_quantized_flatten',
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 5b327b1a8d..cff346b9f4 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -20,6 +20,8 @@
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
+import enum
+import ctypes
import copy
import warnings
import weakref
@@ -1045,6 +1047,32 @@ class HybridBlock(Block):
`Hybridize - A Hybrid of Imperative and Symbolic Programming
<https://mxnet.apache.org/versions/master/api/python/docs/tutorials/packages/gluon/blocks/hybridize.html>`_
"""
+ class OptConstraint:
+ class Flag(enum.Flag):
+ DisableAMP = enum.auto()
+
+ def __init__(self, flag) -> None:
+ self.flag = flag
+ self.enter_state = None
+
+ def __enter__(self):
+ self.enter_state = HybridBlock.OptConstraint.Flag(get_optimization_constraints())
+ target_state = self.enter_state | self.flag
+ set_optimization_constraints(target_state)
+
+ def __exit__(self, ptype, value, trace):
+ set_optimization_constraints(self.enter_state)
+
+ @staticmethod
+ def disable_all():
+ opt_flag = HybridBlock.OptConstraint.Flag()
+ for flag in HybridBlock.OptConstraint.Flag:
+ opt_flag |= flag
+
+ @staticmethod
+ def disable_amp():
+ return HybridBlock.OptConstraint(HybridBlock.OptConstraint.Flag.DisableAMP)
+
def __init__(self):
super(HybridBlock, self).__init__()
assert hasattr(self, "hybrid_forward") is False, (
@@ -1550,7 +1578,6 @@ class HybridBlock(Block):
if remove_amp_cast:
handle = SymbolHandle()
- import ctypes
check_call(_LIB.MXSymbolRemoveAmpCast(sym.handle, ctypes.byref(handle)))
sym = type(sym)(handle)
return sym, arg_dict
@@ -1571,7 +1598,6 @@ class HybridBlock(Block):
"""
def c_callback(name, op_name, array):
"""wrapper for user callback"""
- import ctypes
array = ctypes.cast(array, NDArrayHandle)
array = NDArray(array, writable=False)
name = py_str(name)
@@ -1820,12 +1846,12 @@ class SymbolBlock(HybridBlock):
def __call__(self, x, *args):
"""Calls forward. Only accepts positional arguments."""
for hook in self._forward_pre_hooks.values():
- hook(self, [x] + args)
+ hook(self, [x, *args])
out = self.forward(x, *args)
for hook in self._forward_hooks.values():
- hook(self, [x] + args, out)
+ hook(self, [x, *args], out)
return out
@@ -1943,3 +1969,15 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dt
aux_types.append(default_dtype)
return (arg_types, aux_types)
+
+
+def set_optimization_constraints(state):
+ prev_state = ctypes.c_uint()
+ check_call(_LIB.MXSetOptimizationConstraints(ctypes.c_uint(state.value), ctypes.byref(prev_state)))
+ return HybridBlock.OptConstraint.Flag(prev_state.value)
+
+
+def get_optimization_constraints():
+ curr = ctypes.c_uint()
+ check_call(_LIB.MXGetOptimizationConstraints(ctypes.byref(curr)))
+ return HybridBlock.OptConstraint.Flag(curr.value)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 18fc8e7261..0e6432a999 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2653,6 +2653,8 @@ fixed-size items.
array([[1, 1, 1],
[1, 1, 1]], dtype=int32)
"""
+ if self.dtype == bfloat16:
+ return self.astype(np.float32).asnumpy()
data = np.empty(self.shape, dtype=self.dtype)
check_call(_LIB.MXNDArraySyncCopyToCPU(
self.handle,
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 2e9c0a3736..b91a997b7c 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -291,6 +291,19 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
API_END();
}
+int MXSetOptimizationConstraints(unsigned int constraints, unsigned int* prev) {
+ API_BEGIN();
+ *prev =
+ static_cast<unsigned int>(Imperative::Get()->set_opt_constraints(OptConstraint(constraints)));
+ API_END();
+}
+
+int MXGetOptimizationConstraints(unsigned int* curr) {
+ API_BEGIN();
+ *curr = static_cast<unsigned int>(Imperative::Get()->get_opt_constraints());
+ API_END();
+}
+
int MXIsNumpyShape(int* curr) {
API_BEGIN();
*curr = Imperative::Get()->is_np_shape();
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 6f9eeb2dac..e17d7beb2e 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -986,9 +986,7 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle,
const uint32_t num_fp32_ops,
const char** const fp32_ops_p,
const uint32_t num_widest_dtype_ops,
- const char** const widest_dtype_ops_p,
- const uint32_t num_excluded_symbols,
- const char** const excluded_syms_p) {
+ const char** const widest_dtype_ops_p) {
nnvm::Symbol* result_sym = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
@@ -1002,8 +1000,6 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle,
std::unordered_set<std::string> fp32_ops(fp32_ops_p, fp32_ops_p + num_fp32_ops);
std::unordered_set<std::string> widest_dtype_ops(widest_dtype_ops_p,
widest_dtype_ops_p + num_widest_dtype_ops);
- std::unordered_set<std::string> excluded_syms(excluded_syms_p,
- excluded_syms_p + num_excluded_symbols);
nnvm::DTypeVector arg_types(num_all_args);
std::unordered_map<std::string, int> node_name_to_type_map;
@@ -1022,7 +1018,6 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle,
g.attrs["target_dtype_ops"] = std::make_shared<nnvm::any>(std::move(target_dtype_ops));
g.attrs["fp32_ops"] = std::make_shared<nnvm::any>(std::move(fp32_ops));
g.attrs["widest_dtype_ops"] = std::make_shared<nnvm::any>(std::move(widest_dtype_ops));
- g.attrs["excluded_syms"] = std::make_shared<nnvm::any>(std::move(excluded_syms));
g = ApplyPass(std::move(g), "ReducePrecision");
result_sym->outputs = g.outputs;
diff --git a/src/common/utils.h b/src/common/utils.h
index d839a6db40..fe8413f18e 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -438,6 +438,38 @@ inline std::string attr_value_string(const nnvm::NodeAttrs& attrs,
return attrs.dict.at(attr_name);
}
+/*! \brief Seeks an attribute in a node and its subgraphs and invokes a function on each. */
+template <typename Fn>
+inline void attr_foreach(const nnvm::NodeAttrs& attrs, const std::string& attr_name, const Fn& fn) {
+ const auto& found_it = attrs.dict.find(attr_name);
+ if (found_it != attrs.dict.end()) {
+ fn(found_it->second);
+ }
+ for (const auto& subgraph : attrs.subgraphs) {
+ DFSVisit(subgraph->outputs,
+ [&](const nnvm::ObjectPtr& node) { attr_foreach(node->attrs, attr_name, fn); });
+ }
+}
+
+template <typename ValueType>
+inline ValueType flag_attr_accumulate(const nnvm::NodeAttrs& attrs, const std::string& attr_name) {
+ static_assert(std::is_integral<ValueType>::value, "ValueType must be an integral type.");
+
+ ValueType result = 0;
+ attr_foreach(attrs, attr_name, [&](const std::string& attr_value) {
+ std::istringstream ss(attr_value);
+ ValueType temp;
+ ss >> temp;
+ result |= temp;
+
+ if (ss.fail() || !ss.eof()) {
+ LOG(WARNING) << "Incorrect value of an attribute: " << attr_name
+ << ". Expected an integer, while got: " << attr_value;
+ }
+ });
+ return result;
+}
+
/*! \brief get string representation of the operator stypes */
inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
const int dev_mask,
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index b9bdaac947..fb123c18c9 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -33,11 +33,13 @@ namespace mxnet {
thread_local bool Imperative::is_train_ = false;
thread_local bool Imperative::is_recording_ = false;
thread_local bool Imperative::is_deferred_compute_ = false;
+thread_local OptConstraint Imperative::opt_constraints_ = OptConstraint::None;
thread_local bool Imperative::is_np_shape_thread_local_ = false;
#else
MX_THREAD_LOCAL bool Imperative::is_train_ = false;
MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
MX_THREAD_LOCAL bool Imperative::is_deferred_compute_ = false;
+MX_THREAD_LOCAL OptConstraint Imperative::opt_constraints_ = OptConstraint::None;
MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
#endif
@@ -367,6 +369,11 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs&& attrs,
node_count_++;
}
+ if (get_opt_constraints() != OptConstraint::None) {
+ node->attrs.dict[OPT_CONSTRAINT_ATTR] =
+ std::to_string(static_cast<std::underlying_type_t<OptConstraint>>(get_opt_constraints()));
+ }
+
for (uint32_t i = 0; i < outputs.size(); ++i) {
outputs[i]->deferredcompute_entry_ = nnvm::NodeEntry{node, i, 0};
}
diff --git a/src/nnvm/low_precision_pass.cc b/src/nnvm/low_precision_pass.cc
index 192390eacb..2c2cce5b99 100644
--- a/src/nnvm/low_precision_pass.cc
+++ b/src/nnvm/low_precision_pass.cc
@@ -30,6 +30,7 @@
#include <algorithm>
#include <functional>
#include "operator/operator_common.h"
+#include "common/utils.h"
namespace mxnet {
using nnvm::Graph;
@@ -59,8 +60,9 @@ class MappedNodeEntry {
* converted and dtypes of its outputs may have changed
*/
void UpdateDTypeAfterConversion(const int new_dtype) {
- CHECK_EQ(dtype, original_dtype); // dtype should be changed only once
+ CHECK(dtype == original_dtype || dtype == new_dtype); // dtype should be changed only once
CHECK(entry.node->op());
+ CHECK_NE(new_dtype, -1);
dtype = new_dtype;
}
@@ -91,6 +93,8 @@ class MappedNodeEntry {
/*! \brief Returns whether this entry has the specified dtype or an existing cast to that dtype */
bool HasDTypeEntry(const int target_dtype) const {
+ CHECK_NE(target_dtype, -1);
+
return dtype == target_dtype || casts.count(target_dtype) > 0;
}
@@ -99,6 +103,8 @@ class MappedNodeEntry {
* input entires of a node before its conversion.
*/
bool CanBeCastTo(const int target_dtype) {
+ CHECK_NE(target_dtype, -1);
+
static const auto& amp_cast_op = Op::Get("amp_cast");
static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType")[amp_cast_op];
nnvm::NodeAttrs dummy_atts;
@@ -113,6 +119,8 @@ class MappedNodeEntry {
/*! \brief Returns whether this NodeEntry (of a parameter) can be cast offline */
bool CanBeCastOfflineTo(const int target_dtype) const {
CHECK(entry.node->is_variable());
+ CHECK_NE(target_dtype, -1);
+
return casts.count(target_dtype) > 0;
}
@@ -177,8 +185,21 @@ static bool TryLowPrecision(const int target_dtype,
static const auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
std::vector<int> in_types(old_node->inputs.size(), -1);
+ bool has_lp_input = false;
+ for (int i = 0; i < old_node->inputs.size(); ++i) {
+ if (entry_map->at(old_node->inputs[i]).HasDTypeEntry(target_dtype)) {
+ in_types[i] = target_dtype;
+ has_lp_input = true;
+ }
+ }
+ if (!has_lp_input) {
+ // when inputs are not already in low precision, assume the first input should be in low
+ // precision in order to convert this op
+ in_types[0] = target_dtype;
+ }
+
+ // infer types of other inputs
std::vector<int> out_types(old_node->num_outputs(), -1);
- in_types[0] = target_dtype;
if (infertype.count(old_node->op()) == 0 ||
infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == false) {
return false;
@@ -226,21 +247,38 @@ static void HandleWidestDtypeNode(const int target_dtype,
EntryMap_t* const entry_map) {
static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
- std::vector<int> in_types(old_node->inputs.size(), target_dtype);
- std::vector<int> out_types(old_node->num_outputs(), -1);
- const bool inferred = (infertype.count(old_node->op()) > 0 &&
- infertype[old_node->op()](old_node->attrs, &in_types, &out_types));
-
- bool has_lp_inputs = inferred;
- for (int i = 0; has_lp_inputs && i < old_node->inputs.size(); ++i) {
- const NodeEntry& input = old_node->inputs[i];
- has_lp_inputs &= entry_map->at(input).HasDTypeEntry(in_types[i]);
+ // gather info about current dtypes of inputs
+ // if there is already at least one input with target dtype, we try converting to low precision
+ bool try_lp = false;
+ std::vector<int> in_types(old_node->inputs.size(), -1);
+ for (int i = 0; i < old_node->inputs.size(); ++i) {
+ if (entry_map->at(old_node->inputs[i]).HasDTypeEntry(target_dtype)) {
+ in_types[i] = target_dtype; // set only lp inputs
+ try_lp = true;
+ }
}
- if (!has_lp_inputs ||
- !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
- KeepOriginalNode(old_node, node_map, entry_map);
+ if (try_lp) {
+ // run infertype, to see what other input types this op needs with the current lp inputs
+ std::vector<int> out_types(old_node->num_outputs(), -1);
+ try_lp = (infertype.count(old_node->op()) > 0 &&
+ infertype[old_node->op()](old_node->attrs, &in_types, &out_types));
+
+ if (try_lp) {
+ // if we have to add casts to inputs, this op shouldn't run in low precision
+ for (int i = 0; i < old_node->inputs.size(); ++i) {
+ const NodeEntry& old_input_ne = old_node->inputs[i];
+ if (in_types[i] != -1 && !entry_map->at(old_input_ne).HasDTypeEntry(in_types[i])) {
+ try_lp = false;
+ break;
+ }
+ }
+ if (try_lp && TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
+ return;
+ }
+ }
}
+ KeepOriginalNode(old_node, node_map, entry_map);
}
/*!
* \brief Tries to convert the node to low precision if some of its inputs already are converted.
@@ -309,7 +347,6 @@ Graph ReducePrecision(Graph&& src) {
const auto& target_dtype_ops = src.GetAttr<std::unordered_set<std::string>>("target_dtype_ops");
const auto& fp32_ops = src.GetAttr<std::unordered_set<std::string>>("fp32_ops");
const auto& widest_dtype_ops = src.GetAttr<std::unordered_set<std::string>>("widest_dtype_ops");
- const auto& excluded_syms = src.GetAttr<std::unordered_set<std::string>>("excluded_syms");
auto src_dtypes = src.GetAttr<nnvm::DTypeVector>("dtype"); // copy, not reference
CHECK(target_dtype == mshadow::kFloat16 || target_dtype == mshadow::kBfloat16)
@@ -360,7 +397,7 @@ Graph ReducePrecision(Graph&& src) {
}
// convert the model
- DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
+ const auto convert_node_fn = [&](const ObjectPtr& old_node) {
if (old_node->is_variable() || old_node->op() == Op::Get("amp_multicast") ||
IsCastOp(old_node->op())) {
const ObjectPtr& new_node = node_map.at(old_node.get());
@@ -370,8 +407,10 @@ Graph ReducePrecision(Graph&& src) {
}
return;
}
-
- if (fp32_ops.count(old_node->op()->name) > 0 || excluded_syms.count(old_node->attrs.name) > 0) {
+ auto opt_constraints = common::flag_attr_accumulate<std::underlying_type_t<OptConstraint>>(
+ old_node->attrs, OPT_CONSTRAINT_ATTR);
+ if (fp32_ops.count(old_node->op()->name) > 0 ||
+ (opt_constraints & static_cast<int>(OptConstraint::DisableAMP))) {
KeepOriginalNode(old_node, node_map, &entry_map);
} else if (target_dtype_ops.count(old_node->op()->name) > 0) {
if (!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, &entry_map)) {
@@ -384,7 +423,20 @@ Graph ReducePrecision(Graph&& src) {
} else {
HandleDTypeNeutralNode(target_dtype, old_node, node_map, nodes_entries, &entry_map);
}
+ };
+
+ // Because some nodes depend on casts present in the graph, the order of visited nodes will
+ // determine whether some nodes are converted or not. To avoid this, first we make a virtual
+ // conversion pass in order to have all the necessary casts already present (in the
+ // MappedNodeEntry instances) during the second (true) conversion pass
+
+ // virtual conversion pass
+ DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
+ convert_node_fn(old_node);
+ node_map[old_node.get()]->inputs.clear(); // make this pass "virtual" by removing edges
});
+ // true conversion pass
+ DFSVisit(src.outputs, [&](const ObjectPtr& old_node) { convert_node_fn(old_node); });
std::vector<NodeEntry> outputs;
for (const auto& out_ne : src.outputs) {
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index 3337c0a833..7f297d6e57 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -151,8 +151,20 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 3U);
const QuantizeV2Param& param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+
+#if MXNET_USE_ONEDNN == 1
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kBfloat16 ||
+ in_attrs->at(0) == mshadow::kUint8 || in_attrs->at(0) == mshadow::kInt8);
+ } else {
+ CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
+ in_attrs->at(0) == mshadow::kInt8);
+ }
+#else
CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
in_attrs->at(0) == mshadow::kInt8);
+#endif
+
auto out_type = GetQuantizeOutputType(param);
if (out_type == mshadow::kUint8) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
diff --git a/tests/python/amp/common.py b/tests/python/amp/common.py
new file mode 100644
index 0000000000..3e221c613e
--- /dev/null
+++ b/tests/python/amp/common.py
@@ -0,0 +1,245 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import warnings
+import collections
+import mxnet as mx
+from mxnet import amp
+from mxnet.gluon import nn
+from mxnet.operator import get_all_registered_operators_grouped
+
+
+def test_amp_coverage(lp_dtype, lp_name):
+ conditional = [item[0] for item in amp.list_conditional_fp32_ops(lp_dtype)]
+ lp16_ops = amp.list_lp16_ops(lp_dtype)
+ lp16_fp32_ops = amp.list_lp16_fp32_ops(lp_dtype)
+ fp32_ops = amp.list_fp32_ops(lp_dtype)
+ widest_ops = amp.list_widest_type_cast(lp_dtype)
+ all_lp_lists = [lp16_ops, lp16_fp32_ops, fp32_ops, widest_ops, conditional]
+
+ # Check for duplicates
+ for op_list in all_lp_lists:
+ ret = [op for op, count in collections.Counter(op_list).items() if count > 1]
+ assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
+
+ all_lp_ops = [op for op_list in all_lp_lists for op in op_list]
+ ret = [op for op, count in collections.Counter(all_lp_ops).items() if count > 1]
+ assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
+
+ # Check the coverage
+ covered_ops = set(all_lp_ops)
+ all_mxnet_ops = get_all_registered_operators_grouped()
+ required_ops = {op for op in all_mxnet_ops if not "backward" in op}
+
+ extra_ops = covered_ops - required_ops
+ assert not extra_ops, f"{len(extra_ops)} operators are not needed in the AMP lists: {sorted(extra_ops)}"
+
+ guidelines = f"""Please follow these guidelines for choosing a proper list:
+ - if your operator is not to be used in a computational graph
+ (e.g. image manipulation operators, optimizers) or does not have
+ inputs, put it in {lp_name.upper()}_FP32_FUNCS list,
+ - if your operator requires FP32 inputs or is not safe to use with lower
+ precision, put it in FP32_FUNCS list,
+ - if your operator supports both FP32 and lower precision, has
+ multiple inputs and expects all inputs to be of the same
+ type, put it in WIDEST_TYPE_CASTS list,
+ - if your operator supports both FP32 and lower precision and has
+ either a single input or supports inputs of different type,
+ put it in {lp_name.upper()}_FP32_FUNCS list,
+ - if your operator is both safe to use in lower precision and
+ it is highly beneficial to use it in lower precision, then
+ put it in {lp_name.upper()}_FUNCS (this is unlikely for new operators)
+ - If you are not sure which list to choose, FP32_FUNCS is the
+ safest option"""
+ missing_ops = required_ops - covered_ops
+
+ if len(missing_ops) > 0:
+ warnings.warn(f"{len(missing_ops)} operators {sorted(missing_ops)} do not exist in AMP lists "
+ f"(in python/mxnet/amp/lists/symbol_{lp_name.lower()}.py) - please add them. \n{guidelines}")
+
+
+def test_amp_basic_use(lp_dtype):
+ class TestNet(nn.HybridBlock):
+ def __init__(self):
+ super().__init__()
+ self.fc1 = nn.Dense(4)
+ self.fc2 = nn.Dense(4)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x.reshape((-1, 2, 2))
+
+ data_example = mx.np.random.uniform(-1, 1, (4, 4))
+
+ net = TestNet()
+ net.initialize()
+ net = amp.convert_hybrid_block(net, data_example, lp_dtype)
+
+ lp16_casts = 1 # cast for network input
+ lp16_casts += 2 # cast for weights and bias of `fc1`
+ lp16_casts += 2 # cast for weights and bias of `fc2`
+
+ other_casts = 1 # cast for the network output (from lp16 to f32)
+
+ lp16_tensors = 1 # cast network input
+ lp16_tensors += 3 # cast weights and bias of `fc1`, `fc1` output
+ lp16_tensors += 3 # cast weights and bias of `fc2`, `fc2` output
+ lp16_tensors += 1 # reshape output
+ check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=lp16_tensors, lp16_casts_num=lp16_casts,
+ other_casts_num=other_casts)
+
+
+def test_amp_offline_casting(lp_dtype):
+ class TestNet(nn.HybridBlock):
+ def __init__(self):
+ super().__init__()
+ self.lp16_op1 = nn.Conv2D(4, 3)
+ self.lp16_op2 = nn.Conv2DTranspose(4, 3)
+ self.fp32_op = nn.Dense(4)
+
+ def forward(self, x):
+ x = self.lp16_op1(x)
+ x = self.lp16_op2(x)
+ x = x.reshape(x.shape[0], -1)
+ with nn.HybridBlock.OptConstraint.disable_amp():
+ x = self.fp32_op(x)
+ return x
+
+ net = TestNet()
+ net.initialize()
+ data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
+ lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
+
+ check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4,
+ lp16_casts_num=1, other_casts_num=1)
+ for name, data in lp_net.collect_params().items():
+ assert mx.nd.get_dtype_name(data.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)
+
+
+def test_amp_offline_casting_shared_params(lp_dtype):
+ COMMON_SIZE = 4
+
+ class TestNet(nn.HybridBlock):
+ def __init__(self):
+ super().__init__()
+ self.lp16_op1 = nn.Dense(COMMON_SIZE)
+ self.lp16_op2 = nn.Dense(COMMON_SIZE)
+ self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
+ self.fp32_op = nn.Dense(COMMON_SIZE)
+ self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
+
+ def forward(self, x):
+ x = self.lp16_op1(x)
+ x1 = self.lp16_op2(x)
+ with nn.HybridBlock.OptConstraint.disable_amp():
+ x2 = self.fp32_op(x)
+ x = mx.np.concat((x1, x2), axis=1)
+ return x
+
+ net = TestNet()
+ net.initialize()
+ data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
+ lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
+
+ check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4,
+ lp16_casts_num=2, other_casts_num=2)
+ for name, data in lp_net.collect_params().items():
+ assert mx.nd.get_dtype_name(data.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)
+
+
+def test_lp16_fp32_ops_order_independence(lp_dtype):
+ class TestNet(nn.HybridBlock):
+ def __init__(self, lp16_fp32_is_first):
+ super().__init__()
+ if lp16_fp32_is_first:
+ self.first = mx.npx.batch_flatten # lp16_fp32_op
+ self.second = nn.Dense(4)
+ else:
+ self.first = nn.Dense(4)
+ self.second = mx.npx.batch_flatten # lp16_fp32_op
+
+ def forward(self, x):
+ x = 2**x
+ x1 = self.first(x)
+ x2 = self.second(x)
+ return x1, x2
+
+ data_example = mx.np.random.uniform(-1, 1, (4, 16))
+
+ for lp16_fp32_is_second in [False, True]:
+ net = TestNet(lp16_fp32_is_second)
+ net.initialize()
+ net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
+ check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=3,
+ lp16_casts_num=1, other_casts_num=2)
+
+
+def test_amp_node_excluding(lp_dtype):
+ DISABLE_AMP_ATTR_DICT = {'__opt_constraint__': str(
+ mx.gluon.HybridBlock.OptConstraint.Flag.DisableAMP.value)}
+
+ data = mx.sym.var('data')
+ wei = mx.sym.var('weights')
+ bias = mx.sym.var('bias')
+ # manually excluded
+ fc1 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc1', attr=DISABLE_AMP_ATTR_DICT)
+ # to be excluded using the conversion API
+ fc2 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc2')
+ symnet = mx.sym.Group([fc1, fc2])
+
+ net = mx.gluon.SymbolBlock(symnet, [data])
+ net.initialize()
+
+ # exclude only nodes with set attribute (only 1 node - `fc1`)
+ data_example = mx.np.random.uniform(-1, 1, (4, 16))
+ net_1_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype)
+
+ lp16_tensors = 4 # cast `data`, weights and bias of `fc2`, `fc2` output
+ lp16_casts = 3 # `data` cast, casts for weights and bias of `fc2`
+ other_casts = 1 # cast for the network output (from lp16 to f32)
+ check_amp_net_stats(lp_dtype, net_1_excluded, data_example, lp16_tensors_num=lp16_tensors,
+ lp16_casts_num=lp16_casts, other_casts_num=other_casts)
+
+ # exclude using the `excluded_sym_names` argument (both nodes)
+ net_2_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype,
+ excluded_sym_names=['fc1', 'fc2'])
+ check_amp_net_stats(lp_dtype, net_2_excluded, data_example, lp16_tensors_num=0,
+ lp16_casts_num=0, other_casts_num=0)
+
+
+def check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num, lp16_casts_num, other_casts_num):
+ lp16_tensors = set()
+ lp16_casts = set()
+ other_casts = set()
+
+ def inspect_output(tensor_name, op_name, tensor):
+ dtype = mx.nd.get_dtype_name(tensor.dtype)
+ if op_name == 'amp_cast':
+ if dtype == lp_dtype:
+ lp16_casts.add(tensor_name)
+ else:
+ other_casts.add(tensor_name)
+ if dtype == lp_dtype:
+ lp16_tensors.add(tensor_name)
+
+ net.register_op_hook(inspect_output)
+ net(data_example)
+
+ assert len(lp16_tensors) == lp16_tensors_num, f'Bad lp16 tensors! Present tensors: {sorted(lp16_tensors)}'
+ assert len(lp16_casts) == lp16_casts_num, f'Bad lp16 casts! Present casts: {sorted(lp16_casts)}'
+ assert len(other_casts) == other_casts_num, f'Bad casts! Present casts: {sorted(other_casts)}'
diff --git a/tests/python/dnnl/subgraphs/test_amp_subgraph.py b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
index 4dd7337855..2c5c6e1b45 100644
--- a/tests/python/dnnl/subgraphs/test_amp_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
@@ -17,15 +17,21 @@
import json
import mxnet as mx
-import mxnet.gluon.nn as nn
from mxnet import amp
-from mxnet.amp.amp import bfloat16
+from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal
from subgraph_common import SG_PASS_NAME, QUANTIZE_SG_PASS_NAME
from test_matmul_subgraph import MultiHeadAttention
+import sys
+from pathlib import Path
+curr_path = Path(__file__).resolve().parent
+sys.path.insert(0, str(curr_path.parent.parent))
+
+from amp.common import check_amp_net_stats
+
AMP_SG_PASS_NAME = 'ONEDNN_AMP'
-AMP_DTYPE = bfloat16
+AMP_DTYPE = 'bfloat16'
# Checks if amp (after the AMP_SG_PASS_NAME fuse) changes the name of tensors for calibration
@@ -86,7 +92,7 @@ def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rto
@mx.util.use_np
def test_amp_fc():
- class TestNet(mx.gluon.HybridBlock):
+ class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.fc1 = nn.Dense(16)
@@ -115,7 +121,7 @@ def test_amp_fc():
@mx.util.use_np
def test_amp_conv():
- class TestNet(mx.gluon.HybridBlock):
+ class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.conv1 = nn.Conv2D(16, (3, 3))
@@ -166,7 +172,7 @@ def test_amp_transformers():
@mx.util.use_np
def test_amp_concat():
- class TestNet(mx.gluon.HybridBlock):
+ class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.fc1 = nn.Dense(16)
@@ -241,3 +247,30 @@ def test_amp_fuse_with_branch():
exp_sym = mx.sym.Group([lp16_op_2, f32_op])
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym)
+
+
+def test_amp_excluding_after_graph_pass():
+ class TestNet(nn.HybridBlock):
+ def __init__(self):
+ super(TestNet, self).__init__()
+ self.fc1 = nn.Dense(16)
+ self.fc2 = nn.Dense(16)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ with nn.HybridBlock.OptConstraint.disable_amp():
+ x = self.fc2(x)
+ return x
+
+ data_example = mx.np.random.uniform(-1, 1, (1, 8))
+ net = TestNet()
+ net.initialize()
+
+ net_before = amp.convert_hybrid_block(net, data_example, AMP_DTYPE, cast_params_offline=True)
+ check_amp_net_stats(AMP_DTYPE, net_before, data_example, lp16_tensors_num=2,
+ lp16_casts_num=1, other_casts_num=1)
+
+ net.optimize_for(data_example, backend=SG_PASS_NAME) # introduces new nodes
+ net_after = amp.convert_hybrid_block(net, data_example, AMP_DTYPE, cast_params_offline=True)
+ check_amp_net_stats(AMP_DTYPE, net_after, data_example, lp16_tensors_num=2,
+ lp16_casts_num=1, other_casts_num=1)
diff --git a/tests/python/dnnl/test_amp.py b/tests/python/dnnl/test_amp.py
index 90a66bc608..73be9bb9a8 100644
--- a/tests/python/dnnl/test_amp.py
+++ b/tests/python/dnnl/test_amp.py
@@ -15,133 +15,42 @@
# specific language governing permissions and limitations
# under the License.
-import os
import sys
+from pathlib import Path
+curr_path = Path(__file__).resolve().parent
+sys.path.insert(0, str(curr_path.parent))
+
import mxnet as mx
-import numpy as np
-import warnings
-import collections
-import ctypes
-from mxnet import amp
-from mxnet.amp.amp import bfloat16
-from mxnet.gluon import nn
-from mxnet.operator import get_all_registered_operators_grouped
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-
-
-def test_amp_coverage():
- conditional = [item[0] for item in amp.lists.symbol_bf16.CONDITIONAL_FP32_FUNCS]
-
- # Check for duplicates
- for a in [amp.lists.symbol_bf16.BF16_FUNCS,
- amp.lists.symbol_bf16.BF16_FP32_FUNCS,
- amp.lists.symbol_bf16.FP32_FUNCS,
- amp.lists.symbol_bf16.WIDEST_TYPE_CASTS,
- conditional]:
- ret = [item for item, count in collections.Counter(a).items() if count > 1]
- assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
-
- t = []
- for a in [amp.lists.symbol_bf16.BF16_FUNCS,
- amp.lists.symbol_bf16.BF16_FP32_FUNCS,
- amp.lists.symbol_bf16.FP32_FUNCS,
- amp.lists.symbol_bf16.WIDEST_TYPE_CASTS,
- conditional]:
- t += a
- ret = [item for item, count in collections.Counter(t).items() if count > 1]
- assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
-
- # Check the coverage
- covered = set(t)
- ops = get_all_registered_operators_grouped()
- required = set(k for k in ops
- if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
- not k.endswith("_backward"))
-
- extra = covered - required
- assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"
-
- guidelines = """Please follow these guidelines for choosing a proper list:
- - if your operator is not to be used in a computational graph
- (e.g. image manipulation operators, optimizers) or does not have
- inputs, put it in BF16_FP32_FUNCS list,
- - if your operator requires FP32 inputs or is not safe to use with lower
- precision, put it in FP32_FUNCS list,
- - if your operator supports both FP32 and lower precision, has
- multiple inputs and expects all inputs to be of the same
- type, put it in WIDEST_TYPE_CASTS list,
- - if your operator supports both FP32 and lower precision and has
- either a single input or supports inputs of different type,
- put it in BF16_FP32_FUNCS list,
- - if your operator is both safe to use in lower precision and
- it is highly beneficial to use it in lower precision, then
- put it in BF16_FUNCS (this is unlikely for new operators)
- - If you are not sure which list to choose, FP32_FUNCS is the
- safest option"""
- diff = required - covered
-
- if len(diff) > 0:
- warnings.warn(f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in "
- f"python/mxnet/amp/lists/symbol_bf16.py) - please add them. "
- f"\n{guidelines}")
+import amp.common as amp_common_tests
+
+
+AMP_DTYPE = 'bfloat16'
+
+
+def test_bf16_coverage():
+ amp_common_tests.test_amp_coverage(AMP_DTYPE, 'BF16')
+
+
+@mx.util.use_np
+def test_bf16_basic_use():
+ amp_common_tests.test_amp_basic_use(AMP_DTYPE)
@mx.util.use_np
def test_bf16_offline_casting():
- class TestNet(nn.HybridBlock):
- def __init__(self):
- super().__init__()
- self.lp16_op1 = nn.Conv2D(4, 3)
- self.lp16_op2 = nn.Conv2DTranspose(4, 3)
- self.fp32_op = nn.Dense(4)
-
- def forward(self, x):
- x = self.lp16_op1(x)
- x = self.lp16_op2(x)
- x = x.reshape(x.shape[0], -1)
- x = self.fp32_op(x)
- return x
-
- net = TestNet()
- net.initialize()
- data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
- lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
- target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
- cast_params_offline=True, device=mx.current_context())
- lp_net(data_example)
- for name, data in lp_net.collect_params().items():
- assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)
+ amp_common_tests.test_amp_offline_casting(AMP_DTYPE)
@mx.util.use_np
def test_bf16_offline_casting_shared_params():
- COMMON_SIZE = 4
-
- class TestNet(nn.HybridBlock):
- def __init__(self):
- super().__init__()
- self.lp16_op1 = nn.Dense(COMMON_SIZE)
- self.lp16_op2 = nn.Dense(COMMON_SIZE)
- self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
- self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
- self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
-
- def forward(self, x):
- x = self.lp16_op1(x)
- x1 = self.lp16_op2(x)
- x2 = mx.np.expand_dims(x, 1)
- x2 = self.fp32_op(x2)
- x2 = mx.npx.batch_flatten(x2)
- x = mx.np.concat((x1, x2), axis=1)
- return x
-
- net = TestNet()
- net.initialize()
- data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
- lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
- target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
- cast_params_offline=True, device=mx.current_context())
- lp_net(data_example)
- for name, data in lp_net.collect_params().items():
- assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)
+ amp_common_tests.test_amp_offline_casting_shared_params(AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_bf16_fp32_ops_order_independence():
+ amp_common_tests.test_lp16_fp32_ops_order_independence(AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_bf16_test_node_excluding():
+ amp_common_tests.test_amp_node_excluding(AMP_DTYPE)
diff --git a/tests/python/gpu/test_amp.py b/tests/python/gpu/test_amp.py
index abc8f326e3..0c8ce79a71 100644
--- a/tests/python/gpu/test_amp.py
+++ b/tests/python/gpu/test_amp.py
@@ -15,86 +15,54 @@
# specific language governing permissions and limitations
# under the License.
-import os
import sys
+from pathlib import Path
+curr_path = Path(__file__).resolve().parent
+sys.path.insert(0, str(curr_path.parent))
+sys.path.insert(0, str(curr_path.parent/'unittest'))
+
import mxnet as mx
-import numpy as np
-from random import randint
-import warnings
-import collections
-import ctypes
-from mxnet import amp
import pytest
-from mxnet.test_utils import set_default_device, same_symbol_structure
-from mxnet.gluon.model_zoo.vision import get_model
-from mxnet.gluon import SymbolBlock, nn, rnn
-from mxnet.operator import get_all_registered_operators_grouped
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from mxnet import amp
+from mxnet.test_utils import set_default_device
+from mxnet.gluon import nn, rnn
+
+import amp.common as amp_common_tests
from common import assert_raises_cudnn_not_satisfied
-sys.path.insert(0, os.path.join(curr_path, '../train'))
+
+AMP_DTYPE = 'float16'
+
set_default_device(mx.gpu(0))
-@pytest.fixture()
-def amp_tests(request):
- def teardown():
- mx.nd.waitall()
-
- request.addfinalizer(teardown)
-
-def test_amp_coverage(amp_tests):
- conditional = [item[0] for item in amp.lists.symbol_fp16.CONDITIONAL_FP32_FUNCS]
-
- # Check for duplicates
- for a in [amp.lists.symbol_fp16.FP16_FUNCS,
- amp.lists.symbol_fp16.FP16_FP32_FUNCS,
- amp.lists.symbol_fp16.FP32_FUNCS,
- amp.lists.symbol_fp16.WIDEST_TYPE_CASTS,
- conditional]:
- ret = [item for item, count in collections.Counter(a).items() if count > 1]
- assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
-
- t = []
- for a in [amp.lists.symbol_fp16.FP16_FUNCS,
- amp.lists.symbol_fp16.FP16_FP32_FUNCS,
- amp.lists.symbol_fp16.FP32_FUNCS,
- amp.lists.symbol_fp16.WIDEST_TYPE_CASTS,
- conditional]:
- t += a
- ret = [item for item, count in collections.Counter(t).items() if count > 1]
- assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
-
- # Check the coverage
- covered = set(t)
- ops = get_all_registered_operators_grouped()
- required = set(k for k in ops
- if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
- not k.endswith("_backward"))
-
- extra = covered - required
- assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"
-
- guidelines = """Please follow these guidelines for choosing a proper list:
- - if your operator is not to be used in a computational graph
- (e.g. image manipulation operators, optimizers) or does not have
- inputs, put it in FP16_FP32_FUNCS list,
- - if your operator requires FP32 inputs or is not safe to use with lower
- precision, put it in FP32_FUNCS list,
- - if your operator supports both FP32 and lower precision, has
- multiple inputs and expects all inputs to be of the same
- type, put it in WIDEST_TYPE_CASTS list,
- - if your operator supports both FP32 and lower precision and has
- either a single input or supports inputs of different type,
- put it in FP16_FP32_FUNCS list,
- - if your operator is both safe to use in lower precision and
- it is highly beneficial to use it in lower precision, then
- put it in FP16_FUNCS (this is unlikely for new operators)
- - If you are not sure which list to choose, FP32_FUNCS is the
- safest option"""
- diff = required - covered
- assert not diff, f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in " \
- f"python/mxnet/amp/lists/symbol_fp16.py) - please add them. " \
- f"\n{guidelines}"
+
+def test_fp16_coverage():
+ amp_common_tests.test_amp_coverage(AMP_DTYPE, 'FP16')
+
+
+@mx.util.use_np
+def test_fp16_basic_use():
+ amp_common_tests.test_amp_basic_use(AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_offline_casting():
+ amp_common_tests.test_amp_offline_casting(AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_offline_casting_shared_params():
+ amp_common_tests.test_amp_offline_casting_shared_params(AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_fp32_ops_order_independence():
+ amp_common_tests.test_lp16_fp32_ops_order_independence(AMP_DTYPE)
+
+
+@mx.util.use_np
+def test_fp16_test_node_excluding():
+ amp_common_tests.test_amp_node_excluding(AMP_DTYPE)
+
@pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@@ -109,63 +77,3 @@ def test_amp_conversion_rnn(amp_tests):
new_model = amp.convert_hybrid_block(model)
out2 = new_model(mx.nd.ones((2, 3, 4)))
mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
-
-
-@mx.util.use_np
-def test_fp16_offline_casting():
- class TestNet(nn.HybridBlock):
- def __init__(self):
- super().__init__()
- self.lp16_op1 = nn.Conv2D(4, 3)
- self.lp16_op2 = nn.Conv2DTranspose(4, 3)
- self.fp32_op = nn.Dense(4)
-
- def forward(self, x):
- x = self.lp16_op1(x)
- x = self.lp16_op2(x)
- x = x.reshape(x.shape[0], -1)
- x = self.fp32_op(x)
- return x
-
- net = TestNet()
- net.initialize()
- data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
- lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
- target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
- cast_params_offline=True, device=mx.current_context())
- lp_net(data_example)
- for name, data in lp_net.collect_params().items():
- assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
-
-
-@mx.util.use_np
-def test_fp16_offline_casting_shared_params():
- COMMON_SIZE = 4
-
- class TestNet(nn.HybridBlock):
- def __init__(self):
- super().__init__()
- self.lp16_op1 = nn.Dense(COMMON_SIZE)
- self.lp16_op2 = nn.Dense(COMMON_SIZE)
- self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
- self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
- self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
-
- def forward(self, x):
- x = self.lp16_op1(x)
- x1 = self.lp16_op2(x)
- x2 = mx.np.expand_dims(x, 1)
- x2 = self.fp32_op(x2)
- x2 = nn.Flatten()(x2)
- x = mx.np.concat((x1, x2), axis=1)
- return x
-
- net = TestNet()
- net.initialize()
- data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
- lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
- target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
- cast_params_offline=True, device=mx.current_context())
- lp_net(data_example)
- for name, data in lp_net.collect_params().items():
- assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 65d973171f..5b992bcde3 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -75,6 +75,23 @@ def test_quantize_float32_to_int8():
qdata_np = (onp.sign(data_np) * onp.minimum(onp.abs(data_np) * scale + 0.5, quantized_range)).astype(onp.int8)
assert_almost_equal(qdata.asnumpy(), qdata_np, atol = 1)
+def test_calibrated_quantize_v2_bfloat16_to_int8():
+ shape = rand_shape_nd(4)
+ data = mx.nd.random.normal(0, 1, shape).astype('bfloat16')
+ min_range = mx.nd.min(data).asscalar()
+ max_range = mx.nd.max(data).asscalar()
+ qdata, min_val, max_val = mx.nd.contrib.quantize_v2(data, 'int8', min_range, max_range)
+ data_np = data.asnumpy()
+ real_range = onp.maximum(onp.abs(min_range), onp.abs(max_range))
+ quantized_range = 127.0
+ scale = quantized_range / real_range
+ assert qdata.dtype == onp.int8
+ assert min_val.dtype == onp.float32
+ assert max_val.dtype == onp.float32
+ assert same(min_val.asscalar(), -real_range)
+ assert same(max_val.asscalar(), real_range)
+ qdata_np = (onp.sign(data_np) * onp.minimum(onp.abs(data_np) * scale + 0.5, quantized_range)).astype(onp.int8)
+ assert_almost_equal(qdata.asnumpy(), qdata_np, atol=1)
def test_dequantize_int8_to_float32():