You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/02/04 09:21:21 UTC
[tvm] branch main updated: [Relay] Align strided slice shape functions (#10155)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cbf6468 [Relay] Align strided slice shape functions (#10155)
cbf6468 is described below
commit cbf6468c392f5dc39dab2bfb2187490962f18f61
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Feb 4 02:20:58 2022 -0700
[Relay] Align strided slice shape functions (#10155)
* Fix the static strided_slice shape function for out-of-bounds slicing with negative strides
* Trigger CI
* Trigger CI
---
python/tvm/relay/op/_transform.py | 66 ++++++++++++++++++++++++++-------------
tests/python/relay/test_any.py | 56 +++++++++++++++++++++++----------
2 files changed, 83 insertions(+), 39 deletions(-)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index b67579a..25ecf0e 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -217,43 +217,54 @@ def arange_shape_func(attrs, inputs, _):
@script
def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
- ndim = data_shape.shape[0]
+ ndim = len(data_shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
+ dim_size = int64(data_shape[i])
cbegin = int64(0)
- cend = int64(data_shape[i])
+ cend = dim_size
cstride = int64(1)
+
if len(strides) > i:
cstride = int64(strides[i])
+
if len(begin) > i:
cbegin = int64(begin[i])
- if cbegin < 0:
- cbegin += int64(data_shape[i])
+ elif cstride < 0:
+ cbegin = dim_size
+
if len(end) <= i:
- cend = int64(data_shape[i])
+ if cstride < 0:
+ cend = int64(0)
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
- cend = int64(data_shape[i])
+ cend = dim_size
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
- cend = int64(data_shape[i])
- elif end[i] < -data_shape[i]:
- cend = int64(-1)
+ cend = dim_size
else:
cend = int64(end[i])
- if cend < 0:
- cend += int64(data_shape[i])
+
assert cstride != 0, "Strides can't be zero."
+
+ if cbegin < 0:
+ cbegin += dim_size
+ if cend < 0:
+ cend += dim_size
+
if cstride < 0:
+ if cend < 0:
+ cend = int64(-1)
+ if cbegin > dim_size - 1:
+ cbegin = dim_size - 1
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride
-
out[i] = int64(ceil_div(slice_range, step))
return out
@@ -266,34 +277,45 @@ def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_m
out[i] = data_shape[i]
for i in const_range(len(axes)):
+ dim_size = int64(data_shape[axes[i]])
cbegin = int64(0)
- cend = int64(data_shape[axes[i]])
+ cend = dim_size
cstride = int64(1)
+
if len(strides) > i:
cstride = int64(strides[i])
+
if len(begin) > i:
cbegin = int64(begin[i])
- if cbegin < 0:
- cbegin += int64(data_shape[axes[i]])
+ elif cstride < 0:
+ cbegin = dim_size
+
if len(end) <= i:
- cend = int64(data_shape[axes[i]])
+ cend = dim_size
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
- cend = int64(data_shape[axes[i]])
+ cend = dim_size
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
- cend = int64(data_shape[axes[i]])
- elif end[i] < -data_shape[i]:
- cend = int64(-1)
+ cend = dim_size
else:
cend = int64(end[i])
- if cend < 0:
- cend += int64(data_shape[axes[i]])
+
assert cstride != 0, "Strides can't be zero."
+
+ if cbegin < 0:
+ cbegin += dim_size
+ if cend < 0:
+ cend += dim_size
+
if cstride < 0:
+ if cend < 0:
+ cend = int64(-1)
+ if cbegin > dim_size - 1:
+ cbegin = dim_size - 1
slice_range = cbegin - cend
step = -cstride
else:
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 97770f5..279507f 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1286,45 +1286,60 @@ def test_arange_with_dynamic_shape():
check_result([data], mod, np.array(range(10)).astype("int32") + 1)
-def verify_any_strided_slice(
- data_shape,
+def verify_any_random_strided_slice(
begin_shape,
end_shape,
strides_shape,
- data_np_shape,
+ data_shape,
slice_mode="end",
const_attrs=False,
):
# Generate random numpy input data
- np_data = np.random.uniform(size=data_np_shape).astype("float32")
np_begin = np.random.randint(2, size=begin_shape, dtype="int32")
np_end = np.random.randint(5, 10, size=end_shape, dtype="int32")
np_strides = np.random.randint(
1, 2 if slice_mode == "size" else 3, size=strides_shape, dtype="int32"
)
+
+ verify_any_strided_slice(
+ np_begin, np_end, np_strides, data_shape, slice_mode=slice_mode, const_attrs=const_attrs
+ )
+
+
+def verify_any_strided_slice(
+ np_begin,
+ np_end,
+ np_strides,
+ data_shape,
+ axes=None,
+ slice_mode="end",
+ const_attrs=False,
+):
+ np_data = np.random.uniform(size=data_shape).astype("float32")
# target numpy result
ref_res = tvm.topi.testing.strided_slice_python(
- np_data, np_begin, np_end, np_strides, slice_mode
+ np_data, np_begin, np_end, np_strides, slice_mode, axes
)
# Relay Module
mod = tvm.IRModule()
- data = relay.var("data", shape=data_shape, dtype="float32")
+ data = relay.var("data", shape=any_dims(len(data_shape)), dtype="float32")
if const_attrs:
- data = relay.var("data", shape=data_shape, dtype="float32")
begin = relay.const(np_begin)
end = relay.const(np_end)
strides = relay.const(np_strides)
args = [data]
np_inputs = [np_data]
else:
- begin = relay.var("begin", shape=begin_shape, dtype="int32")
- end = relay.var("end", shape=end_shape, dtype="int32")
- strides = relay.var("strides", shape=strides_shape, dtype="int32")
+ begin = relay.var("begin", shape=np_begin.shape, dtype="int32")
+ end = relay.var("end", shape=np_end.shape, dtype="int32")
+ strides = relay.var("strides", shape=np_strides.shape, dtype="int32")
args = [data, begin, end, strides]
np_inputs = [np_data, np_begin, np_end, np_strides]
- y = relay.strided_slice(data, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
+ y = relay.strided_slice(
+ data, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode
+ )
mod["main"] = relay.Function(args, y)
check_result(np_inputs, mod, ref_res)
@@ -1332,12 +1347,19 @@ def verify_any_strided_slice(
@tvm.testing.uses_gpu
def test_any_strided_slice():
- verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21))
- verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21))
- verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (23, 29, 41))
- verify_any_strided_slice(any_dims(4), (4,), (4,), (4,), (40, 50, 60, 70))
- verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size")
- verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True)
+ verify_any_random_strided_slice((2,), (2,), (2,), (15, 21))
+ verify_any_random_strided_slice((3,), (3,), (3,), (15, 17, 21))
+ verify_any_random_strided_slice((3,), (3,), (3,), (23, 29, 41))
+ verify_any_random_strided_slice((4,), (4,), (4,), (40, 50, 60, 70))
+ verify_any_random_strided_slice((3,), (3,), (3,), (15, 17, 21), slice_mode="size")
+ verify_any_random_strided_slice((2,), (2,), (2,), (15, 21), const_attrs=True)
+
+ begin = np.array([0, 1000000]).astype("int32")
+ end = np.array([1000000, -1000000]).astype("int32")
+ strides = np.array([1, -1]).astype("int32")
+ verify_any_strided_slice(begin, end, strides, (15, 21), const_attrs=False)
+ verify_any_strided_slice(begin, end, strides, (15, 21), const_attrs=True)
+ verify_any_strided_slice(begin, end, strides, (15, 17, 21), axes=[0, 2], const_attrs=True)
@tvm.testing.uses_gpu