Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/01 23:26:29 UTC

[GitHub] [incubator-tvm] eric-haibin-lin opened a new pull request #5976: MXNet frontend support for AMP cast op

eric-haibin-lin opened a new pull request #5976:
URL: https://github.com/apache/incubator-tvm/pull/5976


   Thanks for contributing to TVM! Please refer to the guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-mentioning them in the pull request thread.
   
   Add support for the `amp_cast` and `amp_multicast` operators used by MXNet mixed precision (AMP). @icemelon9 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #5976: MXNet frontend support for AMP cast op

anijain2305 commented on a change in pull request #5976:
URL: https://github.com/apache/incubator-tvm/pull/5976#discussion_r449399713



##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -903,6 +903,22 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_amp_multicast(inputs, attrs):
+    cast_narrow = attrs.get_bool("cast_narrow", False)
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+            "amp_multicast support is limited to float16 and float32 inputs only."
+    dtype = 'float32' if cast_narrow else dtypes[0]

Review comment:
       Just a minor, soft suggestion, so feel free to ignore. Maybe the following has better readability:
   
   ~~~
   has_float16 = any(x == "float16" for x in dtypes)
   has_float32 = any(x == "float32" for x in dtypes)
   
   # logic with three vars: cast_narrow, has_float16, has_float32
   ~~~
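For reference, one way to fill in that three-variable logic is sketched below in plain Python. `select_multicast_dtype` is a hypothetical helper name, not code from the PR. The sketch assumes, like the frontend code above, that only float16 and float32 inputs occur, and that `amp_multicast` picks the widest input dtype by default and the narrowest one when `cast_narrow` is set.

```python
def select_multicast_dtype(dtypes, cast_narrow=False):
    """Pick the common dtype for amp_multicast (hypothetical helper).

    Assumes inputs are limited to float16/float32: the widest type wins
    by default, the narrowest wins when cast_narrow is True.
    """
    supported = ("float16", "float32")
    if not dtypes or not all(t in supported for t in dtypes):
        raise ValueError("amp_multicast support is limited to float16 and float32 inputs only.")
    has_float16 = any(t == "float16" for t in dtypes)
    has_float32 = any(t == "float32" for t in dtypes)
    if cast_narrow:
        # Narrowest type present: float16 if any input is float16.
        return "float16" if has_float16 else "float32"
    # Widest type present: float32 if any input is float32.
    return "float32" if has_float32 else "float16"
```

For example, `select_multicast_dtype(["float16", "float32"])` returns `"float32"`, and the same call with `cast_narrow=True` returns `"float16"`.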
   
   





[GitHub] [incubator-tvm] junrushao1994 commented on pull request #5976: MXNet frontend support for AMP cast op

junrushao1994 commented on pull request #5976:
URL: https://github.com/apache/incubator-tvm/pull/5976#issuecomment-653201033


   Hmm, just to confirm: do we really want to keep those "fake" operators in the NNVM IR?



[GitHub] [incubator-tvm] icemelon9 commented on pull request #5976: MXNet frontend support for AMP cast op

icemelon9 commented on pull request #5976:
URL: https://github.com/apache/incubator-tvm/pull/5976#issuecomment-655297732


   Thanks @eric-haibin-lin @junrushao1994 @anijain2305 



[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5976: MXNet frontend support for AMP cast op

icemelon9 commented on a change in pull request #5976:
URL: https://github.com/apache/incubator-tvm/pull/5976#discussion_r449321892



##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -903,6 +903,22 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_amp_multicast(inputs, attrs):
+    cast_narrow = attrs.get_bool("cast_narrow", False)
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+            "amp_multicast support is limited to float16 and float32 inputs only."
+    dtype = 'float32' if cast_narrow else dtypes[0]
+    for t in dtypes:
+        if cast_narrow and t == 'float16':
+            dtype = 'float16'
+            break
+        elif not cast_narrow and t == 'float32':

Review comment:
       ```suggestion
           if not cast_narrow and t == 'float32':
   ```

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -903,6 +903,22 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_amp_multicast(inputs, attrs):
+    cast_narrow = attrs.get_bool("cast_narrow", False)
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+            "amp_multicast support is limited to float16 and float32 inputs only."
+    dtype = 'float32' if cast_narrow else dtypes[0]

Review comment:
       ```suggestion
       dtype = dtypes[0]
   ```
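Because `dtypes` may only contain float16 and float32, starting from `dtypes[0]` already gives the right answer whenever the loop finds no match (all inputs then share one dtype), and since each matching branch ends in `break`, `elif` and `if` behave the same. The plain-Python check below exhaustively compares the loop, with this suggestion and the `elif` to `if` suggestion applied, against a simple "widest/narrowest by bit width" rule for short input lists. `loop_select`, `rank_select`, and `check_equivalence` are illustrative names, not part of the PR.

```python
from itertools import product

def loop_select(dtypes, cast_narrow):
    # Mirrors the dtype selection in _mx_amp_multicast with both
    # review suggestions applied (dtype = dtypes[0], elif -> if).
    dtype = dtypes[0]
    for t in dtypes:
        if cast_narrow and t == 'float16':
            dtype = 'float16'
            break
        if not cast_narrow and t == 'float32':
            dtype = 'float32'
            break
    return dtype

# Bit widths used to rank the two supported dtypes.
BITS = {'float16': 16, 'float32': 32}

def rank_select(dtypes, cast_narrow):
    # Reference rule: narrowest dtype if cast_narrow, else widest.
    pick = min if cast_narrow else max
    return pick(dtypes, key=BITS.__getitem__)

def check_equivalence(max_len=4):
    # Try every float16/float32 combination up to max_len inputs.
    for n in range(1, max_len + 1):
        for combo in product(BITS, repeat=n):
            for cast_narrow in (False, True):
                got = loop_select(list(combo), cast_narrow)
                want = rank_select(list(combo), cast_narrow)
                assert got == want, (combo, cast_narrow)
    return True
```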

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -903,6 +903,22 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_amp_multicast(inputs, attrs):
+    cast_narrow = attrs.get_bool("cast_narrow", False)
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+            "amp_multicast support is limited to float16 and float32 inputs only."
+    dtype = 'float32' if cast_narrow else dtypes[0]
+    for t in dtypes:
+        if cast_narrow and t == 'float16':
+            dtype = 'float16'
+            break
+        elif not cast_narrow and t == 'float32':
+            dtype = 'float32'
+            break
+    return [relay.cast(x, dtype) for x in inputs]

Review comment:
       ```suggestion
       return [_op.cast(x, dtype) for x in inputs]
   ```

##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1307,6 +1353,8 @@ def verify(batch, seq_length, num_heads, head_dim):
 
 
 if __name__ == '__main__':
+    test_forward_amp_multicast()

Review comment:
       Let's move this to the end of the test list.

##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1131,6 +1131,52 @@ def verify(a_np, b_np):
     verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
     verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
 
+def test_forward_amp_cast():
+    def verify(from_dtype, to_dtype):
+        from_nd = mx.nd.ones((2,2), dtype=from_dtype)

Review comment:
       Could you use random tensors with larger size?
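A minimal NumPy sketch of what random, larger inputs for `test_forward_amp_cast` could look like, assuming the rest of `verify` stays as in the diff. `make_amp_cast_case`, the (16, 32) shape, and the seed are illustrative choices, not taken from the PR. The expected output is simply NumPy's own cast of the input, so the check would compare against it instead of the constant `1.`.

```python
import numpy as np

def make_amp_cast_case(from_dtype, to_dtype, shape=(16, 32), seed=0):
    """Return a random input and its expected amp_cast output (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Uniform values in [-1, 1), stored in the source dtype.
    x = rng.uniform(-1.0, 1.0, size=shape).astype(from_dtype)
    # Assumption: amp_cast behaves like a plain cast to the target dtype.
    expected = x.astype(to_dtype)
    return x, expected

# Inside verify(), roughly (TVM calls as in the existing test):
#     from_np, expected = make_amp_cast_case(from_dtype, to_dtype)
#     shape_dict = {'x': from_np.shape}
#     ...
#     op_res = intrp.evaluate()(from_np)
#     tvm.testing.assert_allclose(op_res.asnumpy(), expected)
```

A looser `rtol` may be worth considering for the float32 to float16 direction if a backend rounds differently from NumPy.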

##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1131,6 +1131,52 @@ def verify(a_np, b_np):
     verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
     verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
 
+def test_forward_amp_cast():
+    def verify(from_dtype, to_dtype):
+        from_nd = mx.nd.ones((2,2), dtype=from_dtype)
+        from_np = from_nd.asnumpy()
+        x_var = mx.sym.var('x', dtype=from_dtype)
+        mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
+        shape_dict = {'x': (2,2)}
+        dtype_dict = {'x': from_dtype}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(from_np)
+                assert op_res.dtype == to_dtype, op_res.dtype
+                tvm.testing.assert_allclose(op_res.asnumpy(), 1.)
+
+    verify('float32', 'float16')
+    verify('float16', 'float32')
+
+def test_forward_amp_multicast():
+    def verify(dtypes, cast_narrow, expected_dtype):
+        x_nps = [np.ones((2,2), dtype=dtype) for dtype in dtypes]

Review comment:
       Same here: please use random tensors with a larger size.
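Likewise for `test_forward_amp_multicast`, a NumPy reference could generate random inputs of mixed dtypes together with the expected outputs. This is a sketch under the same assumptions as the frontend code (float16/float32 inputs only; widest dtype by default, narrowest with `cast_narrow`). `make_multicast_inputs` and `amp_multicast_reference` are hypothetical helpers, and the widest-type case relies on `np.result_type`.

```python
import numpy as np

def make_multicast_inputs(dtypes, shape=(16, 32), seed=0):
    """Random inputs, one per requested dtype (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    return [rng.uniform(-1.0, 1.0, size=shape).astype(t) for t in dtypes]

def amp_multicast_reference(arrays, cast_narrow=False):
    """Cast every array to the common dtype amp_multicast is assumed to pick."""
    if not all(a.dtype in (np.float16, np.float32) for a in arrays):
        raise ValueError("only float16 and float32 inputs are supported")
    if cast_narrow:
        # Narrowest dtype present among the inputs.
        has_f16 = any(a.dtype == np.float16 for a in arrays)
        target = np.float16 if has_f16 else np.float32
    else:
        # Widest dtype present; NumPy promotes float16 + float32 to float32.
        target = np.result_type(*arrays)
    return [a.astype(target) for a in arrays]
```

In the test, each output of the TVM executor would then be compared with the matching reference array via `tvm.testing.assert_allclose`, as the existing tests do.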





[GitHub] [incubator-tvm] icemelon9 merged pull request #5976: MXNet frontend support for AMP cast op

icemelon9 merged pull request #5976:
URL: https://github.com/apache/incubator-tvm/pull/5976


   

