You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/04/19 22:36:04 UTC
[tvm] branch main updated: [ONNX] Fix more upstream tests (#7842)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d3e6227 [ONNX] Fix more upstream tests (#7842)
d3e6227 is described below
commit d3e622728874c9bdc07b14267f0e37bcbb0b30a8
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Apr 19 16:35:38 2021 -0600
[ONNX] Fix more upstream tests (#7842)
* fix unsqueeze test
* fix dynamic strided slice with negative indices
* add Shrink importer
* fix selu defaults
* Implement Hardmax
* add a comment to the test
* Fix typo
---
include/tvm/topi/transform.h | 2 +-
python/tvm/relay/frontend/onnx.py | 70 +++++++++++++++++++-----
python/tvm/relay/op/dyn/_transform.py | 33 +++++++----
tests/python/frontend/onnx/test_forward.py | 21 +++----
tests/python/relay/dyn/test_dynamic_op_level4.py | 18 +++---
5 files changed, 98 insertions(+), 46 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 3ad2305..114b8f6 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -577,7 +577,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
[&](const Array<tvm::tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (int32_t i = 0; i < src_tensor_dim; ++i) {
- real_indices.push_back(indices[i] * strides(i) + begin(i));
+ real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1));
}
return x(real_indices);
},
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ffeb0dd..cc66cd3 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -930,8 +930,8 @@ class Selu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- alpha = float(attr.get("alpha", 1.6732))
- gamma = float(attr.get("gamma", 1.0507))
+ alpha = float(attr.get("alpha", 1.67326319217681884765625))
+ gamma = float(attr.get("gamma", 1.05070102214813232421875))
return _expr.const(gamma) * (
_expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0]))
+ _op.nn.relu(inputs[0])
@@ -948,6 +948,20 @@ class ScaledTanh(OnnxOpConverter):
return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha)
+class Shrink(OnnxOpConverter):
+ """Operator converter for Shrink: y = x + bias where x < -lambd, y = x - bias where x > lambd, otherwise 0."""
+
+ @classmethod
+ def _impl_v9(cls, inputs, attr, params):
+ x = inputs[0]
+ dtype = infer_type(x).checked_type.dtype  # constants below are built in x's dtype so they match x in the comparisons/arithmetic
+ lambd = _op.const(attr.get("lambd", 0.5), dtype=dtype)
+ bias = _op.const(attr.get("bias", 0.0), dtype=dtype)
+
+ zeros = _op.zeros_like(x)
+ return _op.where(x < -lambd, x + bias, zeros) + _op.where(x > lambd, x - bias, zeros)  # summing the two selects is only a clean piecewise choice if the masks are disjoint, i.e. assumes lambd >= 0 — NOTE(review): confirm
+
+
class Softsign(OnnxOpConverter):
"""Operator converter for Softsign."""
@@ -1146,8 +1160,9 @@ class Unsqueeze(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- for axes in attr["axes"]:
- inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1)
+ axes = sorted(attr["axes"])  # for non-negative axes, inserting in ascending order never shifts an already-inserted dim, so each listed index ends up at that position in the output. NOTE(review): negative axes (allowed from opset 11) are not normalized against the output rank first — e.g. axes=[-2, -1] on a 1-D input appears to give (1, a, 1) rather than (a, 1, 1); confirm
+ for axis in axes:
+ inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
return inputs[0]
@@ -1545,10 +1560,7 @@ class Softmax(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- # set default value when axis is not set in the model
- if "axis" not in attr:
- attr["axis"] = 1
- axis = attr["axis"]
+ axis = attr.get("axis", 1)
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
@@ -1564,10 +1576,7 @@ class LogSoftmax(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- # set default value when axis is not set in the model
- if "axis" not in attr:
- attr["axis"] = 1
- axis = attr["axis"]
+ axis = attr.get("axis", 1)
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
@@ -1579,6 +1588,40 @@ class LogSoftmax(OnnxOpConverter):
return x - m - _op.log(s)
+class Hardmax(OnnxOpConverter):
+ """Operator converter for Hardmax: coerce the input to 2-D [prod(shape[:axis]), prod(shape[axis:])], one-hot the argmax of each row, then reshape back to the input shape."""
+
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):  # NOTE(review): ONNX opset 13 redefined Hardmax to act along `axis` only (no 2-D coercion); confirm this converter is not selected for opset >= 13
+ axis = attr.get("axis", 1)
+ ndim = len(infer_shape(inputs[0]))
+ if axis < 0:
+ axis += ndim  # normalize a negative axis against the input rank
+ dtype = infer_type(inputs[0]).checked_type.dtype
+
+ if axis == 0:
+ pre = _op.const([1], "int64")  # no leading dims: the 2-D view has a single row
+ else:
+ pre = _op.prod(
+ _op.strided_slice(shape_of(inputs[0]), [0], [axis], [1]), axis=0, keepdims=True
+ )
+ post = _op.prod(
+ _op.strided_slice(shape_of(inputs[0]), [axis], [2147483647], [1]), axis=0, keepdims=True  # end 2147483647 (INT32_MAX) is a sentinel meaning "through the last dim"
+ )
+ newshape = _op.concatenate([pre, post], axis=0)
+ x = _op.reshape(inputs[0], fold_constant(newshape))
+ argmax = _op.argmax(x, axis=1)  # index of the max within each row of the 2-D view
+ onehot = _op.one_hot(
+ argmax,
+ _op.const(1.0, dtype),
+ _op.const(0.0, dtype),
+ fold_constant(_op.take(shape_of(x), _op.const([1], "int64"))),  # depth = size of dim 1 of the 2-D view (its column count)
+ 1,
+ dtype,
+ )
+ return _op.reshape(onehot, shape_of(inputs[0]))  # restore the original input shape
+
+
class OneHot(OnnxOpConverter):
"""Operator converter for OneHot."""
@@ -2717,7 +2760,8 @@ def _get_convert_map(opset):
"Softmax": Softmax.get_converter(opset),
"LogSoftmax": LogSoftmax.get_converter(opset),
"OneHot": OneHot.get_converter(opset),
- # 'Hardmax'
+ "Hardmax": Hardmax.get_converter(opset),
+ "Shrink": Shrink.get_converter(opset),
"Softsign": Softsign.get_converter(opset),
"Gemm": Gemm.get_converter(opset),
"MatMul": MatMul.get_converter(opset),
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index a36b562..de8ee08 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -151,40 +151,51 @@ def _strided_slice_shape_func_input_data(data_shape, begin, end, strides, slice_
ndim = len(data_shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
+ dim_size = int64(data_shape[i])
cbegin = int64(0)
- cend = int64(data_shape[i])
+ cend = dim_size
cstride = int64(1)
+
if strides.shape[0] > i:
cstride = int64(strides[i])
+
if begin.shape[0] > i:
cbegin = int64(begin[i])
- if cbegin < 0:
- cbegin += int64(data_shape[i])
+ elif cstride < 0:
+ cbegin = dim_size
+
if end.shape[0] <= i:
- cend = int64(data_shape[i])
+ if cstride < 0:
+ cend = int64(0)
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
- cend = int64(data_shape[i])
+ cend = dim_size
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
- cend = int64(data_shape[i])
- elif end[i] < -data_shape[i]:
- cend = int64(-1)
+ cend = dim_size
else:
cend = int64(end[i])
- if cend < 0:
- cend += int64(data_shape[i])
+
assert cstride != 0, "Strides can't be zero."
+
+ if cbegin < 0:
+ cbegin += dim_size
+ if cend < 0:
+ cend += dim_size
+
if cstride < 0:
+ if cend < 0:
+ cend = int64(-1)
+ if cbegin > dim_size - 1:
+ cbegin = dim_size - 1
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride
-
out[i] = int64(ceil_div(slice_range, step))
return out
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 6d22b5a..595a3b1 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4128,6 +4128,14 @@ def test_cumsum():
verify_cumsum(data, 1, 1, 1, type="int32")
+"""
+ The following parameterized test loads the tests that ONNX ships as
+ serialized ONNX files, inputs, and outputs. The goal of this test
+ is to ensure the ONNX importer is in line with the ONNX specification.
+ To allow these tests to run in CI before all pass, a number of tests that
+ are not yet supported are skipped.
+"""
+
from onnx import numpy_helper
f = onnx.__file__
@@ -4159,13 +4167,6 @@ unsupported_onnx_tests = [
"test_eyelike_populate_off_main_diagonal/",
"test_eyelike_with_dtype/",
"test_eyelike_without_dtype/",
- "test_hardmax_axis_0/",
- "test_hardmax_axis_1/",
- "test_hardmax_axis_2/",
- "test_hardmax_default_axis/",
- "test_hardmax_example/",
- "test_hardmax_negative_axis/",
- "test_hardmax_one_hot/",
"test_isinf_negative/",
"test_isinf_positive/",
"test_matmulinteger/",
@@ -4209,13 +4210,8 @@ unsupported_onnx_tests = [
"test_scan9_sum/",
"test_scan_sum/",
"test_scatternd/",
- "test_selu_default/",
- "test_shrink_hard/",
- "test_shrink_soft/",
"test_simple_rnn_defaults/",
"test_simple_rnn_with_initial_bias/",
- "test_slice_neg_steps/",
- "test_slice_start_out_of_bounds/",
"test_strnormalizer_export_monday_casesensintive_lower/",
"test_strnormalizer_export_monday_casesensintive_nochangecase/",
"test_strnormalizer_export_monday_casesensintive_upper/",
@@ -4235,7 +4231,6 @@ unsupported_onnx_tests = [
"test_unique_sorted_with_axis_3d/",
"test_unique_sorted_with_negative_axis/",
"test_unique_sorted_without_axis/",
- "test_unsqueeze_unsorted_axes/",
"test_upsample_nearest/",
]
diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py
index 3cb7064..43e5beb 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level4.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level4.py
@@ -39,18 +39,19 @@ def test_dynamic_strided_slice():
# target numpy result
x_data = np.random.uniform(size=dshape).astype("float32")
ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
- data = [x_data, np.array(begin), np.array(end)]
-
- begin = relay.const(begin, dtype=dtype)
- end = relay.const(end, dtype=dtype)
+ data = [x_data, np.array(begin, dtype=dtype), np.array(end, dtype=dtype)]
+ begin = relay.var("begin", shape=[len(begin)], dtype=dtype)
+ end = relay.var("end", shape=[len(end)], dtype=dtype)
+ inputs = [x, begin, end]
if strides:
- data.append(np.array(strides))
- strides = relay.const(strides, dtype=dtype)
+ data.append(np.array(strides, dtype=dtype))
+ strides = relay.var("strides", shape=[len(strides)], dtype=dtype)
+ inputs.append(strides)
z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
else:
z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
- func = relay.Function([x], z)
+ func = relay.Function(inputs, z)
func = run_infer_type(func)
text = func.astext()
@@ -60,7 +61,7 @@ def test_dynamic_strided_slice():
for target, dev in tvm.testing.enabled_targets():
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor("vm", mod=mod, device=dev, target=target)
- op_res = intrp.evaluate()(x_data)
+ op_res = intrp.evaluate()(*data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
verify(
@@ -79,6 +80,7 @@ def test_dynamic_strided_slice():
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+ verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2))
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
)