Posted to commits@mxnet.apache.org by an...@apache.org on 2019/05/21 03:22:09 UTC

[incubator-mxnet] branch master updated: MXNet AMP (automatic mixed precision) (#14173)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5bc08ce  MXNet AMP (automatic mixed precision) (#14173)
5bc08ce is described below

commit 5bc08cec232d04bd812cc8ade2dc1b5469e0bd2b
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Mon May 20 20:21:26 2019 -0700

    MXNet AMP (automatic mixed precision) (#14173)
    
    * Beginning of AMP
    
    * Optimize noop cast
    
    * More operations added
    
    * Backward cast
    
    * Adding AMPCast and AMPMultiCast
    
    * Fix some lint errors
    
    * Changed symbol wrapper to handle hidden inputs
    Added PoC of dynamic loss scaling
    
    * Moved back to dmlc/tvm repo
    
    * fix counter reset to increase loss scale every 2k iterations
    
    * Fix indentation
    
    * Add contrib from symbol and ndarray to symbol list
    
    * Adding where to widest type cast
    
    * Do not cast in imperative mode on CPU context
    
    * Update dmlc-core to fix unittests
    
    * Fix wrapper metadata, fix self handling
    
    * Blacklist sync batchnorm (since its implementation is FP32 only)
    
    * Fix lint
    
    * Enable losses to be tuple
    
    * Get rid of AMP handle
    
    * Add scaling to Output functions
    
    * Fix pylint
    
    * Update dmlc-core
    
    * Changing prints in AMP to logging.info
    
    * NNVM -> MXNet for FInferShape
    
    * Bring the inplaceidentity fix to copied pass from NNVM
    
    * Added tutorial for AMP
    
    * Making Windows compiler happy
    
    * Fixes to tutorial
    
    * More fixes
    
    * Fix lint
    
    * Fix
    
    * Add amp/index.md to whitelist for tutorial tests
    
    * Whitelisting cuDNN RNN
    
    * Manual unscale
    
    * _internal functions wrapping
    
    * Make SymbolFunctor from Symbol
    
    * Fix the type infer function of AMP multicast
    
    * Added ability to override casting lists
    
    * Making clang-tidy and pylint happy
    
    * More cleaning
    
    * Making clang-tidy really happy
    
    * remove amp_cast and amp_multicast before saving the model
    
    * Changes from review
    
    * Add RemoveAmpCast in a separate c_api function, add the option in symbol.save
    
    * Add remove_amp_cast option (True by default) to every way of saving a symbol
    
    * Fix
    
    * First stab at adding the gray list
    
    * More ops added
    
    * Adding the rest of the functions
    
    * Improvements to AMP test
    
    * Changing of names and changing wrapping
    
    * Moving to contrib
    
    * Modifying tutorial for contrib AMP
    
    * Removing non-existent functions
    
    * Fix import in test
    
    * Fix lint
    
    * Added new functions
    
    * Added assert
    
    * Fix the unknown ndim in PlanMemory pass
    
    * Moving back to FP16_FUNCS and FP16_FP32_FUNCS
    
    * Removing unnecessary ops
    
    * Adding ops that exist only in some build configurations and removing
    tests checking that AMP lists contain only existing ops
    
    * Removing warning when not every function was found during AMP init
    because of functions being available only in specific configurations
    
    * Add tests and doc
    
    * Fix the CPU version of all_finite
    
    * Adding test cases for all_finite operator
    
    * Add new operators
    
    * Fix
---
 3rdparty/tvm                                     |   2 +-
 docs/tutorials/amp/amp_tutorial.md               | 266 ++++++++++
 docs/tutorials/amp/index.md                      |  25 +
 docs/tutorials/index.md                          |   2 +
 include/mxnet/c_api.h                            |  15 +
 python/mxnet/contrib/amp/__init__.py             |  22 +
 python/mxnet/contrib/amp/amp.py                  | 344 +++++++++++++
 python/mxnet/contrib/amp/lists/__init__.py       |  21 +
 python/mxnet/contrib/amp/lists/symbol.py         | 609 +++++++++++++++++++++++
 python/mxnet/contrib/amp/loss_scaler.py          |  77 +++
 python/mxnet/gluon/block.py                      |   4 +-
 python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py  |  12 +-
 python/mxnet/gluon/contrib/rnn/rnn_cell.py       |   8 +-
 python/mxnet/gluon/rnn/rnn_cell.py               |  20 +-
 python/mxnet/model.py                            |  12 +-
 python/mxnet/module/module.py                    |   4 +-
 python/mxnet/symbol/symbol.py                    |  17 +-
 src/c_api/c_api_symbolic.cc                      |  24 +
 src/imperative/cached_op.cc                      |   2 +-
 src/imperative/imperative.cc                     |   2 +-
 src/nnvm/plan_memory.cc                          |   8 +-
 src/operator/contrib/all_finite-inl.h            | 100 ++++
 src/operator/contrib/all_finite.cc               | 168 +++++++
 src/operator/contrib/all_finite.cu               | 107 ++++
 src/operator/contrib/amp_graph_pass.cc           |  61 +++
 src/operator/custom/custom.cc                    |   2 +-
 src/operator/tensor/amp_cast.cc                  | 150 ++++++
 src/operator/tensor/amp_cast.cu                  |  40 ++
 src/operator/tensor/amp_cast.h                   | 165 ++++++
 src/operator/tensor/broadcast_reduce_op_value.cc |   2 +-
 src/operator/tensor/elemwise_sum.cc              |   2 +-
 src/operator/tensor/elemwise_unary_op.h          |   5 +-
 src/operator/tensor/elemwise_unary_op_basic.cc   |  12 +-
 tests/python/unittest/test_amp.py                |  83 +++
 tests/python/unittest/test_operator.py           | 126 +++--
 tests/tutorials/test_sanity_tutorials.py         |   3 +-
 tests/tutorials/test_tutorials.py                |   3 +
 37 files changed, 2456 insertions(+), 69 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 0f053c8..8518c7d 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 0f053c82a747b4dcdf49570ec87c17e0067b7439
+Subproject commit 8518c7ddb561afba8112324fad4b35b8d111c525
diff --git a/docs/tutorials/amp/amp_tutorial.md b/docs/tutorials/amp/amp_tutorial.md
new file mode 100644
index 0000000..02bf82a
--- /dev/null
+++ b/docs/tutorials/amp/amp_tutorial.md
@@ -0,0 +1,266 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Using AMP (Automatic Mixed Precision) in MXNet
+
+Training Deep Learning networks is a very computationally intensive task. Novel model architectures tend to have an increasing number of layers and parameters, which slows down training. Fortunately, new generations of training hardware, as well as software optimizations, make training these networks feasible.
+
+However, most of the optimization opportunities (in both hardware and software) lie in exploiting lower precision (like FP16) to, for example, utilize the Tensor Cores available on new Volta and Turing GPUs. While training in FP16 has shown great success in image classification tasks, other, more complicated neural networks have typically stayed in FP32 due to the difficulty of applying the FP16 training guidelines.
+
+That is where AMP (Automatic Mixed Precision) comes into play. It automatically applies the guidelines of FP16 training, using FP16 precision where it provides the most benefit, while conservatively keeping operations that are unsafe to do in FP16 in full FP32 precision.
+
+This tutorial shows how to get started with mixed precision training using AMP for MXNet. As an example network, we will use the SSD network from GluonCV.
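+
+Before diving in, here is a schematic summary of the three additions that AMP requires. The `net`, `trainer` and `loss` objects below stand in for your own; fully runnable versions of each step appear in the sections that follow.
+
+```python
+from mxnet import autograd
+from mxnet.contrib import amp
+
+amp.init()  # 1. initialize AMP before creating the network
+
+# ... create `net` and a `trainer` with update_on_kvstore=False ...
+amp.init_trainer(trainer)  # 2. attach dynamic loss scaling to the trainer
+
+# 3. inside the training loop, scale the loss before the backward pass
+with amp.scale_loss(loss, trainer) as scaled_loss:
+    autograd.backward(scaled_loss)
+```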
+
+## Data loader and helper functions
+
+For demonstration purposes, we will use a synthetic data loader.
+
+
+```python
+import logging
+import warnings
+import time
+import mxnet as mx
+import mxnet.gluon as gluon
+from mxnet import autograd
+import gluoncv as gcv
+from gluoncv.model_zoo import get_model
+
+data_shape = 512
+batch_size = 8
+lr = 0.001
+wd = 0.0005
+momentum = 0.9
+
+# training contexts
+ctx = [mx.gpu(0)]
+
+# set up logger
+logging.basicConfig()
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+ce_metric = mx.metric.Loss('CrossEntropy')
+smoothl1_metric = mx.metric.Loss('SmoothL1')
+```
+
+
+```python
+class SyntheticDataLoader(object):
+    def __init__(self, data_shape, batch_size):
+        super(SyntheticDataLoader, self).__init__()
+        self.counter = 0
+        self.epoch_size = 200
+        shape = (batch_size, 3, data_shape, data_shape)
+        cls_targets_shape = (batch_size, 6132)
+        box_targets_shape = (batch_size, 6132, 4)
+        self.data = mx.nd.random.uniform(-1, 1, shape=shape, ctx=mx.cpu_pinned())
+        self.cls_targets = mx.nd.random.uniform(0, 1, shape=cls_targets_shape, ctx=mx.cpu_pinned())
+        self.box_targets = mx.nd.random.uniform(0, 1, shape=box_targets_shape, ctx=mx.cpu_pinned())
+    
+    def next(self):
+        if self.counter >= self.epoch_size:
+            self.counter = self.counter % self.epoch_size
+            raise StopIteration
+        self.counter += 1
+        return [self.data, self.cls_targets, self.box_targets]
+    
+    __next__ = next
+    
+    def __iter__(self):
+        return self
+    
+train_data = SyntheticDataLoader(data_shape, batch_size)
+```
+
+
+```python
+def get_network():
+    # SSD with RN50 backbone
+    net_name = 'ssd_512_resnet50_v1_coco'
+    net = get_model(net_name, pretrained_base=True, norm_layer=gluon.nn.BatchNorm)
+    async_net = net
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        net.initialize()
+        net.collect_params().reset_ctx(ctx)
+
+    return net
+```
+
+## Training in FP32
+
+First, let us create the network.
+
+
+```python
+net = get_network()
+net.hybridize(static_alloc=True, static_shape=True)
+```
+
+    /mxnet/code/python/mxnet/gluon/block.py:1138: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
+    	data: None
+      input_sym_arg_type = in_param.infer_type()[0]
+
+
+Next, we need to create a Gluon Trainer.
+
+
+```python
+trainer = gluon.Trainer(
+    net.collect_params(), 'sgd',
+    {'learning_rate': lr, 'wd': wd, 'momentum': momentum})
+```
+
+
+```python
+mbox_loss = gcv.loss.SSDMultiBoxLoss()
+
+for epoch in range(1):
+    ce_metric.reset()
+    smoothl1_metric.reset()
+    tic = time.time()
+    btic = time.time()
+
+    for i, batch in enumerate(train_data):
+        batch_size = batch[0].shape[0]
+        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
+        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
+        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
+        with autograd.record():
+            cls_preds = []
+            box_preds = []
+            for x in data:
+                cls_pred, box_pred, _ = net(x)
+                cls_preds.append(cls_pred)
+                box_preds.append(box_pred)
+            sum_loss, cls_loss, box_loss = mbox_loss(
+                cls_preds, box_preds, cls_targets, box_targets)
+            autograd.backward(sum_loss)
+        trainer.step(1)
+        ce_metric.update(0, [l * batch_size for l in cls_loss])
+        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
+        if not (i + 1) % 50:
+            name1, loss1 = ce_metric.get()
+            name2, loss2 = smoothl1_metric.get()
+            logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
+                epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
+        btic = time.time()
+```
+
+    INFO:root:[Epoch 0][Batch 49], Speed: 58.105 samples/sec, CrossEntropy=1.190, SmoothL1=0.688
+    INFO:root:[Epoch 0][Batch 99], Speed: 58.683 samples/sec, CrossEntropy=0.693, SmoothL1=0.536
+    INFO:root:[Epoch 0][Batch 149], Speed: 58.915 samples/sec, CrossEntropy=0.500, SmoothL1=0.453
+    INFO:root:[Epoch 0][Batch 199], Speed: 58.422 samples/sec, CrossEntropy=0.396, SmoothL1=0.399
+
+
+## Training with AMP
+
+### AMP initialization
+
+In order to start using AMP, we need to import and initialize it. This has to happen before we create the network.
+
+
+```python
+from mxnet.contrib import amp
+
+amp.init()
+```
+
+    INFO:root:Using AMP
+
+
+After that, we can create the network exactly the same way we did in FP32 training.
+
+
+```python
+net = get_network()
+net.hybridize(static_alloc=True, static_shape=True)
+```
+
+    /mxnet/code/python/mxnet/gluon/block.py:1138: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
+    	data: None
+      input_sym_arg_type = in_param.infer_type()[0]
+
+
+For some models that may be enough to start training in mixed precision, but the full FP16 recipe recommends using dynamic loss scaling to guard against over- and underflows of FP16 values. Therefore, as a next step, we create a trainer and initialize it with support for AMP's dynamic loss scaling. Currently, support for dynamic loss scaling is limited to trainers created with the `update_on_kvstore=False` option, so we add it to our trainer initialization.
+
+
+```python
+trainer = gluon.Trainer(
+    net.collect_params(), 'sgd',
+    {'learning_rate': lr, 'wd': wd, 'momentum': momentum},
+    update_on_kvstore=False)
+
+amp.init_trainer(trainer)
+```
+
+### Dynamic loss scaling in the training loop
+
+The last step is to apply dynamic loss scaling during the training loop. We can achieve that using the `amp.scale_loss` function.
+
+
+```python
+mbox_loss = gcv.loss.SSDMultiBoxLoss()
+
+for epoch in range(1):
+    ce_metric.reset()
+    smoothl1_metric.reset()
+    tic = time.time()
+    btic = time.time()
+
+    for i, batch in enumerate(train_data):
+        batch_size = batch[0].shape[0]
+        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
+        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
+        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
+        with autograd.record():
+            cls_preds = []
+            box_preds = []
+            for x in data:
+                cls_pred, box_pred, _ = net(x)
+                cls_preds.append(cls_pred)
+                box_preds.append(box_pred)
+            sum_loss, cls_loss, box_loss = mbox_loss(
+                cls_preds, box_preds, cls_targets, box_targets)
+            with amp.scale_loss(sum_loss, trainer) as scaled_loss:
+                autograd.backward(scaled_loss)
+        trainer.step(1)
+        ce_metric.update(0, [l * batch_size for l in cls_loss])
+        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
+        if not (i + 1) % 50:
+            name1, loss1 = ce_metric.get()
+            name2, loss2 = smoothl1_metric.get()
+            logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
+                epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
+        btic = time.time()
+```
+
+    INFO:root:[Epoch 0][Batch 49], Speed: 93.585 samples/sec, CrossEntropy=1.166, SmoothL1=0.684
+    INFO:root:[Epoch 0][Batch 99], Speed: 93.773 samples/sec, CrossEntropy=0.682, SmoothL1=0.533
+    INFO:root:[Epoch 0][Batch 149], Speed: 93.399 samples/sec, CrossEntropy=0.493, SmoothL1=0.451
+    INFO:root:[Epoch 0][Batch 199], Speed: 93.674 samples/sec, CrossEntropy=0.391, SmoothL1=0.397
+
+
+We got a 60% speed increase from 3 additional lines of code!
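+
+## Manually unscaling gradients
+
+When gradients need to be accessed directly, e.g. for gradient clipping, they still carry the loss scale factor and have to be unscaled first with `amp.unscale`. Below is a minimal sketch of a training step with clipping, reusing the loop variables from the previous section; the clipping threshold and the use of `mx.nd.clip` are illustrative and not part of AMP itself.
+
+```python
+with autograd.record():
+    cls_preds = []
+    box_preds = []
+    for x in data:
+        cls_pred, box_pred, _ = net(x)
+        cls_preds.append(cls_pred)
+        box_preds.append(box_pred)
+    sum_loss, cls_loss, box_loss = mbox_loss(
+        cls_preds, box_preds, cls_targets, box_targets)
+    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
+        autograd.backward(scaled_loss)
+
+amp.unscale(trainer)  # bring gradients back to their true magnitude
+for param in net.collect_params().values():
+    if param.grad_req != 'null':
+        for grad in param.list_grad():
+            grad[:] = mx.nd.clip(grad, -1.0, 1.0)  # illustrative threshold
+
+trainer.step(1)
+```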
+
+## Current limitations of AMP
+
+- AMP's dynamic loss scaling currently supports only Gluon Trainers created with the `update_on_kvstore=False` option
+- Using `SoftmaxOutput`, `LinearRegressionOutput`, `LogisticRegressionOutput`, or `MAERegressionOutput` with dynamic loss scaling does not work when training networks with multiple Gluon Trainers (and therefore multiple loss scales)
+
+<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
diff --git a/docs/tutorials/amp/index.md b/docs/tutorials/amp/index.md
new file mode 100644
index 0000000..faf6526
--- /dev/null
+++ b/docs/tutorials/amp/index.md
@@ -0,0 +1,25 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Tutorials
+
+```eval_rst
+.. toctree::
+   :glob:
+
+   *
+```
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 01c59b1..2527ccf 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -38,6 +38,7 @@
    tensorrt/index.md
    unsupervised_learning/index.md
    vision/index.md
+   amp/index.md
 ```
 
 MXNet tutorials can be found in this section. A variety of language bindings are available for MXNet (including Python, Scala, Java, Clojure, C++ and R) and we have a different tutorial section for each language.
@@ -102,6 +103,7 @@ Select API:&nbsp;
     * [Profiling MXNet Models](/tutorials/python/profiler.html)
     * [Module to Gluon API](/tutorials/python/module_to_gluon.html)<span style="color:red"> (new!)</span>
     * [Gluon end to end from training to inference](/tutorials/gluon/gluon_from_experiment_to_deployment.html)
+    * [Automatic Mixed Precision in Gluon](/tutorials/amp/amp_tutorial.html)
 
 * API Guides
     * Core APIs
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index f79f224..511bff2 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1292,6 +1292,13 @@ MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out);
  */
 MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out);
 /*!
+ * \brief Remove the operators amp_cast and amp_multicast
+ * \param sym_handle the input symbol.
+ * \param ret_sym_handle the output symbol.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle);
+/*!
  * \brief Save a symbol into a json file.
  * \param symbol the input symbol.
  * \param fname the file name.
@@ -1747,6 +1754,14 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
 MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
                                    SymbolHandle *ret_sym_handle);
 
+/*!
+ * \brief Generate atomic symbol (able to be composed) from a source symbol
+ * \param sym_handle source symbol
+ * \param ret_sym_handle returned atomic symbol
+ */
+MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle);
+
+
 //--------------------------------------------
 // Part 4: Executor interface
 //--------------------------------------------
diff --git a/python/mxnet/contrib/amp/__init__.py b/python/mxnet/contrib/amp/__init__.py
new file mode 100644
index 0000000..7aebc41
--- /dev/null
+++ b/python/mxnet/contrib/amp/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Automatic mixed precision module."""
+
+from .amp import *
diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py
new file mode 100755
index 0000000..77a566e
--- /dev/null
+++ b/python/mxnet/contrib/amp/amp.py
@@ -0,0 +1,344 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Functions for enabling AMP (automatic mixed precision)."""
+__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale']
+
+from types import MethodType
+import logging
+import contextlib
+import numpy as np
+
+from ... import symbol
+from ...symbol import Symbol
+from ...symbol import contrib as symbol_contrib
+from ... import ndarray
+from ...ndarray import NDArray
+from . import lists
+from ...gluon import trainer
+from ... import base
+from ... import optimizer as opt
+from .loss_scaler import LossScaler
+
+def _cast_symbol_NDArray(s, dtype):
+    float_types = (np.float16, np.float32)
+    if isinstance(s, Symbol):
+        return symbol.amp_cast(s, dtype=dtype)
+    elif isinstance(s, NDArray):
+        if (s.dtype != dtype and
+                s.dtype in float_types and
+                s.context.device_type != 'cpu'):
+            return ndarray.amp_cast(s, dtype=dtype)
+        else:
+            return s
+    else:
+        return s
+
+def _get_fun_to_wrap(name, module, submodule_dict):
+    module_internal = getattr(module, "_internal")
+    prefix = base._get_op_name_prefix(name)
+    if len(prefix) > 0:
+        if prefix != '_random_' or name.endswith('_like'):
+            func_name = name[len(prefix):]
+            cur_module = submodule_dict[prefix]
+        else:
+            func_name = name
+            cur_module = module_internal
+    elif name.startswith('_'):
+        func_name = name
+        cur_module = module_internal
+    else:
+        func_name = name
+        cur_module = module
+    return func_name, cur_module
+
+def _wrap_symbol_functions(module, target_dtype, target_precision_ops=None,
+                           conditional_fp32_ops=None, fp32_ops=None):
+    def _ndarray_wrapper(f, target_dtype, cond_arg=None):
+        def _new_fun(*args, **kwargs):
+            if cond_arg is not None:
+                if (cond_arg[0] not in kwargs or
+                        kwargs[cond_arg[0]] not in cond_arg[1]):
+                    return f(*args, **kwargs)
+            new_args = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype), args))
+            args = tuple(new_args)
+            kwargs = {k: _cast_symbol_NDArray(v, target_dtype) for k, v in kwargs.items()}
+            return f(*args, **kwargs)
+        _new_fun.__name__ = f.__name__
+        _new_fun.__module__ = f.__module__
+        _new_fun.__doc__ = f.__doc__
+        return _new_fun
+
+    def _symbol_wrapper(f, target_dtype, cond_arg=None):
+        def _new_fun(*args, **kwargs):
+            if cond_arg is not None:
+                if (cond_arg[0] not in kwargs or
+                        kwargs[cond_arg[0]] not in cond_arg[1]):
+                    return f(*args, **kwargs)
+            sym = f(*args, **kwargs)
+            inputs = sym.get_children()
+            aux = sym.list_auxiliary_states()
+            inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype)
+                              if x.name not in aux else x, inputs))
+            atomic_sym = sym._gen_atomic_symbol()
+            wrapped_sym = atomic_sym(*inputs)
+            wrapped_sym._set_attr(name=sym.name)
+            return wrapped_sym
+        _new_fun.__name__ = f.__name__
+        _new_fun.__module__ = f.__module__
+        _new_fun.__doc__ = f.__doc__
+        return _new_fun
+
+    def _symbol_widest_wrapper(f):
+        def _new_fun(*args, **kwargs):
+            symbols = []
+            is_symbol = False
+            args = list(args)
+            for i, arg in enumerate(args):
+                if isinstance(arg, (Symbol, NDArray)):
+                    symbols.append((args, i, arg))
+                    is_symbol = is_symbol or isinstance(arg, Symbol)
+            for k, arg in kwargs.items():
+                if isinstance(arg, (Symbol, NDArray)):
+                    symbols.append((kwargs, k, arg))
+                    is_symbol = is_symbol or isinstance(arg, Symbol)
+            if not is_symbol:
+                # NDArray case
+                widest_type = target_dtype
+                for _, _, arg in symbols:
+                    if isinstance(arg, NDArray):
+                        if arg.dtype == np.float32:
+                            widest_type = np.float32
+                for arr, index, arg in symbols:
+                    if arg.dtype != widest_type and arg.dtype == target_dtype:
+                        arr[index] = ndarray.amp_cast(arg, dtype=widest_type)
+            else:
+                # Symbol case
+                sym_to_check = list(map(lambda x: x[2], symbols))
+                casted_syms = symbol.amp_multicast(*sym_to_check, num_outputs=len(sym_to_check))
+                symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
+                                   zip(symbols, casted_syms)))
+                for arr, index, arg in symbols:
+                    arr[index] = arg
+
+            return f(*args, **kwargs)
+        _new_fun.__name__ = f.__name__
+        _new_fun.__module__ = f.__module__
+        _new_fun.__doc__ = f.__doc__
+        return _new_fun
+
+    _wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib) else _ndarray_wrapper
+
+    submodule_dict = {}
+    for op_name_prefix in base._OP_NAME_PREFIX_LIST:
+        submodule_dict[op_name_prefix] =\
+                getattr(module, op_name_prefix[1:-1])
+
+    wrap_list = target_precision_ops if target_precision_ops is not None \
+                    else lists.symbol.FP16_FUNCS
+    for fun_name in wrap_list:
+        try:
+            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
+            f_to_wrap = getattr(cur_module, fun_name)
+            setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype))
+            if cur_module == module:
+                setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype))
+        except AttributeError:
+            pass
+
+    wrap_list = fp32_ops if fp32_ops is not None else lists.symbol.FP32_FUNCS
+    for fun_name in wrap_list:
+        try:
+            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
+            f_to_wrap = getattr(cur_module, fun_name)
+            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
+            if cur_module == module:
+                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
+        except AttributeError:
+            pass
+
+    wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
+                    else lists.symbol.CONDITIONAL_FP32_FUNCS
+    for fun_name, arg, arg_values in wrap_list:
+        try:
+            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
+            f_to_wrap = getattr(cur_module, fun_name)
+            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, (arg, arg_values)))
+            if cur_module == module:
+                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32, (arg, arg_values)))
+        except AttributeError:
+            pass
+
+    for fun_name in lists.symbol.WIDEST_TYPE_CASTS:
+        try:
+            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
+            f_to_wrap = getattr(cur_module, fun_name)
+            setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
+            if cur_module == module:
+                setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
+        except AttributeError:
+            pass
+
+def _wrap_loss_output_functions(module, ls):
+    if module == ndarray:
+        def _wrapper(f):
+            def _scaling_wrapper(*args, **kwargs):
+                if 'grad_scale' in kwargs:
+                    kwargs['grad_scale'] = kwargs['grad_scale'] * ls.loss_scale
+                else:
+                    kwargs['grad_scale'] = ls.loss_scale
+                return f(*args, **kwargs)
+            _scaling_wrapper.__name__ = f.__name__
+            _scaling_wrapper.__module__ = f.__module__
+            _scaling_wrapper.__doc__ = f.__doc__
+            return _scaling_wrapper
+    else:
+        def _wrapper(f):
+            def _warning_wrapper(*args, **kwargs):
+                logging.warning("%s does not support dynamic loss scaling "
+                                "in symbolic and hybridized execution.", f.__name__)
+                return f(*args, **kwargs)
+            _warning_wrapper.__name__ = f.__name__
+            _warning_wrapper.__module__ = f.__module__
+            _warning_wrapper.__doc__ = f.__doc__
+            return _warning_wrapper
+
+    for fun_name in lists.symbol.LOSS_OUTPUT_FUNCTIONS:
+        try:
+            f_to_wrap = getattr(module, fun_name)
+            setattr(module, fun_name, _wrapper(f_to_wrap))
+        except AttributeError:
+            pass
+
+_amp_initialized = False
+_amp_loss_scale_initialized = False
+_loss_scaler = None
+
+@contextlib.contextmanager
+def scale_loss(loss, optimizer_or_trainer):
+    assert optimizer_or_trainer._amp_loss_scaler is not None, \
+        'Loss scaler is not initialized, did you forget to call amp.init_trainer()?'
+    optimizer_or_trainer._scale = (optimizer_or_trainer._amp_original_scale /
+                                   optimizer_or_trainer._amp_loss_scaler.loss_scale)
+    if isinstance(loss, (list, tuple)):
+        yield [l * optimizer_or_trainer._amp_loss_scaler.loss_scale for l in loss]
+    else:
+        yield optimizer_or_trainer._amp_loss_scaler.loss_scale * loss
+
+def init(target_dtype='float16', target_precision_ops=None,
+         conditional_fp32_ops=None, fp32_ops=None):
+    """Initialize AMP (automatic mixed precision).
+
+    This needs to be done before model creation.
+
+    Parameters
+    ----------
+    target_dtype : {'float16'}
+        Target low precision type for AMP. Currently only float16 is supported.
+    target_precision_ops : list of string
+        Override the list of functions cast to FP16. Entries in this list
+        are names of the functions that will be cast to FP16.
+    conditional_fp32_ops : list of (string, string, list of string)
+        Override the list of functions conditionally cast to FP32. The format
+        of the list is (name of the function, name of the parameter, list of
+        values of the parameter that cause the function to be cast to FP32).
+    fp32_ops : list of string
+        Override the list of functions cast to FP32. Entries in this list
+        are names of the functions that will be cast to FP32.
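+
+    Examples
+    --------
+    >>> # Illustrative only: the op names below come from the default AMP lists.
+    >>> amp.init(target_precision_ops=['FullyConnected'],
+    ...          fp32_ops=['softmax'])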
+    """
+    global _amp_initialized
+    global _loss_scaler
+    if not _amp_initialized:
+        assert target_dtype == 'float16' or target_dtype == np.float16, \
+               "AMP currently supports only float16 as a target_dtype"
+        _amp_initialized = True
+        logging.info("Using AMP")
+        target_dtype = np.dtype(target_dtype)
+        _wrap_symbol_functions(symbol, target_dtype, target_precision_ops,
+                               conditional_fp32_ops, fp32_ops)
+        _wrap_symbol_functions(ndarray, target_dtype, target_precision_ops,
+                               conditional_fp32_ops, fp32_ops)
+        _loss_scaler = LossScaler()
+        _wrap_loss_output_functions(ndarray, _loss_scaler)
+        _wrap_loss_output_functions(symbol, _loss_scaler)
+
+def init_trainer(optimizer_or_trainer):
+    """Initialize trainer or optimizer to work with AMP dynamic loss scaling.
+
+    Parameters
+    ----------
+    optimizer_or_trainer : Optimizer or Trainer
+        MXNet Optimizer or Gluon trainer to initialize with AMP
+    """
+    global _amp_loss_scale_initialized
+    global _amp_initialized
+    global _loss_scaler
+    assert _amp_initialized, "AMP not initialized, did you forget to call amp.init()?"
+    if not _amp_loss_scale_initialized:
+        _amp_loss_scale_initialized = True
+        loss_scaler = _loss_scaler
+    else:
+        loss_scaler = LossScaler()
+    #_wrap_output
+    if isinstance(optimizer_or_trainer, trainer.Trainer):
+        optimizer_or_trainer._amp_loss_scaler = loss_scaler
+        optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
+        skip_update = optimizer_or_trainer._amp_loss_scaler.wait_and_update
+        optimizer_or_trainer._optimizer.old_update_multi_precision = \
+                optimizer_or_trainer._optimizer.update_multi_precision
+        def new_update_multi_precision(self, index, weight, grad, state):
+            if not skip_update():
+                self.old_update_multi_precision(index, weight, grad, state)
+        optimizer_or_trainer._optimizer.update_multi_precision = \
+            MethodType(new_update_multi_precision, optimizer_or_trainer._optimizer)
+        launch_check_overflow = optimizer_or_trainer._amp_loss_scaler.launch_check_overflow
+        optimizer_or_trainer._old_update = optimizer_or_trainer._update
+        def new_update(self, ignore_stale_grad=False):
+            launch_check_overflow(self._params)
+            self._old_update(ignore_stale_grad)
+        optimizer_or_trainer._update = MethodType(new_update, optimizer_or_trainer)
+
+    elif isinstance(optimizer_or_trainer, opt.Optimizer):
+        # TODO(ptredak): make it work with the optimizer
+        raise TypeError("AMP is currently only compatible with Gluon Trainer")
+    else:
+        raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
+                        "an optimizer, instead is %s" % type(optimizer_or_trainer))
+
+def unscale(optimizer_or_trainer):
+    """Check and unscale the gradients manually. This function should only be used
+    if accessing gradients is necessary, e.g. for gradient clipping.
+
+    Parameters
+    ----------
+    optimizer_or_trainer : Optimizer or Trainer
+        MXNet optimizer or Gluon Trainer used when scaling the gradients
+    """
+    if isinstance(optimizer_or_trainer, trainer.Trainer):
+        valid_grads = [p._grad for p in optimizer_or_trainer._params if p._grad is not None]
+        for grads in valid_grads:
+            # TODO(ptredak): make a bulked unscale
+            for g in grads:
+                g[:] *= optimizer_or_trainer._scale
+        optimizer_or_trainer._scale = 1.
+    elif isinstance(optimizer_or_trainer, opt.Optimizer):
+        # TODO(ptredak): make it work with the optimizer
+        raise TypeError("AMP is currently only compatible with Gluon Trainer")
+    else:
+        raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
+                        "an optimizer, instead is %s" % type(optimizer_or_trainer))
diff --git a/python/mxnet/contrib/amp/lists/__init__.py b/python/mxnet/contrib/amp/lists/__init__.py
new file mode 100644
index 0000000..e128944
--- /dev/null
+++ b/python/mxnet/contrib/amp/lists/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Lists of functions whitelisted/blacklisted for automatic mixed precision."""
+
+from . import symbol
diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py
new file mode 100644
index 0000000..2f8b4f0
--- /dev/null
+++ b/python/mxnet/contrib/amp/lists/symbol.py
@@ -0,0 +1,609 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""
+
+# Functions that should be cast to lower precision
+FP16_FUNCS = [
+    'Convolution',
+    'Deconvolution',
+    'FullyConnected',
+    'RNN',
+    ]
+
+# Functions that should not be cast, either because
+# they are irrelevant (not used in the network itself,
+# like image transformations or optimizers) or because
+# they are dtype neutral (can work in both fp16 and fp32)
+FP16_FP32_FUNCS = [
+    'BatchNorm',
+    'BatchNorm_v1',
+    'BilinearSampler',
+    'BlockGrad',
+    'Cast',
+    'cast',
+    'cast_storage',
+    'Crop',
+    'Dropout',
+    'Embedding',
+    '_sparse_Embedding',
+    '_sparse_FullyConnected',
+    'Flatten',
+    'GridGenerator',
+    'Pad',
+    'Pooling',
+    'Pooling_v1',
+    'ROIPooling',
+    'Reshape',
+    'SequenceLast',
+    'SequenceMask',
+    'SequenceReverse',
+    'SliceChannel',
+    'SpatialTransformer',
+    'SwapAxis',
+    'UpSampling',
+    '_CachedOp',
+    '_CrossDeviceCopy',
+    '_CustomFunction',
+    '_DivScalar',
+    '_EqualScalar',
+    '_GreaterScalar',
+    '_GreaterEqualScalar',
+    '_LesserScalar',
+    '_LesserEqualScalar',
+    '_LogicalAndScalar',
+    '_LogicalOrScalar',
+    '_LogicalXorScalar',
+    '_MaximumScalar',
+    '_MinimumScalar',
+    '_MinusScalar',
+    '_ModScalar',
+    '_MulScalar',
+    '_NoGradient',
+    '_NotEqualScalar',
+    '_PlusScalar',
+    '_RMinusScalar',
+    '_RModScalar',
+    '_adamw_update',
+    '_add',
+    '_arange',
+    '_broadcast_backward',
+    '_cond',
+    '_contrib_AdaptiveAvgPooling2D',
+    '_contrib_BilinearResize2D',
+    '_contrib_SparseEmbedding',
+    '_contrib_bipartite_matching',
+    '_contrib_dequantize',
+    '_contrib_div_sqrt_dim',
+    '_contrib_boolean_mask',
+    '_contrib_getnnz',
+    '_contrib_gradientmultiplier',
+    '_contrib_group_adagrad_update',
+    '_contrib_ifft',
+    '_contrib_index_copy',
+    '_contrib_quadratic',
+    '_contrib_quantize',
+    '_contrib_quantize_v2',
+    '_contrib_quantized_concat',
+    '_contrib_quantized_conv',
+    '_contrib_quantized_flatten',
+    '_contrib_quantized_fully_connected',
+    '_contrib_quantized_pooling',
+    '_contrib_quantized_elemwise_add',
+    '_contrib_quantized_act',
+    '_image_crop',
+    '_linspace',
+    '_contrib_requantize',
+    '_copy',
+    '_copyto',
+    '_crop_assign',
+    '_crop_assign_scalar',
+    '_cvcopyMakeBorder',
+    '_cvimdecode',
+    '_cvimread',
+    '_cvimresize',
+    '_div_scalar',
+    '_equal_scalar',
+    '_eye',
+    '_foreach',
+    '_while_loop',
+    '_full',
+    '_grad_add',
+    '_greater_scalar',
+    '_greater_equal_scalar',
+    '_histogram',
+    '_identity_with_attr_like_rhs',
+    '_image_adjust_lighting',
+    '_image_flip_left_right',
+    '_image_flip_top_bottom',
+    '_image_normalize',
+    '_image_random_brightness',
+    '_image_random_color_jitter',
+    '_image_random_contrast',
+    '_image_random_flip_left_right',
+    '_image_random_flip_top_bottom',
+    '_image_random_hue',
+    '_image_random_lighting',
+    '_image_random_saturation',
+    '_image_resize',
+    '_image_to_tensor',
+    '_imdecode',
+    '_lesser_scalar',
+    '_lesser_equal_scalar',
+    '_logical_and_scalar',
+    '_logical_or_scalar',
+    '_logical_xor_scalar',
+    '_maximum_scalar',
+    '_minimum_scalar',
+    '_minus_scalar',
+    '_mod_scalar',
+    '_mp_adamw_update',
+    '_mul_scalar',
+    '_not_equal_scalar',
+    '_onehot_encode',
+    '_ones',
+    '_plus_scalar',
+    '_random_exponential',
+    '_random_exponential_like',
+    '_random_gamma',
+    '_random_gamma_like',
+    '_random_generalized_negative_binomial',
+    '_random_generalized_negative_binomial_like',
+    '_random_negative_binomial',
+    '_random_negative_binomial_like',
+    '_random_normal',
+    '_random_normal_like',
+    '_random_poisson',
+    '_random_poisson_like',
+    '_random_randint',
+    '_random_uniform',
+    '_random_uniform_like',
+    '_ravel_multi_index',
+    '_rminus_scalar',
+    '_rmod_scalar',
+    '_rnn_param_concat',
+    '_sample_exponential',
+    '_sample_gamma',
+    '_sample_generalized_negative_binomial',
+    '_sample_multinomial',
+    '_sample_negative_binomial',
+    '_sample_normal',
+    '_sample_poisson',
+    '_sample_uniform',
+    '_sample_unique_zipfian',
+    '_scatter_minus_scalar',
+    '_scatter_plus_scalar',
+    '_scatter_set_nd',
+    '_set_value',
+    '_shuffle',
+    '_slice_assign',
+    '_slice_assign_scalar',
+    '_sparse_abs',
+    '_sparse_adagrad_update',
+    '_sparse_adam_update',
+    '_sparse_arccosh',
+    '_sparse_arcsinh',
+    '_sparse_arctan',
+    '_sparse_cast_storage',
+    '_sparse_cbrt',
+    '_sparse_ceil',
+    '_sparse_clip',
+    '_sparse_concat',
+    '_sparse_cos',
+    '_sparse_degrees',
+    '_sparse_fix',
+    '_sparse_floor',
+    '_sparse_ftrl_update',
+    '_sparse_negative',
+    '_sparse_radians',
+    '_sparse_relu',
+    '_sparse_retain',
+    '_sparse_rint',
+    '_sparse_round',
+    '_sparse_sgd_mom_update',
+    '_sparse_sgd_update',
+    '_sparse_sigmoid',
+    '_sparse_sign',
+    '_sparse_sin',
+    '_sparse_sinh',
+    '_sparse_slice',
+    '_sparse_sqrt',
+    '_sparse_stop_gradient',
+    '_sparse_tanh',
+    '_sparse_trunc',
+    '_sparse_zeros_like',
+    '_split_v2',
+    '_split_v2_backward',
+    '_unravel_index',
+    '_zeros',
+    '_zeros_without_dtype',
+    'abs',
+    'adam_update',
+    'all_finite',
+    'amp_cast',
+    'amp_multicast',
+    'arccosh',
+    'arcsinh',
+    'arctan',
+    'argmax',
+    'argmax_channel',
+    'argmin',
+    'batch_take',
+    'broadcast_axes',
+    'broadcast_axis',
+    'broadcast_like',
+    'broadcast_to',
+    'cbrt',
+    'ceil',
+    'choose_element_0index',
+    'clip',
+    'cos',
+    'crop',
+    'degrees',
+    'depth_to_space',
+    'diag',
+    'erf',
+    'expand_dims',
+    'fill_element_0index',
+    'fix',
+    'flatten',
+    'flip',
+    'floor',
+    'ftml_update',
+    'ftrl_update',
+    'gather_nd',
+    'hard_sigmoid',
+    'identity',
+    'logical_not',
+    'max_axis',
+    'max',
+    'min',
+    'min_axis',
+    'mp_sgd_mom_update',
+    'mp_sgd_update',
+    'multi_all_finite',
+    'multi_mp_sgd_mom_update',
+    'multi_mp_sgd_update',
+    'multi_sgd_mom_update',
+    'multi_sgd_update',
+    'negative',
+    'normal',
+    'one_hot',
+    'ones_like',
+    'pad',
+    'pick',
+    'radians',
+    'random_exponential',
+    'random_gamma',
+    'random_generalized_negative_binomial',
+    'random_negative_binomial',
+    'random_normal',
+    'random_poisson',
+    'random_randint',
+    'random_uniform',
+    'ravel_multi_index',
+    'relu',
+    'repeat',
+    'reshape',
+    'reshape_like',
+    'reverse',
+    'rint',
+    'rmsprop_update',
+    'rmspropalex_update',
+    'round',
+    'sample_exponential',
+    'sample_gamma',
+    'sample_generalized_negative_binomial',
+    'sample_multinomial',
+    'sample_negative_binomial',
+    'sample_normal',
+    'sample_poisson',
+    'sample_uniform',
+    'scatter_nd',
+    'sgd_mom_update',
+    'sgd_update',
+    'shape_array',
+    'shuffle',
+    'sigmoid',
+    'sign',
+    'signsgd_update',
+    'signum_update',
+    'sin',
+    'size_array',
+    'slice',
+    'slice_axis',
+    'slice_like',
+    'softsign',
+    'sort',
+    'space_to_depth',
+    'split',
+    'sqrt',
+    'squeeze',
+    'stop_gradient',
+    'swapaxes',
+    'take',
+    'tanh',
+    'tile',
+    'topk',
+    'transpose',
+    'trunc',
+    'uniform',
+    'unravel_index',
+    'zeros_like',
+    '_sg_mkldnn_conv',
+    '_sg_mkldnn_fully_connected',
+    'CuDNNBatchNorm',
+    '_TensorRT',
+    ]
+
+# Functions that have to be cast to FP32 due to possible
+# overflows
+FP32_FUNCS = [
+    'Convolution_v1',
+    'IdentityAttachKLSparseReg',
+    'arccos',
+    '_sparse_arccos',
+    'arcsin',
+    'cosh',
+    '_sparse_cosh',
+    'erfinv',
+    'sinh',
+    'tan',
+    '_sparse_tan',
+    'arctanh',
+    '_sparse_arcsin',
+    '_sparse_arctanh',
+
+    # Exponents
+    'exp',
+    'expm1',
+    '_sparse_exp',
+    '_sparse_expm1',
+    'log',
+    'log10',
+    'log2',
+    'log1p',
+
+    # Powers
+    'broadcast_power',
+    'square',
+    '_sparse_square',
+    'reciprocal',
+    '_RDivScalar',
+    '_rdiv_scalar',
+    'rsqrt',
+    'rcbrt',
+    '_Power',
+    '_PowerScalar',
+    '_power',
+    '_power_scalar',
+    '_RPowerScalar',
+    '_rpower_scalar',
+    'linalg_sumlogdiag',
+    '_Hypot',
+    '_HypotScalar',
+    '_hypot',
+    '_hypot_scalar',
+    'broadcast_hypot',
+    '_square_sum',
+    '_contrib_hawkesll',
+
+    # Reductions
+    'sum',
+    'sum_axis',
+    'nansum',
+    'prod',
+    'nanprod',
+    'mean',
+    'norm',
+    'softmin',
+    'khatri_rao',
+    'moments',
+
+    # Misc
+    'gamma',
+    'gammaln',
+    '_linalg_gelqf',
+    '_linalg_gemm',
+    '_linalg_gemm2',
+    '_linalg_potrf',
+    '_linalg_potri',
+    '_linalg_sumlogdiag',
+    '_linalg_syevd',
+    '_linalg_syrk',
+    '_linalg_trmm',
+    '_linalg_trsm',
+    '_linalg_makediag',
+    '_linalg_extractdiag',
+    '_linalg_maketrian',
+    '_linalg_extracttrian',
+    '_linalg_inverse',
+    'linalg_syrk',
+    'linalg_potrf',
+    'linalg_potri',
+    'linalg_gemm2',
+    'linalg_gemm',
+    'linalg_gelqf',
+    'linalg_trmm',
+    'linalg_trsm',
+    'linalg_makediag',
+    'linalg_extractdiag',
+    'linalg_maketrian',
+    'linalg_extracttrian',
+    'linalg_inverse',
+    '_NDArray',
+    '_Native',
+    '_contrib_count_sketch',
+    '_contrib_SyncBatchNorm',
+    '_contrib_fft',
+    '_sparse_gamma',
+    '_sparse_gammaln',
+    '_sparse_log',
+    '_sparse_log10',
+    '_sparse_log1p',
+    '_sparse_log2',
+    '_sparse_make_loss',
+    '_sparse_mean',
+    '_sparse_norm',
+    '_sparse_rsqrt',
+    'argsort',
+
+    # Neural network
+    'SoftmaxOutput',
+    'softmax',
+    'Softmax',
+    'log_softmax',
+    'InstanceNorm',
+    'LayerNorm',
+    'L2Normalization',
+    'LRN',
+    'SoftmaxActivation',
+    'LinearRegressionOutput',
+    'LogisticRegressionOutput',
+    'MAERegressionOutput',
+    '_sparse_LinearRegressionOutput',
+    '_sparse_LogisticRegressionOutput',
+    '_sparse_MAERegressionOutput',
+    'SVMOutput',
+    'softmax_cross_entropy',
+    'smooth_l1',
+    'MakeLoss',
+    'make_loss',
+    'Custom',
+    'CTCLoss',
+    '_contrib_CTCLoss',
+    '_contrib_ctc_loss',
+    'ctc_loss',
+    '_contrib_DeformableConvolution',
+    '_contrib_DeformablePSROIPooling',
+    ]
+
+# Functions that have to be cast to FP32 only for
+# some values of their parameters
+CONDITIONAL_FP32_FUNCS = [
+    ('Activation', 'act_type', ['softrelu']),
+    ('LeakyReLU', 'act_type', ['elu', 'selu']),
+    ]
+
+# Functions with multiple inputs that need all of their
+# inputs to be cast to the same (widest) type
+WIDEST_TYPE_CASTS = [
+    '_Plus',
+    '_plus',
+    '_Minus',
+    '_sub',
+    '_Mul',
+    '_Div',
+    '_div',
+    '_scatter_elemwise_div',
+    '_Mod',
+    '_Not_Equal',
+    '_Equal',
+    '_equal',
+    '_Greater',
+    '_greater',
+    '_Greater_Equal',
+    '_greater_equal',
+    '_Lesser',
+    '_Lesser_Equal',
+    '_lesser',
+    '_lesser_equal',
+    '_Logical_And',
+    '_Logical_Or',
+    '_Logical_Xor',
+    '_logical_and',
+    '_logical_or',
+    '_logical_xor',
+    '_maximum',
+    '_minimum',
+    '_minus',
+    '_mod',
+    '_mul',
+    '_not_equal',
+    'Concat',
+    'concat',
+    'Correlation',
+    'ElementWiseSum',
+    '_sparse_ElementWiseSum',
+    'add_n',
+    '_sparse_add_n',
+    'batch_dot',
+    'broadcast_add',
+    'broadcast_plus',
+    'broadcast_div',
+    'broadcast_equal',
+    'broadcast_greater',
+    'broadcast_greater_equal',
+    'broadcast_lesser',
+    'broadcast_lesser_equal',
+    'broadcast_logical_and',
+    'broadcast_logical_or',
+    'broadcast_logical_xor',
+    'broadcast_maximum',
+    'broadcast_minimum',
+    'broadcast_minus',
+    'broadcast_mod',
+    'broadcast_mul',
+    'broadcast_not_equal',
+    'broadcast_sub',
+    'dot',
+    'elemwise_add',
+    'elemwise_div',
+    'elemwise_mul',
+    'elemwise_sub',
+    'stack',
+    '_Maximum',
+    '_Minimum',
+    '_contrib_MultiBoxDetection',
+    '_contrib_MultiBoxPrior',
+    '_contrib_MultiBoxTarget',
+    '_contrib_MultiProposal',
+    '_contrib_PSROIPooling',
+    '_contrib_Proposal',
+    '_contrib_ROIAlign',
+    '_contrib_box_iou',
+    '_contrib_box_nms',
+    '_contrib_box_non_maximum_suppression',
+    '_contrib_dgl_adjacency',
+    '_contrib_dgl_csr_neighbor_non_uniform_sample',
+    '_contrib_dgl_csr_neighbor_uniform_sample',
+    '_contrib_dgl_graph_compact',
+    '_contrib_dgl_subgraph',
+    '_contrib_edge_id',
+    'where',
+    '_sparse_where',
+    '_sparse_broadcast_add',
+    '_sparse_broadcast_div',
+    '_sparse_broadcast_minus',
+    '_sparse_broadcast_mul',
+    '_sparse_broadcast_plus',
+    '_sparse_broadcast_sub',
+    '_sparse_dot',
+    '_sparse_elemwise_add',
+    '_sparse_elemwise_div',
+    '_sparse_elemwise_mul',
+    '_sparse_elemwise_sub',
+    '_sparse_sum',
+    ]
+
+LOSS_OUTPUT_FUNCTIONS = [
+    'SoftmaxOutput',
+    'LinearRegressionOutput',
+    'LogisticRegressionOutput',
+    'MAERegressionOutput',
+    ]
diff --git a/python/mxnet/contrib/amp/loss_scaler.py b/python/mxnet/contrib/amp/loss_scaler.py
new file mode 100755
index 0000000..a2600bc
--- /dev/null
+++ b/python/mxnet/contrib/amp/loss_scaler.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Dynamic loss scaler for AMP."""
+import logging
+
+from ...ndarray import multi_all_finite
+from ...ndarray import ndarray as nd
+from ... import autograd as ag
+
+class LossScaler(object):
+    """Dynamic loss scaler for AMP.
+
+    Attributes
+    ----------
+    loss_scale : float
+        The current loss scale
+    """
+    def __init__(self):
+        self._loss_scale = 2.**16
+        self._next_loss_scale = self._loss_scale
+        self._max_loss_scale = 2.**24
+        self._scale_seq_len = 2000
+        self._unskipped = 0
+        self._has_overflow = False
+
+    @property
+    def loss_scale(self):
+        return self._loss_scale
+
+    def launch_check_overflow(self, params):
+        """Launch overflow checking for gradients."""
+        self._wait_for_outputs = True
+        self._has_overflow = False
+        with ag.pause():
+            chunk_size = 200
+            valid_params = [p._grad[0] for p in params if p._grad is not None]
+            gpu_output = nd.ones((1,), ctx=valid_params[0].context)
+            nb_params = len(valid_params)
+            for idx in range(0, nb_params, chunk_size):
+                multi_all_finite(*valid_params[idx:idx+chunk_size],
+                                 num_arrays=len(valid_params[idx:idx+chunk_size]),
+                                 init_output=False, out=gpu_output)
+            self.output = gpu_output
+
+    def wait_and_update(self):
+        """Wait for the results of overflow checking and update the loss scale."""
+        if self._wait_for_outputs:
+            self._has_overflow = not bool(self.output.asnumpy())
+            self._loss_scale = self._next_loss_scale
+            if self._has_overflow:
+                self._next_loss_scale = self._loss_scale / 2.
+                self._unskipped = 0
+                logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
+            else:
+                self._unskipped += 1
+            if self._unskipped == self._scale_seq_len:
+                self._unskipped = 0
+                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
+                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
+            self._wait_for_outputs = False
+        return self._has_overflow
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2f3ed91..20f0a32 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -865,7 +865,7 @@ class HybridBlock(Block):
         """Infers data type of Parameters from inputs."""
         self._infer_attrs('infer_type', 'dtype', *args)
 
-    def export(self, path, epoch=0):
+    def export(self, path, epoch=0, remove_amp_cast=True):
         """Export HybridBlock to json format that can be loaded by
         `SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface.
 
@@ -885,7 +885,7 @@ class HybridBlock(Block):
                 "Please first call block.hybridize() and then run forward with "
                 "this block at least once before calling export.")
         sym = self._cached_graph[1]
-        sym.save('%s-symbol.json'%path)
+        sym.save('%s-symbol.json'%path, remove_amp_cast=remove_amp_cast)
 
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
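
A short usage sketch of the new flag (assuming net is a hybridized Gluon block that has already
run one forward pass):

    net.export("model")                             # amp_cast/amp_multicast stripped (default)
    net.export("model_amp", remove_amp_cast=False)  # keep the cast nodes in the saved json
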
diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
index b7a19f7..69ec92f 100644
--- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -462,10 +462,10 @@ class _ConvLSTMCell(_BaseConvRNNCell):
         forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
         in_transform = self._get_activation(F, slice_gates[2], self._activation, name=prefix+'c')
         out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
-                                   name=prefix+'state')
-        next_h = F._internal._mul(out_gate, self._get_activation(F, next_c, self._activation),
-                                  name=prefix+'out')
+        next_c = F.elemwise_add(forget_gate * states[1], in_gate * in_transform,
+                                name=prefix+'state')
+        next_h = F.elemwise_mul(out_gate, self._get_activation(F, next_c, self._activation),
+                                name=prefix+'out')
 
         return next_h, [next_h, next_c]
 
@@ -753,8 +753,8 @@ class _ConvGRUCell(_BaseConvRNNCell):
         next_h_tmp = self._get_activation(F, i2h + reset_gate * h2h, self._activation,
                                           name=prefix+'h_act')
 
-        next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * states[0],
-                                   name=prefix+'out')
+        next_h = F.elemwise_add((1. - update_gate) * next_h_tmp, update_gate * states[0],
+                                name=prefix+'out')
 
         return next_h, [next_h]
 
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index 3bd8e78..5a4d014 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -312,10 +312,10 @@ class LSTMPCell(HybridRecurrentCell):
         forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
         in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
         out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
-                                   name=prefix+'state')
-        hidden = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
-                                  name=prefix+'hidden')
+        next_c = F.elemwise_add(forget_gate * states[1], in_gate * in_transform,
+                                name=prefix+'state')
+        hidden = F.elemwise_mul(out_gate, F.Activation(next_c, act_type="tanh"),
+                                name=prefix+'hidden')
         next_r = F.FullyConnected(data=hidden, num_hidden=self._projection_size,
                                   weight=h2r_weight, no_bias=True, name=prefix+'out')
 
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 6ef3604..9154ccf 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -539,11 +539,11 @@ class LSTMCell(HybridRecurrentCell):
             F, slice_gates[2], self._activation, name=prefix+'c')
         out_gate = self._get_activation(
             F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
-        next_c = F._internal._plus(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
-                                   F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
-                                   name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
-                                  name=prefix+'out')
+        next_c = F.elemwise_add(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
+                                F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
+                                name=prefix+'state')
+        next_h = F.elemwise_mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
+                                name=prefix+'out')
 
         return next_h, [next_h, next_c]
 
@@ -667,11 +667,11 @@ class GRUCell(HybridRecurrentCell):
                                   name=prefix+'h_act')
 
         ones = F.ones_like(update_gate, name=prefix+"ones_like0")
-        next_h = F._internal._plus(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
-                                                  next_h_tmp,
-                                                  name=prefix+'mul1'),
-                                   F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
-                                   name=prefix+'out')
+        next_h = F.elemwise_add(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
+                                               next_h_tmp,
+                                               name=prefix+'mul1'),
+                                F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
+                                name=prefix+'out')
 
         return next_h, [next_h]
 
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 9ff23b7..7e324a1 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -391,7 +391,7 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
     # end of all epochs
 
 
-def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params):
+def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params, remove_amp_cast=True):
     """Checkpoint the model data into file.
 
     Parameters
@@ -406,13 +406,15 @@ def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params):
         Model parameter, dict of name to NDArray of net's weights.
     aux_params : dict of str to NDArray
         Model parameter, dict of name to NDArray of net's auxiliary states.
+    remove_amp_cast : bool, optional
+        Whether to remove the amp_cast and amp_multicast operators before saving the model.
     Notes
     -----
     - ``prefix-symbol.json`` will be saved for symbol.
     - ``prefix-epoch.params`` will be saved for parameters.
     """
     if symbol is not None:
-        symbol.save('%s-symbol.json' % prefix)
+        symbol.save('%s-symbol.json' % prefix, remove_amp_cast=remove_amp_cast)
 
     save_dict = {('arg:%s' % k) : v.as_in_context(cpu()) for k, v in arg_params.items()}
     save_dict.update({('aux:%s' % k) : v.as_in_context(cpu()) for k, v in aux_params.items()})
@@ -905,7 +907,7 @@ class FeedForward(BASE_ESTIMATOR):
                             sym_gen=self.sym_gen)
 
 
-    def save(self, prefix, epoch=None):
+    def save(self, prefix, epoch=None, remove_amp_cast=True):
         """Checkpoint the model checkpoint into file.
         You can also use `pickle` to do the job if you only work on Python.
         The advantage of `load` and `save` (as compared to `pickle`) is that
@@ -916,6 +918,8 @@ class FeedForward(BASE_ESTIMATOR):
         ----------
         prefix : str
             Prefix of model name.
+        remove_amp_cast : bool, optional
+            Whether to remove the amp_cast and amp_multicast operators before saving the model.
 
         Notes
         -----
@@ -925,7 +929,7 @@ class FeedForward(BASE_ESTIMATOR):
         if epoch is None:
             epoch = self.num_epoch
         assert epoch is not None
-        save_checkpoint(prefix, epoch, self.symbol, self.arg_params, self.aux_params)
+        save_checkpoint(prefix, epoch, self.symbol, self.arg_params, self.aux_params, remove_amp_cast=remove_amp_cast)
 
     @staticmethod
     def load(prefix, epoch, ctx=None, **kwargs):
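
For illustration, the same keyword threaded through the checkpointing helper (sym, arg_params,
and aux_params are hypothetical placeholders):

    mx.model.save_checkpoint("net", 10, sym, arg_params, aux_params)  # casts removed
    mx.model.save_checkpoint("net", 10, sym, arg_params, aux_params,
                             remove_amp_cast=False)                   # casts kept
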
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index e83751d..c186728 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -162,7 +162,7 @@ class Module(BaseModule):
             mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
         return mod
 
-    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
+    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False, remove_amp_cast=True):
         """Saves current progress to checkpoint.
         Use `mx.callback.module_checkpoint` as `epoch_end_callback` to save during training.
 
@@ -175,7 +175,7 @@ class Module(BaseModule):
         save_optimizer_states : bool
             Whether to save optimizer states to continue training.
         """
-        self._symbol.save('%s-symbol.json'%prefix)
+        self._symbol.save('%s-symbol.json'%prefix, remove_amp_cast=remove_amp_cast)
         param_name = '%s-%04d.params' % (prefix, epoch)
         self.save_params(param_name)
         logging.info('Saved checkpoint to \"%s\"', param_name)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 7c800df..0ea7c9f 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1275,7 +1275,7 @@ class Symbol(SymbolBase):
             self.handle, ctypes.byref(debug_str)))
         return py_str(debug_str.value)
 
-    def save(self, fname):
+    def save(self, fname, remove_amp_cast=True):
         """Saves symbol to a file.
 
         You can also use pickle to do the job if you only work on python.
@@ -1292,6 +1292,8 @@ class Symbol(SymbolBase):
             - "s3://my-bucket/path/my-s3-symbol"
             - "hdfs://my-bucket/path/my-hdfs-symbol"
             - "/path-to/my-local-symbol"
+        remove_amp_cast : bool, optional
+            Whether to remove the amp_cast and amp_multicast operators before saving the model.
 
         See Also
         --------
@@ -1299,7 +1301,12 @@ class Symbol(SymbolBase):
         """
         if not isinstance(fname, string_types):
             raise TypeError('fname need to be string')
-        check_call(_LIB.MXSymbolSaveToFile(self.handle, c_str(fname)))
+        if remove_amp_cast:
+            handle = SymbolHandle()
+            check_call(_LIB.MXSymbolRemoveAmpCast(self.handle, ctypes.byref(handle)))
+            check_call(_LIB.MXSymbolSaveToFile(handle, c_str(fname)))
+        else:
+            check_call(_LIB.MXSymbolSaveToFile(self.handle, c_str(fname)))
 
     def tojson(self):
         """Saves symbol to a JSON string.
@@ -1371,6 +1378,12 @@ class Symbol(SymbolBase):
             raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
         return c_array(NDArrayHandle, arg_handles), arg_arrays
 
+    def _gen_atomic_symbol(self):
+        handle = SymbolHandle()
+        check_call(_LIB.MXGenAtomicSymbolFromSymbol(self.handle, ctypes.byref(handle)))
+        return Symbol(handle)
+
+
     # pylint: disable=too-many-locals
     def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
                     group2ctx=None, shared_arg_names=None, shared_exec=None,
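
A small sketch of the symbol-level behaviour; the amp_cast node survives in the saved json only
when remove_amp_cast=False:

    import mxnet as mx
    a = mx.sym.Variable("a")
    b = mx.sym.exp(mx.sym.amp_cast(a, dtype="float16"))
    b.save("net-symbol.json")                             # RemoveAmpCast pass applied
    b.save("net-amp-symbol.json", remove_amp_cast=False)  # amp_cast preserved
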
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 24a8852..a3b9fce 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -441,6 +441,16 @@ int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) {
   API_END_HANDLE_ERROR(delete s);
 }
 
+int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) {
+  nnvm::Symbol* s = new nnvm::Symbol();
+  API_BEGIN();
+  nnvm::Symbol *source = static_cast<nnvm::Symbol*>(sym_handle);
+  *s = source->Copy();
+  s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs;
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
+
 int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
   API_BEGIN();
@@ -839,3 +849,17 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
   *ret_sym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
+
+int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle) {
+  nnvm::Symbol *s = new nnvm::Symbol();
+  API_BEGIN();
+  nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
+  CHECK_EQ(source->outputs.size(), 1U)
+    << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
+  const auto& node = source->outputs[0];
+  const auto *op = node.node->op();
+  const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
+  *s = nnvm::Symbol::CreateFunctor(op, attrs);
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 72e69df..dbc1cbf 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -118,7 +118,7 @@ CachedOp::CachedOp(
         if (_copy->attr_parser != nullptr) {
           _copy->attr_parser(&(copy_node->attrs));
         }
-        fwd_graph_.outputs.push_back(NodeEntry{copy_node, 0, 0});
+        fwd_graph_.outputs.emplace_back(copy_node, 0, 0);
       } else {
         dedup_out.insert({i, 0});
         fwd_graph_.outputs.push_back(i);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index b027de0..a1c41ee 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -363,7 +363,7 @@ std::vector<NDArray*> Imperative::Backward(
       auto node = Node::Create();
       node->attrs.op = copy_op;
       node->inputs.push_back(e);
-      graph.outputs.push_back(NodeEntry{node, 0, 0});
+      graph.outputs.emplace_back(node, 0, 0);
     } else {
       graph.outputs.push_back(e);
     }
diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index ac70848..ce442ed 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -240,10 +240,16 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
         bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 &&
                                   fignore_inputs[inode.source->op()](
                                       inode.source->attrs).size() == inode.source->num_inputs());
+        // Identity should only be true if shape.Size() and types match
+        bool real_identity = identity[ipair] &&
+                             ndim_is_known(shape_vec[eid_out]) &&
+                             ndim_is_known(shape_vec[eid_in]) &&
+                             shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
+                             dtype_vec[eid_out] == dtype_vec[eid_in];
         if (taken[kv.first] == false &&
             sid_out == GraphAllocator::kBadStorageID &&
             sid_in >= 0 &&
-            ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) &&
+            ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) &&
             entry_ref_count[eid_out] > 0 &&
             shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
              (dtype_vec[eid_out] == dtype_vec[eid_in] ||
diff --git a/src/operator/contrib/all_finite-inl.h b/src/operator/contrib/all_finite-inl.h
new file mode 100755
index 0000000..cf63fce
--- /dev/null
+++ b/src/operator/contrib/all_finite-inl.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file all_finite-inl.h
+ * \brief operator for checking if a group of arrays is all finite
+ * \author Clement Fuji Tsang
+ */
+
+#ifndef MXNET_OPERATOR_CONTRIB_ALL_FINITE_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ALL_FINITE_INL_H_
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <mxnet/op_attr_types.h>
+#include <mshadow/base.h>
+#include <nnvm/op.h>
+#include <nnvm/op_attr_types.h>
+#include <vector>
+#include "../operator_common.h"
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+#include "../mxnet_op.h"
+#include "../tensor/init_op.h"
+#include "../tensor/util/tensor_util-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct AllFiniteParam : public dmlc::Parameter<AllFiniteParam> {
+  bool init_output;
+  DMLC_DECLARE_PARAMETER(AllFiniteParam) {
+    DMLC_DECLARE_FIELD(init_output)
+    .set_default(true)
+    .describe("Initialize output to 1.");
+  }
+};
+
+struct MultiAllFiniteParam : public dmlc::Parameter<MultiAllFiniteParam> {
+  int num_arrays;
+  bool init_output;
+  DMLC_DECLARE_PARAMETER(MultiAllFiniteParam) {
+    DMLC_DECLARE_FIELD(num_arrays)
+    .set_default(1)
+    .describe("Number of arrays.");
+    DMLC_DECLARE_FIELD(init_output)
+    .set_default(true)
+    .describe("Initialize output to 1.");
+  }
+};
+
+template<typename DType>
+struct MultiAllFiniteKernelParam {
+  static const int N = 200;
+  int count;
+  size_t max_size;
+  size_t sizes[N];
+  DType *arrays[N];
+};
+
+template<typename xpu, typename DType>
+MultiAllFiniteKernelParam<DType> FillMultiAllFiniteParam(const MultiAllFiniteParam& op_param,
+                                                         const OpContext &ctx,
+                                                         const std::vector<TBlob> &inputs) {
+  MultiAllFiniteKernelParam<DType> param;
+  using namespace mxnet_op;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  param.count = op_param.num_arrays;
+  param.max_size = 0;
+  for (int i = 0; i < param.count; ++i) {
+    param.sizes[i] = inputs[i].shape_.Size();
+    if (param.max_size < param.sizes[i]) {
+      param.max_size = param.sizes[i];
+    }
+    param.arrays[i] = inputs[i].FlatTo2D<xpu, DType>(s).dptr_;
+  }
+  return param;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_ALL_FINITE_INL_H_
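
Note the fixed capacity N = 200 here: each multi_all_finite call can address at most N arrays,
which is why LossScaler.launch_check_overflow above walks the gradients in chunks of 200. A
sketch of chunked checking against one shared output flag:

    import mxnet as mx
    arrays = [mx.nd.ones((4,)) for _ in range(450)]   # more than two chunks
    out = mx.nd.ones((1,))
    for i in range(0, len(arrays), 200):
        chunk = arrays[i:i + 200]
        mx.nd.multi_all_finite(*chunk, num_arrays=len(chunk),
                               init_output=False, out=out)
    print(out)   # [1.], since every array is finite
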
diff --git a/src/operator/contrib/all_finite.cc b/src/operator/contrib/all_finite.cc
new file mode 100755
index 0000000..5e77510
--- /dev/null
+++ b/src/operator/contrib/all_finite.cc
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file all_finite.cc
+ * \brief operator for checking if a group of arrays is all finite
+ * \author Clement Fuji Tsang
+ */
+#include "./all_finite-inl.h"
+#include <cmath>
+
+namespace mxnet {
+namespace op {
+
+template<typename DType>
+struct AllFiniteCPUKernel {
+  MSHADOW_XINLINE static void Map(int i, const DType* in, float* out) {
+    bool is_finite = true;
+    is_finite = std::isfinite(static_cast<float>(in[i])) ? is_finite : false;
+    if (!is_finite) {
+      out[0] = 0.;
+    }
+  }
+};
+
+inline void AllFiniteCPU(const nnvm::NodeAttrs& attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  const AllFiniteParam& op_param = nnvm::get<AllFiniteParam>(attrs.parsed);
+  Tensor<cpu, 2, float> out = outputs[0].FlatTo2D<cpu, float>(s);
+  if (op_param.init_output) {
+    out = 1.;
+  }
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<cpu, 2, DType> in = inputs[0].FlatTo2D<cpu, DType>(s);
+    const int n = in.shape_.Size();
+    Kernel<AllFiniteCPUKernel<DType>, cpu>::Launch(s, n, in.dptr_, out.dptr_);
+  });
+}
+
+template<typename DType>
+struct MultiAllFiniteCPUKernel {
+  MSHADOW_XINLINE static void Map(int i, const MultiAllFiniteKernelParam<DType> param,
+                                  float* out) {
+    bool is_finite = true;
+    for (int index = 0; index < param.count; ++index) {
+      if ((size_t)i < param.sizes[index]) {
+        is_finite = std::isfinite(static_cast<float>(param.arrays[index][i])) ? is_finite : false;
+      }
+    }
+    if (!is_finite) {
+      out[0] = 0.;
+    }
+  }
+};
+
+inline void MultiAllFiniteCPU(const nnvm::NodeAttrs& attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  const MultiAllFiniteParam& op_param = nnvm::get<MultiAllFiniteParam>(attrs.parsed);
+  Tensor<cpu, 2, float> out = outputs[0].FlatTo2D<cpu, float>(s);
+  if (op_param.init_output)
+    out = 1.;
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MultiAllFiniteKernelParam<DType> param =
+      FillMultiAllFiniteParam<cpu, DType>(op_param, ctx, inputs);
+    Kernel<MultiAllFiniteCPUKernel<DType>, cpu>::Launch(s, param.max_size,
+                                                       param, out.dptr_);
+  });
+}
+
+DMLC_REGISTER_PARAMETER(AllFiniteParam);
+
+NNVM_REGISTER_OP(all_finite)
+.describe(R"code(Check if all the float numbers in the array are finite (used for AMP)
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AllFiniteParam>)
+.set_attr<mxnet::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs){
+    (*out_attrs)[0] = TShape({1});
+    return true;
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_attrs,
+     std::vector<int> *out_attrs){
+    (*out_attrs)[0] = mshadow::kFloat32;
+    return true;
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    std::vector<std::string> ret;
+    ret.emplace_back("data");
+    return ret;
+  })
+.add_argument("data", "NDArray", "Array")
+.add_arguments(AllFiniteParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", AllFiniteCPU);
+
+DMLC_REGISTER_PARAMETER(MultiAllFiniteParam);
+
+NNVM_REGISTER_OP(multi_all_finite)
+.describe(R"code(Check if all the float numbers in all the arrays are finite (used for AMP)
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    const MultiAllFiniteParam& param = dmlc::get<MultiAllFiniteParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_arrays);
+  })
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<MultiAllFiniteParam>)
+.set_attr<mxnet::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs) {
+    (*out_attrs)[0] = TShape({1});
+    return true;
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_attrs,
+     std::vector<int> *out_attrs) {
+    (*out_attrs)[0] = mshadow::kFloat32;
+    return true;
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    uint32_t num_args = dmlc::get<MultiAllFiniteParam>(attrs.parsed).num_arrays;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("array_") + std::to_string(i));
+    }
+    return ret;
+  })
+.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
+.add_arguments(MultiAllFiniteParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", MultiAllFiniteCPU);
+
+}  // namespace op
+}  // namespace mxnet
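
For illustration, the NDArray frontends of these operators; the output is a single float32 flag
that is 1 when everything is finite:

    import numpy as np
    import mxnet as mx
    x = mx.nd.array([1.0, 2.0, 3.0])
    y = mx.nd.array([np.inf, 0.0])
    print(mx.nd.all_finite(x))                         # [1.]
    print(mx.nd.multi_all_finite(x, y, num_arrays=2))  # [0.], one array has an inf
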
diff --git a/src/operator/contrib/all_finite.cu b/src/operator/contrib/all_finite.cu
new file mode 100755
index 0000000..69ba35f
--- /dev/null
+++ b/src/operator/contrib/all_finite.cu
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file all_finite.cu
+ * \brief operator for checking if a group of arrays is all finite
+ * \author Clement Fuji Tsang
+ */
+
+#include "./all_finite-inl.h"
+
+namespace mxnet {
+namespace op {
+
+template <typename DType>
+__global__ void AllFiniteGPUKernel(const int size, const DType* in, float* out) {
+  bool is_finite = true;
+  CUDA_KERNEL_LOOP(i, size) {
+    is_finite = isfinite(static_cast<float>(in[i])) ? is_finite : false;
+  }
+  __syncthreads();
+  if (!is_finite) {
+    out[0] = 0.;
+  }
+}
+
+inline void AllFiniteGPU(const nnvm::NodeAttrs& attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  const AllFiniteParam& op_param = nnvm::get<AllFiniteParam>(attrs.parsed);
+  Tensor<gpu, 2, float> out = outputs[0].FlatTo2D<gpu, float>(s);
+  if (op_param.init_output)
+    out = 1.;
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<gpu, 2, DType> in = inputs[0].FlatTo2D<gpu, DType>(s);
+    const int n = in.shape_.Size();
+    AllFiniteGPUKernel<DType><<<cuda_get_num_blocks(n),
+                                mshadow::cuda::kBaseThreadNum, 0,
+                                mshadow::Stream<gpu>::GetStream(s)>>>(n, in.dptr_, out.dptr_);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(AllFiniteGPUKernel<DType>);
+  });
+}
+
+template <typename DType>
+__global__ void MultiAllFiniteGPUKernel(const MultiAllFiniteKernelParam<DType> param, float* out) {
+  bool is_finite = true;
+  for (int index = 0; index < param.count; ++index) {
+    CUDA_KERNEL_LOOP(i, param.sizes[index]) {
+      is_finite = isfinite(static_cast<float>(param.arrays[index][i])) ? is_finite : false;
+    }
+  }
+  __syncthreads();
+  if (!is_finite) {
+    out[0] = 0.;
+  }
+}
+
+inline void MultiAllFiniteGPU(const nnvm::NodeAttrs& attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  const MultiAllFiniteParam& op_param = nnvm::get<MultiAllFiniteParam>(attrs.parsed);
+  Tensor<gpu, 2, float> out = outputs[0].FlatTo2D<gpu, float>(s);
+  if (op_param.init_output)
+    out = 1.;
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MultiAllFiniteKernelParam<DType> param =
+      FillMultiAllFiniteParam<gpu, DType>(op_param, ctx, inputs);
+    MultiAllFiniteGPUKernel<DType><<<cuda_get_num_blocks(param.max_size),
+                                     mshadow::cuda::kBaseThreadNum, 1,
+                                     mshadow::Stream<gpu>::GetStream(s)>>>(param, out.dptr_);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(MultiAllFiniteGPUKernel<DType>);
+  });
+}
+
+NNVM_REGISTER_OP(all_finite)
+.set_attr<FCompute>("FCompute<gpu>", AllFiniteGPU);
+
+NNVM_REGISTER_OP(multi_all_finite)
+.set_attr<FCompute>("FCompute<gpu>", MultiAllFiniteGPU);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/amp_graph_pass.cc b/src/operator/contrib/amp_graph_pass.cc
new file mode 100644
index 0000000..abecc4a
--- /dev/null
+++ b/src/operator/contrib/amp_graph_pass.cc
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file amp_graph_pass.cc
+ * \brief graph pass regarding AMP
+ * \author Clement Fuji Tsang
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+
+namespace mxnet {
+namespace op {
+
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::Graph;
+
+
+/*!
+ * \brief Remove amp_cast and amp_multicast and replug the fp32 weights
+ */
+Graph RemoveAmpCast(Graph&& g) {
+  DFSVisit(g.outputs, [](const NodePtr& n) {
+    for (size_t i = 0; i < n->inputs.size(); ++i) {
+      auto e = n->inputs[i];
+      if (e.node->op() == Op::Get("amp_cast")) {
+        n->inputs[i] = e.node->inputs[0];
+      } else if (e.node->op() == Op::Get("amp_multicast")) {
+        n->inputs[i] = e.node->inputs[e.index];
+      }
+    }
+  });
+  return g;
+}
+
+NNVM_REGISTER_PASS(RemoveAmpCast)
+.describe("")
+.set_body(RemoveAmpCast)
+.set_change_graph(true);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 412bfa1..63b4200 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -224,7 +224,7 @@ std::vector<nnvm::NodeEntry> Gradient(
     size_t i = static_cast<size_t>(t);
     if (i >= params.num_outs + params.num_args) {
       uint32_t idx = static_cast<uint32_t>(i-params.num_outs-params.num_args);
-      g->inputs.push_back(nnvm::NodeEntry{n, idx, 0});
+      g->inputs.emplace_back(n, idx, 0);
     } else if (i >= params.num_outs) {
       g->inputs.push_back(n->inputs[i-params.num_outs]);
     } else {
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
new file mode 100644
index 0000000..08d4387
--- /dev/null
+++ b/src/operator/tensor/amp_cast.cc
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file amp_cast.cc
+ * \brief Casts used by AMP
+ */
+
+#include "./amp_cast.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(AMPCastParam);
+DMLC_REGISTER_PARAMETER(AMPMultiCastParam);
+
+NNVM_REGISTER_OP(amp_cast)
+.describe(R"code(Cast function between low precision float/FP32 used by AMP.
+
+It casts only between low precision float/FP32 and does not do anything for other types.
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<AMPCastParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", AMPCastType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
+.set_attr<FCompute>("FCompute<cpu>", AMPCastCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_amp_cast"})
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(AMPCastParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_amp_cast)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
+.set_attr<FCompute>("FCompute<cpu>", AMPCastCompute<cpu>);
+
+NNVM_REGISTER_OP(amp_multicast)
+.describe(R"code(Cast function used by AMP, that casts its inputs to the common widest type.
+
+It casts only between low precision float/FP32 and does not do anything for other types.
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    const AMPMultiCastParam& param = dmlc::get<AMPMultiCastParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_outputs);
+  })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    const AMPMultiCastParam& param = dmlc::get<AMPMultiCastParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_outputs);
+  })
+.set_attr_parser(ParamParser<AMPMultiCastParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", AMPMultiCastShape)
+.set_attr<nnvm::FInferType>("FInferType", AMPMultiCastType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    uint32_t num_args = dmlc::get<AMPMultiCastParam>(attrs.parsed).num_outputs;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("data_") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    int num_args = dmlc::get<AMPMultiCastParam>(attrs.parsed).num_outputs;
+    std::vector<std::pair<int, int>> ret;
+    for (int i = 0; i < num_args; ++i) {
+      ret.emplace_back(i, i);
+    }
+    return ret;
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    int num_args = dmlc::get<AMPMultiCastParam>(attrs.parsed).num_outputs;
+    return std::vector<bool>(num_args, true);
+  })
+.set_attr<FCompute>("FCompute<cpu>", AMPMultiCastCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_amp_multicast"})
+.add_argument("data", "NDArray-or-Symbol[]", "Weights")
+.add_arguments(AMPMultiCastParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_amp_multicast)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    const AMPMultiCastParam& param = dmlc::get<AMPMultiCastParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_outputs);
+  })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    const AMPMultiCastParam& param = dmlc::get<AMPMultiCastParam>(attrs.parsed);
+    return static_cast<uint32_t>(param.num_outputs);
+  })
+.set_attr_parser(ParamParser<AMPMultiCastParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    uint32_t num_args = dmlc::get<AMPMultiCastParam>(attrs.parsed).num_outputs;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("grad_") + std::to_string(i));
+    }
+    return ret;
+  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    int num_args = dmlc::get<AMPMultiCastParam>(attrs.parsed).num_outputs;
+    std::vector<std::pair<int, int>> ret;
+    for (int i = 0; i < num_args; ++i) {
+      ret.emplace_back(i, i);
+    }
+    return ret;
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    int num_args = dmlc::get<AMPMultiCastParam>(attrs.parsed).num_outputs;
+    return std::vector<bool>(num_args, true);
+  })
+.set_attr<FCompute>("FCompute<cpu>", AMPMultiCastCompute<cpu>)
+.add_argument("grad", "NDArray-or-Symbol[]", "Gradients")
+.add_arguments(AMPMultiCastParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
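
A short sketch of the semantics from the NDArray side: only FP16/FP32 inputs are converted,
anything else passes through unchanged:

    import numpy as np
    import mxnet as mx
    f32 = mx.nd.ones((2, 2), dtype=np.float32)
    f16 = mx.nd.ones((2, 2), dtype=np.float16)
    i32 = mx.nd.ones((2, 2), dtype=np.int32)
    print(mx.nd.amp_cast(f32, dtype=np.float16).dtype)   # float16
    print(mx.nd.amp_cast(i32, dtype=np.float16).dtype)   # int32, left untouched
    outs = mx.nd.amp_multicast(f16, f32, num_outputs=2)  # widest input type is float32
    print([o.dtype for o in outs])                       # both float32
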
diff --git a/src/operator/tensor/amp_cast.cu b/src/operator/tensor/amp_cast.cu
new file mode 100644
index 0000000..0a4f7c5
--- /dev/null
+++ b/src/operator/tensor/amp_cast.cu
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file amp_cast.cu
+ * \brief Casts used by AMP (GPU operators)
+ */
+
+#include "./amp_cast.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(amp_cast)
+.set_attr<FCompute>("FCompute<gpu>", AMPCastCompute<gpu>);
+NNVM_REGISTER_OP(_backward_amp_cast)
+.set_attr<FCompute>("FCompute<gpu>", AMPCastCompute<gpu>);
+
+NNVM_REGISTER_OP(amp_multicast)
+.set_attr<FCompute>("FCompute<gpu>", AMPMultiCastCompute<gpu>);
+NNVM_REGISTER_OP(_backward_amp_multicast)
+.set_attr<FCompute>("FCompute<gpu>", AMPMultiCastCompute<gpu>);
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/tensor/amp_cast.h b/src/operator/tensor/amp_cast.h
new file mode 100644
index 0000000..a722b41
--- /dev/null
+++ b/src/operator/tensor/amp_cast.h
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file amp_cast.h
+ * \brief Function definition of casts used by AMP
+ */
+
+#ifndef MXNET_OPERATOR_TENSOR_AMP_CAST_H_
+#define MXNET_OPERATOR_TENSOR_AMP_CAST_H_
+
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../elemwise_op_common.h"
+#include "../operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct AMPCastParam : public dmlc::Parameter<AMPCastParam> {
+  // use int for enumeration
+  int dtype;
+  DMLC_DECLARE_PARAMETER(AMPCastParam) {
+    DMLC_DECLARE_FIELD(dtype)
+    MXNET_ADD_ALL_TYPES
+    .describe("Output data type.");
+  }
+};
+
+struct AMPMultiCastParam : public dmlc::Parameter<AMPMultiCastParam> {
+  int num_outputs;
+
+  DMLC_DECLARE_PARAMETER(AMPMultiCastParam) {
+    DMLC_DECLARE_FIELD(num_outputs)
+    .describe("Number of input/output pairs to be casted to the widest type.");
+  }
+};
+
+inline bool AMPCastType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  using mshadow::kFloat32;
+  using mshadow::kFloat16;
+  const AMPCastParam& param = nnvm::get<AMPCastParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if ((*in_attrs)[0] == kFloat32 || (*in_attrs)[0] == kFloat16) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
+  } else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
+  }
+  return (*in_attrs)[0] != -1;
+}
+
+inline bool AMPMultiCastType(const nnvm::NodeAttrs& attrs,
+                             std::vector<int> *in_attrs,
+                             std::vector<int> *out_attrs) {
+  using mshadow::kFloat32;
+  using mshadow::kFloat16;
+  const AMPMultiCastParam& param = nnvm::get<AMPMultiCastParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), param.num_outputs);
+  CHECK_EQ(out_attrs->size(), param.num_outputs);
+  bool ret = true;
+  int widest_type = kFloat16;
+  for (int i = 0; i < param.num_outputs; ++i) {
+    if ((*in_attrs)[i] == kFloat32 || (*out_attrs)[i] == kFloat32) {
+      widest_type = kFloat32;
+    }
+  }
+  for (int i = 0; i < param.num_outputs; ++i) {
+    if ((*in_attrs)[i] == kFloat32 || (*in_attrs)[i] == kFloat16) {
+      TYPE_ASSIGN_CHECK(*out_attrs, i, widest_type);
+    } else {
+      TYPE_ASSIGN_CHECK(*out_attrs, i, (*in_attrs)[i]);
+    }
+    ret = ret && ((*in_attrs)[i] != -1);
+  }
+  return ret;
+}
+
+inline bool AMPMultiCastShape(const nnvm::NodeAttrs& attrs,
+                              std::vector<TShape> *in_attrs,
+                              std::vector<TShape> *out_attrs) {
+  const AMPMultiCastParam& param = dmlc::get<AMPMultiCastParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), param.num_outputs);
+  CHECK_EQ(out_attrs->size(), param.num_outputs);
+
+  bool all_inferred = true;
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    // forward inference
+    SHAPE_ASSIGN_CHECK(*out_attrs, i, (*in_attrs)[i]);
+    // backward inference
+    SHAPE_ASSIGN_CHECK(*in_attrs, i, (*out_attrs)[i]);
+    all_inferred = all_inferred && !shape_is_none((*in_attrs)[i]);
+  }
+  return all_inferred;
+}
+
+template<typename xpu>
+void AMPCastCompute(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DstDType, {
+    Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, {
+      Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
+      if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
+          req[0] != kWriteInplace) {
+        Assign(out, req[0], tcast<DstDType>(data));
+      }
+    });
+  });
+}
+
+template<typename xpu>
+void AMPMultiCastCompute(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    MSHADOW_TYPE_SWITCH(outputs[i].type_flag_, DstDType, {
+      Tensor<xpu, 1, DstDType> out = outputs[i].FlatTo1D<xpu, DstDType>(s);
+      MSHADOW_TYPE_SWITCH(inputs[i].type_flag_, SrcDType, {
+        Tensor<xpu, 1, SrcDType> data = inputs[i].FlatTo1D<xpu, SrcDType>(s);
+        if (outputs[i].type_flag_ != inputs[i].type_flag_ ||
+            req[i] != kWriteInplace) {
+          Assign(out, req[i], tcast<DstDType>(data));
+        }
+      });
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_TENSOR_AMP_CAST_H_
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index f890963..a806307 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -291,7 +291,7 @@ NNVM_REGISTER_OP(broadcast_like)
                                  {{"keepdims", "true"}});
       auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                          {n->inputs[1]}, nullptr, &n);
-      lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
+      lhs.emplace_back(ng, 0, 0);
       return lhs;
     })
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index f1ec8b5..77044cb 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -54,7 +54,7 @@ std::vector<nnvm::NodeEntry> ElementWiseSumGrad(
     nnvm::NodePtr id_node = nnvm::Node::Create();
     id_node->attrs.op = copy_op;
     id_node->inputs = {ograds[0]};
-    ret.push_back(nnvm::NodeEntry{id_node, 0, 0});
+    ret.emplace_back(id_node, 0, 0);
   }
   return ret;
 }
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 3085f6d..86e8b01 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -428,7 +428,10 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, {
       Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
-      Assign(out, req[0], tcast<DstDType>(data));
+      if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
+          req[0] != kWriteInplace) {
+        Assign(out, req[0], tcast<DstDType>(data));
+      }
     });
   });
 }
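
This guard pairs with the FInplaceIdentity attribute added to Cast in
elemwise_unary_op_basic.cc below: when the planner makes a same-dtype cast write in place, the
now-redundant copy is skipped entirely. The user-visible result is unchanged, e.g.:

    import numpy as np
    import mxnet as mx
    x = mx.nd.ones((2, 2), dtype=np.float32)
    y = mx.nd.cast(x, dtype=np.float32)        # same dtype; the copy may be elided
    print((x.asnumpy() == y.asnumpy()).all())  # True
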
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 5114a5d..1634606 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -360,7 +360,7 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
                               std::unordered_map<std::string, std::string>());
       auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                          {n->inputs[1]}, nullptr, &n);
-      lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
+      lhs.emplace_back(ng, 0, 0);
       return lhs;
     })
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
@@ -499,7 +499,7 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or `
                               std::unordered_map<std::string, std::string>());
       auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                          {n->inputs[1]}, nullptr, &n);
-      lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
+      lhs.emplace_back(ng, 0, 0);
       return lhs;
     })
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
@@ -624,6 +624,10 @@ Example::
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
   })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
 .set_attr<FCompute>("FCompute<cpu>", CastCompute<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_cast"})
 .add_argument("data", "NDArray-or-Symbol", "The input.")
@@ -635,6 +639,10 @@ NNVM_REGISTER_OP(_backward_cast)
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
   })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
 .set_attr<FCompute>("FCompute<cpu>", CastCompute<cpu>);
 
 // negative
diff --git a/tests/python/unittest/test_amp.py b/tests/python/unittest/test_amp.py
new file mode 100644
index 0000000..b3e5598
--- /dev/null
+++ b/tests/python/unittest/test_amp.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import collections
+import ctypes
+import mxnet.contrib.amp as amp
+
+def test_amp_coverage():
+    conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
+
+    # Check for duplicates
+    for a in [amp.lists.symbol.FP16_FUNCS,
+              amp.lists.symbol.FP16_FP32_FUNCS,
+              amp.lists.symbol.FP32_FUNCS,
+              amp.lists.symbol.WIDEST_TYPE_CASTS,
+              conditional]:
+        ret = [item for item, count in collections.Counter(a).items() if count > 1]
+        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."
+
+    t = []
+    for a in [amp.lists.symbol.FP16_FUNCS,
+              amp.lists.symbol.FP16_FP32_FUNCS,
+              amp.lists.symbol.FP32_FUNCS,
+              amp.lists.symbol.WIDEST_TYPE_CASTS,
+              conditional]:
+        t += a
+    ret = [item for item, count in collections.Counter(t).items() if count > 1]
+    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
+
+    # Check the coverage
+    py_str = lambda x: x.decode('utf-8')
+
+    plist = ctypes.POINTER(ctypes.c_char_p)()
+    size = ctypes.c_uint()
+
+    mx.base._LIB.MXListAllOpNames(ctypes.byref(size),
+                                  ctypes.byref(plist))
+    op_names = []
+    for i in range(size.value):
+        s = py_str(plist[i])
+        if not s.startswith("_backward") \
+           and not s.startswith("_contrib_backward_"):
+            op_names.append(s)
+
+    ret1 = set(op_names) - set(t)
+
+    assert ret1 == set(), ("Operators " + str(ret1) + " do not exist in AMP lists (in "
+                           "python/mxnet/contrib/amp/lists/symbol.py) - please add them. "
+                           """Please follow these guidelines for choosing a proper list:
+                           - if your operator is not to be used in a computational graph
+                             (e.g. image manipulation operators, optimizers) or does not have
+                             inputs, put it in FP16_FP32_FUNCS list,
+                           - if your operator requires FP32 inputs or is not safe to use with lower
+                             precision, put it in FP32_FUNCS list,
+                           - if your operator supports both FP32 and lower precision, has
+                             multiple inputs and expects all inputs to be of the same
+                             type, put it in WIDEST_TYPE_CASTS list,
+                           - if your operator supports both FP32 and lower precision and has
+                             either a single input or supports inputs of different type,
+                             put it in FP16_FP32_FUNCS list,
+                           - if your operator is both safe to use in lower precision and
+                             it is highly beneficial to use it in lower precision, then
+                             put it in FP16_FUNCS (this is unlikely for new operators)
+                           - If you are not sure which list to choose, FP32_FUNCS is the
+                             safest option""")
+
+if __name__ == '__main__':
+    test_amp_coverage()
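
For context, the user-facing entry points these lists feed into look roughly like the following,
per the tutorial added by this commit (a hedged sketch; net, trainer, loss_fn, data, label, and
batch_size are hypothetical):

    import mxnet as mx
    from mxnet.contrib import amp

    amp.init()                  # patch operators to follow the AMP lists
    # ... build net, hybridize it, create trainer ...
    amp.init_trainer(trainer)   # enable dynamic loss scaling
    with mx.autograd.record():
        loss = loss_fn(net(data), label)
        with amp.scale_loss(loss, trainer) as scaled_loss:
            mx.autograd.backward(scaled_loss)
    trainer.step(batch_size)
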
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 18d2ace..56607dc 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4314,47 +4314,117 @@ def test_cast():
             assert_almost_equal(exe.outputs[0].asnumpy(), X.astype(srctype).astype(dsttype), rtol=1e-3, atol=1e-5)
             assert_almost_equal(exe.grad_arrays[0].asnumpy(), X.astype(dsttype).astype(srctype), rtol=1e-3, atol=1e-5)
 
-
-# Test requires all platforms to round float32->float16 with same round-to-nearest-even policy.
-@with_seed()
-def test_cast_float32_to_float16():
+def get_cast_op_data():
     FP16_FRACTION_BITS = 10
     FP32_FRACTION_BITS = 23
     FP32_EXP_MIN = -126
     FP32_EXP_MAX = 127
     # generate test cases in the vicinity of representable float16 mantissas
     # and mid-way between them, but over the full range of float32 exponents.
-    def get_data():
-        for sign_bit in [0, 1]:
-            for exponent in range(FP32_EXP_MIN - FP32_FRACTION_BITS - 1, FP32_EXP_MAX + 2):
-                denominator = 2**(FP16_FRACTION_BITS + 1)
-                for numerator in range(0, denominator):
-                    fraction = numerator / float(denominator)
-                    for y in [-1.0, 0.0, 1.0]:
-                        small_delta = y / 2**FP32_FRACTION_BITS
-                        val = (-1.0)**sign_bit * 2.0**exponent * (1.0 + fraction + small_delta)
-                        yield val
-        # Add np.nan as a final data value to process
-        yield np.nan
-
-    input_np = np.array(list(get_data())).astype(np.float32)
+
+    for sign_bit in [0, 1]:
+        for exponent in range(FP32_EXP_MIN - FP32_FRACTION_BITS - 1, FP32_EXP_MAX + 2):
+            denominator = 2**(FP16_FRACTION_BITS + 1)
+            for numerator in range(0, denominator):
+                fraction = numerator / float(denominator)
+                for y in [-1.0, 0.0, 1.0]:
+                    small_delta = y / 2**FP32_FRACTION_BITS
+                    val = (-1.0)**sign_bit * 2.0**exponent * (1.0 + fraction + small_delta)
+                    yield val
+    # Add np.nan as a final data value to process
+    yield np.nan
+
+# Test requires all platforms to round float32->float16 with same round-to-nearest-even policy.
+@with_seed()
+def test_cast_float32_to_float16():
+    input_np = np.array(list(get_cast_op_data())).astype(np.float32)
     # The intermediate cast to np.float64 below gets around a numpy rounding bug that is fixed
     # as of numpy 1.17 by PR https://github.com/numpy/numpy/pull/12722
     expected_output = input_np.astype(np.float64).astype(np.float16)
 
-    x = mx.sym.Variable('x', dtype=np.float32)
-    sym = mx.sym.Cast(x, dtype=np.float16)
+    def check_cast(op, input_np, expected_output):
+        x = mx.sym.Variable('x', dtype=np.float32)
+        sym = op(x, dtype=np.float16)
+        ctx = default_context()
+        exe = sym.bind(ctx, {'x': mx.nd.array(input_np, dtype=np.float32, ctx=ctx)})
+        assert exe.arg_arrays[0].dtype == np.float32
+        assert exe.outputs[0].dtype == np.float16
+        exe.forward(is_train=True)
+        sym_output = exe.outputs[0].asnumpy()
+        for fp32_val, model_fp16_val, np_fp16_val in zip(input_np, sym_output, expected_output):
+            assert (model_fp16_val == np_fp16_val) or \
+                   (np.isnan(model_fp16_val) and np.isnan(np_fp16_val)), \
+                   'fp32->fp16 cast mismatch: with fp32 value {}, model_fp16 = {}, numpy_fp16 = {}'.format(
+                    fp32_val, model_fp16_val, np_fp16_val)
+
+    check_cast(mx.sym.Cast, input_np, expected_output)
+    check_cast(mx.sym.amp_cast, input_np, expected_output)
+
+
+@with_seed()
+def test_amp_multicast():
+    x = mx.sym.Variable('x', dtype=np.float16)
+    y = mx.sym.Variable('y', dtype=np.float32)
+    z = mx.sym.Variable('z', dtype=np.float16)
     ctx = default_context()
-    exe = sym.bind(ctx, {'x' : mx.nd.array(input_np, dtype=np.float32, ctx=ctx)})
-    assert exe.arg_arrays[0].dtype == np.float32
-    assert exe.outputs[0].dtype == np.float16
+    res = mx.sym.amp_multicast(x, y, z, num_outputs=3)
+    exe = res.bind(ctx, {'x': mx.nd.random.uniform(shape=(3, 3), dtype=np.float16, ctx=ctx),
+                         'y': mx.nd.random.uniform(shape=(3, 3), dtype=np.float32, ctx=ctx),
+                         'z': mx.nd.random.uniform(shape=(3, 3), dtype=np.float16, ctx=ctx)})
+    exe.forward(is_train=True)
+    out1, out2, out3 = exe.outputs
+    assert out1.asnumpy().dtype == np.float32
+    assert out2.asnumpy().dtype == np.float32
+    assert out3.asnumpy().dtype == np.float32
+
+    def check_amp_multicast(input_np, expected_output):
+        x = mx.sym.Variable('x', dtype=np.float16)
+        y = mx.sym.Variable('y', dtype=np.float32)
+        z = mx.sym.Variable('z', dtype=np.float16)
+        ctx = default_context()
+        res = mx.sym.amp_multicast(x, y, z, num_outputs=3)
+        exe = res.bind(ctx, {'x': mx.nd.array(input_np, dtype=np.float16, ctx=ctx),
+                             'y': mx.nd.array(input_np, dtype=np.float32, ctx=ctx),
+                             'z': mx.nd.array(input_np, dtype=np.float16, ctx=ctx)})
+        exe.forward(is_train=True)
+        sym_output = exe.outputs[0].asnumpy()
+        for in_val, model_fp32_val, np_fp32_val in zip(input_np, sym_output, expected_output):
+            assert (model_fp32_val == np_fp32_val) or \
+                   (np.isnan(model_fp32_val) and np.isnan(np_fp32_val)), \
+                   'fp16->fp32 multicast mismatch: with input value {}, model_fp32 = {}, numpy_fp32 = {}'.format(
+                    in_val, model_fp32_val, np_fp32_val)
+
+    input_np = np.array(list(get_cast_op_data()), dtype=np.float16)
+    expected_output = input_np.astype(np.float32)
+    check_amp_multicast(input_np, expected_output)
+
+
+@with_seed()
+def test_all_finite():
+    data = mx.sym.Variable("data", dtype=np.float32)
+    data2 = mx.sym.Variable("data2", dtype=np.float32)
+    finite_arr = mx.nd.array([[0, 0]])
+    inf_arr = mx.nd.array([[np.inf, np.inf]])
+    z = mx.sym.all_finite(data)
+    ctx = default_context()
+    exe = z.bind(ctx, {'data': inf_arr})
+    exe.forward(is_train=False)
+    sym_output = exe.outputs[0].asnumpy()
+    assert sym_output[0] == 0
+    exe = z.bind(ctx, {'data': finite_arr})
+    exe.forward(is_train=False)
+    sym_output = exe.outputs[0].asnumpy()
+    assert sym_output[0] == 1
+    z = mx.sym.multi_all_finite(data, data2, num_arrays=2)
+    exe = z.bind(ctx, {'data': finite_arr, 'data2': inf_arr})
+    exe.forward(is_train=False)
+    sym_output = exe.outputs[0].asnumpy()
+    assert sym_output[0] == 0
+    z = mx.sym.multi_all_finite(data, data2, num_arrays=2)
+    exe = z.bind(ctx, {'data': finite_arr, 'data2': finite_arr})
     exe.forward(is_train=False)
     sym_output = exe.outputs[0].asnumpy()
-    for fp32_val, model_fp16_val, np_fp16_val in zip(input_np, sym_output, expected_output):
-        assert (model_fp16_val == np_fp16_val) or \
-               (np.isnan(model_fp16_val) and np.isnan(np_fp16_val)), \
-            'fp32->fp16 cast mismatch: with fp32 value {}, model_fp16 = {}, numpy_fp16 = {}'.format(
-                fp32_val, model_fp16_val, np_fp16_val)
+    assert sym_output[0] == 1
 
 
 @with_seed()
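
The tests above exercise the new operators through the symbol API; the same semantics can be checked imperatively. An illustrative sketch (operator names as registered by this commit; the asserts encode the expected behavior, not captured output):

    import numpy as np
    import mxnet as mx

    x = mx.nd.array([1.5, 2.5], dtype=np.float32)
    y = mx.nd.array([1.5, 2.5], dtype=np.float16)

    # amp_cast casts like Cast (and can be stripped on save via remove_amp_cast)
    assert mx.nd.amp_cast(x, dtype=np.float16).dtype == np.float16

    # amp_multicast promotes every input to the widest input type (float32 here)
    outs = mx.nd.amp_multicast(x, y, num_outputs=2)
    assert all(o.dtype == np.float32 for o in outs)

    # all_finite / multi_all_finite return 1 when no inf or nan is present
    assert mx.nd.all_finite(x).asscalar() == 1
    assert mx.nd.multi_all_finite(x, y, num_arrays=2).asscalar() == 1
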
diff --git a/tests/tutorials/test_sanity_tutorials.py b/tests/tutorials/test_sanity_tutorials.py
index f89c234..fb751b4 100644
--- a/tests/tutorials/test_sanity_tutorials.py
+++ b/tests/tutorials/test_sanity_tutorials.py
@@ -63,7 +63,8 @@ whitelist = ['basic/index.md',
              'tensorrt/inference_with_trt.md',
              'java/index.md',
              'java/mxnet_java_on_intellij.md',
-             'java/ssd_inference.md']
+             'java/ssd_inference.md',
+             'amp/index.md']
 whitelist_set = set(whitelist)
 
 def test_tutorial_downloadable():
diff --git a/tests/tutorials/test_tutorials.py b/tests/tutorials/test_tutorials.py
index c58881c..bbb45c7 100644
--- a/tests/tutorials/test_tutorials.py
+++ b/tests/tutorials/test_tutorials.py
@@ -198,3 +198,6 @@ def test_vision_cnn_visualization():
 
 def test_control_flow():
     assert _test_tutorial_nb('control_flow/ControlFlowTutorial')
+
+def test_amp():
+    assert _test_tutorial_nb('amp/amp_tutorial')
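
The tutorial exercised by this test relies on dynamic loss scaling, which the all_finite operators above make cheap to implement. A rough sketch of the idea, with hypothetical variable names (not code from the tutorial):

    import mxnet as mx

    loss_scale = 2.0 ** 10                                # starting scale
    grads = [mx.nd.ones((2, 2)), mx.nd.ones((2, 2))]      # stand-in gradients

    # The backward pass ran on loss * loss_scale, so gradients are scaled.
    if mx.nd.multi_all_finite(*grads, num_arrays=len(grads)).asscalar():
        for g in grads:
            g /= loss_scale        # unscale, then apply the optimizer update
        # the real scheme also doubles loss_scale after ~2000 clean iterations
    else:
        loss_scale /= 2.0          # overflow detected: back off and skip update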