You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/02/26 21:05:36 UTC

[tvm] branch main updated: [ONNX]fix datatype on Reciprocal op (#7519)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d57470  [ONNX]fix datatype on Reciprocal op (#7519)
2d57470 is described below

commit 2d5747054ca05a0863236b317e2fed281b455a00
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Feb 26 14:05:22 2021 -0700

    [ONNX]fix datatype on Reciprocal op (#7519)
    
    * fix datatype on Reciprocal op
    
    * clean up test case
---
 python/tvm/relay/frontend/onnx.py          |  3 ++-
 tests/python/frontend/onnx/test_forward.py | 11 +++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 58c2dbc..860753d 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -839,7 +839,8 @@ class Reciprocal(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _expr.const(1.0) / inputs[0]
+        dtype = infer_type(inputs[0]).checked_type.dtype
+        return _expr.const(1.0, dtype=dtype) / inputs[0]
 
 
 class Flatten(OnnxOpConverter):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 8dbd049..1e13416 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1830,23 +1830,26 @@ def test_unary_ops():
     dtype = "float32"
     out_shape = in_shape
 
-    def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5):
+    def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"):
+        x = x.astype(dtype)
+        ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
         z = helper.make_node(op, ["in1"], ["out"])
         graph = helper.make_graph(
             [z],
             "_test",
             inputs=[
-                helper.make_tensor_value_info("in1", TensorProto.FLOAT, list(in_shape)),
+                helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)),
             ],
-            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+            outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))],
         )
         model = helper.make_model(graph, producer_name="_test")
         verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol)
 
-    x = np.random.uniform(size=in_shape).astype(dtype)
+    x = np.random.uniform(size=in_shape)
     verify_unary_ops("Neg", x)
     verify_unary_ops("Abs", x)
     verify_unary_ops("Reciprocal", x)
+    verify_unary_ops("Reciprocal", x, dtype="float16")
     verify_unary_ops("Sqrt", x)
     verify_unary_ops("Relu", x)
     verify_unary_ops("Exp", x)