Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/26 01:59:13 UTC

[tvm] branch main updated: [Frontend][Tensorflow] Add unique operator (#7441)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 09b0c8e  [Frontend][Tensorflow] Add unique operator (#7441)
09b0c8e is described below

commit 09b0c8e6f688d1c25734b6371426972ab1c37183
Author: Yanming Wang <ya...@gmail.com>
AuthorDate: Thu Feb 25 17:58:39 2021 -0800

    [Frontend][Tensorflow] Add unique operator (#7441)
    
    * Initial commit of the unique operator
    
    Add unit tests for unique operator
    
    * Add tensorflow unique op
    
    * Refactor unique to use sort-based algorithm
    
    * Change relay.unique test to run only on cpu
    
    * Change topi.unique test to run only on cpu
    
    * Change range to parallel for parallelizable loops
    
    * Add return_counts option for relay.unique and topi.unique, add pytorch frontend
    
    * Fix pylint
    
    * Patch pytorch frontend
    
    * Initial support of topi.cuda.unique
    
    * Refactor to use ir_builder directly
    
    * Modularize adjacent difference
    
    * Refactor to simplify
    
    * Fix typo
    
    * Combine _unique and _unique_with_counts
    
    * Reuse indices_ptr to remove arange_ptr
    
    Co-authored-by: Yanming Wang <ya...@amazon.com>
---
 include/tvm/relay/attrs/transform.h              |  12 +
 python/tvm/relay/frontend/pytorch.py             |  19 ++
 python/tvm/relay/frontend/tensorflow.py          |  26 ++
 python/tvm/relay/op/_transform.py                |  44 +++
 python/tvm/relay/op/strategy/cuda.py             |  12 +
 python/tvm/relay/op/strategy/generic.py          |  21 ++
 python/tvm/relay/op/transform.py                 |  54 ++++
 python/tvm/topi/__init__.py                      |   1 +
 python/tvm/topi/cuda/__init__.py                 |   1 +
 python/tvm/topi/cuda/unique.py                   | 396 +++++++++++++++++++++++
 python/tvm/topi/generic/search.py                |  16 +
 python/tvm/topi/unique.py                        | 297 +++++++++++++++++
 src/relay/op/tensor/transform.cc                 |  47 +++
 tests/python/frontend/pytorch/test_forward.py    |  25 +-
 tests/python/frontend/tensorflow/test_forward.py |  65 ++++
 tests/python/relay/test_op_level3.py             |  53 +++
 tests/python/topi/python/test_topi_unique.py     | 111 +++++++
 17 files changed, 1199 insertions(+), 1 deletion(-)
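
The patch implements unique with a sort-based algorithm: sort the data, mark the
positions where adjacent sorted values differ, and turn that mask into output
slots with an inclusive scan. A minimal NumPy sketch of the idea (illustrative
only; the helper name and the use of NumPy are not part of the patch, and the
real op leaves the padded tail undefined rather than zeroed):

    import numpy as np

    def unique_sorted_sketch(data):
        order = np.argsort(data, kind="stable")
        sorted_data = data[order]
        # adjacent difference: 1 where a new value starts, 0 otherwise
        adj = np.concatenate(([0], (sorted_data[1:] != sorted_data[:-1]).astype("int32")))
        inc_scan = np.cumsum(adj)            # unique-element slot of each sorted element
        num_unique = int(inc_scan[-1]) + 1
        unique = np.zeros_like(data)         # outputs stay padded to len(data)
        indices = np.zeros(len(data), dtype="int32")
        for i in range(len(data)):
            indices[order[i]] = inc_scan[i]
            if i == 0 or inc_scan[i] != inc_scan[i - 1]:
                unique[inc_scan[i]] = sorted_data[i]
        return unique, indices, np.array([num_unique], dtype="int32")

    # unique_sorted_sketch(np.array([4, 5, 1, 2, 3, 3, 4, 5]))
    # -> ([1, 2, 3, 4, 5, 0, 0, 0], [3, 4, 0, 1, 2, 2, 3, 4], [5])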

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 24098b7..ff344f5 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -452,6 +452,18 @@ struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
   }
 };
 
+/*! \brief Attributes used in unique operator */
+struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
+  bool sorted;
+  bool return_counts;
+  TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") {
+    TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true);
+    TVM_ATTR_FIELD(return_counts)
+        .describe("Whether to return an additional tensor with counts of each unique elements")
+        .set_default(false);
+  }
+};  // struct UniqueAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9316112..6795410 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2157,6 +2157,24 @@ class PyTorchOpConverter:
         is_float = input_type in ["float32", "float64", "float16", "bfloat16"]
         return _expr.const(is_float)
 
+    def unique(self, inputs, input_types):
+        assert len(inputs) == 4
+        [data, is_sorted, return_inverse, return_counts] = inputs
+        if not is_sorted:
+            logging.warning("TVM always assumes sorted=True for torch.unique")
+            is_sorted = True
+        if return_counts:
+            [unique, indices, num_uniq, counts] = _op.unique(
+                data, is_sorted=is_sorted, return_counts=True
+            )
+            unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
+            counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
+            return (unique_sliced, indices, counts_sliced)
+        else:
+            [unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False)
+            unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
+            return (unique_sliced, indices)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2363,6 +2381,7 @@ class PyTorchOpConverter:
             "aten::masked_select": self.masked_select,
             "aten::argsort": self.argsort,
             "aten::sort": self.sort,
+            "aten::_unique2": self.unique,
         }
 
     def update_convert_map(self, custom_map):
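
For reference, a hedged sketch of the PyTorch call that exercises this
converter (the tensor values are illustrative, not from the patch):

    import torch

    x = torch.tensor([4, 5, 1, 2, 3, 3, 4, 5], dtype=torch.int32)
    out, inverse, counts = torch.unique(x, sorted=True, return_inverse=True, return_counts=True)
    # Tracing this call produces aten::_unique2, which the converter above
    # lowers to relay.unique plus strided_slice to drop the padded tail.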
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index ab98cdd..65f18c0 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -2471,6 +2471,30 @@ def _LSTMBlockCell():
     return _impl
 
 
+def _unique(return_counts=True):
+    def _impl(inputs, attr, params, mod):
+        assert len(inputs) == 1
+        data = inputs[0]
+        if return_counts:
+            [unique, indices, num_uniq, counts] = _op.unique(
+                data, is_sorted=False, return_counts=True
+            )
+            unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
+            counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
+            return _expr.TupleWrapper(
+                _expr.Tuple([unique_sliced, indices, counts_sliced]),
+                3,
+            )
+        [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False)
+        unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
+        return _expr.TupleWrapper(
+            _expr.Tuple([unique_sliced, indices]),
+            2,
+        )
+
+    return _impl
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -2650,6 +2674,8 @@ _convert_map = {
     "TopKV2": _topk(),
     "Transpose": _transpose(),
     "TruncateMod": _elemwise("mod"),
+    "Unique": _unique(False),
+    "UniqueWithCounts": _unique(True),
     "Unpack": _unpack(),
     "UnravelIndex": _unravel_index(),
     "Where": _where(),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 01bcf4a..e9cf3d8 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -142,6 +142,15 @@ def compute_cumsum(attrs, inputs, output_type):
 _reg.register_strategy("cumsum", strategy.cumsum_strategy)
 _reg.register_shape_func("cumsum", False, elemwise_shape_func)
 
+
+@_reg.register_compute("unique")
+def compute_unique(attrs, inputs, output_type):
+    """Compute definition of unique"""
+    return topi.unique(inputs[0], attrs.sorted, attrs.return_counts)
+
+
+_reg.register_strategy("unique", strategy.unique_strategy)
+
 #####################
 #  Shape functions  #
 #####################
@@ -946,3 +955,38 @@ def where_shape_func(attrs, inputs, _):
     out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)
 
     return [out_shape]
+
+
+@script
+def _unique_shape(data_shape):
+    unique_shape = output_tensor((1,), "int64")
+    indices_shape = output_tensor((1,), "int64")
+    num_unique_shape = output_tensor((1,), "int64")
+    unique_shape[0] = data_shape[0]
+    indices_shape[0] = data_shape[0]
+    num_unique_shape[0] = int64(1)
+    return (unique_shape, indices_shape, num_unique_shape)
+
+
+@script
+def _unique_with_counts_shape(data_shape):
+    unique_shape = output_tensor((1,), "int64")
+    indices_shape = output_tensor((1,), "int64")
+    num_unique_shape = output_tensor((1,), "int64")
+    counts_shape = output_tensor((1,), "int64")
+    unique_shape[0] = data_shape[0]
+    indices_shape[0] = data_shape[0]
+    num_unique_shape[0] = int64(1)
+    counts_shape[0] = data_shape[0]
+    return (unique_shape, indices_shape, num_unique_shape, counts_shape)
+
+
+@_reg.register_shape_func("unique", False)
+def unique_shape_func(attrs, inputs, _):
+    """
+    Shape func for unique operator.
+    """
+    if attrs.return_counts:
+        return _unique_with_counts_shape(inputs[0])
+    else:
+        return _unique_shape(inputs[0])
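
The shape functions above report static shapes on purpose: every output is
padded to the input length so downstream shapes stay known, and num_unique
(shape [1]) carries the dynamic count as a value. A plain-Python paraphrase
(hypothetical helper, not part of the patch):

    def unique_out_shapes(n, return_counts):
        # mirrors _unique_shape / _unique_with_counts_shape above
        shapes = [(n,), (n,), (1,)]   # unique, indices, num_unique
        if return_counts:
            shapes.append((n,))       # counts, padded like unique
        return shapes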
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 20c5f03..3abc9c4 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1009,3 +1009,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
         name="cumsum.cuda",
     )
     return strategy
+
+
+@unique_strategy.register(["cuda", "gpu"])
+def unique_strategy_cuda(attrs, inputs, out_type, target):
+    """unique cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_unique(topi.cuda.unique),
+        wrap_topi_schedule(topi.cuda.schedule_scan),
+        name="unique.cuda",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index f076176..8a2724d 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1432,3 +1432,24 @@ def cumsum_strategy(attrs, inputs, out_type, target):
         name="cumsum.generic",
     )
     return strategy
+
+
+def wrap_compute_unique(topi_compute):
+    """Wrap unique topi compute"""
+
+    def _compute_unique(attrs, inputs, _):
+        return topi_compute(inputs[0], attrs.sorted, attrs.return_counts)
+
+    return _compute_unique
+
+
+@override_native_generic_func("unique_strategy")
+def unique_strategy(attrs, inputs, out_type, target):
+    """unique generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_unique(topi.unique),
+        wrap_topi_schedule(topi.generic.schedule_unique),
+        name="unique.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b676fe7..c0a0d31 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1463,3 +1463,57 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
         -> [1, 1, 2, 2, 3, 4, 4]
     """
     return _make.cumsum(data, axis, dtype, exclusive)
+
+
+def unique(data, is_sorted=True, return_counts=False):
+    """
+    Find the unique elements of a 1-D tensor. Please note `output` and `counts` are both padded to
+    have the same length as `data`, and elements with index >= num_unique[0] have undefined values.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        A 1-D tensor of integers.
+
+    is_sorted : bool
+        Whether to sort the unique elements in ascending order before returning as output.
+
+    return_counts : bool
+        Whether to return the count of each unique element.
+
+    Returns
+    -------
+    output : relay.Expr
+        A 1-D tensor containing the unique elements of the input data tensor.
+
+    indices : relay.Expr
+        A 1-D tensor containing the index of each data element in the output tensor.
+
+    num_unique : relay.Expr
+        A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
+
+    counts (optional) : relay.Expr
+        A 1-D tensor containing the count of each unique element in the output.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
+        output         =  [4, 5, 1, 2, 3, ?, ?, ?]
+        indices        =  [0, 1, 2, 3, 4, 4, 0, 1]
+        num_unique     =  [5]
+
+        [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
+        output         =  [4, 5, 1, 2, 3, ?, ?, ?]
+        indices        =  [0, 1, 2, 3, 4, 4, 0, 1]
+        num_unique     =  [5]
+        counts         =  [2, 2, 1, 1, 2, ?, ?, ?]
+
+        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
+        output         =  [1, 2, 3, 4, 5, ?, ?, ?]
+        indices        =  [3, 4, 0, 1, 2, 2, 3, 4]
+        num_unique     =  [5]
+    """
+    if return_counts:
+        return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
+    return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3)
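
A hedged usage sketch at the Relay level, mirroring how the TensorFlow and
PyTorch frontends in this patch trim the padded outputs (untested sketch):

    from tvm import relay

    x = relay.var("x", relay.TensorType([8], "int32"))
    unique, indices, num_unique = relay.unique(x, is_sorted=True)
    # Slice off the undefined tail using the dynamic num_unique.
    trimmed = relay.strided_slice(unique, begin=[0], end=num_unique, slice_mode="size")
    func = relay.Function([x], relay.Tuple([trimmed, indices]))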
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 2b17162..63dc4bd 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -43,6 +43,7 @@ from .scatter_add import *
 from .argwhere import *
 from .cumsum import *
 from .einsum import *
+from .unique import *
 from . import generic
 from . import nn
 from . import x86
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index bf3582c..df75c67 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -58,3 +58,4 @@ from .sparse import *
 from . import tensorcore_alter_op
 from .argwhere import *
 from .scan import *
+from .unique import *
diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py
new file mode 100644
index 0000000..02a5cf3
--- /dev/null
+++ b/python/tvm/topi/cuda/unique.py
@@ -0,0 +1,396 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Unique operator"""
+import tvm
+from tvm import te, tir
+from ...te import hybrid
+from .scan import cumsum
+from .sort import sort, argsort
+from ..utils import ceil_div
+
+
+def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
+    """Low level IR to calculate adjacent difference in an 1-D array.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input 1-D Buffer.
+
+    output: Buffer
+        A buffer to store adjacent difference, of the same shape as data. The adjacent difference
+        is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1])
+        where i > 0 and i < len(data).
+
+    binop: function, optional
+        A binary associative op to use for calculating adjacent difference. The function takes two
+        TIR expressions and produces a new TIR expression. By default it uses tvm.tir.Sub to
+        compute the adjacent difference.
+    """
+    ib = tir.ir_builder.create()
+    data_ptr = ib.buffer_ptr(data)
+    output_ptr = ib.buffer_ptr(output)
+    batch_size = data.shape[0]
+    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size):
+            with ib.if_scope(tid == 0):
+                output_ptr[tid] = 0
+            with ib.else_scope():
+                output_ptr[tid] = tir.Cast(output.dtype, binop(data_ptr[tid], data_ptr[tid - 1]))
+    return ib.get()
+
+
+def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub):
+    """Function calculate adjacent difference in an 1-D array.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input 1-D tensor.
+
+    out_dtype : str
+        The output tensor data type.
+
+    binop: function, optional
+        A binary associative op to use for calculating difference. The function takes two
+        TIR expressions and produces a new TIR expression. By default it uses tvm.tir.Sub to
+        compute the adjacent difference.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        1-D tensor storing the adjacent difference of the input tensor. The adjacent difference
+        is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1])
+        where i > 0 and i < len(data).
+    """
+    data_buf = tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
+    output_buf = tir.decl_buffer(data.shape, out_dtype, "output_buf", data_alignment=8)
+    return te.extern(
+        [data.shape],
+        [data],
+        lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop),
+        dtype=[out_dtype],
+        in_buffers=[data_buf],
+        out_buffers=[output_buf],
+        name="_calc_adjacent_diff",
+        tag="_calc_adjacent_diff_gpu",
+    )
+
+
+@hybrid.script
+def _calc_num_unique(inc_scan):
+    """Helper function to get the number of unique elements fron inc_scan tensor"""
+    output = output_tensor((1,), "int32")
+    for i in bind("threadIdx.x", 1):
+        output[i] = inc_scan[inc_scan.shape[0] - 1] + int32(1)
+    return output
+
+
+def _calc_unique_ir(
+    data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts
+):
+    """Low level IR to calculate unique elements, inverse indices, and counts (optional) of
+    unique elements of 1-D array.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input 1-D Buffer.
+
+    argsorted_indices : Buffer
+        A buffer that stores the argsorted indices of the input data.
+
+    inc_scan : Buffer
+        A buffer that stores the inclusive scan of the binary tir.NE adjacent difference
+        of the sorted data.
+
+    index_converter (optional) : Buffer
+        An optional index converter that transforms the unique element index
+        such that new_idx = index_converter[old_idx].
+
+    unique_elements : Buffer
+        A buffer that stores the unique elements.
+
+    indices : Buffer
+        A buffer that stores the index of each input data element in the unique element array.
+
+    counts (optional) : Buffer
+        A buffer that stores the count of each unique element.
+    """
+    ib = tir.ir_builder.create()
+    data_ptr = ib.buffer_ptr(data)
+    argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
+    inc_scan_ptr = ib.buffer_ptr(inc_scan)
+    unique_elements_ptr = ib.buffer_ptr(unique_elements)
+    indices_ptr = ib.buffer_ptr(indices)
+
+    index_converter_ptr = None
+    if isinstance(index_converter, tir.Buffer):
+        index_converter_ptr = ib.buffer_ptr(index_converter)
+
+    if isinstance(counts, tir.Buffer):
+        counts_ptr = ib.buffer_ptr(counts)
+        # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1]
+        unique_seq_indices_ptr = ib.buffer_ptr(indices)
+
+    batch_size = data.shape[0]
+    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    # if need to return counts
+    if isinstance(counts, tir.Buffer):
+        num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1
+        num_elements = data.shape[0]
+        with ib.new_scope():
+            nthread_tx = max_threads
+            nthread_bx = ceil_div(batch_size, max_threads)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            tid = bx * max_threads + tx
+            with ib.if_scope(tid < batch_size):
+                with ib.if_scope(tid == 0):
+                    unique_seq_indices_ptr[num_unique - 1] = num_elements
+                with ib.else_scope():
+                    with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
+                        unique_seq_indices_ptr[inc_scan_ptr[tid] - 1] = tid
+        with ib.new_scope():
+            nthread_tx = max_threads
+            nthread_bx = ceil_div(batch_size, max_threads)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            tid = bx * max_threads + tx
+            with ib.if_scope(tid < num_unique):
+                unique_idx = tid if not index_converter_ptr else index_converter_ptr[tid]
+                with ib.if_scope(tid == 0):
+                    counts_ptr[unique_idx] = unique_seq_indices_ptr[tid]
+                with ib.else_scope():
+                    counts_ptr[unique_idx] = (
+                        unique_seq_indices_ptr[tid] - unique_seq_indices_ptr[tid - 1]
+                    )
+    # calculate unique elements and inverse indices
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size):
+            data_idx = argsorted_indices_ptr[tid]
+            unique_idx = (
+                inc_scan_ptr[tid]
+                if not index_converter_ptr
+                else index_converter_ptr[inc_scan_ptr[tid]]
+            )
+            indices_ptr[data_idx] = unique_idx
+            with ib.if_scope(tid == 0):
+                unique_elements_ptr[unique_idx] = data_ptr[data_idx]
+            with ib.else_scope():
+                with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
+                    unique_elements_ptr[unique_idx] = data_ptr[data_idx]
+    return ib.get()
+
+
+def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
+    """Low level IR to calculate the first occurence of each unique element in the input data.
+
+    Parameters
+    ----------
+    argsorted_indices : Buffer
+        A buffer that stores the argsorted indices of the input data.
+
+    inc_scan : Buffer
+        A buffer that stores the inclusive scan of the binary tir.NE adjacent difference
+        of the sorted data.
+
+    first_occurence : Buffer
+        A buffer that stores the first occurence of each unique element in the input data.
+    """
+    ib = tir.ir_builder.create()
+    argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
+    inc_scan_ptr = ib.buffer_ptr(inc_scan)
+    first_occurence_ptr = ib.buffer_ptr(first_occurence)
+    batch_size = argsorted_indices.shape[0]
+    max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size):
+            first_occurence_ptr[tid] = batch_size
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size):
+            with ib.if_scope(tid == 0):
+                first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid]
+            with ib.else_scope():
+                with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
+                    first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid]
+    return ib.get()
+
+
+def unique(data, is_sorted=True, return_counts=False):
+    """
+    Find the unique elements of a 1-D tensor. Please note `output` and `counts` are both padded to
+    have the same length as `data`, and elements with index >= num_unique[0] have undefined values.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        A 1-D tensor of integers.
+
+    is_sorted : bool
+        Whether to sort the unique elements in ascending order before returning as output.
+
+    return_counts : bool
+        Whether to return the count of each unique element.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        A 1-D tensor containing the unique elements of the input data tensor.
+
+    indices : tvm.te.Tensor
+        A 1-D tensor containing the index of each data element in the output tensor.
+
+    num_unique : tvm.te.Tensor
+        A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
+
+    counts (optional) : tvm.te.Tensor
+        A 1-D tensor containing the count of each unique element in the output.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
+        output         =  [4, 5, 1, 2, 3, ?, ?, ?]
+        indices        =  [0, 1, 2, 3, 4, 4, 0, 1]
+        num_unique     =  [5]
+
+        [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
+        output         =  [4, 5, 1, 2, 3, ?, ?, ?]
+        indices        =  [0, 1, 2, 3, 4, 4, 0, 1]
+        num_unique     =  [5]
+        counts         =  [2, 2, 1, 1, 2, ?, ?, ?]
+
+        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
+        output         =  [1, 2, 3, 4, 5, ?, ?, ?]
+        indices        =  [3, 4, 0, 1, 2, 2, 3, 4]
+        num_unique     =  [5]
+    """
+    sorted_data = sort(data)
+    argsorted_indices = argsort(data, dtype="int32")
+    # adjacent difference
+    adjacent_diff = _calc_adjacent_diff(sorted_data, out_dtype="int32", binop=tir.NE)
+    # inclusive scan
+    inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0)
+    # total number of unique elements
+    num_unique_elements = _calc_num_unique(inc_scan)
+    # buffers
+    data_buf = tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    argsorted_indices_buf = tir.decl_buffer(
+        data.shape, "int32", "argsorted_indices_buf", data_alignment=8
+    )
+    inc_scan_buf = tvm.tir.decl_buffer(data.shape, "int32", "inc_scan_buf", data_alignment=8)
+    unique_elements_buf = tir.decl_buffer(
+        data.shape, data.dtype, "unique_elements_buf", data_alignment=8
+    )
+    inverse_indices_buf = tvm.tir.decl_buffer(
+        data.shape, "int32", "inverse_indices_buf", data_alignment=8
+    )
+    # prepare outputs
+    if return_counts:
+        counts_buf = tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8)
+        out_data_shape = [data.shape] * 3
+        out_buffers = [unique_elements_buf, inverse_indices_buf, counts_buf]
+        out_dtypes = [data.dtype, "int32", "int32"]
+    else:
+        out_data_shape = [data.shape] * 2
+        out_buffers = [unique_elements_buf, inverse_indices_buf]
+        out_dtypes = [data.dtype, "int32"]
+    # prepare inputs and fcompute
+    if is_sorted:
+        in_data = [data, argsorted_indices, inc_scan]
+        in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf]
+        if return_counts:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs)
+        else:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None)
+    else:
+        # calculate the index converter if the unique elements should not be sorted
+        # calculate first occurence
+        first_occurence_buf = tir.decl_buffer(
+            data.shape, "int32", "first_occurence_buf", data_alignment=8
+        )
+        first_occurence = te.extern(
+            [data.shape],
+            [argsorted_indices, inc_scan],
+            lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]),
+            dtype=["int32"],
+            in_buffers=[argsorted_indices_buf, inc_scan_buf],
+            out_buffers=[first_occurence_buf],
+            name="_calc_first_occurence",
+            tag="_calc_first_occurence_gpu",
+        )
+        # calculate index converter by sorting unique elements by their first occurence
+        argsorted_first_occurence = argsort(first_occurence, dtype="int32")
+        index_converter = argsort(argsorted_first_occurence, dtype="int32")
+        index_converter_buf = tir.decl_buffer(
+            data.shape, "int32", "index_converter_buf", data_alignment=8
+        )
+        in_data = [data, argsorted_indices, inc_scan, index_converter]
+        in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf, index_converter_buf]
+        if return_counts:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs)
+        else:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None)
+    outs = te.extern(
+        out_data_shape,
+        in_data,
+        fcompute,
+        dtype=out_dtypes,
+        in_buffers=in_buffers,
+        out_buffers=out_buffers,
+        name="_calc_unique",
+        tag="_calc_unique_gpu",
+    )
+    if return_counts:
+        return [outs[0], outs[1], num_unique_elements, outs[2]]
+    return [*outs, num_unique_elements]
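
The double argsort above is the whole index-converter trick: sorting the
first-occurrence positions gives the appearance order of the unique values,
and argsorting that permutation gives each sorted-unique slot its new index.
A NumPy illustration (truncated to the unique elements for clarity; the
kernel keeps the full padded length):

    import numpy as np

    # input [4, 5, 1, 2, 3, 3, 4, 5]; sorted-unique values are [1, 2, 3, 4, 5]
    first_occurence = np.array([2, 3, 4, 0, 1])     # first input index of each
    appearance_order = np.argsort(first_occurence)  # [3, 4, 0, 1, 2]
    index_converter = np.argsort(appearance_order)  # [2, 3, 4, 0, 1]
    # slot k of the sorted-unique array moves to index_converter[k], so the
    # final unique output is [4, 5, 1, 2, 3, ...], the order of first appearance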
diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py
index 5924d35..f458ee7 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -70,3 +70,19 @@ def schedule_scatter_add(outs):
 
 def schedule_sparse_fill_empty_rows(outs):
     return _default_schedule(outs, False)
+
+
+def schedule_unique(outs):
+    """Schedule for unique operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The computation graph description of unique.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py
new file mode 100644
index 0000000..b4f27b3
--- /dev/null
+++ b/python/tvm/topi/unique.py
@@ -0,0 +1,297 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Unique operator"""
+from tvm import te, tir
+from ..te import hybrid
+from .cumsum import cumsum
+from .sort import sort, argsort
+
+
+def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
+    """Low level IR to calculate adjacent difference in an 1-D array.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input 1-D Buffer.
+
+    output: Buffer
+        A buffer to store adjacent difference, of the same shape as data. The adjacent difference
+        is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1])
+        where i > 0 and i < len(data).
+
+    binop: function, optional
+        A binary associative op to use for calculating adjacent difference. The function takes two
+        TIR expressions and produces a new TIR expression. By default it uses tvm.tir.Sub to
+        compute the adjacent difference.
+    """
+    ib = tir.ir_builder.create()
+    data_ptr = ib.buffer_ptr(data)
+    output_ptr = ib.buffer_ptr(output)
+    with ib.for_range(0, data.shape[0], kind="parallel") as i:
+        with ib.if_scope(i == 0):
+            output_ptr[0] = 0
+        with ib.else_scope():
+            output_ptr[i] = tir.Cast(output.dtype, binop(data_ptr[i], data_ptr[i - 1]))
+    return ib.get()
+
+
+def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub):
+    """Function calculate adjacent difference in an 1-D array.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input 1-D tensor.
+
+    out_dtype : str
+        The output tensor data type.
+
+    binop: function, optional
+        A binary associative op to use for calculating difference. The function takes two
+        TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to
+        compute the adjacent difference.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        1-D tensor storing the adjacent difference of the input tensor. The adjacent difference
+        is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1])
+        where i > 0 and i < len(data).
+    """
+    return te.extern(
+        [data.shape],
+        [data],
+        lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop),
+        dtype=[out_dtype],
+        name="_calc_adjacent_diff",
+        tag="_calc_adjacent_diff_cpu",
+    )
+
+
+@hybrid.script
+def _calc_num_unique(inc_scan):
+    """Helper function to get the number of unique elements fron inc_scan tensor"""
+    output = output_tensor((1,), "int32")
+    output[0] = inc_scan[inc_scan.shape[0] - 1] + int32(1)
+    return output
+
+
+def _calc_unique_ir(
+    data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts
+):
+    """Low level IR to calculate unique elements, inverse indices, and counts (optional) of
+    unique elements of 1-D array.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input 1-D Buffer.
+
+    argsorted_indices : Buffer
+        A buffer that stores the argsorted indices of the input data.
+
+    inc_scan : Buffer
+        A buffer that stores the inclusive scan of the binary tir.NE adjacent difference
+        of the sorted data.
+
+    index_converter (optional) : Buffer
+        An optional index converter that transforms the unique element index
+        such that new_idx = index_converter[old_idx].
+
+    unique_elements : Buffer
+        A buffer that stores the unique elements.
+
+    indices : Buffer
+        A buffer that stores the index of each input data element in the unique element array.
+
+    counts (optional) : Buffer
+        A buffer that stores the count of each unique element.
+    """
+    ib = tir.ir_builder.create()
+    data_ptr = ib.buffer_ptr(data)
+    argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
+    inc_scan_ptr = ib.buffer_ptr(inc_scan)
+    unique_elements_ptr = ib.buffer_ptr(unique_elements)
+    indices_ptr = ib.buffer_ptr(indices)
+
+    index_converter_ptr = None
+    if isinstance(index_converter, tir.Buffer):
+        index_converter_ptr = ib.buffer_ptr(index_converter)
+
+    if isinstance(counts, tir.Buffer):
+        counts_ptr = ib.buffer_ptr(counts)
+        # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1]
+        unique_seq_indices_ptr = ib.buffer_ptr(indices)
+
+    data_length = data.shape[0]
+
+    # if need to return counts
+    if isinstance(counts, tir.Buffer):
+        num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1
+        num_elements = data.shape[0]
+        unique_seq_indices_ptr[num_unique - 1] = num_elements
+        with ib.new_scope():
+            with ib.for_range(0, data_length, kind="parallel") as i:
+                with ib.if_scope(i > 0):
+                    with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]):
+                        unique_seq_indices_ptr[inc_scan_ptr[i] - 1] = i
+        with ib.new_scope():
+            with ib.for_range(0, num_unique, kind="parallel") as i:
+                unique_idx = i if not index_converter_ptr else index_converter_ptr[i]
+                with ib.if_scope(i == 0):
+                    counts_ptr[unique_idx] = unique_seq_indices_ptr[i]
+                with ib.else_scope():
+                    counts_ptr[unique_idx] = (
+                        unique_seq_indices_ptr[i] - unique_seq_indices_ptr[i - 1]
+                    )
+    # calculate unique elements and inverse indices
+    with ib.new_scope():
+        with ib.for_range(0, data_length, kind="parallel") as i:
+            data_idx = argsorted_indices_ptr[i]
+            unique_idx = (
+                inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]]
+            )
+            indices_ptr[data_idx] = unique_idx
+            with ib.if_scope(i == 0):
+                unique_elements_ptr[unique_idx] = data_ptr[data_idx]
+            with ib.else_scope():
+                with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]):
+                    unique_elements_ptr[unique_idx] = data_ptr[data_idx]
+    return ib.get()
+
+
+@hybrid.script
+def _calc_first_occurence(argsorted_indices, inc_scan):
+    """Hybrid script to calculate the first occurence of each unique element in the input data.
+
+    Parameters
+    ----------
+    argsorted_indices : tvm.te.Tensor
+        A tensor that stores the argsorted indices of the input data.
+
+    inc_scan : tvm.te.Tensor
+        A tensor that stores the inclusive scan of the binary tir.NE adjacent difference
+        of the sorted data.
+
+    Returns
+    -------
+    first_occurence : tvm.te.Tensor
+        A tensor that stores the first occurence of each unique element in the input data.
+    """
+    first_occurence = output_tensor(argsorted_indices.shape, "int32")
+    for i in parallel(argsorted_indices.shape[0]):
+        first_occurence[i] = argsorted_indices.shape[0]
+    for i in parallel(argsorted_indices.shape[0]):
+        if i == 0 or inc_scan[i] != inc_scan[i - 1]:
+            first_occurence[inc_scan[i]] = argsorted_indices[i]
+    return first_occurence
+
+
+def unique(data, is_sorted=True, return_counts=False):
+    """
+    Find the unique elements of a 1-D tensor. Please note `output` and `counts` are both padded to
+    have the same length as `data`, and elements with index >= num_unique[0] have undefined values.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        A 1-D tensor of integers.
+
+    is_sorted : bool
+        Whether to sort the unique elements in ascending order before returning as output.
+
+    return_counts : bool
+        Whether to return the count of each unique element.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        A 1-D tensor containing the unique elements of the input data tensor.
+
+    indices : tvm.te.Tensor
+        A 1-D tensor containing the index of each data element in the output tensor.
+
+    num_unique : tvm.te.Tensor
+        A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
+
+    counts (optional) : tvm.te.Tensor
+        A 1-D tensor containing the count of each unique element in the output.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
+        output         =  [4, 5, 1, 2, 3, ?, ?, ?]
+        indices        =  [0, 1, 2, 3, 4, 4, 0, 1]
+        num_unique     =  [5]
+
+        [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
+        output         =  [4, 5, 1, 2, 3, ?, ?, ?]
+        indices        =  [0, 1, 2, 3, 4, 4, 0, 1]
+        num_unique     =  [5]
+        counts         =  [2, 2, 1, 1, 2, ?, ?, ?]
+
+        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
+        output         =  [1, 2, 3, 4, 5, ?, ?, ?]
+        indices        =  [3, 4, 0, 1, 2, 2, 3, 4]
+        num_unique     =  [5]
+    """
+    sorted_data = sort(data)
+    argsorted_indices = argsort(data, dtype="int32")
+    # adjacent difference
+    adjacent_diff = _calc_adjacent_diff(sorted_data, "int32", tir.NE)
+    # inclusive scan
+    inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0)
+    # total number of unique elements
+    num_unique_elements = _calc_num_unique(inc_scan)
+    # prepare outputs
+    if return_counts:
+        out_data_shape = [data.shape] * 3
+        out_dtypes = [data.dtype, "int32", "int32"]
+    else:
+        out_data_shape = [data.shape] * 2
+        out_dtypes = [data.dtype, "int32"]
+    # prepare inputs and fcompute
+    if is_sorted:
+        in_data = [data, argsorted_indices, inc_scan]
+        if return_counts:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs)
+        else:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None)
+    else:
+        # calculate the index converter if the unique elements should not be sorted
+        # calculate first occurence
+        first_occurence = _calc_first_occurence(argsorted_indices, inc_scan)
+        # calculate index converter by sorting unique elements by their first occurence
+        argsorted_first_occurence = argsort(first_occurence, dtype="int32")
+        index_converter = argsort(argsorted_first_occurence, dtype="int32")
+        in_data = [data, argsorted_indices, inc_scan, index_converter]
+        if return_counts:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs)
+        else:
+            fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None)
+    outs = te.extern(
+        out_data_shape,
+        in_data,
+        fcompute,
+        dtype=out_dtypes,
+        name="_calc_unique",
+        tag="_calc_unique_cpu",
+    )
+    if return_counts:
+        return [outs[0], outs[1], num_unique_elements, outs[2]]
+    return [*outs, num_unique_elements]
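
The counts computation above reuses the indices buffer as scratch space:
boundary positions from the inclusive scan are scattered into it, and counts
fall out as differences of consecutive boundaries. A NumPy illustration
(variable names are illustrative, not from the patch):

    import numpy as np

    sorted_data = np.array([1, 2, 3, 3, 4, 4, 5, 5])
    inc_scan = np.array([0, 1, 2, 2, 3, 3, 4, 4])
    num_unique = inc_scan[-1] + 1
    seq = np.zeros(num_unique, dtype="int64")
    seq[num_unique - 1] = len(sorted_data)        # end of the last run
    for i in range(1, len(sorted_data)):
        if inc_scan[i] != inc_scan[i - 1]:
            seq[inc_scan[i] - 1] = i              # end of the previous run
    counts = np.diff(seq, prepend=0)              # [1, 1, 2, 2, 2]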
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 12db859..eae231f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3772,5 +3772,52 @@ RELAY_REGISTER_OP("cumsum")
     .add_type_rel("Cumsum", CumsumRel)
     .set_attr<TOpPattern>("TOpPattern", kOpaque);
 
+TVM_REGISTER_NODE_TYPE(UniqueAttrs);
+
+bool UniqueRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // types: [data, result]
+  ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided";
+  ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 input but " << num_inputs << " provided";
+  auto data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "Unique: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+  const int ndim = static_cast<int>(data->shape.size());
+  ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor";
+  ICHECK_EQ(data->dtype.is_int(), true) << "Unique: input must have int32 or int64 dtype";
+  std::vector<Type> fields;
+  fields.push_back(TensorType(data->shape, data->dtype));               // unique
+  fields.push_back(TensorType(data->shape, DataType::Int(32)));         // indices
+  fields.push_back(TensorType(Array<PrimExpr>{1}, DataType::Int(32)));  // num_unique
+  const auto* param = attrs.as<UniqueAttrs>();
+  if (param->return_counts) {
+    fields.push_back(TensorType(data->shape, DataType::Int(32)));  // counts
+  }
+  reporter->Assign(types[1], TupleType(Array<Type>(fields)));
+  return true;
+}
+
+Expr MakeUnique(Expr data, bool sorted, bool return_counts) {
+  auto attrs = make_object<UniqueAttrs>();
+  attrs->sorted = sorted;
+  attrs->return_counts = return_counts;
+  static const Op& op = Op::Get("unique");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique);
+
+RELAY_REGISTER_OP("unique")
+    .describe(
+        R"code(This operation returns the unique elements and the new index of each item in a given 1-D array.
+    )code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .add_type_rel("unique", UniqueRel)
+    .set_support_level(3)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index aa42b0f..0cf4839 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2064,7 +2064,12 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llv
             pt_result = input_model(*input_data)
 
         # Verify the accuracy
-        if not isinstance(pt_result, torch.Tensor):
+        if isinstance(pt_result, tuple):
+            # handle multiple outputs
+            for i in range(len(pt_result)):
+                tvm_res = vm_res[i].asnumpy()
+                tvm.testing.assert_allclose(tvm_res, pt_result[i].numpy(), rtol=1e-5, atol=1e-5)
+        elif not isinstance(pt_result, torch.Tensor):
             tvm_res = vm_res.asnumpy().item()
             assert pt_result == tvm_res
         else:
@@ -3654,6 +3659,23 @@ def test_masked_select():
         verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"])
 
 
+def test_unique():
+    def test_fn(is_sorted, return_inverse, return_counts):
+        return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts)
+
+    in_data = torch.randint(0, 20, (10,), dtype=torch.int32)
+    targets = ["llvm", "cuda", "nvptx"]
+    verify_trace_model(test_fn(True, True, True), [in_data], targets)
+    verify_trace_model(test_fn(True, False, True), [in_data], targets)
+    verify_trace_model(test_fn(True, True, False), [in_data], targets)
+    verify_trace_model(test_fn(True, False, True), [in_data], targets)
+    in_data = torch.randint(0, 20, (20,), dtype=torch.int64)
+    verify_trace_model(test_fn(True, True, True), [in_data], targets)
+    verify_trace_model(test_fn(True, False, True), [in_data], targets)
+    verify_trace_model(test_fn(True, True, False), [in_data], targets)
+    verify_trace_model(test_fn(True, False, True), [in_data], targets)
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -3789,6 +3811,7 @@ if __name__ == "__main__":
     test_argsort()
     test_logical_and()
     test_masked_select()
+    test_unique()
 
     # Model tests
     test_resnet18()
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 5ed3e72..8b146b6 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -4988,5 +4988,70 @@ def test_forward_dynmaic_rnn_lstmblockcell():
             tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
 
+#######################################################################
+# Unique
+# ------------
+
+
+def _test_unique(n, dtype, is_dyn):
+    tf.reset_default_graph()
+    np_data = np.random.randint(100, size=n).astype(dtype)
+    with tf.Graph().as_default():
+        if is_dyn:
+            in_data = tf.placeholder(dtype, [n], name="in_data")
+        else:
+            in_data = tf.constant(np_data, dtype, name="in_data")
+        tf.unique(in_data)
+        if is_dyn:
+            compare_tf_with_tvm(np_data, "in_data:0", ["Unique:0", "Unique:1"], mode="vm")
+        else:
+            compare_tf_with_tvm(None, "", ["Unique:0", "Unique:1"])
+
+
+def test_forward_unique():
+    """test Unique"""
+
+    for dtype in ["int32", "int64"]:
+        for is_dyn in [False, True]:
+            _test_unique(50, dtype, is_dyn)
+            _test_unique(100, dtype, is_dyn)
+
+
+#######################################################################
+# Unique with counts
+# ------------
+
+
+def _test_unique_with_counts(n, dtype, is_dyn):
+    tf.reset_default_graph()
+    np_data = np.random.randint(100, size=n).astype(dtype)
+    with tf.Graph().as_default():
+        if is_dyn:
+            in_data = tf.placeholder(dtype, [n], name="in_data")
+        else:
+            in_data = tf.constant(np_data, dtype, name="in_data")
+        tf.unique_with_counts(in_data)
+        if is_dyn:
+            compare_tf_with_tvm(
+                np_data,
+                "in_data:0",
+                ["UniqueWithCounts:0", "UniqueWithCounts:1", "UniqueWithCounts:2"],
+                mode="vm",
+            )
+        else:
+            compare_tf_with_tvm(
+                None, "", ["UniqueWithCounts:0", "UniqueWithCounts:1", "UniqueWithCounts:2"]
+            )
+
+
+def test_forward_unique_with_counts():
+    """test UniqueWithCounts"""
+
+    for dtype in ["int32", "int64"]:
+        for is_dyn in [False, True]:
+            _test_unique_with_counts(10, dtype, is_dyn)
+            _test_unique_with_counts(20, dtype, is_dyn)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 94fac3b..ee55b53 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1453,5 +1453,58 @@ def test_scatter_nd(target, ctx):
     verify_scatter_nd_with_stack(data, indices, shape, out)
 
 
+def test_unique():
+    def calc_numpy_unique(data, is_sorted=False):
+        uniq, index, inverse, counts = np.unique(
+            data, return_index=True, return_inverse=True, return_counts=True
+        )
+        num_uniq = np.array([len(uniq)]).astype("int32")
+        if not is_sorted:
+            order = np.argsort(index)
+            reverse_order = np.argsort(order)
+            uniq = uniq[order].astype(data.dtype)
+            inverse = np.array([reverse_order[i] for i in inverse]).astype("int32")
+            counts = counts[order].astype("int32")
+        return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq]
+
+    def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False):
+        if is_dyn:
+            x = relay.var("x", relay.TensorType([relay.Any()], dtype))
+        else:
+            x = relay.var("x", relay.TensorType([n], dtype))
+        outs = relay.unique(x, is_sorted, return_counts)
+        outs = outs.astuple()
+        func = relay.Function([x], outs)
+        x_data = np.random.randint(50, size=n).astype(dtype)
+
+        if is_dyn:
+            backends = ["vm", "debug"]
+        else:
+            backends = ["graph", "debug"]
+
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in backends:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                tvm_res = intrp.evaluate()(x_data)
+                np_res = calc_numpy_unique(x_data, is_sorted)
+                num_unique = np_res[3][0]
+                assert num_unique == tvm_res[2].asnumpy()[0]
+                # unique
+                tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5)
+                # inverse_indices
+                tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5)
+                # counts
+                if return_counts:
+                    tvm.testing.assert_allclose(
+                        tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5
+                    )
+
+    for dtype in ["int32", "int64"]:
+        for i in range(8):
+            is_dyn, is_sorted, return_counts = bool(i & 1), bool(i & 2), bool(i & 4)
+            verify_unique(10, dtype, is_dyn, is_sorted, return_counts)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py
new file mode 100644
index 0000000..d7ee742
--- /dev/null
+++ b/tests/python/topi/python/test_topi_unique.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import topi
+import tvm.topi.testing
+
+
+@tvm.testing.parametrize_targets
+def test_unique(ctx, target):
+    def calc_numpy_unique(data, is_sorted=False):
+        uniq, index, inverse, counts = np.unique(
+            data, return_index=True, return_inverse=True, return_counts=True
+        )
+        num_uniq = np.array([len(uniq)]).astype("int32")
+        if not is_sorted:
+            order = np.argsort(index)
+            reverse_order = np.argsort(order)
+            uniq = uniq[order].astype(data.dtype)
+            inverse = np.array([reverse_order[i] for i in inverse]).astype("int32")
+            counts = counts[order].astype("int32")
+        return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq]
+
+    def check_unique(data, is_sorted=False):
+        # numpy reference
+        np_unique, np_indices, np_counts, np_num_unique = calc_numpy_unique(data, is_sorted)
+        num_unique = np_num_unique[0]
+
+        implementations = {
+            "generic": (
+                lambda x, return_counts: topi.unique(x, is_sorted, return_counts),
+                topi.generic.schedule_unique,
+            ),
+            "cuda": (
+                lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts),
+                topi.cuda.schedule_scan,
+            ),
+            "nvptx": (
+                lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts),
+                topi.cuda.schedule_scan,
+            ),
+        }
+        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+        tvm_data = tvm.nd.array(data, ctx=ctx)
+        tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), ctx=ctx)
+        tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx)
+        tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), ctx=ctx)
+
+        # without counts
+        with tvm.target.Target(target):
+            te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype))
+            outs = fcompute(te_input, False)
+            s = fschedule(outs)
+            func = tvm.build(s, [te_input, *outs])
+            func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique)
+
+        assert tvm_num_unique.asnumpy()[0] == np_num_unique
+        np.testing.assert_allclose(
+            tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5
+        )
+        np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5)
+
+        # with counts
+        tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx)
+        with tvm.target.Target(target):
+            te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype))
+            outs = fcompute(te_input, True)
+            s = fschedule(outs)
+            func = tvm.build(s, [te_input, *outs])
+            func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique, tvm_counts)
+
+        np_unique, np_indices, _, np_num_unique = calc_numpy_unique(data, is_sorted)
+        num_unique = np_num_unique[0]
+        assert tvm_num_unique.asnumpy()[0] == np_num_unique
+        np.testing.assert_allclose(
+            tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5
+        )
+        np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5)
+        np.testing.assert_allclose(
+            tvm_counts.asnumpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5
+        )
+
+    for in_dtype in ["int32", "int64"]:
+        for is_sorted in [True, False]:
+            data = np.random.randint(0, 100, size=(1)).astype(in_dtype)
+            check_unique(data, is_sorted)
+            data = np.random.randint(0, 10, size=(10)).astype(in_dtype)
+            check_unique(data, is_sorted)
+            data = np.random.randint(0, 100, size=(10000)).astype(in_dtype)
+            check_unique(data, is_sorted)
+
+
+if __name__ == "__main__":
+    test_unique(tvm.context("cpu"), tvm.target.Target("llvm"))
+    test_unique(tvm.context("cuda"), tvm.target.Target("cuda"))
+    test_unique(tvm.context("nvptx"), tvm.target.Target("nvptx"))