You are viewing a plain-text version of this content. (The canonical link to it is not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/02/04 09:21:21 UTC

[tvm] branch main updated: [Relay] Align strided slice shape functions (#10155)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cbf6468  [Relay] Align strided slice shape functions (#10155)
cbf6468 is described below

commit cbf6468c392f5dc39dab2bfb2187490962f18f61
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Feb 4 02:20:58 2022 -0700

    [Relay] Align strided slice shape functions (#10155)
    
    * Fix the static strided_slice shape function for out-of-bounds slicing with negative strides
    
    * Trigger CI
    
    * Trigger CI
---
 python/tvm/relay/op/_transform.py | 66 ++++++++++++++++++++++++++-------------
 tests/python/relay/test_any.py    | 56 +++++++++++++++++++++++----------
 2 files changed, 83 insertions(+), 39 deletions(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index b67579a..25ecf0e 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -217,43 +217,54 @@ def arange_shape_func(attrs, inputs, _):
 
 @script
 def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
-    ndim = data_shape.shape[0]
+    ndim = len(data_shape)
     out = output_tensor((ndim,), "int64")
     for i in const_range(ndim):
+        dim_size = int64(data_shape[i])
         cbegin = int64(0)
-        cend = int64(data_shape[i])
+        cend = dim_size
         cstride = int64(1)
+
         if len(strides) > i:
             cstride = int64(strides[i])
+
         if len(begin) > i:
             cbegin = int64(begin[i])
-            if cbegin < 0:
-                cbegin += int64(data_shape[i])
+        elif cstride < 0:
+            cbegin = dim_size
+
         if len(end) <= i:
-            cend = int64(data_shape[i])
+            if cstride < 0:
+                cend = int64(0)
         elif slice_mode != 0:
             cstride = int64(1)
             if end[i] < 0:
-                cend = int64(data_shape[i])
+                cend = dim_size
             else:
                 cend = cbegin + int64(end[i])
         else:
             if end[i] > data_shape[i]:
-                cend = int64(data_shape[i])
-            elif end[i] < -data_shape[i]:
-                cend = int64(-1)
+                cend = dim_size
             else:
                 cend = int64(end[i])
-                if cend < 0:
-                    cend += int64(data_shape[i])
+
         assert cstride != 0, "Strides can't be zero."
+
+        if cbegin < 0:
+            cbegin += dim_size
+        if cend < 0:
+            cend += dim_size
+
         if cstride < 0:
+            if cend < 0:
+                cend = int64(-1)
+            if cbegin > dim_size - 1:
+                cbegin = dim_size - 1
             slice_range = cbegin - cend
             step = -cstride
         else:
             slice_range = cend - cbegin
             step = cstride
-
         out[i] = int64(ceil_div(slice_range, step))
     return out
 
@@ -266,34 +277,45 @@ def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_m
         out[i] = data_shape[i]
 
     for i in const_range(len(axes)):
+        dim_size = int64(data_shape[axes[i]])
         cbegin = int64(0)
-        cend = int64(data_shape[axes[i]])
+        cend = dim_size
         cstride = int64(1)
+
         if len(strides) > i:
             cstride = int64(strides[i])
+
         if len(begin) > i:
             cbegin = int64(begin[i])
-            if cbegin < 0:
-                cbegin += int64(data_shape[axes[i]])
+        elif cstride < 0:
+            cbegin = dim_size
+
         if len(end) <= i:
-            cend = int64(data_shape[axes[i]])
+            cend = dim_size
         elif slice_mode != 0:
             cstride = int64(1)
             if end[i] < 0:
-                cend = int64(data_shape[axes[i]])
+                cend = dim_size
             else:
                 cend = cbegin + int64(end[i])
         else:
             if end[i] > data_shape[i]:
-                cend = int64(data_shape[axes[i]])
-            elif end[i] < -data_shape[i]:
-                cend = int64(-1)
+                cend = dim_size
             else:
                 cend = int64(end[i])
-                if cend < 0:
-                    cend += int64(data_shape[axes[i]])
+
         assert cstride != 0, "Strides can't be zero."
+
+        if cbegin < 0:
+            cbegin += dim_size
+        if cend < 0:
+            cend += dim_size
+
         if cstride < 0:
+            if cend < 0:
+                cend = int64(-1)
+            if cbegin > dim_size - 1:
+                cbegin = dim_size - 1
             slice_range = cbegin - cend
             step = -cstride
         else:
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 97770f5..279507f 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1286,45 +1286,60 @@ def test_arange_with_dynamic_shape():
     check_result([data], mod, np.array(range(10)).astype("int32") + 1)
 
 
-def verify_any_strided_slice(
-    data_shape,
+def verify_any_random_strided_slice(
     begin_shape,
     end_shape,
     strides_shape,
-    data_np_shape,
+    data_shape,
     slice_mode="end",
     const_attrs=False,
 ):
     # Generate random numpy input data
-    np_data = np.random.uniform(size=data_np_shape).astype("float32")
     np_begin = np.random.randint(2, size=begin_shape, dtype="int32")
     np_end = np.random.randint(5, 10, size=end_shape, dtype="int32")
     np_strides = np.random.randint(
         1, 2 if slice_mode == "size" else 3, size=strides_shape, dtype="int32"
     )
+
+    verify_any_strided_slice(
+        np_begin, np_end, np_strides, data_shape, slice_mode=slice_mode, const_attrs=const_attrs
+    )
+
+
+def verify_any_strided_slice(
+    np_begin,
+    np_end,
+    np_strides,
+    data_shape,
+    axes=None,
+    slice_mode="end",
+    const_attrs=False,
+):
+    np_data = np.random.uniform(size=data_shape).astype("float32")
     # target numpy result
     ref_res = tvm.topi.testing.strided_slice_python(
-        np_data, np_begin, np_end, np_strides, slice_mode
+        np_data, np_begin, np_end, np_strides, slice_mode, axes
     )
 
     # Relay Module
     mod = tvm.IRModule()
-    data = relay.var("data", shape=data_shape, dtype="float32")
+    data = relay.var("data", shape=any_dims(len(data_shape)), dtype="float32")
     if const_attrs:
-        data = relay.var("data", shape=data_shape, dtype="float32")
         begin = relay.const(np_begin)
         end = relay.const(np_end)
         strides = relay.const(np_strides)
         args = [data]
         np_inputs = [np_data]
     else:
-        begin = relay.var("begin", shape=begin_shape, dtype="int32")
-        end = relay.var("end", shape=end_shape, dtype="int32")
-        strides = relay.var("strides", shape=strides_shape, dtype="int32")
+        begin = relay.var("begin", shape=np_begin.shape, dtype="int32")
+        end = relay.var("end", shape=np_end.shape, dtype="int32")
+        strides = relay.var("strides", shape=np_strides.shape, dtype="int32")
         args = [data, begin, end, strides]
         np_inputs = [np_data, np_begin, np_end, np_strides]
 
-    y = relay.strided_slice(data, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
+    y = relay.strided_slice(
+        data, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode
+    )
     mod["main"] = relay.Function(args, y)
 
     check_result(np_inputs, mod, ref_res)
@@ -1332,12 +1347,19 @@ def verify_any_strided_slice(
 
 @tvm.testing.uses_gpu
 def test_any_strided_slice():
-    verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21))
-    verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21))
-    verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (23, 29, 41))
-    verify_any_strided_slice(any_dims(4), (4,), (4,), (4,), (40, 50, 60, 70))
-    verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size")
-    verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True)
+    verify_any_random_strided_slice((2,), (2,), (2,), (15, 21))
+    verify_any_random_strided_slice((3,), (3,), (3,), (15, 17, 21))
+    verify_any_random_strided_slice((3,), (3,), (3,), (23, 29, 41))
+    verify_any_random_strided_slice((4,), (4,), (4,), (40, 50, 60, 70))
+    verify_any_random_strided_slice((3,), (3,), (3,), (15, 17, 21), slice_mode="size")
+    verify_any_random_strided_slice((2,), (2,), (2,), (15, 21), const_attrs=True)
+
+    begin = np.array([0, 1000000]).astype("int32")
+    end = np.array([1000000, -1000000]).astype("int32")
+    strides = np.array([1, -1]).astype("int32")
+    verify_any_strided_slice(begin, end, strides, (15, 21), const_attrs=False)
+    verify_any_strided_slice(begin, end, strides, (15, 21), const_attrs=True)
+    verify_any_strided_slice(begin, end, strides, (15, 17, 21), axes=[0, 2], const_attrs=True)
 
 
 @tvm.testing.uses_gpu