You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/12/01 19:20:30 UTC

[tvm] branch main updated: [RELAY,TOPI] Add scatter_nd op (#6854)

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0421efb  [RELAY,TOPI] Add scatter_nd op (#6854)
0421efb is described below

commit 0421efba4c3a42c6cf8d692734c24fe8e08e3884
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Tue Dec 1 11:20:09 2020 -0800

    [RELAY,TOPI] Add scatter_nd op (#6854)
    
    * [RELAY,TOPI] Add scatter_nd op
    
    Scatter_nd is the inverse of gather_nd and also happens to be its
    gradient. The implementation here is not optimized. At first there were
    no cpu or gpu specific implementations; the cuda and x86 ones were added by later commits listed below.
    
    * formatting
    
    * Fix tests
    
    * formatting
    
    * specify types on test
    
    * Fix grad test
    
    * scatter_nd cuda impl
    
    * cuda impl
    
    * x86 impl
    
    * formatting
    
    * fix shape rel
    
    * fix tests
    
    * formatting
---
 include/tvm/relay/attrs/transform.h           |   8 ++
 python/tvm/relay/backend/compile_engine.py    |   5 +-
 python/tvm/relay/op/_tensor_grad.py           |   7 ++
 python/tvm/relay/op/_transform.py             |   9 ++
 python/tvm/relay/op/strategy/cuda.py          |  13 +++
 python/tvm/relay/op/strategy/generic.py       |  22 +++++
 python/tvm/relay/op/strategy/x86.py           |  13 +++
 python/tvm/relay/op/transform.py              |  24 ++++++
 python/tvm/relay/testing/__init__.py          |   2 +
 python/tvm/te/operation.py                    |   6 +-
 python/tvm/topi/cuda/scatter.py               | 106 +++++++++++++++++++++++
 python/tvm/topi/scatter.py                    | 120 +++++++++++++++++++++++++-
 python/tvm/topi/testing/__init__.py           |   1 +
 python/tvm/topi/testing/common.py             |  31 +++++++
 python/tvm/topi/x86/__init__.py               |   1 +
 python/tvm/topi/x86/scatter.py                | 109 +++++++++++++++++++++++
 src/relay/analysis/type_solver.cc             |   9 +-
 src/relay/op/tensor/transform.cc              |  68 +++++++++++++++
 tests/python/relay/test_any.py                |   5 +-
 tests/python/relay/test_op_grad_level3.py     |   9 ++
 tests/python/topi/python/test_topi_scatter.py |  67 ++++++++++++++
 21 files changed, 627 insertions(+), 8 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a7830cf..3ed6b83 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -129,6 +129,14 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
   }
 };
 
+struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {  // Attributes of scatter_nd.
+  Array<Integer> out_shape;  // Static shape of the tensor produced by scatter_nd.
+
+  TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
+    TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
+  }
+};
+
 struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
   Integer axis;
 
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 32affe7..a39f72e 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -122,7 +122,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
         The list of all valid op implementations.
     """
     fstrategy = op.get_attr("FTVMStrategy")
-    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
+    assert fstrategy is not None, (
+        "%s doesn't have an FTVMStrategy registered. You can register "
+        "one in python with `tvm.relay.op.register_strategy`." % op.name
+    )
     with target:
         strategy = fstrategy(attrs, inputs, out_type, target)
     analyzer = tvm.arith.Analyzer()
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index b070d9f..9c84411 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -62,6 +62,7 @@ from .transform import (
     squeeze,
     strided_set,
     arange,
+    scatter_nd,
 )
 
 
@@ -803,3 +804,9 @@ def arange_grad(orig, grad):
     grad_step = cast_like(_sum(grad_step), step)
 
     return [grad_start, grad_stop, grad_step]
+
+
+@register_gradient("gather_nd")
+def gather_nd_grad(orig, grad):  # grad wrt data: scatter (summing repeats) grad into data's shape
+    data, indices = orig.args  # concrete_shape below presumably needs a static data shape
+    return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 439d44b..e1cb9e9 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -115,6 +115,15 @@ def compute_scatter_add(attrs, inputs, output_type):
 
 _reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
 
+# scatter_nd
+@_reg.register_compute("scatter_nd")
+def compute_scatter_nd(attrs, inputs, output_type):
+    """Compute definition of scatter_nd (inputs: data, indices; shape from attrs.out_shape)."""
+    return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]
+
+
+_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index f37fc2a..bd96cad 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -776,6 +776,19 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@scatter_nd_strategy.register(["cuda", "gpu"])
+def scatter_nd_cuda(attrs, inputs, out_type, target):
+    """scatter_nd cuda strategy: topi.cuda.scatter_nd with the generic extern schedule."""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter_nd(topi.cuda.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_nd.cuda",
+        plevel=10,
+    )
+    return strategy
+
+
 @argsort_strategy.register(["cuda", "gpu"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index e49135c..ac9d3b1 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1063,6 +1063,28 @@ def scatter_add_strategy(attrs, outs, out_type, target):
     return strategy
 
 
+# scatter_nd
+@override_native_generic_func("scatter_nd_strategy")
+def scatter_nd_strategy(attrs, inputs, out_type, target):
+    """scatter_nd generic strategy: topi.scatter_nd with the generic extern schedule."""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter_nd(topi.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_nd.generic",
+    )
+    return strategy
+
+
+def wrap_compute_scatter_nd(topi_compute):
+    """Wrap a scatter_nd topi compute as a relay compute taking (attrs, inputs, out_type)."""
+
+    def _compute_scatter_nd(attrs, inputs, _):
+        return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]  # (data, indices, shape)
+
+    return _compute_scatter_nd
+
+
 # bitserial_conv2d
 def wrap_compute_bitserial_conv2d(topi_compute):
     """wrap bitserial_conv2d topi compute"""
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 3c5735b..3f129c4 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -446,3 +446,16 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
         name="bitserial_dense.x86",
     )
     return strategy
+
+
+@scatter_nd_strategy.register("cpu")
+def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
+    """scatter_nd x86 strategy: topi.x86.scatter_nd with the generic extern schedule."""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter_nd(topi.x86.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_nd.x86",
+        plevel=10,
+    )
+    return strategy
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 19488a0..7e7f9b2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -308,6 +308,30 @@ def scatter_add(data, indices, updates, axis):
     return _make.scatter_add(data, indices, updates, axis)
 
 
+def scatter_nd(data, indices, out_shape):
+    """Scatter values from an array into a new zero-initialized tensor of shape ``out_shape``.
+
+    See :py:func:`tvm.topi.scatter_nd` for how data is scattered.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The values to scatter.
+
+    indices : relay.Expr
+        Integer tensor of shape (M, Y_0, ..., Y_{K-1}) giving the output locations.
+
+    out_shape : Union[Tuple[int], List[int]]
+        Static output shape of the scatter.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.scatter_nd(data, indices, out_shape)
+
+
 def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
     """Reshapes the input tensor by the size of another tensor.
     For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes
diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index 9c87f27..93110e3 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -143,6 +143,8 @@ def check_grad(
                         break
             grads = tmp
 
+        assert len(grads) > 0, "You must test at least one gradient."
+
         # Get numeric gradients for each dimension of each param, using two-sided approximation.
         approx_grads = []
         for x in test_inputs:
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 30d0df3..0f3457a 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -317,7 +317,11 @@ def extern(
     if isinstance(body, tvm.tir.PrimExpr):
         body = tvm.tir.Evaluate(body)
     if not isinstance(body, tvm.tir.Stmt):
-        raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__))
+        raise ValueError(
+            "Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
+                fcompute.__name__, type(body)
+            )
+        )
 
     op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
     res = [op.output(i) for i in range(len(output_placeholders))]
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 0a3e96f..5e03faf 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -18,6 +18,7 @@
 """Scatter operator """
 import tvm
 from tvm import te
+from ..scatter import _verify_scatter_nd_inputs
 
 
 def ceil_div(a, b):
@@ -522,3 +523,108 @@ def scatter_add(data, indices, updates, axis=0):
     )
 
     return out
+
+
+def scatter_nd(data, indices, shape):
+    """Scatter elements from an n-dimensional array.
+
+    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+
+    .. code-block::
+
+        output[indices[0, y_0, ..., y_{K-1}],
+               ...,
+               indices[M-1, y_0, ..., y_{K-1}],
+               x_M,
+               ...,
+               x_{N-1}
+              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+    all other entries in the output are 0. Repeated indices are summed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array.
+
+    indices : tvm.te.Tensor
+        The output locations to scatter data into.
+
+    shape : Sequence[int]
+        The output shape. This must be specified because it cannot be inferred.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    _verify_scatter_nd_inputs(data, indices, shape)
+
+    def gen_ir(data_ptr, indices_ptr, out_ptr):
+        ib = tvm.tir.ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = ib.buffer_ptr(indices_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the data dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_data_dimension = 1
+        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_data_dimension *= i
+
+        fused_shape = 1
+        for i in shape:
+            fused_shape *= i
+
+        # For now we avoid parallelizing over dimensions indexed by `indices` as
+        # there may be repeated indices and handling parallel accumulation can
+        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
+        # work well when these dimensions are large enough to saturate memory
+        # bandwidth, but performance will be bad when these dimensions are
+        # small.
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        tdim = min(max_threads, fused_data_dimension)
+        ib.scope_attr(tx, "thread_extent", tdim)
+        bdim = ceil_div(fused_data_dimension, tdim)
+        ib.scope_attr(bx, "thread_extent", bdim)
+        tid = bx * tdim + tx  # flattened thread id; it is the x_M..x_{N-1} offset j below
+        # Zero data; each thread only zeroes the slots it alone accumulates into below.
+        # TODO(tkonolige): could we use topi.full to zero it instead?
+        with ib.for_range(0, ceil_div(fused_shape, fused_data_dimension)) as i:
+            index = i * fused_data_dimension + tid
+            with ib.if_scope(tvm.tir.all(tid < fused_data_dimension, index < fused_shape)):
+                out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+
+        with ib.for_range(0, fused_indices_dimension) as i:
+            j = tid
+            with ib.if_scope(j < fused_data_dimension):
+                offset = fused_data_dimension
+                index = j  # This is x_M, .. x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
+                # of the index into out.
+                for l in reversed(range(indices_ptr.shape[0].value)):
+                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
+                    index += offset * indices[i + l * fused_indices_dimension]
+                    offset *= shape[l]
+                out[index] += data[i * fused_data_dimension + j]
+
+        return ib.get()
+
+    out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+    return te.extern(
+        [shape],
+        [data, indices],
+        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd_cuda",
+        tag="scatter_nd_cuda",
+    )
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index f1c307a..a376963 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from tvm.te import hybrid
+from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate, And
+from ..te import extern, hybrid
 
 
 @hybrid.script
@@ -196,3 +197,120 @@ def scatter(data, indices, updates, axis=0):
     if len(data.shape) == 4:
         return _scatter_4d(data, indices, updates, axis)
     raise ValueError("scatter only support for 1-4 dimensions")
+
+
+def _verify_scatter_nd_inputs(data, indices, shape):  # Check scatter_nd shapes and dtype agree.
+    mdim = int(indices.shape[0])  # M: number of leading output dims addressed by indices
+    assert mdim <= len(shape), (  # NOTE(review): plain asserts are skipped under python -O
+        f"The first dimension of the indices ({mdim}) must be less than or equal to "
+        f"the length of the shape of the output ({len(shape)})."
+    )
+    for i in range(len(indices.shape) - 1):  # indices[1:] must match data's leading Y dims
+        assert indices.shape[i + 1] == data.shape[i], (
+            f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
+            f"data[{i}] ({data.shape[i]})."
+        )
+    for i in range(mdim, len(shape)):  # data's trailing X_M..X_{N-1} must match the output
+        data_ind = i - mdim + len(indices.shape) - 1
+        assert data.shape[data_ind] == shape[i], (
+            f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension "
+            f"of out_shape[{i}] ({shape[i]})."
+        )
+
+    assert (
+        "int" in indices.dtype
+    ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}."
+
+
+def scatter_nd(data, indices, shape):
+    """Scatter elements from an n-dimensional array.
+
+    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+
+    .. code-block::
+
+        output[indices[0, y_0, ..., y_{K-1}],
+               ...,
+               indices[M-1, y_0, ..., y_{K-1}],
+               x_M,
+               ...,
+               x_{N-1}
+              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+    all other entries in the output are 0. Repeated indices are summed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array.
+
+    indices : tvm.te.Tensor
+        The output locations to scatter data into; each must satisfy 0 <= index < shape[l].
+
+    shape : Sequence[int]
+        The output shape. This must be specified because it cannot be inferred.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    _verify_scatter_nd_inputs(data, indices, shape)
+
+    def gen_ir(data_ptr, indices_ptr, out_ptr):
+        ib = ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = ib.buffer_ptr(indices_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # zero data
+        # TODO(tkonolige): could we use topi.full to zero it instead?
+        fused_shape = 1
+        for i in shape:
+            fused_shape *= i
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = Cast(data_ptr.dtype, 0)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the data dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_data_dimension = 1
+        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_data_dimension *= i
+
+        with ib.for_range(0, fused_indices_dimension, name="i") as i:
+            with ib.for_range(0, fused_data_dimension, name="j") as j:
+                offset = fused_data_dimension
+                index = j  # This is x_M, .. x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
+                # of the index into out.
+                for l in reversed(range(indices_ptr.shape[0].value)):
+                    idx = indices[i + l * fused_indices_dimension]  # indices[l, y_0, ..., y_{K-1}]
+                    index += offset * idx
+                    ib.emit(
+                        AssertStmt(
+                            And(idx >= 0, idx < shape[l]),
+                            StringImm("index out of bounds"),
+                            Evaluate(0),
+                        )
+                    )
+                    offset *= shape[l]
+                out[index] += data[i * fused_data_dimension + j]
+
+        return ib.get()
+
+    out_buf = decl_buffer(shape, data.dtype, "out_buf")
+    return extern(
+        [shape],
+        [data, indices],
+        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd_generic",
+        tag="scatter_nd_generic",
+    )
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 4f90550..0654344 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -57,6 +57,7 @@ from .depth_to_space import depth_to_space_python
 from .space_to_depth import space_to_depth_python
 from .crop_and_resize_python import crop_and_resize_python
 from .common import (
+    compare_numpy_tvm,
     get_injective_schedule,
     get_reduce_schedule,
     get_broadcast_schedule,
diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py
index 51ea19a..e4e5e81 100644
--- a/python/tvm/topi/testing/common.py
+++ b/python/tvm/topi/testing/common.py
@@ -17,8 +17,10 @@
 # pylint: disable=invalid-name
 """Common utility for topi test"""
 
+import numpy as np
 import tvm
 from tvm import topi
+from tvm.testing import assert_allclose
 
 _injective_schedule = {
     "generic": topi.generic.schedule_injective,
@@ -77,3 +79,32 @@ _conv2d_nchw_implement = {
 
 def get_conv2d_nchw_implement(target):
     return dispatch(target, _conv2d_nchw_implement)
+
+
+def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
+    """Check a TVM compute/schedule against a reference numpy output on the same inputs.
+
+    Parameters
+    ----------
+    inputs : Sequence[numpy.ndarray]
+        List of input numpy arrays to pass to the function.
+    output : numpy.ndarray
+        Verified correct function output.
+    target : tvm.target.Target
+        Target to run on.
+    ctx : tvm.TVMContext
+        Context to run on.
+    compute : callable
+        Topi compute function to test against.
+    schedule : callable
+        Topi scheduling function to test against.
+    """
+    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
+    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)  # result buffer
+    with tvm.target.Target(target):
+        out = compute(*te_inputs)
+        s = schedule([out])
+        func = tvm.build(s, te_inputs + [out])
+        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
+        func(*(arys + [te_out]))
+        assert_allclose(te_out.asnumpy(), output, atol=1e-4, rtol=1e-4)
diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py
index 659668c..1545110 100644
--- a/python/tvm/topi/x86/__init__.py
+++ b/python/tvm/topi/x86/__init__.py
@@ -39,3 +39,4 @@ from .conv2d_transpose import *
 from .conv3d_transpose import *
 from .sparse import *
 from .conv2d_alter_op import *
+from .scatter import *
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
new file mode 100644
index 0000000..8147d3a
--- /dev/null
+++ b/python/tvm/topi/x86/scatter.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Scatter operators for x86"""
+import tvm
+from tvm import te
+from ..scatter import _verify_scatter_nd_inputs
+
+
+def scatter_nd(data, indices, shape):
+    """Scatter elements from an n-dimensional array (x86 implementation).
+
+    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+
+    .. code-block::
+
+        output[indices[0, y_0, ..., y_{K-1}],
+               ...,
+               indices[M-1, y_0, ..., y_{K-1}],
+               x_M,
+               ...,
+               x_{N-1}
+              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+    all other entries in the output are 0. Repeated indices are summed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array.
+
+    indices : tvm.te.Tensor
+        The output locations to scatter data into.
+
+    shape : Sequence[int]
+        The output shape. This must be specified because it cannot be inferred.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    _verify_scatter_nd_inputs(data, indices, shape)
+
+    def gen_ir(data_ptr, indices_ptr, out_ptr):
+        # pylint: disable=invalid-name
+        ib = tvm.tir.ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = ib.buffer_ptr(indices_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the data dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_data_dimension = 1
+        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_data_dimension *= i
+
+        fused_shape = 1
+        for i in shape:
+            fused_shape *= i
+
+        # zero data
+        # TODO(tkonolige): could we use topi.full to zero it instead?
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = tvm.tir.Cast(data_ptr.dtype, 0)
+
+        with ib.for_range(0, fused_indices_dimension) as i:
+            with ib.for_range(0, fused_data_dimension, for_type="parallel") as j:
+                offset = fused_data_dimension  # index == j (mod this): parallel j's never collide
+                index = j  # This is x_M, .. x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
+                # of the index into out. NOTE(review): indices are not bounds-checked here.
+                for l in reversed(range(indices_ptr.shape[0].value)):
+                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
+                    index += offset * indices[i + l * fused_indices_dimension]
+                    offset *= shape[l]
+                out[index] += data[i * fused_data_dimension + j]
+
+        return ib.get()
+
+    out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+    return te.extern(
+        [shape],
+        [data, indices],
+        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd_x86",
+        tag="scatter_nd_x86",
+    )
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index 8f14b55..64db13a 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -246,7 +246,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     for (size_t i = 0; i < tt1->shape.size(); i++) {
       auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]);
       if (!dim.defined()) {
-        // NB: We push an arbitrary dimension here so we can continue error propogation.
+        // NB: We push an arbitrary dimension here so we can continue error propagation.
         shape.push_back(tt1->shape[i]);
         tvm::PrimExpr shape1 = tt1->shape[i];
         tvm::PrimExpr shape2 = tt2->shape[i];
@@ -259,10 +259,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
 
     if (mismatches.size() != 0) {
       auto err = Diagnostic::Error(this->span);
-      err << "in particular ";
+      err << "The Relay type checker is unable to show the following types match.\n";
+      err << "In particular ";
       for (auto mismatch : mismatches) {
-        err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch)
-            << " does not match " << std::get<2>(mismatch);
+        err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch)
+            << " does not match " << std::get<2>(mismatch) << ".";
       }
       this->solver_->diag_ctx_.Emit(err);
       return Type(nullptr);
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index d1f2f26..5a13e9a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -977,6 +977,74 @@ RELAY_REGISTER_OP("scatter_add")
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
+// scatter_nd operator
+TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
+
+bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  // `types` contains: [data, indices, result]
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* indices = types[1].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "ScatterND: expect input data type to be TensorType but got " << types[0];
+    return false;
+  }
+  if (indices == nullptr) {
+    ICHECK(types[1].as<IncompleteTypeNode>())
+        << "ScatterND: expect indices type to be TensorType but got " << types[1];
+    return false;
+  }
+  ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
+  const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
+  const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();  // null if M is not static
+  const size_t kdim = indices->shape.size() - 1;
+  const size_t ndim = out_shape.size();
+  ICHECK(mdim != nullptr && size_t(mdim->value) <= ndim)
+      << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
+         "with shape (M, Y_0, ..., Y_{K-1}), M must be a constant less than or equal to N.";
+  // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
+  for (size_t i = 0; i < kdim; i++) {
+    reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
+  }
+
+  std::vector<IndexExpr> oshape;
+  for (auto& x : out_shape) {
+    oshape.push_back(x);
+  }
+
+  // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
+  for (size_t i = mdim->value; i < ndim; i++) {
+    reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+  }
+
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {
+  auto attrs = make_object<ScatterNDAttrs>();
+  attrs->out_shape = out_shape;  // the output shape is an attribute, not an input tensor
+  static const Op& op = Op::Get("scatter_nd");
+  return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
+
+RELAY_REGISTER_OP("scatter_nd")
+    .describe(R"code(Scatter elements or slices from data and store to a tensor
+whose shape is given by the out_shape attribute.
+
+Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape
+(M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}).
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(3)
+    .add_type_rel("ScatterND", ScatterNDRel)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);  // te.extern compute: must not be fused
+
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 5469737..eec6aa2 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -989,7 +989,10 @@ def test_recursive_concat_with_wrong_annotation():
     body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
     func = relay.Function([start], relay.TupleGetItem(body, 1))
     with DiagnosticTesting() as diagnostics:
-        diagnostics.assert_message("in particular dimension 0 conflicts 2 does not match 1")
+        diagnostics.assert_message(
+            "The Relay type checker is unable to show the following types "
+            "match.\nIn particular dimension 0 conflicts: 2 does not match 1."
+        )
         func = infer_type(func)
 
 
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 9c27afd..98ff62e 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -117,5 +117,14 @@ def test_arange_grad():
     check_grad(fwd_func, inputs=values)
 
 
+def test_gather_nd_grad():
+    """Check the gradient of gather_nd with respect to ``data``.
+
+    The columns of ``indices`` select (0, 0), (1, 1), (1, 0) and (0, 0) again, so the
+    repeated position should exercise accumulation in the backward pass. Only ``data_np``
+    is listed in ``test_inputs``, presumably so the integer ``indices`` input is not
+    differentiated.
+    """
+    data = relay.var("data", relay.TensorType((2, 3), "float64"))
+    indices = relay.var("indices", relay.TensorType((2, 4), "int64"))
+    fwd = relay.Function([data, indices], relay.gather_nd(data, indices))
+    data_np = np.random.rand(2, 3).astype("float64")
+    indices_np = np.array([[0, 1, 1, 0], [0, 1, 0, 0]], dtype="int64")
+    check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np])
+
+
 if __name__ == "__main__":
     pytest.main()
diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py
new file mode 100644
index 0000000..2e701e2
--- /dev/null
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import topi
+import tvm.topi.testing
+
+
+@tvm.testing.parametrize_targets
+def test_scatter_nd(ctx, target):
+    def check_scatter_nd(data, indices, shape, out):
+        implementations = {
+            "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+            "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+            "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+        }
+        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+        tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule)
+
+    data = np.array([2, 3, 0])
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    shape = (2, 2)
+    out = np.array([[0, 0], [2, 3]])
+    check_scatter_nd(data, indices, shape, out)
+
+    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    indices = np.array([[0, 1], [1, 1]])
+    shape = (2, 2, 2, 2)
+    out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
+    check_scatter_nd(data, indices, shape, out)
+
+    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
+    indices = np.array([[1, 0, 0]])
+    shape = (2, 1560)
+    out = np.zeros(shape).astype("float32")
+    out[1, :] += data[0, :]
+    out[0, :] += data[1, :]
+    out[0, :] += data[2, :]
+    check_scatter_nd(data, indices, shape, out)
+
+    data = np.ones((5, 3)).astype("float64")
+    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
+    shape = (2, 7, 3)
+    out = np.zeros(shape).astype("float64")
+    for i in range(indices.shape[1]):
+        for j in range(data.shape[1]):
+            out[indices[0, i], indices[1, i], j] += data[i, j]
+    check_scatter_nd(data, indices, shape, out)
+
+
+if __name__ == "__main__":
+    test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm"))