You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the word "here" in the original page and is not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/12 20:35:00 UTC
[tvm] branch main updated: [Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (#7447)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b36bdf6 [Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (#7447)
b36bdf6 is described below
commit b36bdf6ec859e888e0ac8cf54d09e8955b436cc8
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Fri Feb 12 12:34:41 2021 -0800
[Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (#7447)
* fix missing return in scatter_nd cuda strategy
* add Relay test for scatter_nd, fix documentation
---
python/tvm/relay/op/strategy/cuda.py | 1 +
python/tvm/relay/op/transform.py | 2 +-
tests/python/relay/test_op_level3.py | 83 ++++++++++++++++++++----------------
3 files changed, 48 insertions(+), 38 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 346e934..032d2dd 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -826,6 +826,7 @@ def scatter_nd_cuda(attrs, inputs, out_type, target):
         name="scatter_nd.cuda",
         plevel=10,
     )
+    return strategy
 
 
 @sort_strategy.register(["cuda", "gpu"])
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index e9d081e..d42ef47 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -321,7 +321,7 @@ def scatter_nd(data, indices, out_shape):
     indices : relay.Expr
         The index locations to update.
 
-    out_shape : relay.Expr
+    out_shape : Union[Tuple[int], List[int]]
         Output shape of the scatter.
 
     Returns
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 559eb24..625c472 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1377,41 +1377,50 @@ def test_cumsum(target, ctx):
     verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64")
 
 
+@tvm.testing.parametrize_targets
+def test_scatter_nd(target, ctx):
+    def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
+
+        out = relay.op.scatter_nd(data, indices, shape)
+        func = relay.Function([data, indices], out)
+
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(data_np, indices_np)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
+
+    data = np.array([2, 3, 0])
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    shape = (2, 2)
+    out = np.array([[0, 0], [2, 3]])
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    indices = np.array([[0, 1], [1, 1]])
+    shape = (2, 2, 2, 2)
+    out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
+    indices = np.array([[1, 0, 0]])
+    shape = (2, 1560)
+    out = np.zeros(shape).astype("float32")
+    out[1, :] += data[0, :]
+    out[0, :] += data[1, :]
+    out[0, :] += data[2, :]
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.ones((5, 3)).astype("float64")
+    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
+    shape = (2, 7, 3)
+    out = np.zeros(shape).astype("float64")
+    for i in range(indices.shape[1]):
+        for j in range(data.shape[1]):
+            out[indices[0, i], indices[1, i], j] += data[i, j]
+    verify_scatter_nd(data, indices, shape, out)
+
+
 if __name__ == "__main__":
-    test_cast()
-    test_zeros_ones()
-    test_unary_identity()
-    test_clip()
-    test_transpose_infer_type()
-    test_transpose()
-    test_reshape_infer_type()
-    test_reshape()
-    test_reshape_fail()
-    test_reshape_like_infer_type()
-    test_reshape_like()
-    test_take_infer_type()
-    test_take()
-    test_full_infer_type()
-    test_full()
-    test_full_like_infer_type()
-    test_full_like()
-    test_infer_type_leaky_relu()
-    test_infer_type_prelu()
-    test_squeeze()
-    test_squeeze_infer_type()
-    test_squeeze_bad_axes_infer_type()
-    test_split_infer_type()
-    test_arange()
-    test_meshgrid()
-    test_reverse()
-    test_stack()
-    test_tile()
-    test_repeat()
-    test_gather_nd()
-    test_isfinite()
-    test_isinf()
-    test_unravel_index()
-    test_sparse_to_dense()
-    test_fixed_point_multiply()
-    test_adv_index()
-    test_cumsum()
+    pytest.main([__file__])