You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/02/18 14:45:03 UTC

[tvm] branch main updated: Set TOpPattern=kOpaque for scatter_nd (#7464)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 84c4b15  Set TOpPattern=kOpaque for scatter_nd (#7464)
84c4b15 is described below

commit 84c4b150ab25aa3ea822beed4702dcb56dddce4c
Author: Alexander Pivovarov <pi...@amazon.com>
AuthorDate: Thu Feb 18 06:44:40 2021 -0800

    Set TOpPattern=kOpaque for scatter_nd (#7464)
---
 src/relay/op/tensor/transform.cc     |  5 ++++-
 tests/python/relay/test_op_level3.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 1e782a5..12db859 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1146,6 +1146,9 @@ Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {
 
 TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
 
+// The scatter_nd operator has extern schedules for CPU and GPU devices.
+// Fusing extern schedules with injective schedules leads to errors.
+// Therefore, scatter_nd is marked as Opaque to prevent compilation failures.
 RELAY_REGISTER_OP("scatter_nd")
     .describe(R"code(Scatter elements or slices from data and store to a tensor
 whose shape is defined by indices.
@@ -1158,7 +1161,7 @@ Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with sh
     .add_argument("indices", "Tensor", "The indices tensor.")
     .set_support_level(3)
     .add_type_rel("ScatterND", ScatterNDRel)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
 
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 625c472..94fac3b 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1391,17 +1391,46 @@ def test_scatter_nd(target, ctx):
             op_res = intrp.evaluate(func)(data_np, indices_np)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
 
+    def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices_vars = [
+            relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
+        ]
+
+        # test if scatter_nd works in case indices are prepared by another Relay operator
+        indices = relay.op.stack(indices_vars, axis=0)
+        out = relay.op.scatter_nd(data, indices, shape)
+        func = relay.Function(
+            [
+                data,
+            ]
+            + indices_vars,
+            out,
+        )
+
+        fargs = [
+            data_np,
+        ]
+        for a in indices_np:
+            fargs.append(a)
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(*fargs)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
+
     data = np.array([2, 3, 0])
     indices = np.array([[1, 1, 0], [0, 1, 0]])
     shape = (2, 2)
     out = np.array([[0, 0], [2, 3]])
     verify_scatter_nd(data, indices, shape, out)
+    verify_scatter_nd_with_stack(data, indices, shape, out)
 
     data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
     indices = np.array([[0, 1], [1, 1]])
     shape = (2, 2, 2, 2)
     out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
     verify_scatter_nd(data, indices, shape, out)
+    verify_scatter_nd_with_stack(data, indices, shape, out)
 
     data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     indices = np.array([[1, 0, 0]])
@@ -1411,6 +1440,7 @@ def test_scatter_nd(target, ctx):
     out[0, :] += data[1, :]
     out[0, :] += data[2, :]
     verify_scatter_nd(data, indices, shape, out)
+    verify_scatter_nd_with_stack(data, indices, shape, out)
 
     data = np.ones((5, 3)).astype("float64")
     indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
@@ -1420,6 +1450,7 @@ def test_scatter_nd(target, ctx):
         for j in range(data.shape[1]):
             out[indices[0, i], indices[1, i], j] += data[i, j]
     verify_scatter_nd(data, indices, shape, out)
+    verify_scatter_nd_with_stack(data, indices, shape, out)
 
 
 if __name__ == "__main__":