You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/12/06 08:51:25 UTC

(tvm) branch unity updated: [Unity][Bugfix] Fix `tests/python/topi/test_topi_transform.py::test_relax_dynamic_strided_slice` (#16205)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4e8c975700 [Unity][Bugfix] Fix `tests/python/topi/test_topi_transform.py::test_relax_dynamic_strided_slice` (#16205)
4e8c975700 is described below

commit 4e8c97570012c548636c5a4fdaff862a3a772763
Author: Sunghyun Park <su...@umich.edu>
AuthorDate: Wed Dec 6 00:51:16 2023 -0800

    [Unity][Bugfix] Fix `tests/python/topi/test_topi_transform.py::test_relax_dynamic_strided_slice` (#16205)
    
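    * fix
    
    * fix
    
    Summary of the change (derived from the diff below): the llvm-only
    check built on `topi.shape_func_dynamic_strided_slice` is removed from
    `verify_relax_dynamic_strided_slice`; the test now asserts that the
    shape of the `topi.dynamic_strided_slice` output matches the NumPy
    reference before comparing values, and `out_nd` is renamed to `tvm_out`.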
---
 tests/python/topi/test_topi_transform.py | 21 +++++----------------
 1 file changed, 5 insertions(+), 16 deletions(-)

diff --git a/tests/python/topi/test_topi_transform.py b/tests/python/topi/test_topi_transform.py
index 862f4a66ed..575e7aa450 100644
--- a/tests/python/topi/test_topi_transform.py
+++ b/tests/python/topi/test_topi_transform.py
@@ -467,8 +467,6 @@ def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_sha
 
     B = topi.dynamic_strided_slice(A, Begin, End, Strides, output_shape) + 1
 
-    OutShape = topi.shape_func_dynamic_strided_slice(A, Begin, End, Strides)
-
     def check_device(target):
         dev = tvm.device(target, 0)
         if not tvm.testing.device_enabled(target):
@@ -478,27 +476,18 @@ def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_sha
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1
         data_nd = tvm.nd.array(x_np, dev)
-        out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
+        tvm_out = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
         begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
         end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
         strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
 
-        if target == "llvm":
-            # Check shape func
-            s = tvm.te.create_schedule(OutShape.op)
-            bar = tvm.build(
-                s, [A, Begin, End, Strides, OutShape], target, name="shape_func_stride_slice"
-            )
-            out_shape_nd = tvm.nd.empty((len(out_npy.shape),), device=dev, dtype="int64")
-            bar(data_nd, begin_nd, end_nd, strides_nd, out_shape_nd)
-
-            tvm.testing.assert_allclose(out_shape_nd.numpy(), output_shape)
-
         with tvm.target.Target(target):
             s = tvm.topi.testing.get_injective_schedule(target)(B)
         foo = tvm.build(s, [A, Begin, End, Strides, B], target, name="stride_slice")
-        foo(data_nd, begin_nd, end_nd, strides_nd, out_nd)
-        tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+        foo(data_nd, begin_nd, end_nd, strides_nd, tvm_out)
+        tvm_out_npy = tvm_out.numpy()
+        assert out_npy.shape == tvm_out_npy.shape
+        tvm.testing.assert_allclose(tvm_out_npy, out_npy)
 
     for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(target)