You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/12/06 08:51:25 UTC
(tvm) branch unity updated: [Unity][Bugfix] Fix `tests/python/topi/test_topi_transform.py::test_relax_dynamic_strided_slice` (#16205)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4e8c975700 [Unity][Bugfix] Fix `tests/python/topi/test_topi_transform.py::test_relax_dynamic_strided_slice` (#16205)
4e8c975700 is described below
commit 4e8c97570012c548636c5a4fdaff862a3a772763
Author: Sunghyun Park <su...@umich.edu>
AuthorDate: Wed Dec 6 00:51:16 2023 -0800
[Unity][Bugfix] Fix `tests/python/topi/test_topi_transform.py::test_relax_dynamic_strided_slice` (#16205)
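* Remove the `topi.shape_func_dynamic_strided_slice` shape-function check (previously run only for the `llvm` target) from `verify_relax_dynamic_strided_slice`
* Rename the output buffer to `tvm_out` and assert that its shape matches the NumPy reference before comparing values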
---
tests/python/topi/test_topi_transform.py | 21 +++++----------------
1 file changed, 5 insertions(+), 16 deletions(-)
diff --git a/tests/python/topi/test_topi_transform.py b/tests/python/topi/test_topi_transform.py
index 862f4a66ed..575e7aa450 100644
--- a/tests/python/topi/test_topi_transform.py
+++ b/tests/python/topi/test_topi_transform.py
@@ -467,8 +467,6 @@ def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_sha
B = topi.dynamic_strided_slice(A, Begin, End, Strides, output_shape) + 1
- OutShape = topi.shape_func_dynamic_strided_slice(A, Begin, End, Strides)
-
def check_device(target):
dev = tvm.device(target, 0)
if not tvm.testing.device_enabled(target):
@@ -478,27 +476,18 @@ def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_sha
x_np = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1
data_nd = tvm.nd.array(x_np, dev)
- out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
+ tvm_out = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
- if target == "llvm":
- # Check shape func
- s = tvm.te.create_schedule(OutShape.op)
- bar = tvm.build(
- s, [A, Begin, End, Strides, OutShape], target, name="shape_func_stride_slice"
- )
- out_shape_nd = tvm.nd.empty((len(out_npy.shape),), device=dev, dtype="int64")
- bar(data_nd, begin_nd, end_nd, strides_nd, out_shape_nd)
-
- tvm.testing.assert_allclose(out_shape_nd.numpy(), output_shape)
-
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(B)
foo = tvm.build(s, [A, Begin, End, Strides, B], target, name="stride_slice")
- foo(data_nd, begin_nd, end_nd, strides_nd, out_nd)
- tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+ foo(data_nd, begin_nd, end_nd, strides_nd, tvm_out)
+ tvm_out_npy = tvm_out.numpy()
+ assert out_npy.shape == tvm_out_npy.shape
+ tvm.testing.assert_allclose(tvm_out_npy, out_npy)
for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(target)