You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/03 04:02:41 UTC

[tvm] branch main updated: [ONNX] More Unit Tests! (#7956)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c380a69  [ONNX] More Unit Tests! (#7956)
c380a69 is described below

commit c380a699e2c033a3d3356f3a32094c29c7c8f27d
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Sun May 2 22:02:29 2021 -0600

    [ONNX] More Unit Tests! (#7956)
    
    * support same lower and maxpool in autopad
    
    * fix isinf tests
    
    * lower tolerance on roialign test because the onnx result is cropped to 4 decimal places
    
    * slow support for bottom-k
    
    * throw with nullptr in gathernd and scatternd, fix typo
    
    * fix lint
    
    * fix a copy typo
---
 python/tvm/relay/frontend/onnx.py          | 98 +++++++++++++++++++++++++++---
 src/relay/op/tensor/transform.cc           |  4 ++
 tests/python/frontend/onnx/test_forward.py | 16 ++---
 3 files changed, 100 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index deb2948..1f5ac44 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -276,6 +276,7 @@ class Pool(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         data = inputs[0]
         input_shape = infer_shape(data)
+        input_dtype = infer_type(data).checked_type.dtype
         ndim = len(input_shape)
         if "auto_pad" in attr:
             attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
@@ -293,7 +294,19 @@ class Pool(OnnxOpConverter):
                 else:
                     # Warning: Pool does not yet support dynamic shapes,
                     # one will need to run dynamic_to_static on this model after import
-                    data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim)
+                    if "int" in input_dtype:
+                        pad_val = np.iinfo(np.dtype(input_dtype)).min
+                    else:
+                        pad_val = np.finfo(np.dtype(input_dtype)).min
+                    data = autopad(
+                        data,
+                        attr.get("strides", [1] * (ndim - 2)),
+                        attr["kernel_shape"],
+                        [1] * ndim,
+                        ndim,
+                        pad_value=pad_val,
+                        mode=attr["auto_pad"],
+                    )
             elif attr["auto_pad"] == "VALID":
                 attr["pads"] = tuple([0 for i in range(ndim - 2)])
             elif attr["auto_pad"] == "NOTSET":
@@ -356,7 +369,17 @@ class InstanceNorm(OnnxOpConverter):
         return AttrCvt(op_name="instance_norm")(inputs, attr, params)
 
 
-def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", deconv=False):
+def autopad(
+    data,
+    strides,
+    kernel_shape,
+    dilations,
+    ndim,
+    pad_type="constant",
+    deconv=False,
+    mode="SAME_UPPER",
+    pad_value=0.0,
+):
     """
     Perform autopadding with dynamic input shapes
     """
@@ -391,14 +414,19 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d
     pad_after = total_pad - pad_before
 
     # combine
-    pad = _op.concatenate(
-        [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
-    )
+    if "LOWER" in mode:
+        pad = _op.concatenate(
+            [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1
+        )
+    else:
+        pad = _op.concatenate(
+            [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
+        )
 
     # pad N and C with zeros
     pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)
 
-    return _op.nn.pad(data, fold_constant(pad), _op.const(0.0), pad_type)
+    return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type)
 
 
 class Conv(OnnxOpConverter):
@@ -427,6 +455,7 @@ class Conv(OnnxOpConverter):
                     attr["kernel_shape"],
                     attr.get("dilations", [1] * (ndim - 2)),
                     ndim,
+                    mode=attr["auto_pad"],
                 )
             elif attr["auto_pad"] == "VALID":
                 attr["pads"] = tuple([0 for i in range(ndim - 2)])
@@ -485,6 +514,7 @@ class ConvTranspose(OnnxOpConverter):
                     attr.get("dilations", [1] * (ndim - 2)),
                     ndim,
                     deconv=True,
+                    mode=attr["auto_pad"],
                 )
             elif attr["auto_pad"] == "VALID":
                 attr["pads"] = tuple([0 for i in range(ndim - 2)])
@@ -757,7 +787,14 @@ class LpPool(OnnxOpConverter):
             if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                 # Warning: LpPool does not yet support dynamic shapes,
                 # one will need to run dynamic_to_static on this model after import
-                data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim)
+                data = autopad(
+                    data,
+                    attr["strides"],
+                    attr["kernel_shape"],
+                    [1] * ndim,
+                    ndim,
+                    mode=attr["auto_pad"],
+                )
             elif attr["auto_pad"] == "VALID":
                 attr["pads"] = tuple([0 for i in range(ndim - 2)])
             elif attr["auto_pad"] == "NOTSET":
@@ -1377,7 +1414,7 @@ class Scatter(OnnxOpConverter):
 
 
 class ScatterND(OnnxOpConverter):
-    """Operator converter for Scatter."""
+    """Operator converter for ScatterND."""
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
@@ -2228,7 +2265,32 @@ class TopK(OnnxOpConverter):
         largest = attr.get("largest", 1)
 
         if largest == 0:
-            raise NotImplementedError("TVM only supports finding TopK largest elements")
+            # TODO(mbrookhart): optimize this by adding a smallest attribute to topi if this
+            # ever becomes a bottleneck
+            ndim = len(infer_shape(inputs[0]))
+            if axis < 0:
+                axis += ndim
+            sort = _op.sort(inputs[0], axis=axis)
+            argsort = _op.argsort(inputs[0], axis=axis, dtype="int64")
+            begin = [0] * ndim
+            stride = [1] * ndim
+            end = _op.concatenate(
+                [
+                    _op.const([np.iinfo(np.int64).max] * axis, dtype="int64"),
+                    inputs[1],
+                    _op.const([np.iinfo(np.int64).max] * (ndim - axis - 1), dtype="int64"),
+                ],
+                axis=0,
+            )
+            return _expr.TupleWrapper(
+                _expr.Tuple(
+                    [
+                        _op.strided_slice(sort, begin, end, stride),
+                        _op.strided_slice(argsort, begin, end, stride),
+                    ]
+                ),
+                2,
+            )
 
         return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64")
 
@@ -2246,6 +2308,22 @@ class Range(OnnxOpConverter):
         )
 
 
+class IsInf(OnnxOpConverter):
+    """Operator converter for IsInf"""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        detect_negative = attr.get("detect_negative", 1)
+        detect_positive = attr.get("detect_positive", 1)
+        dtype = infer_type(inputs[0]).checked_type.dtype
+        isinf = _op.isinf(inputs[0])
+        if not detect_negative:
+            isinf = isinf * (inputs[0] > _op.const(0, dtype))
+        if not detect_positive:
+            isinf = isinf * (inputs[0] < _op.const(0, dtype))
+        return isinf
+
+
 class MaxRoiPool(OnnxOpConverter):
     """Operator converter for MaxRoiPool."""
 
@@ -2789,7 +2867,7 @@ def _get_convert_map(opset):
         "Floor": Renamer("floor"),
         "Ceil": Renamer("ceil"),
         "Round": Renamer("round"),
-        "IsInf": Renamer("isinf"),
+        "IsInf": IsInf.get_converter(opset),
         "IsNaN": Renamer("isnan"),
         "Sqrt": Renamer("sqrt"),
         "Relu": Renamer("relu"),
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 02cdb21..120d648 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1122,6 +1122,8 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const auto out_shape = data->shape;
   const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
+  ICHECK(mdim) << "ScatterND needs a static shape for the first axis of indices, got "
+               << indices->shape;
   const size_t kdim = indices->shape.size() - 1;
   const size_t ndim = out_shape.size();
   ICHECK_LE(size_t(mdim->value), ndim)
@@ -3331,6 +3333,8 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   const size_t ndim = data->shape.size();
   const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
+  ICHECK(mdim) << "GatherND needs a static shape for the first axis of indices, got "
+               << indices->shape;
   const size_t kdim = indices->shape.size() - 1;
   ICHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does satisfy.";
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index e11689c..c25dc01 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4209,12 +4209,8 @@ unsupported_onnx_tests = [
     "test_eyelike_populate_off_main_diagonal/",
     "test_eyelike_with_dtype/",
     "test_eyelike_without_dtype/",
-    "test_isinf_negative/",
-    "test_isinf_positive/",
     "test_matmulinteger/",
     "test_maxpool_2d_dilations/",
-    "test_maxpool_2d_same_lower/",
-    "test_maxpool_2d_same_upper/",
     "test_maxpool_with_argmax_2d_precomputed_pads/",
     "test_maxpool_with_argmax_2d_precomputed_strides/",
     "test_maxunpool_export_with_output_shape/",
@@ -4233,7 +4229,6 @@ unsupported_onnx_tests = [
     "test_reversesequence_batch/",
     "test_reversesequence_time/",
     "test_rnn_seq_length/",
-    "test_roialign/",
     "test_round/",
     "test_scan9_sum/",
     "test_scan_sum/",
@@ -4252,7 +4247,6 @@ unsupported_onnx_tests = [
     "test_tfidfvectorizer_tf_onlybigrams_levelempty/",
     "test_tfidfvectorizer_tf_onlybigrams_skip5/",
     "test_tfidfvectorizer_tf_uniandbigrams_skip5/",
-    "test_top_k_smallest/",
     "test_unique_not_sorted_without_axis/",
     "test_unique_sorted_with_axis/",
     "test_unique_sorted_with_axis_3d/",
@@ -4268,6 +4262,12 @@ def test_onnx_nodes(test):
         if failure in test:
             pytest.skip()
             break
+    atol = 1e-5
+    rtol = 1e-5
+    if "roialign" in test:
+        # for some reason the ONNX test crops the
+        # roialign results to 4 decimal places
+        atol = 1e-4
     onnx_model = onnx.load(test + "/model.onnx")
     inputs = []
     outputs = []
@@ -4285,10 +4285,10 @@ def test_onnx_nodes(test):
                 raise ImportError(str(tensor) + " not labeled as an import or an output")
         tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0))
         if len(outputs) == 1:
-            tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)
         else:
             for output, val in zip(outputs, tvm_val):
-                tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol)
 
 
 def test_wrong_input():