You are viewing a plain-text version of this content; the hyperlink to the canonical version is not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/02/26 21:05:36 UTC
[tvm] branch main updated: [ONNX]fix datatype on Reciprocal op (#7519)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2d57470 [ONNX]fix datatype on Reciprocal op (#7519)
2d57470 is described below
commit 2d5747054ca05a0863236b317e2fed281b455a00
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Feb 26 14:05:22 2021 -0700
[ONNX]fix datatype on Reciprocal op (#7519)
* fix datatype on Reciprocal op
* clean up test case
---
python/tvm/relay/frontend/onnx.py | 3 ++-
tests/python/frontend/onnx/test_forward.py | 11 +++++++----
2 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 58c2dbc..860753d 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -839,7 +839,8 @@ class Reciprocal(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- return _expr.const(1.0) / inputs[0]
+ dtype = infer_type(inputs[0]).checked_type.dtype
+ return _expr.const(1.0, dtype=dtype) / inputs[0]
class Flatten(OnnxOpConverter):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 8dbd049..1e13416 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1830,23 +1830,26 @@ def test_unary_ops():
dtype = "float32"
out_shape = in_shape
- def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5):
+ def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"):
+ x = x.astype(dtype)
+ ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
z = helper.make_node(op, ["in1"], ["out"])
graph = helper.make_graph(
[z],
"_test",
inputs=[
- helper.make_tensor_value_info("in1", TensorProto.FLOAT, list(in_shape)),
+ helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)),
],
- outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+ outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))],
)
model = helper.make_model(graph, producer_name="_test")
verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol)
- x = np.random.uniform(size=in_shape).astype(dtype)
+ x = np.random.uniform(size=in_shape)
verify_unary_ops("Neg", x)
verify_unary_ops("Abs", x)
verify_unary_ops("Reciprocal", x)
+ verify_unary_ops("Reciprocal", x, dtype="float16")
verify_unary_ops("Sqrt", x)
verify_unary_ops("Relu", x)
verify_unary_ops("Exp", x)