You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/12 20:35:00 UTC

[tvm] branch main updated: [Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (#7447)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b36bdf6  [Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (#7447)
b36bdf6 is described below

commit b36bdf6ec859e888e0ac8cf54d09e8955b436cc8
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Fri Feb 12 12:34:41 2021 -0800

    [Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (#7447)
    
    * fix missing return in scatter_nd cuda strategy
    
    * add Relay test for scatter_nd, fix documentation
---
 python/tvm/relay/op/strategy/cuda.py |  1 +
 python/tvm/relay/op/transform.py     |  2 +-
 tests/python/relay/test_op_level3.py | 83 ++++++++++++++++++++----------------
 3 files changed, 48 insertions(+), 38 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 346e934..032d2dd 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -826,6 +826,7 @@ def scatter_nd_cuda(attrs, inputs, out_type, target):
         name="scatter_nd.cuda",
         plevel=10,
     )
+    return strategy
 
 
 @sort_strategy.register(["cuda", "gpu"])
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index e9d081e..d42ef47 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -321,7 +321,7 @@ def scatter_nd(data, indices, out_shape):
     indices : relay.Expr
         The index locations to update.
 
-    out_shape : relay.Expr
+    out_shape : Union[Tuple[int], List[int]]
         Output shape of the scatter.
 
     Returns
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 559eb24..625c472 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1377,41 +1377,50 @@ def test_cumsum(target, ctx):
     verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64")
 
 
+@tvm.testing.parametrize_targets
+def test_scatter_nd(target, ctx):
+    def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
+
+        out = relay.op.scatter_nd(data, indices, shape)
+        func = relay.Function([data, indices], out)
+
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(data_np, indices_np)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
+
+    data = np.array([2, 3, 0])
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    shape = (2, 2)
+    out = np.array([[0, 0], [2, 3]])
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    indices = np.array([[0, 1], [1, 1]])
+    shape = (2, 2, 2, 2)
+    out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
+    indices = np.array([[1, 0, 0]])
+    shape = (2, 1560)
+    out = np.zeros(shape).astype("float32")
+    out[1, :] += data[0, :]
+    out[0, :] += data[1, :]
+    out[0, :] += data[2, :]
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.ones((5, 3)).astype("float64")
+    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
+    shape = (2, 7, 3)
+    out = np.zeros(shape).astype("float64")
+    for i in range(indices.shape[1]):
+        for j in range(data.shape[1]):
+            out[indices[0, i], indices[1, i], j] += data[i, j]
+    verify_scatter_nd(data, indices, shape, out)
+
+
 if __name__ == "__main__":
-    test_cast()
-    test_zeros_ones()
-    test_unary_identity()
-    test_clip()
-    test_transpose_infer_type()
-    test_transpose()
-    test_reshape_infer_type()
-    test_reshape()
-    test_reshape_fail()
-    test_reshape_like_infer_type()
-    test_reshape_like()
-    test_take_infer_type()
-    test_take()
-    test_full_infer_type()
-    test_full()
-    test_full_like_infer_type()
-    test_full_like()
-    test_infer_type_leaky_relu()
-    test_infer_type_prelu()
-    test_squeeze()
-    test_squeeze_infer_type()
-    test_squeeze_bad_axes_infer_type()
-    test_split_infer_type()
-    test_arange()
-    test_meshgrid()
-    test_reverse()
-    test_stack()
-    test_tile()
-    test_repeat()
-    test_gather_nd()
-    test_isfinite()
-    test_isinf()
-    test_unravel_index()
-    test_sparse_to_dense()
-    test_fixed_point_multiply()
-    test_adv_index()
-    test_cumsum()
+    pytest.main([__file__])