Posted to commits@tvm.apache.org by jc...@apache.org on 2021/04/28 05:13:18 UTC
[tvm] branch main updated: [TOPI][RELAY][ONNX] Scatter ND (#7927)
This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8fce895 [TOPI][RELAY][ONNX] Scatter ND (#7927)
8fce895 is described below
commit 8fce89500c520c4dc6ce8733172fa87ead107709
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Tue Apr 27 23:13:06 2021 -0600
[TOPI][RELAY][ONNX] Scatter ND (#7927)
* passing topi tests
* passing relay tests, needs better shape checking still
* support ONNX operator
* add shape checking back in
* fix lint
* update docstring
---
include/tvm/relay/attrs/transform.h | 5 +-
python/tvm/relay/frontend/onnx.py | 13 ++++
python/tvm/relay/frontend/pytorch.py | 21 ++-----
python/tvm/relay/op/_tensor_grad.py | 2 +-
python/tvm/relay/op/_transform.py | 2 +-
python/tvm/relay/op/strategy/generic.py | 2 +-
python/tvm/relay/op/transform.py | 13 ++--
python/tvm/topi/cuda/scatter.py | 66 ++++++++++++---------
python/tvm/topi/scatter.py | 60 +++++++++++--------
python/tvm/topi/x86/scatter.py | 60 +++++++++++--------
src/relay/op/tensor/transform.cc | 26 +++++---
tests/python/frontend/onnx/test_forward.py | 1 -
tests/python/relay/test_op_level3.py | 85 +++++++++++++++------------
tests/python/topi/python/test_topi_scatter.py | 70 ++++++++++++++--------
14 files changed, 248 insertions(+), 178 deletions(-)
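Before the per-file diffs, here is a minimal NumPy sketch of the semantics this patch introduces (illustrative only, not part of the commit): scatter_nd no longer takes an out_shape and zero-fills the result; the output starts as a copy of data, and each index tuple either overwrites ("update") or accumulates into ("add") the addressed slice.

    import numpy as np

    def scatter_nd_ref(data, indices, updates, mode="update"):
        # indices: (M, Y_0, ..., Y_{K-1}); updates: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
        # The output starts as a copy of data, whose shape is (X_0, ..., X_{N-1}).
        out = data.copy()
        m = indices.shape[0]
        for y in np.ndindex(*indices.shape[1:]):
            idx = tuple(indices[(k,) + y] for k in range(m))
            if mode == "update":
                out[idx] = updates[y]
            elif mode == "add":
                out[idx] += updates[y]
            else:
                raise ValueError("mode must be 'update' or 'add'")
        return out

    # First test case from this commit: writes land at (1, 0), (1, 1), (0, 0).
    data = np.zeros((2, 2), dtype="int64")
    indices = np.array([[1, 1, 0], [0, 1, 0]])
    updates = np.array([2, 3, 0])
    print(scatter_nd_ref(data, indices, updates, "add"))  # [[0 0]
                                                          #  [2 3]]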
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a5544c8..113c820 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -126,10 +126,11 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
};
struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
- Array<Integer> out_shape;
+ String mode;
TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
- TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
+ TVM_ATTR_FIELD(mode).describe(
+ "Accumulation mode of the scatter, either \"update\" or \"add\".");
}
};
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a695e00..deb2948 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1376,6 +1376,18 @@ class Scatter(OnnxOpConverter):
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+class ScatterND(OnnxOpConverter):
+ """Operator converter for Scatter."""
+
+ @classmethod
+ def _impl_v11(cls, inputs, attr, params):
+ indices_dim = len(infer_shape(inputs[1]))
+ axes = list(range(indices_dim))
+ return _op.scatter_nd(
+ inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
+ )
+
+
class Greater(OnnxOpConverter):
"""Operator logical greater."""
@@ -2874,6 +2886,7 @@ def _get_convert_map(opset):
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
+ "ScatterND": ScatterND.get_converter(opset),
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
"Unsqueeze": Unsqueeze.get_converter(opset),
"Pad": Pad.get_converter(opset),
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a31c44a..025942b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2118,26 +2118,13 @@ class PyTorchOpConverter:
indices = inputs[1]
values = inputs[2]
accumulate = inputs[3]
- # accumulate parameter is ignored.
- # torch.index_put default is False but Relay.scatter_nd accumulates values.
- # We assume there is no duplicate indices in torch.index_put input
if not accumulate:
- logging.warning(
- "torch.index_put accumulate parameter is False. "
- "TVM uses tvm.relay.scatter_nd operator which accumulates values. "
- "Make sure there is no duplicate indices in torch.index_put input."
- )
- # Relay scatter_nd does not support input tensor
- # We assume that torch.index_put is used with empty zero-values input tensor
- # scatter_nd will create empty zero-values tensor with a given shape
- out_shape = self.infer_shape(in_tensor)
- logging.warning(
- "tvm.relay.scatter_nd operator does not support input tensor parameter. "
- "TVM assumes that torch.index_put is used with empty zero-values input tensor"
- )
+ mode = "update"
+ else:
+ mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)
- return _op.transform.scatter_nd(values, index_tensor, out_shape)
+ return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
def scalar_tensor(self, inputs, input_types):
data = inputs[0]
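With data now threaded through scatter_nd, the accumulate flag of torch.index_put maps directly onto the new mode attribute instead of being ignored. A hedged PyTorch sketch of the two behaviors being matched (assumes torch is installed; not part of the patch):

    import torch

    base = torch.zeros(2, 2)
    rows = torch.tensor([0, 0, 1])
    cols = torch.tensor([1, 1, 0])
    vals = torch.tensor([2.0, 3.0, 5.0])

    # accumulate=False corresponds to scatter_nd(..., "update");
    # for duplicate indices the write order is not guaranteed.
    print(base.index_put((rows, cols), vals, accumulate=False))

    # accumulate=True corresponds to scatter_nd(..., "add");
    # duplicates sum: position (0, 1) receives 2.0 + 3.0 = 5.0.
    print(base.index_put((rows, cols), vals, accumulate=True))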
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 5836aeb..108bef0 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -834,7 +834,7 @@ def gather_nd_grad(orig, grad):
Returns the gradient of gather_nd, which is simply scatter_nd.
"""
data, indices = orig.args
- return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
+ return [scatter_nd(zeros_like(data), indices, grad, mode="add"), zeros_like(indices)]
@register_gradient("reshape_like")
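This one-line change also makes the gradient rule correct for repeated indices: scattering the incoming gradient into zeros_like(data) with mode="add" is the adjoint of gather_nd, since a position gathered twice must receive twice the gradient. A NumPy sketch of that identity (illustrative):

    import numpy as np

    data = np.arange(6.0).reshape(2, 3)
    indices = np.array([[0, 1, 0],      # row components
                        [2, 1, 2]])     # column components
    gathered = data[indices[0], indices[1]]   # forward gather_nd: [2., 4., 2.]

    grad = np.ones_like(gathered)             # incoming gradient of the output
    data_grad = np.zeros_like(data)           # zeros_like(data) seed
    np.add.at(data_grad, (indices[0], indices[1]), grad)  # scatter_nd, mode="add"
    print(data_grad)  # data[0, 2] was gathered twice, so its gradient is 2.0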
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 8220ad3..2920c99 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -145,7 +145,7 @@ _reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
- return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]
+ return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)]
_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 70e0219..7451b39 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1288,7 +1288,7 @@ def wrap_compute_scatter_nd(topi_compute):
"""Wrap scatter_nd topi compute"""
def _compute_scatter_nd(attrs, inputs, _):
- return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]
+ return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.mode)]
return _compute_scatter_nd
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index f94a00d..df26861 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -310,8 +310,8 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)
-def scatter_nd(data, indices, out_shape):
- """Scatter values from an array.
+def scatter_nd(data, indices, updates, mode="update"):
+ """Scatter values from an array and update.
See :py:func:`tvm.topi.scatter` for how data is scattered.
@@ -323,15 +323,18 @@ def scatter_nd(data, indices, out_shape):
indices : relay.Expr
The index locations to update.
- out_shape : Union[Tuple[int], List[int]]
- Output shape of the scatter.
+ updates : relay.Expr
+ The values to update.
+
+ mode : string
+ The accumulation mode for scatter, either "update" or "add".
Returns
-------
ret : relay.Expr
The computed result.
"""
- return _make.scatter_nd(data, indices, out_shape)
+ return _make.scatter_nd(data, indices, updates, mode)
def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
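A hedged end-to-end sketch of the updated Relay API, mirroring the first test case added further down (assumes a CPU build with the llvm target; the executor call follows the test file):

    import numpy as np
    from tvm import relay

    data = relay.var("data", shape=(2, 2), dtype="int64")
    indices = relay.var("indices", shape=(2, 3), dtype="int64")
    updates = relay.var("updates", shape=(3,), dtype="int64")
    out = relay.op.scatter_nd(data, indices, updates, "add")
    func = relay.Function([data, indices, updates], out)

    intrp = relay.create_executor("graph", target="llvm")
    res = intrp.evaluate(func)(
        np.zeros((2, 2), dtype="int64"),
        np.array([[1, 1, 0], [0, 1, 0]], dtype="int64"),
        np.array([2, 3, 0], dtype="int64"),
    )
    print(res)  # [[0 0]
                #  [2 3]]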
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index fd05904..cee13d7 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -723,11 +723,12 @@ def scatter_add(data, indices, updates, axis=0):
return out
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.
- Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
- (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+ Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+ (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+ scatter_nd computes
.. code-block::
@@ -737,9 +738,9 @@ def scatter_nd(data, indices, shape):
x_M,
...,
x_{N-1}
- ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+ ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
- all other entries in the output are 0. Repeated indices are summed.
+ where the update function f is determined by the mode.
Parameters
----------
@@ -749,35 +750,41 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.
- shape : Sequence[int]
- The output shape. This must be specified because it cannot be inferred.
+ updates : tvm.te.Tensor
+ The updates to apply at the indices.
+
+ mode : string
+ The update mode for the algorithm, either "update" or "add".
+ If "update", the update values will replace the input data.
+ If "add", the update values will be added to the input data.
Returns
-------
ret : tvm.te.Tensor
"""
- _verify_scatter_nd_inputs(data, indices, shape)
+ _verify_scatter_nd_inputs(data, indices, updates)
- def gen_ir(data_ptr, indices_ptr, out_ptr):
+ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
ib = tvm.tir.ir_builder.create()
data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
+ updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)
# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in a single loop instead of an arbitrary
- # number of loops. We do the same thing for all the data dimensions.
+ # number of loops. We do the same thing for all the update dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i
- fused_data_dimension = 1
- for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
- fused_data_dimension *= i
+ fused_updates_dimension = 1
+ for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+ fused_updates_dimension *= i
fused_shape = 1
- for i in shape:
+ for i in data_ptr.shape:
fused_shape *= i
# For now we avoid parallelizing over dimensions indexed by `indices` as
@@ -789,38 +796,41 @@ def scatter_nd(data, indices, shape):
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
- tdim = min(max_threads, fused_data_dimension)
+ tdim = min(max_threads, fused_updates_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
- bdim = ceil_div(fused_data_dimension, tdim)
+ bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)
- # zero data
- # TODO(tkonolige): could we use topi.full to zero it instead?
with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
- index = i * fused_data_dimension + bx * tdim + tx
+ index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
- out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+ out[index] = data[index]
with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
- with ib.if_scope(j < fused_data_dimension):
- offset = fused_data_dimension
+ with ib.if_scope(j < fused_updates_dimension):
+ offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
- offset *= shape[l]
- out[index] += data[i * fused_data_dimension + j]
+ offset *= data_ptr.shape[l]
+ if mode == "update":
+ out[index] = updates[i * fused_updates_dimension + j]
+ elif mode == "add":
+ out[index] += updates[i * fused_updates_dimension + j]
+ else:
+ raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
return ib.get()
- out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+ out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
- [shape],
- [data, indices],
- lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+ [data.shape],
+ [data, indices, updates],
+ lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_cuda",
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index a376963..d7b008c 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
-from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate
+from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
from ..te import extern, hybrid
@@ -199,22 +199,22 @@ def scatter(data, indices, updates, axis=0):
raise ValueError("scatter only support for 1-4 dimensions")
-def _verify_scatter_nd_inputs(data, indices, shape):
+def _verify_scatter_nd_inputs(data, indices, updates):
mdim = int(indices.shape[0])
- assert mdim <= len(shape), (
+ assert mdim <= len(data.shape), (
f"The first dimension of the indices ({mdim}) must be less than or equal to "
f"the length of the shape of the output ({len(shape)})."
)
for i in range(len(indices.shape) - 1):
- assert indices.shape[i + 1] == data.shape[i], (
+ assert indices.shape[i + 1] == updates.shape[i], (
f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
- f"data[{i}] ({data.shape[i]})."
+ f"updates[{i}] ({updates.shape[i]})."
)
- for i in range(mdim, len(shape)):
+ for i in range(mdim, len(data.shape)):
data_ind = i - mdim + len(indices.shape) - 1
- assert data.shape[data_ind] == shape[i], (
- f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension "
- f"of out_shape[{i}] ({shape[i]})."
+ assert updates.shape[data_ind] == data.shape[i], (
+ f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
+ f"of out_shape[{i}] ({data.shape[i]})."
)
assert (
@@ -222,11 +222,12 @@ def _verify_scatter_nd_inputs(data, indices, shape):
), f"Indices must be a tensor of integers, but its elements are {indices.dtype}."
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.
- Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
- (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+ Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+ (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+ scatter_nd computes
.. code-block::
@@ -236,9 +237,9 @@ def scatter_nd(data, indices, shape):
x_M,
...,
x_{N-1}
- ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+ ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
- all other entries in the output are 0. Repeated indices are summed.
+ where the update function f is determined by the mode.
Parameters
----------
@@ -248,29 +249,33 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.
- shape : Sequence[int]
- The output shape. This must be specified because it cannot be inferred.
+ updates : tvm.te.Tensor
+ The updates to apply at the indices.
+
+ mode : string
+ The update mode for the algorithm, either "update" or "add".
+ If "update", the update values will replace the input data.
+ If "add", the update values will be added to the input data.
Returns
-------
ret : tvm.te.Tensor
"""
- _verify_scatter_nd_inputs(data, indices, shape)
+ _verify_scatter_nd_inputs(data, indices, updates)
- def gen_ir(data_ptr, indices_ptr, out_ptr):
+ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
ib = ir_builder.create()
data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
+ updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)
- # zero data
- # TODO(tkonolige): could we use topi.full to zero it instead?
fused_shape = 1
- for i in shape:
+ for i in data_ptr.shape:
fused_shape *= i
with ib.for_range(0, fused_shape) as i:
- out[i] = Cast(data_ptr.dtype, 0)
+ out[i] = data[i]
# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in a single loop instead of an arbitrary
@@ -300,15 +305,20 @@ def scatter_nd(data, indices, shape):
)
)
- offset *= shape[l]
+ offset *= data_ptr.shape[l]
- out[index] += data[i * fused_data_dimension + j]
+ if mode == "add":
+ out[index] += updates[i * fused_data_dimension + j]
+ elif mode == "update":
+ out[index] = updates[i * fused_data_dimension + j]
+ else:
+ raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
return ib.get()
- out_buf = decl_buffer(shape, data.dtype, "out_buf")
+ out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
return extern(
- [shape],
+ [data.shape],
- [data, indices],
- lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+ [data, indices, updates],
+ lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_generic",
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
index 8bb3f57..5eb5e6e 100644
--- a/python/tvm/topi/x86/scatter.py
+++ b/python/tvm/topi/x86/scatter.py
@@ -20,11 +20,12 @@ from tvm import te
from ..scatter import _verify_scatter_nd_inputs
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.
- Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
- (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+ Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+ (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+ scatter_nd computes
.. code-block::
@@ -34,9 +35,9 @@ def scatter_nd(data, indices, shape):
x_M,
...,
x_{N-1}
- ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+ ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
- all other entries in the output are 0. Repeated indices are summed.
+ where the update function f is determined by the mode.
Parameters
----------
@@ -46,62 +47,71 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.
- shape : Sequence[int]
- The output shape. This must be specified because it cannot be inferred.
+ updates : tvm.te.Tensor
+ The updates to apply at the indices.
+
+ mode : string
+ The update mode for the algorithm, either "update" or "add".
+ If "update", the update values will replace the input data.
+ If "add", the update values will be added to the input data.
Returns
-------
ret : tvm.te.Tensor
"""
- _verify_scatter_nd_inputs(data, indices, shape)
+ _verify_scatter_nd_inputs(data, indices, updates)
- def gen_ir(data_ptr, indices_ptr, out_ptr):
+ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
# pylint: disable=invalid-name
ib = tvm.tir.ir_builder.create()
data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
+ updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)
# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in a single loop instead of an arbitrary
- # number of loops. We do the same thing for all the data dimensions.
+ # number of loops. We do the same thing for all the update dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i
- fused_data_dimension = 1
- for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
- fused_data_dimension *= i
+ fused_updates_dimension = 1
+ for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+ fused_updates_dimension *= i
fused_shape = 1
- for i in shape:
+ for i in data_ptr.shape:
fused_shape *= i
- # zero data
- # TODO(tkonolige): could we use topi.full to zero it instead?
with ib.for_range(0, fused_shape) as i:
- out[i] = tvm.tir.Cast(data_ptr.dtype, 0)
+ out[i] = data[i]
with ib.for_range(0, fused_indices_dimension) as i:
- with ib.for_range(0, fused_data_dimension, kind="parallel") as j:
- offset = fused_data_dimension
+ with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
+ offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
- offset *= shape[l]
- out[index] += data[i * fused_data_dimension + j]
+ offset *= data_ptr.shape[l]
+ if mode == "update":
+ out[index] = updates[i * fused_updates_dimension + j]
+ elif mode == "add":
+ out[index] += updates[i * fused_updates_dimension + j]
+ else:
+ raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
return ib.get()
- out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+ out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
- [shape],
- [data, indices],
- lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+ [data.shape],
+ [data, indices, updates],
+ lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_x86",
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 11e94cb..e937cb0 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1096,10 +1096,11 @@ TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
- // `types` contains: [data, indices, result]
- ICHECK_EQ(types.size(), 3);
+ // `types` contains: [data, indices, updates, result]
+ ICHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* indices = types[1].as<TensorTypeNode>();
+ const auto* updates = types[2].as<TensorTypeNode>();
if (data == nullptr) {
ICHECK(types[0].as<IncompleteTypeNode>())
<< "ScatterND: expect input data type to be TensorType but got " << types[0];
@@ -1110,8 +1111,14 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "ScatterND: expect indices type to be TensorType but got " << types[1];
return false;
}
+ if (updates == nullptr) {
+ ICHECK(types[2].as<IncompleteTypeNode>())
+ << "ScatterND: expect updates type to be TensorType but got " << types[2];
+ return false;
+ }
ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
- const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
+
+ const auto out_shape = data->shape;
const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
const size_t kdim = indices->shape.size() - 1;
const size_t ndim = out_shape.size();
@@ -1120,7 +1127,7 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
"with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
// Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
for (size_t i = 0; i < kdim; i++) {
- reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
+ reporter->AssertEQ(indices->shape[i + 1], updates->shape[i]);
}
std::vector<IndexExpr> oshape;
@@ -1133,15 +1140,15 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
}
- reporter->Assign(types[2], TensorType(oshape, data->dtype));
+ reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}
-Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {
+Expr MakeScatterND(Expr data, Expr indices, Expr updates, String mode) {
auto attrs = make_object<ScatterNDAttrs>();
- attrs->out_shape = out_shape;
+ attrs->mode = std::move(mode);
static const Op& op = Op::Get("scatter_nd");
- return Call(op, {data, indices}, Attrs(attrs), {});
+ return Call(op, {data, indices, updates}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
@@ -1156,9 +1163,10 @@ whose shape is defined by indices.
- Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape
- (M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}).
+ Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+ (M, Y_0, ..., Y_{K-1}), and data with shape (X_0, X_1, ..., X_{N-1}), the output
+ will have the same shape as data.
)code" TVM_ADD_FILELINE)
- .set_num_inputs(2)
+ .set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
+ .add_argument("updates", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("ScatterND", ScatterNDRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque);
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1a3d0d4..e11689c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4237,7 +4237,6 @@ unsupported_onnx_tests = [
"test_round/",
"test_scan9_sum/",
"test_scan_sum/",
- "test_scatternd/",
"test_simple_rnn_defaults/",
"test_simple_rnn_with_initial_bias/",
"test_strnormalizer_export_monday_casesensintive_lower/",
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index bf0a7e4..e84b22b 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1833,38 +1833,39 @@ def test_cumprod(target, dev):
@tvm.testing.parametrize_targets
def test_scatter_nd(target, dev):
- def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+ def verify_scatter_nd(
+ data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
+ ):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
+ updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
- out = relay.op.scatter_nd(data, indices, shape)
- func = relay.Function([data, indices], out)
+ out = relay.op.scatter_nd(data, indices, updates, mode)
+ func = relay.Function([data, indices, updates], out)
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, device=dev, target=target)
- op_res = intrp.evaluate(func)(data_np, indices_np)
+ op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
- def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+ def verify_scatter_nd_with_stack(
+ data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
+ ):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices_vars = [
relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
]
+ updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
# test if scatter_nd works in case indices are prepared by another Relay operator
indices = relay.op.stack(indices_vars, axis=0)
- out = relay.op.scatter_nd(data, indices, shape)
+ out = relay.op.scatter_nd(data, indices, updates, mode)
func = relay.Function(
- [
- data,
- ]
- + indices_vars,
+ [data, updates] + indices_vars,
out,
)
- fargs = [
- data_np,
- ]
+ fargs = [data_np, updates_np]
for a in indices_np:
fargs.append(a)
for kind in ["graph", "debug"]:
@@ -1872,39 +1873,47 @@ def test_scatter_nd(target, dev):
op_res = intrp.evaluate(func)(*fargs)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
- data = np.array([2, 3, 0])
+ data = np.zeros((2, 2)).astype("int64")
indices = np.array([[1, 1, 0], [0, 1, 0]])
- shape = (2, 2)
+ updates = np.array([2, 3, 0])
out = np.array([[0, 0], [2, 3]])
- verify_scatter_nd(data, indices, shape, out)
- verify_scatter_nd_with_stack(data, indices, shape, out)
+ verify_scatter_nd(data, indices, updates, out)
+ verify_scatter_nd_with_stack(data, indices, updates, out)
- data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+ data = np.zeros((2, 2, 2, 2)).astype("int64")
indices = np.array([[0, 1], [1, 1]])
- shape = (2, 2, 2, 2)
+ updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
- verify_scatter_nd(data, indices, shape, out)
- verify_scatter_nd_with_stack(data, indices, shape, out)
+ verify_scatter_nd(data, indices, updates, out)
+ verify_scatter_nd_with_stack(data, indices, updates, out)
- data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
indices = np.array([[1, 0, 0]])
+ updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
shape = (2, 1560)
- out = np.zeros(shape).astype("float32")
- out[1, :] += data[0, :]
- out[0, :] += data[1, :]
- out[0, :] += data[2, :]
- verify_scatter_nd(data, indices, shape, out)
- verify_scatter_nd_with_stack(data, indices, shape, out)
-
- data = np.ones((5, 3)).astype("float64")
- indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
- shape = (2, 7, 3)
- out = np.zeros(shape).astype("float64")
- for i in range(indices.shape[1]):
- for j in range(data.shape[1]):
- out[indices[0, i], indices[1, i], j] += data[i, j]
- verify_scatter_nd(data, indices, shape, out)
- verify_scatter_nd_with_stack(data, indices, shape, out)
+ data = np.zeros(shape).astype("float32")
+ out = data.copy()
+ out[1, :] += updates[0, :]
+ out[0, :] += updates[1, :]
+ out[0, :] += updates[2, :]
+ verify_scatter_nd(data, indices, updates, out)
+ verify_scatter_nd_with_stack(data, indices, updates, out)
+
+ for mode in ["add", "update"]:
+ indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+ "int64"
+ )
+ updates = np.ones((5, 3)).astype("float64")
+ shape = (2, 7, 3)
+ data = np.random.random(shape).astype("float64")
+ out = data.copy()
+ for i in range(indices.shape[1]):
+ for j in range(updates.shape[1]):
+ if mode == "add":
+ out[indices[0, i], indices[1, i], j] += updates[i, j]
+ elif mode == "update":
+ out[indices[0, i], indices[1, i], j] = updates[i, j]
+ verify_scatter_nd(data, indices, updates, out, mode)
+ verify_scatter_nd_with_stack(data, indices, updates, out, mode)
def test_unique():
diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py
index ad73bb5..648ef62 100644
--- a/tests/python/topi/python/test_topi_scatter.py
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -23,44 +23,64 @@ import tvm.topi.testing
@tvm.testing.parametrize_targets
def test_scatter_nd(dev, target):
- def check_scatter_nd(data, indices, shape, out):
+ def check_scatter_nd(data, indices, updates, out, mode="add"):
implementations = {
- "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern),
- "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern),
- "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+ "generic": (
+ lambda x, y, z: topi.scatter_nd(x, y, z, mode),
+ topi.generic.schedule_extern,
+ ),
+ "gpu": (
+ lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
+ topi.generic.schedule_extern,
+ ),
+ "cpu": (
+ lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
+ topi.generic.schedule_extern,
+ ),
}
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
- tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, dev, fcompute, fschedule)
+ tvm.topi.testing.compare_numpy_tvm(
+ [data, indices, updates], out, target, dev, fcompute, fschedule
+ )
- data = np.array([2, 3, 0])
+ data = np.zeros((2, 2)).astype("int64")
indices = np.array([[1, 1, 0], [0, 1, 0]])
- shape = (2, 2)
+ updates = np.array([2, 3, 0])
out = np.array([[0, 0], [2, 3]])
- check_scatter_nd(data, indices, shape, out)
+ check_scatter_nd(data, indices, updates, out)
- data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+ data = np.zeros((2, 2, 2, 2)).astype("int64")
indices = np.array([[0, 1], [1, 1]])
- shape = (2, 2, 2, 2)
+ updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
- check_scatter_nd(data, indices, shape, out)
+ check_scatter_nd(data, indices, updates, out)
- data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
indices = np.array([[1, 0, 0]])
+ updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
shape = (2, 1560)
- out = np.zeros(shape).astype("float32")
- out[1, :] += data[0, :]
- out[0, :] += data[1, :]
- out[0, :] += data[2, :]
- check_scatter_nd(data, indices, shape, out)
+ data = np.zeros(shape).astype("float32")
+ out = data.copy()
+ out[1, :] += updates[0, :]
+ out[0, :] += updates[1, :]
+ out[0, :] += updates[2, :]
+ check_scatter_nd(data, indices, updates, out)
- data = np.ones((5, 3)).astype("float64")
- indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
- shape = (2, 7, 3)
- out = np.zeros(shape).astype("float64")
- for i in range(indices.shape[1]):
- for j in range(data.shape[1]):
- out[indices[0, i], indices[1, i], j] += data[i, j]
- check_scatter_nd(data, indices, shape, out)
+ for mode in ["add", "update"]:
+ updates = np.ones((5, 3)).astype("float64")
+ indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+ "int64"
+ )
+ shape = (2, 7, 3)
+ data = np.random.random(shape).astype("float64")
+ out = data.copy()
+ for i in range(indices.shape[1]):
+ for j in range(updates.shape[1]):
+ if mode == "add":
+ out[indices[0, i], indices[1, i], j] += updates[i, j]
+ elif mode == "update":
+ out[indices[0, i], indices[1, i], j] = updates[i, j]
+
+ check_scatter_nd(data, indices, updates, out, mode)
if __name__ == "__main__":