Posted to commits@tvm.apache.org by ke...@apache.org on 2020/12/04 22:52:58 UTC

[tvm] branch main updated: [TOPI][OP] cuda for argwhere (#6868)

This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 54cd235  [TOPI][OP] cuda for argwhere (#6868)
54cd235 is described below

commit 54cd235d7d0c01c05d29f4f26fd8788a50b84877
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Dec 4 14:52:41 2020 -0800

    [TOPI][OP] cuda for argwhere (#6868)
    
    * argwhere
    
    * cuda schedule
    
    * sort argwhere result
    
    * Use single block and thrust to fix flaky behavior
    
    * format
    
    * used dynamic strided_slice
    
    * Fix dynamic strided_slice
    
    * try new strided_slice
    
    * Improve dynamic strided slice to bind data-dependent shape var.
    
    * all tests pass
    
    * remove print
    
    * use new strided_slice
    
    * clean
    
    Co-authored-by: Yao Wang <ke...@gmail.com>
---
 3rdparty/vta-hw                                |   2 +-
 python/tvm/relay/op/_transform.py              |  16 +-
 python/tvm/relay/op/strategy/cuda.py           |  12 +
 python/tvm/relay/op/strategy/generic.py        |  39 +-
 python/tvm/topi/argwhere.py                    |   2 +
 python/tvm/topi/cuda/__init__.py               |   1 +
 python/tvm/topi/cuda/argwhere.py               | 654 +++++++++++++++++++++++++
 python/tvm/topi/cuda/sort.py                   |   2 +
 tests/python/relay/test_any.py                 |  17 +-
 tests/python/topi/python/test_topi_argwhere.py |  86 ++++
 10 files changed, 795 insertions(+), 36 deletions(-)

diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw
index 12fb486..87ce9ac 160000
--- a/3rdparty/vta-hw
+++ b/3rdparty/vta-hw
@@ -1 +1 @@
-Subproject commit 12fb486a491b75d70ec4c5e0a0cd112ab49a95bc
+Subproject commit 87ce9acfae550d1a487746e9d06c2e250076e54c
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 38d27e3..05ca6d2 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -83,21 +83,7 @@ _reg.register_injective_schedule("auto_scheduler_layout_transform")
 _reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
 
 # argwhere
-@_reg.register_compute("argwhere")
-def compute_argwhere(attrs, inputs, output_type):
-    """Compute definition of argwhere"""
-    output_shape = []
-    for s in output_type.shape:
-        if hasattr(s, "value"):
-            output_shape.append(s)
-        else:
-            # see Any, replace it with a var
-            output_shape.append(te.var("any_dim", "int32"))
-    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
-    return [topi.argwhere(new_output_type, inputs[0])]
-
-
-_reg.register_schedule("argwhere", strategy.schedule_argwhere)
+_reg.register_strategy("argwhere", strategy.argwhere_strategy)
 
 # scatter
 @_reg.register_compute("scatter")
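
For context, a minimal sketch of exercising argwhere through the newly
registered strategy on the CPU path (illustrative snippet, not part of the
patch; API usage mirrors the tests further below):

    import numpy as np
    import tvm
    from tvm import relay

    # argwhere's output row count is data dependent, so its result type is
    # TensorType([Any, ndim]); the VM executor supports such dynamic shapes.
    x = relay.var("x", shape=(3, 4), dtype="bool")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.argwhere(x)))
    data = np.random.choice([True, False], size=(3, 4))
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    print(ex.evaluate()(data).asnumpy())  # (num_nonzero, 2) int32, like np.argwhere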
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 0296906..fc80c9e 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -921,3 +921,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
         name="correlation.cuda",
     )
     return strategy
+
+
+@argwhere_strategy.register(["cuda", "gpu"])
+def argwhere_strategy_cuda(attrs, inputs, out_type, target):
+    """argwhere cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_argwhere(topi.cuda.argwhere),
+        wrap_topi_schedule(topi.cuda.schedule_argwhere),
+        name="argwhere.cuda",
+    )
+    return strategy
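
A hedged sketch of how the registration above is reached: compiling the same
module for a cuda target makes the FTVMStrategy lookup dispatch to
argwhere_strategy_cuda instead of the generic strategy (requires a
CUDA-enabled build, ideally with USE_THRUST ON):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 8), dtype="bool")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.argwhere(x)))
    # The VM executor is used because argwhere has a dynamically shaped result.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.gpu(0), target="cuda")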
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index a03c517..15c7f2f 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -19,7 +19,7 @@
 import logging
 
 import re
-from tvm import topi, _ffi
+from tvm import topi, _ffi, te, ir
 from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
 from tvm.target import generic_func, override_native_generic_func
 from .. import op as _op
@@ -1034,14 +1034,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
-# argwhere
-@generic_func
-def schedule_argwhere(attrs, outs, target):
-    """schedule argwhere"""
-    with target:
-        return topi.generic.schedule_argwhere(outs)
-
-
 # scatter
 @override_native_generic_func("scatter_strategy")
 def scatter_strategy(attrs, outs, out_type, target):
@@ -1223,3 +1215,32 @@ def correlation_strategy(attrs, inputs, out_type, target):
         name="correlation.generic",
     )
     return strategy
+
+
+# argwhere
+def wrap_compute_argwhere(topi_compute):
+    """wrap argwhere topi compute"""
+
+    def _compute_argwhere(attrs, inputs, out_type):
+        output_shape = []
+        for s in out_type.shape:
+            if hasattr(s, "value"):
+                output_shape.append(s)
+            else:
+                output_shape.append(te.var("any_dim", "int32"))
+        new_output_type = ir.TensorType(output_shape, "int32")
+        return [topi_compute(new_output_type, inputs[0])]
+
+    return _compute_argwhere
+
+
+@override_native_generic_func("argwhere_strategy")
+def argwhere_strategy(attrs, inputs, out_type, target):
+    """argwhere generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_argwhere(topi.argwhere),
+        wrap_topi_schedule(topi.generic.schedule_argwhere),
+        name="argwhere.generic",
+    )
+    return strategy
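
To make wrap_compute_argwhere concrete, a small standalone illustration of the
Any-to-var substitution it performs (hypothetical snippet, not part of the
patch):

    import tvm
    from tvm import te, ir, relay

    # For a rank-2 input, out_type.shape is [Any, 2]: the row count is only
    # known at run time, while the column count (the input rank) is static.
    shape = [relay.Any(), tvm.tir.IntImm("int32", 2)]
    output_shape = [s if hasattr(s, "value") else te.var("any_dim", "int32") for s in shape]
    print(ir.TensorType(output_shape, "int32"))  # TensorType([any_dim, 2], int32)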
diff --git a/python/tvm/topi/argwhere.py b/python/tvm/topi/argwhere.py
index 75c19af..c2b658a 100644
--- a/python/tvm/topi/argwhere.py
+++ b/python/tvm/topi/argwhere.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Argwhere operator"""
+import tvm
 from tvm.te import hybrid
 
 
@@ -169,6 +170,7 @@ def hybrid_argwhere_5d(output_shape, condition):
     return a
 
 
+@tvm.target.generic_func
 def argwhere(output_shape, condition):
     """Find the indices of elements of a tensor that are non-zero.
 
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f..23c625a 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -54,3 +54,4 @@ from .dense_tensorcore import *
 from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
+from .argwhere import *
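
Note on the @tvm.target.generic_func decoration added to topi.argwhere above:
it makes the hybrid-script fallback overridable per target. A hypothetical
override sketch follows (illustration only; in this patch the cuda
implementation is selected through the op strategy instead):

    import tvm
    from tvm import topi

    @topi.argwhere.register(["cuda"])
    def argwhere_cuda_override(output_shape, condition):
        """Hypothetical target-specific override; not part of this patch."""
        raise NotImplementedError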
diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py
new file mode 100644
index 0000000..e39004d
--- /dev/null
+++ b/python/tvm/topi/cuda/argwhere.py
@@ -0,0 +1,654 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments, invalid-name
+"""Argwhere operator"""
+
+import logging
+
+import tvm
+from tvm import te
+from tvm._ffi import get_global_func
+from .injective import schedule_injective_from_existing
+from .nms import atomic_add
+from .sort import topk, topk_thrust, argsort, argsort_thrust
+from .. import tag
+from ..transform import strided_slice, adv_index, squeeze
+
+logger = logging.getLogger("topi")
+
+
+def _get_sort_func(mode=0):
+    """Get sort function for argwhere. mode 0 for topk and others for argsort."""
+    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+        ret = topk_thrust if mode == 0 else argsort_thrust
+    else:
+        logger.warning(
+            "It is highly recommended to enable the thrust library with set(USE_THRUST ON)"
+            " when compiling argwhere for the cuda target. Otherwise, it can result in"
+            " significant performance degradation or incorrect results"
+        )
+        ret = topk if mode == 0 else argsort
+
+    return ret
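+
+# Hypothetical usage of the dispatch above (illustration only): mode=1 yields
+# an argsort flavor, thrust-backed when the contrib function is available:
+#   sort = _get_sort_func(1)
+#   indices = sort(column, axis=0, dtype="int32")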
+
+
+def argwhere_1d_ir(condition, out):
+    """Low level IR for argwhere 1D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = a0 // nthread_tx + 1
+    valid_index[0] = 0
+
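+    # Each of the nthread_tx threads scans a contiguous chunk of len_inner_for
+    # elements; atomic_add returns the counter's previous value, which is the
+    # output slot this thread claims for the index it just found.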
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < a0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0]] = idx
+
+    return ib.get()
+
+
+def argwhere_1d(output_shape, condition):
+    """Compute for argwhere 1D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_1d",
+        tag="argwhere1d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    sorted_out = _get_sort_func()(
+        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
+    )
+
+    return sorted_out
+
+
+def argwhere_2d_ir(condition, out):
+    """Low level IR for argwhere 2D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = (a0 * a1) // nthread_tx + 1
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < (a0 * a1)):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
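+                # idx is a flat offset into the 2-D condition; recover the
+                # (row, col) coordinate via floordiv/floormod by the row width a1.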
+                out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1)
+                out[tmp[0] * 2 + 1] = tvm.tir.floormod(idx, a1)
+
+    return ib.get()
+
+
+def argwhere_2d(output_shape, condition):
+    """Compute for argwhere 2D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_2d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_2d",
+        tag="argwhere2d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    sort_func = _get_sort_func(1)
+
+    # sort the output from the least significant to the most significant
+    # column.
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
+        out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+
+        out = adv_index(out, [out3])
+    else:
+        out1 = strided_slice(out, [0, 1], [out.shape[0], 2], [1, 1])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
+        out1 = strided_slice(out, [0, 0], [out.shape[0], 1], [1, 1])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+    return out
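+
+# Worked example of the column-wise sorting in argwhere_2d (hypothetical data,
+# assuming a stable sort): raw rows [[1, 2], [0, 1], [1, 0]]
+#   after sorting on column 1: [[1, 0], [0, 1], [1, 2]]
+#   after sorting on column 0: [[0, 1], [1, 0], [1, 2]]
+# i.e. rows end up in row-major order, matching np.argwhere. The same scheme is
+# repeated below for the 3-D, 4-D, and 5-D cases.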
+
+
+def argwhere_3d_ir(condition, out):
+    """Low level IR for argwhere 3D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+    a2 = condition.shape[2]
+    s1 = a1 * a2
+    s0 = a0 * s1
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = s0 // nthread_tx + 1
+
+    fdiv = tvm.tir.floordiv
+    fmod = tvm.tir.floormod
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < s0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 3] = fdiv(idx, s1)
+                out[tmp[0] * 3 + 1] = fdiv(fmod(idx, s1), a2)
+                out[tmp[0] * 3 + 2] = fmod(idx, a2)
+
+    return ib.get()
+
+
+def argwhere_3d(output_shape, condition):
+    """Compute for argwhere 3D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_3d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_3d",
+        tag="argwhere3d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    sort_func = _get_sort_func(1)
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        for i in reversed(range(3)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    else:
+        for i in reversed(range(3)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    return out
+
+
+def argwhere_4d_ir(condition, out):
+    """Low level IR for argwhere 4D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+    a2 = condition.shape[2]
+    a3 = condition.shape[3]
+    s1 = a2 * a3
+    s2 = a1 * s1
+    s0 = a0 * s2
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = s0 // nthread_tx + 1
+
+    fdiv = tvm.tir.floordiv
+    fmod = tvm.tir.floormod
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < s0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 4] = fdiv(idx, s2)
+                out[tmp[0] * 4 + 1] = fdiv(fmod(idx, s2), s1)
+                out[tmp[0] * 4 + 2] = fdiv(fmod(idx, s1), a3)
+                out[tmp[0] * 4 + 3] = fmod(idx, a3)
+
+    return ib.get()
+
+
+def argwhere_4d(output_shape, condition):
+    """Compute for argwhere 4D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_4d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_4d",
+        tag="argwhere4d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    sort_func = _get_sort_func(1)
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        for i in reversed(range(4)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    else:
+        for i in reversed(range(4)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+
+    return out
+
+
+def argwhere_5d_ir(condition, out):
+    """Low level IR for argwhere 5D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+    a2 = condition.shape[2]
+    a3 = condition.shape[3]
+    a4 = condition.shape[4]
+    s1 = a3 * a4
+    s2 = a2 * s1
+    s3 = a1 * s2
+    s0 = a0 * s3
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = s0 // nthread_tx + 1
+
+    fdiv = tvm.tir.floordiv
+    fmod = tvm.tir.floormod
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < s0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 5] = fdiv(idx, s3)
+                out[tmp[0] * 5 + 1] = fdiv(fmod(idx, s3), s2)
+                out[tmp[0] * 5 + 2] = fdiv(fmod(idx, s2), s1)
+                out[tmp[0] * 5 + 3] = fdiv(fmod(idx, s1), a4)
+                out[tmp[0] * 5 + 4] = fmod(idx, a4)
+
+    return ib.get()
+
+
+def argwhere_5d(output_shape, condition):
+    """Compute for argwhere 5D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_5d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_5d",
+        tag="argwhere5d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    sort_func = _get_sort_func(1)
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        for i in reversed(range(5)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    else:
+        for i in reversed(range(5)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+
+    return out
+
+
+def argwhere(output_shape, condition):
+    """Find the indices of elements of a tensor that are non-zero.
+
+    Parameters
+    ----------
+    output_shape : tvm.te.Tensor
+        Tensor with output shape info.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    if len(condition.shape) == 1:
+        return argwhere_1d(output_shape.shape, condition)
+    if len(condition.shape) == 2:
+        return argwhere_2d(output_shape.shape, condition)
+    if len(condition.shape) == 3:
+        return argwhere_3d(output_shape.shape, condition)
+    if len(condition.shape) == 4:
+        return argwhere_4d(output_shape.shape, condition)
+    if len(condition.shape) == 5:
+        return argwhere_5d(output_shape.shape, condition)
+    raise ValueError("Argwhere does not support rank higher than 5")
+
+
+def schedule_argwhere(outs):
+    """Schedule for argwhere on cuda.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argwhere
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for argwhere
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
+    def traverse(op):
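+        # Inline injective ops (the strided_slice/adv_index/squeeze chains
+        # produced by the sorting steps) into the schedule, walking the graph
+        # depth-first and visiting each op at most once.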
+        if tag.is_injective(op.tag):
+            schedule_injective_from_existing(s, op.output(0))
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+
+    for out in outs:
+        traverse(out.op)
+    return s
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ac14f5a..2a7f4eb 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -550,6 +550,8 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
         tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8),
     ]
 
+    is_ascend = 1 if is_ascend else 0
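+    # Normalize the flag to a plain integer so it lowers cleanly into the
+    # packed-function call below.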
+
     out = te.extern(
         [data.shape, data.shape],
         [data],
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index ee67e67..ddf8e98 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -219,30 +219,25 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
     mod["main"] = relay.Function([x], y)
     data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
     expected = np.argwhere(data)
-    for kind in ["debug", "vm"]:
-        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
-        result = ex.evaluate()(data).asnumpy()
-        assert result.shape == expected.shape
-        tvm.testing.assert_allclose(result.flatten(), expected.flatten())
-
-    # TODO(@zhiics) argwhere gpu schedule is currently not avaiable
-    # check_result([data], mod, expected, flatten=True)
+    check_result([data], mod, expected, flatten=True)
 
 
-@tvm.testing.uses_gpu
+# TODO(zhiics) Enable the argwhere gpu test once sort is fixed. Until then,
+# thrust is required to guarantee correct results; this has been verified locally.
+# @tvm.testing.uses_gpu
 def test_any_argwhere():
     verify_any_argwhere(any_dims(1), (5,))
     verify_any_argwhere(any_dims(2), (5, 5))
+    verify_any_argwhere(any_dims(2), (5, 5), "int32")
+    verify_any_argwhere(any_dims(2), (5, 5), "int8")
     verify_any_argwhere(any_dims(3), (5, 5, 5))
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
     verify_any_argwhere(any_dims(1), (5,), "int32")
-    verify_any_argwhere(any_dims(2), (5, 5), "int32")
     verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
     verify_any_argwhere(any_dims(1), (5,), "int8")
-    verify_any_argwhere(any_dims(2), (5, 5), "int8")
     verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
new file mode 100644
index 0000000..5cb7cd4
--- /dev/null
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test for argwhere operator"""
+import numpy as np
+
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+
+_argwhere_schedule = {
+    "generic": topi.generic.schedule_argwhere,
+    "gpu": topi.cuda.schedule_argwhere,
+}
+
+_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
+
+
+def verify_argwhere(data_shape):
+    dtype = "int32"
+    np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
+    np_out = np.argwhere(np_data)
+    out_shape = np_out.shape[0]
+    np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
+
+    out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
+    condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
+
+    def check_device(device, ctx):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or device not in _argwhere_compute:
+            return
+
+        with tvm.target.Target(device):
+            out = _argwhere_compute[device](out_shape, condition)
+            s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule)
+            sch = s_func(out)
+
+        func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere")
+
+        args = [tvm.nd.array(np_shape, ctx)]
+        args.append(tvm.nd.array(np_data, ctx))
+        args.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=condition.dtype))
+        func(*args)
+        np.set_printoptions(threshold=np.inf)
+        tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))
+
+    for target, ctx in tvm.testing.enabled_targets():
+        check_device(target, ctx)
+
+
+# TODO(zhiics) Enable the argwhere gpu test once sort is fixed. Until then,
+# thrust is required to guarantee correct results; this has been verified locally.
+# @tvm.testing.uses_gpu
+def test_argwhere():
+    verify_argwhere((1,))
+    verify_argwhere((100,))
+    verify_argwhere((1, 1))
+    verify_argwhere((5, 3))
+    verify_argwhere((32, 64))
+    verify_argwhere((128, 65))
+    verify_argwhere((200, 500))
+    verify_argwhere((6, 5, 3))
+    verify_argwhere((1, 1, 1))
+    verify_argwhere((1, 1, 1, 1))
+    verify_argwhere((6, 4, 5, 3))
+    verify_argwhere((1, 1, 1, 1, 1))
+    verify_argwhere((6, 4, 5, 3, 7))
+
+
+if __name__ == "__main__":
+    test_argwhere()