You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ha...@apache.org on 2020/07/08 05:37:22 UTC

[incubator-tvm] branch master updated: [Frontend][MXNet] MXNet frontend support for AMP cast op (#5976)

This is an automated email from the ASF dual-hosted git repository.

haichen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d9f009a  [Frontend][MXNet] MXNet frontend support for AMP cast op (#5976)
d9f009a is described below

commit d9f009a560fbec1f1f2394fdbbafbe7d43a92768
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue Jul 7 22:37:10 2020 -0700

    [Frontend][MXNet] MXNet frontend support for AMP cast op (#5976)
    
    * amp_cast
    
    * fix test
    
    * more tests
    
    * test more ctxs
    
    * fix doc
    
    * fix typo
    
    * address CR comment
    
    * fix lint
    
    * revert doc change
    
    * Revert "revert doc change"
    
    This reverts commit a410dd5569730ac81af67ddb333c3afbe97eddd7.
    
    * fix doc
    
    * Update relay_pass_infra.rst
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-42-138.ec2.internal>
---
 docs/dev/relay_add_pass.rst                 |  6 ++--
 docs/dev/virtual_machine.rst                |  4 +--
 python/tvm/relay/frontend/mxnet.py          | 19 +++++++++++-
 src/relay/transforms/fold_constant.cc       |  2 +-
 src/relay/transforms/partial_eval.cc        |  2 +-
 tests/python/frontend/mxnet/test_forward.py | 47 +++++++++++++++++++++++++++++
 6 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst
index a82ae4f..fc26559 100644
--- a/docs/dev/relay_add_pass.rst
+++ b/docs/dev/relay_add_pass.rst
@@ -181,7 +181,7 @@ Example: Constant Folding
 -------------------------
 
 In order to better understand the process of writing a pass, we will look at
-the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
+the constant folding pass (found in `src/relay/transforms/fold_constant.cc`_)
 as a guide, because it is a relatively simple pass that incorporates
 both types of traversals.
 
@@ -329,7 +329,7 @@ Now, we construct a more convenient interface ``FoldConstant`` for our constant
 folder. ``FoldConstant`` is a standalone function outside of the ``ConstantFolder``
 class that takes an expression and internally creates and uses a
 ``ConstantFolder`` instance (the full definition can be found in
-`src/relay/pass/fold_constant.cc`_).
+`src/relay/transforms/fold_constant.cc`_).
 
 
 Registering a Pass with the Pass Manager
@@ -403,4 +403,4 @@ in `src/relay/pass/`_.
 
 .. _src/relay/pass/: https://github.com/apache/incubator-tvm/tree/master/src/relay/pass
 
-.. _src/relay/pass/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/pass/fold_constant.cc
+.. _src/relay/transforms/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/transforms/fold_constant.cc
diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst
index 5bb5ade..5878003 100644
--- a/docs/dev/virtual_machine.rst
+++ b/docs/dev/virtual_machine.rst
@@ -38,8 +38,8 @@ them on the runtime. Graph runtime provides a fast execution experience but only
 subset of Relay programs.
 
 An alternative but not-standard approach is Relay's ahead-of-time compiler,
-which compiles a Relay program into a shared library containing an ahead-
-of-time implementation. The ahead-of-time compiler provides compelling performance
+which compiles a Relay program into a shared library containing an ahead-of-time
+implementation. The ahead-of-time compiler provides compelling performance
 but is difficult to extend and instrument, which can only be done by modifying the
 code generation and optimization mechanisms.
 
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 135756b..97b9d7a 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -903,6 +903,21 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_amp_multicast(inputs, attrs):  # cast every input to one common float dtype
+    cast_narrow = attrs.get_bool("cast_narrow", False)  # defaults to False
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]  # inferred dtype per input
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+            "amp_multicast support is limited to float16 and float32 inputs only."
+    has_float16 = any(x == "float16" for x in dtypes)
+    has_float32 = any(x == "float32" for x in dtypes)
+    dtype = dtypes[0]  # kept when neither rule below overrides it
+    if cast_narrow and has_float16:  # narrowing: any float16 input selects float16
+        dtype = 'float16'
+    if not cast_narrow and has_float32:  # otherwise: any float32 input selects float32
+        dtype = 'float32'
+    return [_op.cast(x, dtype) for x in inputs]  # one cast per input, in input order
+
 def _mx_grid_generator(inputs, attrs):
     transform_type = attrs.get_str("transform_type")
     if transform_type == 'affine':
@@ -1481,7 +1496,7 @@ def _qnn_contrib_concat(inputs, attrs):
         # Get all dtypes. Find input and output scales, call concatenate.
         dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs]
         assert all([x == 'uint8' for x in dtypes]), \
-                "Current suppor is limited to uint8 inputs only."
+                "Current support is limited to uint8 inputs only."
         new_min = min(mins)
         new_max = max(maxs)
         assert new_min == 0
@@ -2184,6 +2199,8 @@ _convert_map = {
     "Reshape"       : _reshape,
     "reshape"       : _reshape,
     "Cast"          : _cast,
+    "amp_cast"      : _cast,
+    "amp_multicast" : _mx_amp_multicast,
     "clip"          : _clip,
     "transpose"     : _transpose,
     "UpSampling"    : _upsampling,
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 50de871..d66d6bc 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -194,7 +194,7 @@ class ConstantFolder : public ExprMutator {
       return Expr();
     }
   }
-  // Constant evaluate a expression.
+  // Constant evaluate an expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::ToANormalForm(),
                                            transform::InferType()};
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index 371142a..63bd04d 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -901,7 +901,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     }
   }
 
-  // Constant evaluate a expression.
+  // Constant evaluate an expression.
   PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::InferType()};
     auto mod = IRModule::FromExpr(expr);
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 4d8b1e9..c8bbf88 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -1131,6 +1131,51 @@ def test_forward_cond():
     verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
     verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
 
+def test_forward_amp_cast():  # amp_cast through the Relay frontend must match numpy astype
+    def verify(from_dtype, to_dtype):  # checks a single from_dtype -> to_dtype conversion
+        from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype)
+        x_var = mx.sym.var('x', dtype=from_dtype)
+        mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
+        shape_dict = {'x': (1,3,18)}
+        dtype_dict = {'x': from_dtype}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)  # params unused
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:  # run under each listed executor kind
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(from_np)
+                assert op_res.dtype == to_dtype, op_res.dtype  # result dtype must be the target
+                tvm.testing.assert_allclose(op_res.asnumpy(), from_np.astype(to_dtype))
+
+    verify('float32', 'float16')  # narrowing cast
+    verify('float16', 'float32')  # widening cast
+
+def test_forward_amp_multicast():  # amp_multicast must bring all outputs to expected_dtype
+    def verify(dtypes, cast_narrow, expected_dtype):  # one input per entry in dtypes
+        x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes]
+        x_vars = [mx.sym.var(str(i), dtype=dtype) for i, dtype in enumerate(dtypes)]
+        mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow,
+                                      num_outputs=len(dtypes))  # one output per input
+        shape_dict = {}
+        dtype_dict = {}
+        for i, dtype in enumerate(dtypes):  # variables are named "0", "1", ... by position
+            shape_dict[str(i)] = (1,3,18)
+            dtype_dict[str(i)] = dtype
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)  # params unused
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:  # run under each listed executor kind
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(*x_nps)
+                for i, res in enumerate(op_res):  # output i is compared against input i
+                    assert res.dtype == expected_dtype, res.dtype
+                    tvm.testing.assert_allclose(res.asnumpy(), x_nps[i].astype(expected_dtype))
+
+    verify(['float32', 'float16'], False, 'float32')  # mixed, no narrowing -> float32
+    verify(['float32', 'float16'], True,  'float16')  # mixed, narrowing -> float16
+    verify(['float32', 'float32'], False, 'float32')  # uniform float32 stays float32
+    verify(['float32', 'float32'], True,  'float32')  # narrowing with no float16 present
+    verify(['float16', 'float16'], False, 'float16')  # no float32 present to widen to
+    verify(['float16', 'float16'], True, 'float16')  # uniform float16 stays float16
+
 
 def test_forward_unravel_index():
     def verify(x, shape, dtype):
@@ -1402,3 +1447,5 @@ if __name__ == '__main__':
     test_forward_interleaved_matmul_selfatt_qk()
     test_forward_interleaved_matmul_selfatt_valatt()
     test_forward_box_decode()
+    test_forward_amp_multicast()
+    test_forward_amp_cast()