You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/08/07 00:57:29 UTC

[incubator-tvm] branch master updated: [ONNX]Mod operator, bug fix (#6160)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 87f9010  [ONNX]Mod operator, bug fix (#6160)
87f9010 is described below

commit 87f90107846841eba41409d65e8a77c82c033bf4
Author: Siju Samuel <si...@huawei.com>
AuthorDate: Fri Aug 7 06:27:20 2020 +0530

    [ONNX]Mod operator, bug fix (#6160)
    
    * Onnx mod, bug fix
    
    * Added comment for the mod/floor_mod behaviour difference between numpy & relay
---
 python/tvm/relay/frontend/onnx.py          |  7 ++++++-
 tests/python/frontend/onnx/test_forward.py | 29 +++++++++++++----------------
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 1568c97..74626d4 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -530,10 +530,15 @@ class Mod(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs))
-        if attr['fmod'] == 1:
+
+        # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod.
+        # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment.
+        # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod
+        if attr['fmod'] == 0:
             op_name = "floor_mod"
         else:
             op_name = "mod"
+
         return AttrCvt(op_name)(inputs, {}, params)
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 56ea96d..14b827c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -2374,17 +2374,11 @@ def test_pooling():
                        auto_pad='SAME_UPPER')
 
 
-def verify_mod(x_shape, y_shape, fmod, dtype='float32'):
-    x_np = np.random.uniform(size=x_shape).astype(dtype)
-    y_np = np.random.uniform(size=y_shape).astype(dtype)
+def verify_mod(x_shape, y_shape, fmod, out_shape, dtype='float32'):
+    x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype)
+    y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype)
     y_np = np.where(y_np==0, 1, y_np) #remove 0's to avoid division by zero error
 
-    if fmod:
-        np_out = np.fmod(x_np, y_np)
-    else:
-        np_out = np.mod(x_np, y_np)
-
-    out_shape = np_out.shape
     mod_node = helper.make_node("Mod",
                                 inputs=["x", "y"],
                                 outputs=["z"],
@@ -2401,22 +2395,25 @@ def verify_mod(x_shape, y_shape, fmod, dtype='float32'):
                                                                     onnx_dtype, list(out_shape))])
     model = helper.make_model(graph, producer_name='mod_test')
 
+    onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0]
+
     for target, ctx in ctx_list():
         tvm_out = get_tvm_output(
             model, [x_np, y_np], target, ctx, out_shape)
-        tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def test_mod():
     # Mod
-    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=0)
-
-    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, dtype="int32")
+    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32")
+    verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=0, out_shape=(1, 32, 32, 32), dtype="int32")
 
     # fmod
-    verify_mod(x_shape=[1, 1, 32], y_shape=[1, 32, 32], fmod=1)
-
-    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, dtype="int32")
+    verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32")
+    verify_mod(x_shape=[1, 1, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
+    verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 1, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
+    verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32), dtype="int32")
+    verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
 
 
 def verify_xor(x_shape, y_shape):