You are viewing a plain text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/19 16:49:19 UTC

[incubator-tvm] branch master updated: [Relay][Frontend][ONNX] operator support NonZero (#5073)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new e1ebf06  [Relay][Frontend][ONNX] operator support NonZero (#5073)
e1ebf06 is described below

commit e1ebf062a76586b7ef261ee9e49bae3e9910d034
Author: Neo Chien <cc...@cs.ccu.edu.tw>
AuthorDate: Fri Mar 20 00:49:07 2020 +0800

    [Relay][Frontend][ONNX] operator support NonZero (#5073)
    
    * [Relay][Frontend][ONNX] operator support: NonZero
    
    * update
    
    * Solve the build fail
    
    * solve the build fail
    
    * Replace ctx_list with tvm.cpu()
---
 python/tvm/relay/frontend/onnx.py          | 13 +++++++
 tests/python/frontend/onnx/test_forward.py | 59 +++++++++++++++++++++++++++---
 2 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index e1b0a7f..beb8e85 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1444,6 +1444,18 @@ class Resize(OnnxOpConverter):
         return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
 
 
+class NonZero(OnnxOpConverter):
+    """Operator converter for NonZero
+    """
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        if len(inputs) > 1:
+            raise ValueError("Expect 1 input only")
+
+        output = AttrCvt(op_name='argwhere')(inputs, attr, params)
+        return _op.transpose(output, axes=(1, 0))
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1573,6 +1585,7 @@ def _get_convert_map(opset):
         'Where': Where.get_converter(opset),
         'Or': Or.get_converter(opset),
         'Resize': Resize.get_converter(opset),
+        'NonZero': NonZero.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 44696f5..917ec99 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -30,21 +30,38 @@ from tvm.relay.testing.config import ctx_list
 import scipy
 
 
-def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
-    """ Generic function to execute and get tvm output"""
-    target = 'llvm'
+def get_input_data_shape_dict(graph_def, input_data):
     if isinstance(input_data, list):
         input_names = {}
         shape_dict = {}
-        dtype_dict = {}
         for i, _ in enumerate(input_data):
             input_names[i] = graph_def.graph.input[i].name
             shape_dict[input_names[i]] = input_data[i].shape
-            dtype_dict[input_names[i]] = input_data[i].dtype
     else:
         input_names = graph_def.graph.input[0].name
         shape_dict = {input_names: input_data.shape}
-        dtype_dict = {input_names: input_data.dtype}
+
+    return input_names, shape_dict
+
+
+def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
+    """ Generic function to execute and get tvm output with vm executor"""
+
+    _, shape_dict = get_input_data_shape_dict(graph_def, input_data)
+
+    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
+
+    ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
+    indata = tvm.nd.array(input_data)
+    result = ex.evaluate()(indata)
+    return result.asnumpy()
+
+
+def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
+    """ Generic function to execute and get tvm output"""
+    target = 'llvm'
+
+    input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
     with relay.build_config(opt_level=1):
@@ -2209,6 +2226,35 @@ def test_resize():
     verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel")
 
 
+def test_nonzero():
+
+    def verify_nonzero(indata, outdata, dtype):
+        node = helper.make_node('NonZero',
+                                inputs=['X'],
+                                outputs=['Y'],)
+
+        graph = helper.make_graph([node],
+                                  "nonzero_test",
+                                  inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
+                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))])
+
+        model = helper.make_model(graph, producer_name='nonzero_test')
+
+        onnx_out = get_onnxruntime_output(model, indata, dtype)
+
+        for target, ctx in [('llvm', tvm.cpu())]:
+            tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
+            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
+
+    input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
+    result = np.array((np.nonzero(input_data)))  # expected output [[0, 1, 1], [0, 0, 1]]
+    verify_nonzero(input_data, result, dtype=np.int64)
+
+    input_data = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]], dtype=np.int64)
+    result = np.array((np.nonzero(input_data)))  # expected output [[0, 1, 2, 2], [0, 1, 0, 1]]
+    verify_nonzero(input_data, result, dtype=np.int64)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -2269,3 +2315,4 @@ if __name__ == '__main__':
     test_pooling()
     test_lstm()
     test_resize()
+    test_nonzero()