Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/30 03:06:49 UTC

[tvm] branch main updated: [Relay] Support dynamic indices size in gather_nd and scatter_nd (#8105)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 27e44ee  [Relay] Support dynamic indices size in gather_nd and scatter_nd (#8105)
27e44ee is described below

commit 27e44eee21bb65c0ef897799355084944c863aec
Author: masahi <ma...@gmail.com>
AuthorDate: Sun May 30 12:06:22 2021 +0900

    [Relay] Support dynamic indices size in gather_nd and scatter_nd (#8105)
    
    * add gather_nd shape func
    
    * refactor gather_nd ref funcs
    
    * add dynamic gather_nd test
    
    * gather_dim -> num_indices_per_tuple
    
    * support dynamic scatter nd
    
    * minor fix
    
    * fix pylint
    
    * rename to index_rank and make it Optional
    
    * pylint, do not use -1 for default value
---
 include/tvm/relay/attrs/transform.h   |  7 +++++
 python/tvm/relay/frontend/onnx.py     |  4 ++-
 python/tvm/relay/op/_transform.py     | 31 ++++++++++++++++++++++
 python/tvm/relay/op/transform.py      |  8 ++++--
 python/tvm/topi/scatter.py            |  6 ++++-
 src/relay/op/tensor/transform.cc      |  4 ++-
 tests/python/relay/test_any.py        | 48 +++++++++++++++++++++++++++++++++++
 tests/python/relay/test_op_level3.py  | 22 ++--------------
 tests/python/relay/utils/ref_funcs.py | 48 +++++++++++++++++++++++++++++++++++
 9 files changed, 153 insertions(+), 25 deletions(-)
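
A minimal usage sketch of the new API (not part of the patch; it mirrors the
test_gather_nd test added at the bottom of this diff, and the shapes and
variable names are illustrative): the non-leading dimensions of indices may
now be relay.Any(), with the static indexing-tuple size passed as the new
optional index_rank argument.

    import tvm
    from tvm import relay

    # indices has a static first dim (the indexing-tuple size, here 2) and
    # a dynamic second dim (the number of tuples); index_rank records the
    # static size so the shape function can compute the output rank.
    x = relay.var("x", relay.TensorType((relay.Any(), 2), "float32"))
    y = relay.var("y", relay.TensorType((2, relay.Any()), "int32"))
    z = relay.gather_nd(x, y, batch_dims=0, index_rank=2)

    mod = tvm.IRModule()
    mod["main"] = relay.Function([x, y], z)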

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index cc97a94..027b3fe 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 
 struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   Integer batch_dims;
+  Optional<Integer> index_rank;
 
   TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
     TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
+    TVM_ATTR_FIELD(index_rank)
+        .set_default(NullValue<Integer>())
+        .describe(
+            "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
+            "indexting tuples is dynamic.");
   }
 };
+
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer batch_dims;
   Integer axis;
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3f876f4..896e8af 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1416,8 +1416,10 @@ class GatherND(OnnxOpConverter):
     @classmethod
     def _impl_common(cls, data, indices, batch_dims=0):
         indices_dims = len(infer_shape(indices))
+        indices_shape = infer_shape(indices)
         indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
-        return _op.gather_nd(data, indices, batch_dims)
+        index_rank = indices_shape[-1]
+        return _op.gather_nd(data, indices, batch_dims, index_rank)
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
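
For reference, the shape bookkeeping done by the converter above: ONNX stores
the indexing-tuple size in the last axis of indices, while Relay's gather_nd
expects it in the first axis, so the converter transposes the indices and
records the static last-axis size as index_rank. A small numpy sketch with
illustrative shapes:

    import numpy as np

    onnx_indices = np.zeros((4, 5, 2), dtype="int32")  # (..., tuple size)
    index_rank = onnx_indices.shape[-1]                # 2; static in ONNX
    # axes = [-1] + list(range(2)) == [-1, 0, 1], i.e. (2, 0, 1) here
    relay_indices = np.transpose(onnx_indices, (2, 0, 1))
    assert relay_indices.shape == (index_rank, 4, 5)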
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 412acb4..94c413b 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1074,3 +1074,34 @@ def unique_shape_func(attrs, inputs, _):
         return _unique_with_counts_shape(inputs[0])
     else:
         return _unique_shape(inputs[0])
+
+
+@script
+def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank):
+    ndim = data_shape.shape[0]
+    # Using mdim = indices_shape[0] would not work: an output rank cannot
+    # depend on a runtime shape dimension of the indices tensor, even if that
+    # dimension is always a known, fixed value. As a workaround, we assume
+    # that the fixed size of an indexing tuple (the gather dimension) is
+    # recorded in the gather_nd op attributes as index_rank.
+    mdim = index_rank
+    kdim = indices_shape.shape[0] - 1
+    out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
+    for i in range(1, kdim + 1):
+        out_shape[i - 1] = indices_shape[i]
+    for i in range(mdim + batch_dims, ndim):
+        out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
+    return out_shape
+
+
+@_reg.register_shape_func("gather_nd", False)
+def gather_nd_shape_func(attrs, inputs, _):
+    """
+    Shape func for gather_nd operator.
+    """
+    batch_dims = get_const_int(attrs.batch_dims)
+    index_rank = get_const_int(attrs.index_rank)
+
+    assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"
+
+    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
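
A plain-Python restatement of the shape function above (the helper name is
made up for illustration), checked against the (3, 3, 2) test case added
later in this patch:

    def gather_nd_out_shape(data_shape, indices_shape, batch_dims, index_rank):
        # trailing dims of indices, then the data dims not consumed by the
        # indexing tuples or the batch dims
        return list(indices_shape[1:]) + list(data_shape[index_rank + batch_dims:])

    # data (3, 2, 2, 3, 4), indices (3, 3, 2), batch_dims=2, index_rank=3
    assert gather_nd_out_shape((3, 2, 2, 3, 4), (3, 3, 2), 2, 3) == [3, 2]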
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index c87f545..74fb44f 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
     return _make.gather(data, axis, indices)
 
 
-def gather_nd(data, indices, batch_dims=0):
+def gather_nd(data, indices, batch_dims=0, index_rank=None):
     """Gather elements or slices from data and store to a tensor whose shape is
     defined by indices.
 
@@ -1087,6 +1087,10 @@ def gather_nd(data, indices, batch_dims=0):
     batch_dims : int
         The number of batch dimensions.
 
+    index_rank : int, optional
+        The size of an indexing tuple, which is a fixed value equal to indices.shape[0].
+        Only needed when the other dimensions of indices are dynamic.
+
     Returns
     -------
     ret : relay.Expr
@@ -1108,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0):
         indices = [[1, 0]]
         relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
     """
-    return _make.gather_nd(data, indices, batch_dims)
+    return _make.gather_nd(data, indices, batch_dims, index_rank)
 
 
 def sequence_mask(data, valid_length, mask_value=0, axis=0):
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index d7b008c..0fe29f3 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
+from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
 from ..te import extern, hybrid
 
 
@@ -206,12 +206,16 @@ def _verify_scatter_nd_inputs(data, indices, updates):
         f"the length of the shape of the output ({len(shape)})."
     )
     for i in range(len(indices.shape) - 1):
+        if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
+            continue
         assert indices.shape[i + 1] == updates.shape[i], (
             f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
             f"updates[{i}] ({updates.shape[i]})."
         )
     for i in range(mdim, len(data.shape)):
         data_ind = i - mdim + len(indices.shape) - 1
+        if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var):
+            continue
         assert updates.shape[data_ind] == data.shape[i], (
             f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
             f"of out_shape[{i}] ({data.shape[i]})."
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index bf45a41..10fe5e5 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3373,10 +3373,12 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
   return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }
 
-Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
+                  Optional<Integer> index_rank = NullValue<Integer>()) {
   static const Op& op = Op::Get("gather_nd");
   auto attrs = make_object<GatherNDAttrs>();
   attrs->batch_dims = batch_dims;
+  attrs->index_rank = index_rank;
   return Call(op, {data, indices}, Attrs(attrs));
 }
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 11f4515..8016e43 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -25,6 +25,7 @@ from tvm.relay.loops import while_loop
 from tvm.relay.testing import run_infer_type as infer_type
 
 from utils.assert_diagnostic import DiagnosticTesting
+from utils import ref_funcs
 
 
 def int32(val):
@@ -1703,5 +1704,52 @@ def test_all_class_non_max_suppression():
     )
 
 
+@tvm.testing.uses_gpu
+def test_gather_nd():
+    def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
+        x = relay.var("x", relay.TensorType(data_shape, "float32"))
+        y = relay.var("y", relay.TensorType(indices_shape, "int32"))
+        z = relay.gather_nd(x, y, batch_dims, indices_shape[0])
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([x, y], z)
+
+        data_np = np.random.uniform(size=data_shape_np).astype("float32")
+        indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")
+
+        ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims)
+        check_result([data_np, indices_np], mod, [ref_res])
+
+    verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3))
+    verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3))
+    verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1)
+    verify_gather_nd(
+        (relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2
+    )
+
+
+@tvm.testing.uses_gpu
+def test_scatter_nd():
+    def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
+        indices_shape = (2, relay.Any())
+        updates_shape = (relay.Any(),)
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
+        updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))
+
+        out = relay.op.scatter_nd(data, indices, updates, "add")
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([data, indices, updates], out)
+
+        check_result([data_np, indices_np, updates_np], mod, [ref_res])
+
+    data = np.zeros((2, 2)).astype("int64")
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    updates = np.array([2, 3, 0])
+    out = np.array([[0, 0], [2, 3]])
+    verify_scatter_nd(data, indices, updates, out)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index fd6d7a9..0795594 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -26,6 +26,7 @@ from tvm import relay, te
 from tvm.error import TVMError
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import check_grad, run_infer_type
+from utils import ref_funcs
 
 
 def test_zeros_ones():
@@ -1266,26 +1267,7 @@ def test_gather_nd():
         else:
             y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")
 
-        def gather_nd_batch_dims_1_ref(data, indices):
-            res = []
-            for i, row in enumerate(data):
-                indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
-                res.append(row[indices_tuple])
-            # stack on the batch dim
-            return np.stack(res, 0)
-
-        if batch_dims > 1:
-            x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
-            y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])
-
-            ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)
-
-            out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
-            ref_res = np.reshape(ref_res, out_shape)
-        elif batch_dims == 1:
-            ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
-        else:
-            ref_res = x_data[tuple(y_data)]
+        ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)
 
         for target, dev in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
diff --git a/tests/python/relay/utils/ref_funcs.py b/tests/python/relay/utils/ref_funcs.py
new file mode 100644
index 0000000..924805b
--- /dev/null
+++ b/tests/python/relay/utils/ref_funcs.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+
+def gather_nd(data_np, indices_np, batch_dims=0):
+    """gather_nd implemented using numpy"""
+    data_shape = data_np.shape
+    indices_shape = indices_np.shape
+
+    def gather_nd_batch_dims_1_ref(data, indices):
+        res = []
+        for i, row in enumerate(data):
+            indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
+            res.append(row[indices_tuple])
+        # stack on the batch dim
+        return np.stack(res, 0)
+
+    if batch_dims > 1:
+        data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
+        indices_np_reshape = np.reshape(
+            indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
+        )
+
+        ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)
+
+        out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
+        ref_res = np.reshape(ref_res, out_shape)
+    elif batch_dims == 1:
+        ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
+    else:
+        ref_res = data_np[tuple(indices_np)]
+
+    return ref_res
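
A quick usage sketch for the new reference helper (not part of the commit;
outputs verified by hand against the numpy code above):

    import numpy as np
    from utils import ref_funcs  # as imported in the tests above

    data = np.array([[0, 1], [2, 3]])
    indices = np.array([[1, 0],    # first-axis (row) indices
                        [0, 1]])   # second-axis (column) indices
    # batch_dims=0: equivalent to data[(indices[0], indices[1])]
    assert list(ref_funcs.gather_nd(data, indices)) == [2, 1]

    data = np.array([[0, 1, 2], [3, 4, 5]])
    indices = np.array([[2, 0]])   # one column index per batch row
    assert list(ref_funcs.gather_nd(data, indices, batch_dims=1)) == [2, 3]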