You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jc...@apache.org on 2021/04/28 05:13:18 UTC

[tvm] branch main updated: [TOPI][RELAY][ONNX] Scatter ND (#7927)

This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8fce895  [TOPI][RELAY][ONNX] Scatter ND (#7927)
8fce895 is described below

commit 8fce89500c520c4dc6ce8733172fa87ead107709
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Tue Apr 27 23:13:06 2021 -0600

    [TOPI][RELAY][ONNX] Scatter ND (#7927)
    
    * passing topi tests
    
    * passing relay tests, needs better shape checking still
    
    * support ONNX operator
    
    * add shape checking back in
    
    * fix lint
    
    * update docstring
---
 include/tvm/relay/attrs/transform.h           |  5 +-
 python/tvm/relay/frontend/onnx.py             | 13 ++++
 python/tvm/relay/frontend/pytorch.py          | 21 ++-----
 python/tvm/relay/op/_tensor_grad.py           |  2 +-
 python/tvm/relay/op/_transform.py             |  2 +-
 python/tvm/relay/op/strategy/generic.py       |  2 +-
 python/tvm/relay/op/transform.py              | 13 ++--
 python/tvm/topi/cuda/scatter.py               | 66 ++++++++++++---------
 python/tvm/topi/scatter.py                    | 60 +++++++++++--------
 python/tvm/topi/x86/scatter.py                | 60 +++++++++++--------
 src/relay/op/tensor/transform.cc              | 26 +++++---
 tests/python/frontend/onnx/test_forward.py    |  1 -
 tests/python/relay/test_op_level3.py          | 85 +++++++++++++++------------
 tests/python/topi/python/test_topi_scatter.py | 70 ++++++++++++++--------
 14 files changed, 248 insertions(+), 178 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a5544c8..113c820 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -126,10 +126,11 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
 };
 
 struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
-  Array<Integer> out_shape;
+  String mode;
 
   TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
-    TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
+    TVM_ATTR_FIELD(mode).describe(
+        "Accumulation mode of the scatter, either \"update\" or \"add\".");
   }
 };
 
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a695e00..deb2948 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1376,6 +1376,18 @@ class Scatter(OnnxOpConverter):
         return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
 
 
+class ScatterND(OnnxOpConverter):
+    """Operator converter for ScatterND."""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        indices_dim = len(infer_shape(inputs[1]))
+        axes = list(range(indices_dim))
+        return _op.scatter_nd(
+            inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
+        )
+
+
 class Greater(OnnxOpConverter):
     """Operator logical greater."""
 
@@ -2874,6 +2886,7 @@ def _get_convert_map(opset):
         "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
         "Scatter": Scatter.get_converter(opset),
         "ScatterElements": Scatter.get_converter(opset),
+        "ScatterND": ScatterND.get_converter(opset),
         "Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
         "Unsqueeze": Unsqueeze.get_converter(opset),
         "Pad": Pad.get_converter(opset),
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a31c44a..025942b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2118,26 +2118,13 @@ class PyTorchOpConverter:
         indices = inputs[1]
         values = inputs[2]
         accumulate = inputs[3]
-        # accumulate parameter is ignored.
-        # torch.index_put default is False but Relay.scatter_nd accumulates values.
-        # We assume there is no duplicate indices in torch.index_put input
         if not accumulate:
-            logging.warning(
-                "torch.index_put accumulate parameter is False. "
-                "TVM uses tvm.relay.scatter_nd operator which accumulates values. "
-                "Make sure there is no duplicate indices in torch.index_put input."
-            )
-        # Relay scatter_nd does not support input tensor
-        # We assume that torch.index_put is used with empty zero-values input tensor
-        # scatter_nd will create empty zero-values tensor with a given shape
-        out_shape = self.infer_shape(in_tensor)
-        logging.warning(
-            "tvm.relay.scatter_nd operator does not support input tensor parameter. "
-            "TVM assumes that torch.index_put is used with empty zero-values input tensor"
-        )
+            mode = "update"
+        else:
+            mode = "add"
         # Combine array of index tensors into one index tensor with shape (N,_)
         index_tensor = _op.stack(indices, axis=0)
-        return _op.transform.scatter_nd(values, index_tensor, out_shape)
+        return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
 
     def scalar_tensor(self, inputs, input_types):
         data = inputs[0]
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 5836aeb..108bef0 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -834,7 +834,7 @@ def gather_nd_grad(orig, grad):
     Returns the gradient of gather_nd, which is simply scatter_nd.
     """
     data, indices = orig.args
-    return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
+    return [scatter_nd(zeros_like(data), indices, grad, mode="add"), zeros_like(indices)]
 
 
 @register_gradient("reshape_like")
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 8220ad3..2920c99 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -145,7 +145,7 @@ _reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
 @_reg.register_compute("scatter_nd")
 def compute_scatter_nd(attrs, inputs, output_type):
     """Compute definition of scatter_nd"""
-    return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]
+    return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)]
 
 
 _reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 70e0219..7451b39 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1288,7 +1288,7 @@ def wrap_compute_scatter_nd(topi_compute):
     """Wrap scatter_nd topi compute"""
 
     def _compute_scatter_nd(attrs, inputs, _):
-        return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]
+        return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.mode)]
 
     return _compute_scatter_nd
 
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index f94a00d..df26861 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -310,8 +310,8 @@ def scatter_add(data, indices, updates, axis):
     return _make.scatter_add(data, indices, updates, axis)
 
 
-def scatter_nd(data, indices, out_shape):
-    """Scatter values from an array.
+def scatter_nd(data, indices, updates, mode="update"):
+    """Scatter values from an array and update.
 
     See :py:func:`tvm.topi.scatter` for how data is scattered.
 
@@ -323,15 +323,18 @@ def scatter_nd(data, indices, out_shape):
     indices : relay.Expr
         The index locations to update.
 
-    out_shape : Union[Tuple[int], List[int]]
-        Output shape of the scatter.
+    updates : relay.Expr
+        The values to update.
+
+    mode : string
+        The accumulation mode for scatter. "update" or "add"
 
     Returns
     -------
     ret : relay.Expr
         The computed result.
     """
-    return _make.scatter_nd(data, indices, out_shape)
+    return _make.scatter_nd(data, indices, updates, mode)
 
 
 def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index fd05904..cee13d7 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -723,11 +723,12 @@ def scatter_add(data, indices, updates, axis=0):
     return out
 
 
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
-    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
-    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+    scatter_nd computes
 
     .. code-block::
 
@@ -737,9 +738,9 @@ def scatter_nd(data, indices, shape):
                x_M,
                ...,
                x_{N-1}
-              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+              ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
 
-    all other entries in the output are 0. Repeated indices are summed.
+    where the update function f is determined by the mode.
 
     Parameters
     ----------
@@ -749,35 +750,41 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorithm, either "update" or "add"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
 
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    _verify_scatter_nd_inputs(data, indices, shape)
+    _verify_scatter_nd_inputs(data, indices, updates)
 
-    def gen_ir(data_ptr, indices_ptr, out_ptr):
+    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         ib = tvm.tir.ir_builder.create()
 
         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
+        updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)
 
         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the data dimensions.
+        # number of loops. We do the same thing for all the update dimensions.
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i
 
-        fused_data_dimension = 1
-        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_data_dimension *= i
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i
 
         fused_shape = 1
-        for i in shape:
+        for i in data_ptr.shape:
             fused_shape *= i
 
         # For now we avoid parallizing over dimensions indexed by `indices` as
@@ -789,38 +796,41 @@ def scatter_nd(data, indices, shape):
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_data_dimension)
+        tdim = min(max_threads, fused_updates_dimension)
         ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_data_dimension, tdim)
+        bdim = ceil_div(fused_updates_dimension, tdim)
         ib.scope_attr(bx, "thread_extent", bdim)
 
-        # zero data
-        # TODO(tkonolige): could we use topi.full to zero it instead?
-        with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
+        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension)) as i:
-            index = i * fused_data_dimension + bx * tdim + tx
+            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(index < fused_shape):
+            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
-                out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+                out[index] = data[index]
 
         with ib.for_range(0, fused_indices_dimension) as i:
             j = bx * tdim + tx
-            with ib.if_scope(j < fused_data_dimension):
-                offset = fused_data_dimension
+            with ib.if_scope(j < fused_updates_dimension):
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= shape[l]
-                out[index] += data[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
+                else:
+                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
         return ib.get()
 
-    out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
     return te.extern(
-        [shape],
-        [data, indices],
-        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        [data.shape],
+        [data, indices, updates],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
         name="scatter_nd_cuda",
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index a376963..d7b008c 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate
+from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
 from ..te import extern, hybrid
 
 
@@ -199,22 +199,22 @@ def scatter(data, indices, updates, axis=0):
     raise ValueError("scatter only support for 1-4 dimensions")


-def _verify_scatter_nd_inputs(data, indices, shape):
+def _verify_scatter_nd_inputs(data, indices, updates):
     mdim = int(indices.shape[0])
-    assert mdim <= len(shape), (
+    assert mdim <= len(data.shape), (
         f"The first dimension of the indices ({mdim}) must be less than or equal to "
-        f"the length of the shape of the output ({len(shape)})."
+        f"the length of the shape of the data ({len(data.shape)})."
     )
     for i in range(len(indices.shape) - 1):
-        assert indices.shape[i + 1] == data.shape[i], (
+        assert indices.shape[i + 1] == updates.shape[i], (
             f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
-            f"data[{i}] ({data.shape[i]})."
+            f"updates[{i}] ({updates.shape[i]})."
         )
-    for i in range(mdim, len(shape)):
+    for i in range(mdim, len(data.shape)):
         data_ind = i - mdim + len(indices.shape) - 1
-        assert data.shape[data_ind] == shape[i], (
-            f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension "
-            f"of out_shape[{i}] ({shape[i]})."
+        assert updates.shape[data_ind] == data.shape[i], (
+            f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
+            f"of data[{i}] ({data.shape[i]})."
         )

     assert (
@@ -222,11 +222,12 @@ def _verify_scatter_nd_inputs(data, indices, shape):
     ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}."
 
 
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
-    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
-    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+    scatter_nd computes
 
     .. code-block::
 
@@ -236,9 +237,9 @@ def scatter_nd(data, indices, shape):
                x_M,
                ...,
                x_{N-1}
-              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+              ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
 
-    all other entries in the output are 0. Repeated indices are summed.
+    where the update function f is determined by the mode.
 
     Parameters
     ----------
@@ -248,29 +249,33 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorithm, either "update" or "add"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
 
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    _verify_scatter_nd_inputs(data, indices, shape)
+    _verify_scatter_nd_inputs(data, indices, updates)
 
-    def gen_ir(data_ptr, indices_ptr, out_ptr):
+    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         ib = ir_builder.create()
 
         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
+        updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)
 
-        # zero data
-        # TODO(tkonolige): could we use topi.full to zero it instead?
         fused_shape = 1
-        for i in shape:
+        for i in data.shape:
             fused_shape *= i
         with ib.for_range(0, fused_shape) as i:
-            out[i] = Cast(data_ptr.dtype, 0)
+            out[i] = data[i]
 
         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
@@ -300,15 +305,20 @@ def scatter_nd(data, indices, shape):
                         )
                     )
                     offset *= shape[l]
-                out[index] += data[i * fused_data_dimension + j]
+                if mode == "add":
+                    out[index] += updates[i * fused_data_dimension + j]
+                elif mode == "update":
+                    out[index] = updates[i * fused_data_dimension + j]
+                else:
+                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
         return ib.get()
 
     out_buf = decl_buffer(shape, data.dtype, "out_buf")
     return extern(
         [shape],
-        [data, indices],
-        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        [data, indices, updates],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
         name="scatter_nd_generic",
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
index 8bb3f57..5eb5e6e 100644
--- a/python/tvm/topi/x86/scatter.py
+++ b/python/tvm/topi/x86/scatter.py
@@ -20,11 +20,12 @@ from tvm import te
 from ..scatter import _verify_scatter_nd_inputs
 
 
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
-    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
-    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+    scatter_nd computes
 
     .. code-block::
 
@@ -34,9 +35,9 @@ def scatter_nd(data, indices, shape):
                x_M,
                ...,
                x_{N-1}
-              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+              ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
 
-    all other entries in the output are 0. Repeated indices are summed.
+    where the update function f is determined by the mode.
 
     Parameters
     ----------
@@ -46,62 +47,71 @@ def scatter_nd(data, indices, shape):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorithm, either "update" or "add"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
 
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    _verify_scatter_nd_inputs(data, indices, shape)
+    _verify_scatter_nd_inputs(data, indices, updates)
 
-    def gen_ir(data_ptr, indices_ptr, out_ptr):
+    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
         ib = tvm.tir.ir_builder.create()
 
         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
+        updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)
 
         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the data dimensions.
+        # number of loops. We do the same thing for all the update dimensions.
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i
 
-        fused_data_dimension = 1
-        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_data_dimension *= i
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i
 
         fused_shape = 1
-        for i in shape:
+        for i in data_ptr.shape:
             fused_shape *= i
 
-        # zero data
-        # TODO(tkonolige): could we use topi.full to zero it instead?
         with ib.for_range(0, fused_shape) as i:
-            out[i] = tvm.tir.Cast(data_ptr.dtype, 0)
+            out[i] = data[i]
 
         with ib.for_range(0, fused_indices_dimension) as i:
-            with ib.for_range(0, fused_data_dimension, kind="parallel") as j:
-                offset = fused_data_dimension
+            with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= shape[l]
-                out[index] += data[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
+                else:
+                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
         return ib.get()
 
-    out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
     return te.extern(
-        [shape],
-        [data, indices],
-        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        [data.shape],
+        [data, indices, updates],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
         name="scatter_nd_x86",
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 11e94cb..e937cb0 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1096,10 +1096,11 @@ TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
 
 bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
-  // `types` contains: [data, indices, result]
-  ICHECK_EQ(types.size(), 3);
+  // `types` contains: [data, indices, updates, result]
+  ICHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* indices = types[1].as<TensorTypeNode>();
+  const auto* updates = types[2].as<TensorTypeNode>();
   if (data == nullptr) {
     ICHECK(types[0].as<IncompleteTypeNode>())
         << "ScatterND: expect input data type to be TensorType but got " << types[0];
@@ -1110,8 +1111,14 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         << "ScatterND: expect indices type to be TensorType but got " << types[1];
     return false;
   }
+  if (updates == nullptr) {
+    ICHECK(types[2].as<IncompleteTypeNode>())
+        << "ScatterND: expect updates type to be TensorType but got " << types[2];
+    return false;
+  }
   ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
-  const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
+
+  const auto out_shape = data->shape;
   const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
   const size_t kdim = indices->shape.size() - 1;
   const size_t ndim = out_shape.size();
@@ -1120,7 +1127,7 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
-  // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
+  // Indices: (M, Y_0, .. Y_{K-1}) updates: (Y_0, .. Y_{K-1}, ...), verify Y's.
   for (size_t i = 0; i < kdim; i++) {
-    reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
+    reporter->AssertEQ(indices->shape[i + 1], updates->shape[i]);
   }
 
   std::vector<IndexExpr> oshape;
@@ -1133,15 +1140,15 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-    reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+    reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
   }
 
-  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
   return true;
 }
 
-Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {
+Expr MakeScatterND(Expr data, Expr indices, Expr updates, String mode) {
   auto attrs = make_object<ScatterNDAttrs>();
-  attrs->out_shape = out_shape;
+  attrs->mode = std::move(mode);
   static const Op& op = Op::Get("scatter_nd");
-  return Call(op, {data, indices}, Attrs(attrs), {});
+  return Call(op, {data, indices, updates}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
@@ -1156,9 +1163,10 @@ whose shape is defined by indices.
-Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape
-(M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}).
+Given data with shape (X_0, ..., X_{N-1}), indices with shape (M, Y_0, ..., Y_{K-1}) and
+updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), the output has the shape of data.
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(2)
+    .set_num_inputs(3)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("indices", "Tensor", "The indices tensor.")
+    .add_argument("updates", "Tensor", "The updates tensor.")
     .set_support_level(3)
     .add_type_rel("ScatterND", ScatterNDRel)
     .set_attr<TOpPattern>("TOpPattern", kOpaque);
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1a3d0d4..e11689c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4237,7 +4237,6 @@ unsupported_onnx_tests = [
     "test_round/",
     "test_scan9_sum/",
     "test_scan_sum/",
-    "test_scatternd/",
     "test_simple_rnn_defaults/",
     "test_simple_rnn_with_initial_bias/",
     "test_strnormalizer_export_monday_casesensintive_lower/",
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index bf0a7e4..e84b22b 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1833,38 +1833,39 @@ def test_cumprod(target, dev):
 
 @tvm.testing.parametrize_targets
 def test_scatter_nd(target, dev):
-    def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+    def verify_scatter_nd(
+        data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
+    ):
         data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
         indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
+        updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
 
-        out = relay.op.scatter_nd(data, indices, shape)
-        func = relay.Function([data, indices], out)
+        out = relay.op.scatter_nd(data, indices, updates, mode)
+        func = relay.Function([data, indices, updates], out)
 
         for kind in ["graph", "debug"]:
             intrp = relay.create_executor(kind, device=dev, target=target)
-            op_res = intrp.evaluate(func)(data_np, indices_np)
+            op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
 
-    def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+    def verify_scatter_nd_with_stack(
+        data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
+    ):
         data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
         indices_vars = [
             relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
         ]
+        updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
 
         # test if scatter_nd works in case indices are prepared by another Relay operator
         indices = relay.op.stack(indices_vars, axis=0)
-        out = relay.op.scatter_nd(data, indices, shape)
+        out = relay.op.scatter_nd(data, indices, updates, mode)
         func = relay.Function(
-            [
-                data,
-            ]
-            + indices_vars,
+            [data, updates] + indices_vars,
             out,
         )
 
-        fargs = [
-            data_np,
-        ]
+        fargs = [data_np, updates_np]
         for a in indices_np:
             fargs.append(a)
         for kind in ["graph", "debug"]:
@@ -1872,39 +1873,47 @@ def test_scatter_nd(target, dev):
             op_res = intrp.evaluate(func)(*fargs)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
 
-    data = np.array([2, 3, 0])
+    data = np.zeros((2, 2)).astype("int64")
     indices = np.array([[1, 1, 0], [0, 1, 0]])
-    shape = (2, 2)
+    updates = np.array([2, 3, 0])
     out = np.array([[0, 0], [2, 3]])
-    verify_scatter_nd(data, indices, shape, out)
-    verify_scatter_nd_with_stack(data, indices, shape, out)
+    verify_scatter_nd(data, indices, updates, out)
+    verify_scatter_nd_with_stack(data, indices, updates, out)
 
-    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    data = np.zeros((2, 2, 2, 2)).astype("int64")
     indices = np.array([[0, 1], [1, 1]])
-    shape = (2, 2, 2, 2)
+    updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
     out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
-    verify_scatter_nd(data, indices, shape, out)
-    verify_scatter_nd_with_stack(data, indices, shape, out)
+    verify_scatter_nd(data, indices, updates, out)
+    verify_scatter_nd_with_stack(data, indices, updates, out)
 
-    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     indices = np.array([[1, 0, 0]])
+    updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     shape = (2, 1560)
-    out = np.zeros(shape).astype("float32")
-    out[1, :] += data[0, :]
-    out[0, :] += data[1, :]
-    out[0, :] += data[2, :]
-    verify_scatter_nd(data, indices, shape, out)
-    verify_scatter_nd_with_stack(data, indices, shape, out)
-
-    data = np.ones((5, 3)).astype("float64")
-    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
-    shape = (2, 7, 3)
-    out = np.zeros(shape).astype("float64")
-    for i in range(indices.shape[1]):
-        for j in range(data.shape[1]):
-            out[indices[0, i], indices[1, i], j] += data[i, j]
-    verify_scatter_nd(data, indices, shape, out)
-    verify_scatter_nd_with_stack(data, indices, shape, out)
+    data = np.zeros(shape).astype("float32")
+    out = data.copy()
+    out[1, :] += updates[0, :]
+    out[0, :] += updates[1, :]
+    out[0, :] += updates[2, :]
+    verify_scatter_nd(data, indices, updates, out)
+    verify_scatter_nd_with_stack(data, indices, updates, out)
+
+    for mode in ["add", "update"]:
+        indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+            "int64"
+        )
+        updates = np.ones((5, 3)).astype("float64")
+        shape = (2, 7, 3)
+        data = np.random.random(shape).astype("float64")
+        out = data.copy()
+        for i in range(indices.shape[1]):
+            for j in range(updates.shape[1]):
+                if mode == "add":
+                    out[indices[0, i], indices[1, i], j] += updates[i, j]
+                elif mode == "update":
+                    out[indices[0, i], indices[1, i], j] = updates[i, j]
+        verify_scatter_nd(data, indices, updates, out, mode)
+        verify_scatter_nd_with_stack(data, indices, updates, out, mode)
 
 
 def test_unique():
diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py
index ad73bb5..648ef62 100644
--- a/tests/python/topi/python/test_topi_scatter.py
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -23,44 +23,64 @@ import tvm.topi.testing
 
 @tvm.testing.parametrize_targets
 def test_scatter_nd(dev, target):
-    def check_scatter_nd(data, indices, shape, out):
+    def check_scatter_nd(data, indices, updates, out, mode="add"):
         implementations = {
-            "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern),
-            "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern),
-            "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+            "generic": (
+                lambda x, y, z: topi.scatter_nd(x, y, z, mode),
+                topi.generic.schedule_extern,
+            ),
+            "gpu": (
+                lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
+                topi.generic.schedule_extern,
+            ),
+            "cpu": (
+                lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
+                topi.generic.schedule_extern,
+            ),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
-        tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, dev, fcompute, fschedule)
+        tvm.topi.testing.compare_numpy_tvm(
+            [data, indices, updates], out, target, dev, fcompute, fschedule
+        )
 
-    data = np.array([2, 3, 0])
+    data = np.zeros((2, 2)).astype("int64")
     indices = np.array([[1, 1, 0], [0, 1, 0]])
-    shape = (2, 2)
+    updates = np.array([2, 3, 0])
     out = np.array([[0, 0], [2, 3]])
-    check_scatter_nd(data, indices, shape, out)
+    check_scatter_nd(data, indices, updates, out)
 
-    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    data = np.zeros((2, 2, 2, 2)).astype("int64")
     indices = np.array([[0, 1], [1, 1]])
-    shape = (2, 2, 2, 2)
+    updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
     out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
-    check_scatter_nd(data, indices, shape, out)
+    check_scatter_nd(data, indices, updates, out)
 
-    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     indices = np.array([[1, 0, 0]])
+    updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     shape = (2, 1560)
-    out = np.zeros(shape).astype("float32")
-    out[1, :] += data[0, :]
-    out[0, :] += data[1, :]
-    out[0, :] += data[2, :]
-    check_scatter_nd(data, indices, shape, out)
+    data = np.zeros(shape).astype("float32")
+    out = data.copy()
+    out[1, :] += updates[0, :]
+    out[0, :] += updates[1, :]
+    out[0, :] += updates[2, :]
+    check_scatter_nd(data, indices, updates, out)
 
-    data = np.ones((5, 3)).astype("float64")
-    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
-    shape = (2, 7, 3)
-    out = np.zeros(shape).astype("float64")
-    for i in range(indices.shape[1]):
-        for j in range(data.shape[1]):
-            out[indices[0, i], indices[1, i], j] += data[i, j]
-    check_scatter_nd(data, indices, shape, out)
+    for mode in ["add", "update"]:
+        updates = np.ones((5, 3)).astype("float64")
+        indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+            "int64"
+        )
+        shape = (2, 7, 3)
+        data = np.random.random(shape).astype("float64")
+        out = data.copy()
+        for i in range(indices.shape[1]):
+            for j in range(updates.shape[1]):
+                if mode == "add":
+                    out[indices[0, i], indices[1, i], j] += updates[i, j]
+                elif mode == "update":
+                    out[indices[0, i], indices[1, i], j] = updates[i, j]
+
+        check_scatter_nd(data, indices, updates, out, mode)
 
 
 if __name__ == "__main__":