Posted to commits@tvm.apache.org by ma...@apache.org on 2021/04/19 22:36:04 UTC

[tvm] branch main updated: [ONNX] Fix more upstream tests (#7842)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d3e6227  [ONNX] Fix more upstream tests (#7842)
d3e6227 is described below

commit d3e622728874c9bdc07b14267f0e37bcbb0b30a8
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Apr 19 16:35:38 2021 -0600

    [ONNX] Fix more upstream tests (#7842)
    
    * fix unsqueeze test
    
    * fix dynamic strided slice with negative indices
    
    * add Shrink importer
    
    * fix selu defaults
    
    * Implement Hardmax
    
    * add a comment to the test
    
    * Fix typo
---
 include/tvm/topi/transform.h                     |  2 +-
 python/tvm/relay/frontend/onnx.py                | 70 +++++++++++++++++++-----
 python/tvm/relay/op/dyn/_transform.py            | 33 +++++++----
 tests/python/frontend/onnx/test_forward.py       | 21 +++----
 tests/python/relay/dyn/test_dynamic_op_level4.py | 18 +++---
 5 files changed, 98 insertions(+), 46 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 3ad2305..114b8f6 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -577,7 +577,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
       [&](const Array<tvm::tir::Var>& indices) {
         Array<PrimExpr> real_indices;
         for (int32_t i = 0; i < src_tensor_dim; ++i) {
-          real_indices.push_back(indices[i] * strides(i) + begin(i));
+          real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1));
         }
         return x(real_indices);
       },
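Before this fix, a negative-stride slice whose begin index equals the dimension size (e.g. a full reverse starting at begin = dim) would compute an address one element past the end of the buffer. Clamping begin(i) to x->shape[i] - 1 matches NumPy, which clamps out-of-bounds slice starts rather than reading past the data. A minimal NumPy illustration of the slice semantics the clamp emulates:

    import numpy as np

    x = np.arange(5)
    # NumPy clamps an out-of-bounds start instead of reading past the buffer:
    print(x[5::-1])           # [4 3 2 1 0]
    print(x[len(x) - 1::-1])  # same result with an in-bounds start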
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ffeb0dd..cc66cd3 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -930,8 +930,8 @@ class Selu(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get("alpha", 1.6732))
-        gamma = float(attr.get("gamma", 1.0507))
+        alpha = float(attr.get("alpha", 1.67326319217681884765625))
+        gamma = float(attr.get("gamma", 1.05070102214813232421875))
         return _expr.const(gamma) * (
             _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0]))
             + _op.nn.relu(inputs[0])
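The old defaults (1.6732 / 1.0507) were truncated; the new values are the float32-exact defaults from the ONNX Selu spec, which is why test_selu_default now passes. The relu-based rewrite used by the converter is algebraically identical to the textbook definition selu(x) = gamma * (x if x > 0 else alpha * (exp(x) - 1)); a quick NumPy check of that equivalence:

    import numpy as np

    ALPHA = 1.67326319217681884765625
    GAMMA = 1.05070102214813232421875

    def selu_ref(x):
        # Textbook SELU definition.
        return GAMMA * np.where(x > 0, x, ALPHA * (np.exp(x) - 1))

    def selu_converter_form(x):
        # The converter's rewrite: for x > 0 the first relu term is 0,
        # for x <= 0 the second is 0, so the two forms agree everywhere.
        relu = lambda v: np.maximum(v, 0.0)
        return GAMMA * (-ALPHA * relu(1.0 - np.exp(x)) + relu(x))

    x = np.linspace(-3, 3, 13)
    assert np.allclose(selu_ref(x), selu_converter_form(x))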
@@ -948,6 +948,20 @@ class ScaledTanh(OnnxOpConverter):
         return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha)
 
 
+class Shrink(OnnxOpConverter):
+    """Operator converter for Shrink."""
+
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        x = inputs[0]
+        dtype = infer_type(x).checked_type.dtype
+        lambd = _op.const(attr.get("lambd", 0.5), dtype=dtype)
+        bias = _op.const(attr.get("bias", 0.0), dtype=dtype)
+
+        zeros = _op.zeros_like(x)
+        return _op.where(x < -lambd, x + bias, zeros) + _op.where(x > lambd, x - bias, zeros)
+
+
 class Softsign(OnnxOpConverter):
     """Operator converter for Softsign."""
 
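Per the ONNX spec, Shrink computes y = x + bias where x < -lambd, y = x - bias where x > lambd, and 0 otherwise. Because the two conditions are disjoint for lambd >= 0, the converter can simply sum the two masked branches. A NumPy reference of the same formulation:

    import numpy as np

    def shrink_ref(x, lambd=0.5, bias=0.0):
        # The two branches never overlap, so their sum selects exactly one.
        return (np.where(x < -lambd, x + bias, 0.0)
                + np.where(x > lambd, x - bias, 0.0))

    x = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype="float32")
    print(shrink_ref(x))  # [-1.  0.  0.  0.  1.]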
@@ -1146,8 +1160,9 @@ class Unsqueeze(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        for axes in attr["axes"]:
-            inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1)
+        axes = sorted(attr["axes"])
+        for axis in axes:
+            inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
         return inputs[0]
 
 
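Expanding one axis at a time is only well defined when the axes are inserted in ascending order: each insertion shifts every later position, and a high axis index can exceed the current rank until the lower axes have been added. Sorting first is what makes test_unsqueeze_unsorted_axes pass. A NumPy sketch of the same idea:

    import numpy as np

    def unsqueeze_ref(x, axes):
        # Insert in ascending order so earlier insertions never shift
        # (or invalidate) the positions later insertions refer to.
        for axis in sorted(axes):
            x = np.expand_dims(x, axis=axis)
        return x

    x = np.ones((3, 4))
    print(unsqueeze_ref(x, [3, 0]).shape)  # (1, 3, 4, 1)
    # Unsorted, np.expand_dims(x, axis=3) on a 2-D input is already
    # out of range.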
@@ -1545,10 +1560,7 @@ class Softmax(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # set default value when axis is not set in the model
-        if "axis" not in attr:
-            attr["axis"] = 1
-        axis = attr["axis"]
+        axis = attr.get("axis", 1)
         ndim = len(infer_shape(inputs[0]))
         if axis < 0:
             axis += ndim
@@ -1564,10 +1576,7 @@ class LogSoftmax(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # set default value when axis is not set in the model
-        if "axis" not in attr:
-            attr["axis"] = 1
-        axis = attr["axis"]
+        axis = attr.get("axis", 1)
         ndim = len(infer_shape(inputs[0]))
         if axis < 0:
             axis += ndim
@@ -1579,6 +1588,40 @@ class LogSoftmax(OnnxOpConverter):
         return x - m - _op.log(s)
 
 
+class Hardmax(OnnxOpConverter):
+    """Operator converter for Hardmax."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        axis = attr.get("axis", 1)
+        ndim = len(infer_shape(inputs[0]))
+        if axis < 0:
+            axis += ndim
+        dtype = infer_type(inputs[0]).checked_type.dtype
+
+        if axis == 0:
+            pre = _op.const([1], "int64")
+        else:
+            pre = _op.prod(
+                _op.strided_slice(shape_of(inputs[0]), [0], [axis], [1]), axis=0, keepdims=True
+            )
+        post = _op.prod(
+            _op.strided_slice(shape_of(inputs[0]), [axis], [2147483647], [1]), axis=0, keepdims=True
+        )
+        newshape = _op.concatenate([pre, post], axis=0)
+        x = _op.reshape(inputs[0], fold_constant(newshape))
+        argmax = _op.argmax(x, axis=1)
+        onehot = _op.one_hot(
+            argmax,
+            _op.const(1.0, dtype),
+            _op.const(0.0, dtype),
+            fold_constant(_op.take(shape_of(x), _op.const([1], "int64"))),
+            1,
+            dtype,
+        )
+        return _op.reshape(onehot, shape_of(inputs[0]))
+
+
 class OneHot(OnnxOpConverter):
     """Operator converter for OneHot."""
 
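Hardmax (opset <= 11 semantics) coerces the input to 2-D around `axis`, i.e. shape [prod(shape[:axis]), prod(shape[axis:])], writes a 1 at each row's first maximum and 0 elsewhere, then restores the original shape; the 2147483647 literal is INT32_MAX, used as an "up to the end" sentinel when slicing the shape tensor. A NumPy reference of what the one_hot-of-argmax construction computes:

    import numpy as np

    def hardmax_ref(x, axis=1):
        # Flatten to 2-D around `axis`, one-hot the row-wise argmax
        # (first occurrence wins on ties), then restore the shape.
        if axis < 0:
            axis += x.ndim
        flat = x.reshape(int(np.prod(x.shape[:axis])), -1)
        out = np.zeros_like(flat)
        out[np.arange(flat.shape[0]), flat.argmax(axis=1)] = 1
        return out.reshape(x.shape)

    x = np.array([[1.0, 3.0, 2.0], [4.0, 0.0, 4.0]])
    print(hardmax_ref(x))  # [[0. 1. 0.] [1. 0. 0.]]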
@@ -2717,7 +2760,8 @@ def _get_convert_map(opset):
         "Softmax": Softmax.get_converter(opset),
         "LogSoftmax": LogSoftmax.get_converter(opset),
         "OneHot": OneHot.get_converter(opset),
-        # 'Hardmax'
+        "Hardmax": Hardmax.get_converter(opset),
+        "Shrink": Shrink.get_converter(opset),
         "Softsign": Softsign.get_converter(opset),
         "Gemm": Gemm.get_converter(opset),
         "MatMul": MatMul.get_converter(opset),
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index a36b562..de8ee08 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -151,40 +151,51 @@ def _strided_slice_shape_func_input_data(data_shape, begin, end, strides, slice_
     ndim = len(data_shape)
     out = output_tensor((ndim,), "int64")
     for i in const_range(ndim):
+        dim_size = int64(data_shape[i])
         cbegin = int64(0)
-        cend = int64(data_shape[i])
+        cend = dim_size
         cstride = int64(1)
+
         if strides.shape[0] > i:
             cstride = int64(strides[i])
+
         if begin.shape[0] > i:
             cbegin = int64(begin[i])
-            if cbegin < 0:
-                cbegin += int64(data_shape[i])
+        elif cstride < 0:
+            cbegin = dim_size
+
         if end.shape[0] <= i:
-            cend = int64(data_shape[i])
+            if cstride < 0:
+                cend = int64(0)
         elif slice_mode != 0:
             cstride = int64(1)
             if end[i] < 0:
-                cend = int64(data_shape[i])
+                cend = dim_size
             else:
                 cend = cbegin + int64(end[i])
         else:
             if end[i] > data_shape[i]:
-                cend = int64(data_shape[i])
-            elif end[i] < -data_shape[i]:
-                cend = int64(-1)
+                cend = dim_size
             else:
                 cend = int64(end[i])
-                if cend < 0:
-                    cend += int64(data_shape[i])
+
         assert cstride != 0, "Strides can't be zero."
+
+        if cbegin < 0:
+            cbegin += dim_size
+        if cend < 0:
+            cend += dim_size
+
         if cstride < 0:
+            if cend < 0:
+                cend = int64(-1)
+            if cbegin > dim_size - 1:
+                cbegin = dim_size - 1
             slice_range = cbegin - cend
             step = -cstride
         else:
             slice_range = cend - cbegin
             step = cstride
-
         out[i] = int64(ceil_div(slice_range, step))
     return out
 
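The rewritten shape function normalizes negative begin/end values once, then clamps for the negative-stride case (an end that is still negative after normalization means "through index 0", and an out-of-bounds begin is pulled back to the last element) before taking ceil(range / step). A one-axis Python sketch of the default "end" slice-mode path, checked against Python's own slicing for the new negative-stride test case:

    def slice_len(dim, begin, end, stride):
        # One-axis sketch of the new logic; assumes stride != 0 and
        # explicit begin/end in "end" slice mode.
        assert stride != 0, "Strides can't be zero."
        if end > dim:
            end = dim
        if begin < 0:
            begin += dim
        if end < 0:
            end += dim
        if stride < 0:
            if end < 0:
                end = -1          # slice runs through index 0
            if begin > dim - 1:
                begin = dim - 1   # clamp an out-of-bounds start
            rng, step = begin - end, -stride
        else:
            rng, step = end - begin, stride
        return -(-rng // step)    # ceil division

    # The three axes of verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], ...):
    assert slice_len(20, 20, 0, -1) == len(range(20)[20:0:-1]) == 19
    assert slice_len(10, 10, 0, -3) == len(range(10)[10:0:-3]) == 3
    assert slice_len(5, 4, 1, -2) == len(range(5)[4:1:-2]) == 2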
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 6d22b5a..595a3b1 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4128,6 +4128,14 @@ def test_cumsum():
     verify_cumsum(data, 1, 1, 1, type="int32")
 
 
+"""
+  The following parameterized tests loads the tests that ONNX ships as
+  serialized ONNX files, inputs, and outputs. The goal of this test
+  is to ensure the ONNX importer is in line with the ONNX specification.
+  To allow these tests to run in CI before all pass, a number of tests that
+  are not yet supported are skipped.
+"""
+
 from onnx import numpy_helper
 
 f = onnx.__file__
@@ -4159,13 +4167,6 @@ unsupported_onnx_tests = [
     "test_eyelike_populate_off_main_diagonal/",
     "test_eyelike_with_dtype/",
     "test_eyelike_without_dtype/",
-    "test_hardmax_axis_0/",
-    "test_hardmax_axis_1/",
-    "test_hardmax_axis_2/",
-    "test_hardmax_default_axis/",
-    "test_hardmax_example/",
-    "test_hardmax_negative_axis/",
-    "test_hardmax_one_hot/",
     "test_isinf_negative/",
     "test_isinf_positive/",
     "test_matmulinteger/",
@@ -4209,13 +4210,8 @@ unsupported_onnx_tests = [
     "test_scan9_sum/",
     "test_scan_sum/",
     "test_scatternd/",
-    "test_selu_default/",
-    "test_shrink_hard/",
-    "test_shrink_soft/",
     "test_simple_rnn_defaults/",
     "test_simple_rnn_with_initial_bias/",
-    "test_slice_neg_steps/",
-    "test_slice_start_out_of_bounds/",
     "test_strnormalizer_export_monday_casesensintive_lower/",
     "test_strnormalizer_export_monday_casesensintive_nochangecase/",
     "test_strnormalizer_export_monday_casesensintive_upper/",
@@ -4235,7 +4231,6 @@ unsupported_onnx_tests = [
     "test_unique_sorted_with_axis_3d/",
     "test_unique_sorted_with_negative_axis/",
     "test_unique_sorted_without_axis/",
-    "test_unsqueeze_unsorted_axes/",
     "test_upsample_nearest/",
 ]
 
diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py
index 3cb7064..43e5beb 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level4.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level4.py
@@ -39,18 +39,19 @@ def test_dynamic_strided_slice():
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
         ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
-        data = [x_data, np.array(begin), np.array(end)]
-
-        begin = relay.const(begin, dtype=dtype)
-        end = relay.const(end, dtype=dtype)
+        data = [x_data, np.array(begin, dtype=dtype), np.array(end, dtype=dtype)]
 
+        begin = relay.var("begin", shape=[len(begin)], dtype=dtype)
+        end = relay.var("end", shape=[len(end)], dtype=dtype)
+        inputs = [x, begin, end]
         if strides:
-            data.append(np.array(strides))
-            strides = relay.const(strides, dtype=dtype)
+            data.append(np.array(strides, dtype=dtype))
+            strides = relay.var("strides", shape=[len(strides)], dtype=dtype)
+            inputs.append(strides)
             z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
         else:
             z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
-        func = relay.Function([x], z)
+        func = relay.Function(inputs, z)
 
         func = run_infer_type(func)
         text = func.astext()
@@ -60,7 +61,7 @@ def test_dynamic_strided_slice():
         for target, dev in tvm.testing.enabled_targets():
             mod = tvm.ir.IRModule.from_expr(func)
             intrp = relay.create_executor("vm", mod=mod, device=dev, target=target)
-            op_res = intrp.evaluate()(x_data)
+            op_res = intrp.evaluate()(*data)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
 
     verify(
@@ -79,6 +80,7 @@ def test_dynamic_strided_slice():
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
     verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+    verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2))
     verify(
         (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
     )
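The test previously bound begin/end/strides as relay constants, so constant folding could specialize the slice away and the VM never exercised the dynamic shape function; passing them as relay.var inputs and feeding the arrays at call time keeps the op genuinely dynamic. A standalone sketch of that pattern (int64 index dtype and an LLVM-enabled TVM build assumed):

    import numpy as np
    import tvm
    from tvm import relay

    # begin/end/strides are runtime inputs, so the dynamic
    # strided_slice kernel and its shape function must run.
    x = relay.var("x", shape=(20, 10, 5), dtype="float32")
    begin = relay.var("begin", shape=[3], dtype="int64")
    end = relay.var("end", shape=[3], dtype="int64")
    strides = relay.var("strides", shape=[3], dtype="int64")
    func = relay.Function(
        [x, begin, end, strides],
        relay.strided_slice(x, begin=begin, end=end, strides=strides),
    )

    mod = tvm.ir.IRModule.from_expr(func)
    intrp = relay.create_executor("vm", mod=mod, target="llvm")
    out = intrp.evaluate()(
        np.random.rand(20, 10, 5).astype("float32"),
        np.array([20, 10, 4], dtype="int64"),
        np.array([0, 0, 1], dtype="int64"),
        np.array([-1, -3, -2], dtype="int64"),
    )
    print(out.asnumpy().shape)  # (19, 3, 2)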