You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ke...@apache.org on 2020/12/04 22:52:58 UTC
[tvm] branch main updated: [TOPI][OP] cuda for argwhere (#6868)
This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 54cd235 [TOPI][OP] cuda for argwhere (#6868)
54cd235 is described below
commit 54cd235d7d0c01c05d29f4f26fd8788a50b84877
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Dec 4 14:52:41 2020 -0800
[TOPI][OP] cuda for argwhere (#6868)
* argwhere
* cuda schedule
* sort argwhere result
* Use single block and thrust to fix flaky behavior
* format
* used dynamic strided_slice
* Fix dynamic strided_slice
* try new strided_slice
* Improve dynamic strided slice to bind data-dependent shape var.
* all tests pass
* remove print
* use new strided_slice
* clean
Co-authored-by: Yao Wang <ke...@gmail.com>
---
3rdparty/vta-hw | 2 +-
python/tvm/relay/op/_transform.py | 16 +-
python/tvm/relay/op/strategy/cuda.py | 12 +
python/tvm/relay/op/strategy/generic.py | 39 +-
python/tvm/topi/argwhere.py | 2 +
python/tvm/topi/cuda/__init__.py | 1 +
python/tvm/topi/cuda/argwhere.py | 654 +++++++++++++++++++++++++
python/tvm/topi/cuda/sort.py | 2 +
tests/python/relay/test_any.py | 17 +-
tests/python/topi/python/test_topi_argwhere.py | 86 ++++
10 files changed, 795 insertions(+), 36 deletions(-)
diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw
index 12fb486..87ce9ac 160000
--- a/3rdparty/vta-hw
+++ b/3rdparty/vta-hw
@@ -1 +1 @@
-Subproject commit 12fb486a491b75d70ec4c5e0a0cd112ab49a95bc
+Subproject commit 87ce9acfae550d1a487746e9d06c2e250076e54c
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 38d27e3..05ca6d2 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -83,21 +83,7 @@ _reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
# argwhere
-@_reg.register_compute("argwhere")
-def compute_argwhere(attrs, inputs, output_type):
- """Compute definition of argwhere"""
- output_shape = []
- for s in output_type.shape:
- if hasattr(s, "value"):
- output_shape.append(s)
- else:
- # see Any, replace it with a var
- output_shape.append(te.var("any_dim", "int32"))
- new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
- return [topi.argwhere(new_output_type, inputs[0])]
-
-
-_reg.register_schedule("argwhere", strategy.schedule_argwhere)
+_reg.register_strategy("argwhere", strategy.argwhere_strategy)
# scatter
@_reg.register_compute("scatter")
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 0296906..fc80c9e 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -921,3 +921,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
name="correlation.cuda",
)
return strategy
+
+
+@argwhere_strategy.register(["cuda", "gpu"])
+def argwhere_strategy_cuda(attrs, inputs, out_type, target):
+ """argwhere cuda strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_argwhere(topi.cuda.argwhere),
+ wrap_topi_schedule(topi.cuda.schedule_argwhere),
+ name="argwhere.cuda",
+ )
+ return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index a03c517..15c7f2f 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -19,7 +19,7 @@
import logging
import re
-from tvm import topi, _ffi
+from tvm import topi, _ffi, te, ir
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
@@ -1034,14 +1034,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
return strategy
-# argwhere
-@generic_func
-def schedule_argwhere(attrs, outs, target):
- """schedule argwhere"""
- with target:
- return topi.generic.schedule_argwhere(outs)
-
-
# scatter
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
@@ -1223,3 +1215,32 @@ def correlation_strategy(attrs, inputs, out_type, target):
name="correlation.generic",
)
return strategy
+
+
+# argwhere
+def wrap_compute_argwhere(topi_compute):
+ """wrap argwhere topi compute"""
+
+ def _compute_argwhere(attrs, inputs, out_type):
+ output_shape = []
+ for s in out_type.shape:
+ if hasattr(s, "value"):
+ output_shape.append(s)
+ else:
+ output_shape.append(te.var("any_dim", "int32"))
+ new_output_type = ir.TensorType(output_shape, "int32")
+ return [topi_compute(new_output_type, inputs[0])]
+
+ return _compute_argwhere
+
+
+@override_native_generic_func("argwhere_strategy")
+def argwhere_strategy(attrs, inputs, out_type, target):
+ """argwhere generic strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_argwhere(topi.argwhere),
+ wrap_topi_schedule(topi.generic.schedule_argwhere),
+ name="argwhere.generic",
+ )
+ return strategy
diff --git a/python/tvm/topi/argwhere.py b/python/tvm/topi/argwhere.py
index 75c19af..c2b658a 100644
--- a/python/tvm/topi/argwhere.py
+++ b/python/tvm/topi/argwhere.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
+import tvm
from tvm.te import hybrid
@@ -169,6 +170,7 @@ def hybrid_argwhere_5d(output_shape, condition):
return a
+@tvm.target.generic_func
def argwhere(output_shape, condition):
"""Find the indices of elements of a tensor that are non-zero.
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f..23c625a 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -54,3 +54,4 @@ from .dense_tensorcore import *
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
+from .argwhere import *
diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py
new file mode 100644
index 0000000..e39004d
--- /dev/null
+++ b/python/tvm/topi/cuda/argwhere.py
@@ -0,0 +1,654 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments, invalid-name
+"""Argwhere operator"""
+
+import logging
+
+import tvm
+from tvm import te
+from tvm._ffi import get_global_func
+from .injective import schedule_injective_from_existing
+from .nms import atomic_add
+from .sort import topk, topk_thrust, argsort, argsort_thrust
+from .. import tag
+from ..transform import strided_slice, adv_index, squeeze
+
+logger = logging.getLogger("topi")
+
+
+def _get_sort_func(mode=0):
+ """Get sort function for argwhere. mode 0 for topk and others for argsort."""
+ if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+ ret = topk_thrust if mode == 0 else argsort_thrust
+ else:
+ logger.warning(
+ "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
+ " when compiling argwhere for cuda target. Otherwise, it can result in"
+ " significant performance degradation or incorrect result"
+ )
+ ret = topk if mode == 0 else argsort
+
+ return ret
+
+
+def argwhere_1d_ir(condition, out):
+ """Low level IR for argwhere 1D
+
+ Parameters
+ ----------
+ condition : Buffer
+ The condition buffer.
+
+ out : Buffer
+ The output buffer.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.tir.ir_builder.create()
+ a0 = condition.shape[0]
+
+ condition = ib.buffer_ptr(condition)
+ out = ib.buffer_ptr(out)
+
+ valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
+ tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+ one_count = tvm.tir.const(1, dtype="int32")
+
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ # Limit threads to a single block to make sure atomic_add works normally.
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ len_inner_for = a0 // nthread_tx + 1
+ valid_index[0] = 0
+
+ with ib.for_range(0, len_inner_for, name="i") as i:
+ idx = tx * len_inner_for + i
+ with ib.if_scope(idx < a0):
+ with ib.if_scope(condition[idx] != 0):
+ tmp[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+ one_count,
+ )
+ out[tmp[0]] = idx
+
+ return ib.get()
+
+
+def argwhere_1d(output_shape, condition):
+ """Compute for argwhere 1D
+
+ Parameters
+ ----------
+ condition : list of int or tvm.tir.Any
+ The output shape
+
+ out : tvm.te.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ condition_buf = tvm.tir.decl_buffer(
+ condition.shape, condition.dtype, "data_buf", data_alignment=8
+ )
+ out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+ out = te.extern(
+ [output_shape],
+ [condition],
+ lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[condition_buf],
+ out_buffers=[out_buf],
+ name="argwhere_1d",
+ tag="argwhere1d_gpu",
+ )
+
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+ return out
+
+ sorted_out = _get_sort_func()(
+ out, k=0, axis=0, ret_type="values", is_ascend="True", dtype="int32"
+ )
+
+ return sorted_out
+
+
+def argwhere_2d_ir(condition, out):
+ """Low level IR for argwhere 2D
+
+ Parameters
+ ----------
+ condition : Buffer
+ The condition buffer.
+
+ out : Buffer
+ The output buffer.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.tir.ir_builder.create()
+ a0 = condition.shape[0]
+ a1 = condition.shape[1]
+
+ condition = ib.buffer_ptr(condition)
+ out = ib.buffer_ptr(out)
+
+ valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+ tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+ one_count = tvm.tir.const(1, dtype="int32")
+
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+
+ # Limit threads to a single block to make sure atomic_add works normally.
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ len_inner_for = (a0 * a1) // nthread_tx + 1
+
+ valid_index[0] = 0
+
+ with ib.for_range(0, len_inner_for, name="i") as i:
+ idx = tx * len_inner_for + i
+ with ib.if_scope(idx < (a0 * a1)):
+ with ib.if_scope(condition[idx] != 0):
+ tmp[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+ one_count,
+ )
+ out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1)
+ out[tmp[0] * 2 + 1] = tvm.tir.floormod(idx, a1)
+
+ return ib.get()
+
+
+def argwhere_2d(output_shape, condition):
+ """Compute for argwhere 2D
+
+ Parameters
+ ----------
+ condition : list of int or tvm.tir.Any
+ The output shape
+
+ out : tvm.te.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ condition_buf = tvm.tir.decl_buffer(
+ condition.shape, condition.dtype, "data_buf", data_alignment=8
+ )
+ out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+ out = te.extern(
+ [output_shape],
+ [condition],
+ lambda ins, outs: argwhere_2d_ir(ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[condition_buf],
+ out_buffers=[out_buf],
+ name="argwhere_2d",
+ tag="argwhere2d_gpu",
+ )
+
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+ return out
+
+ sort_func = _get_sort_func(1)
+
+ # sort the output from the least significant to the most significant
+ # column.
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+ out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+
+ out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+
+ out = adv_index(out, [out3])
+ else:
+ out1 = strided_slice(out, [0, 1], [out.shape[0], 2], [1, 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+
+ out1 = strided_slice(out, [0, 0], [out.shape[0], 1], [1, 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+ return out
+
+
+def argwhere_3d_ir(condition, out):
+ """Low level IR for argwhere 3D
+
+ Parameters
+ ----------
+ condition : Buffer
+ The condition buffer.
+
+ out : Buffer
+ The output buffer.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.tir.ir_builder.create()
+ a0 = condition.shape[0]
+ a1 = condition.shape[1]
+ a2 = condition.shape[2]
+ s1 = a1 * a2
+ s0 = a0 * s1
+
+ condition = ib.buffer_ptr(condition)
+ out = ib.buffer_ptr(out)
+
+ valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+ tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+ one_count = tvm.tir.const(1, dtype="int32")
+
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+
+ # Limit threads to a single block to make sure atomic_add works normally.
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ len_inner_for = s0 // nthread_tx + 1
+
+ fdiv = tvm.tir.floordiv
+ fmod = tvm.tir.floormod
+
+ valid_index[0] = 0
+
+ with ib.for_range(0, len_inner_for, name="i") as i:
+ idx = tx * len_inner_for + i
+ with ib.if_scope(idx < s0):
+ with ib.if_scope(condition[idx] != 0):
+ tmp[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+ one_count,
+ )
+ out[tmp[0] * 3] = fdiv(idx, s1)
+ out[tmp[0] * 3 + 1] = fdiv(fmod(idx, s1), a2)
+ out[tmp[0] * 3 + 2] = fmod(idx, a2)
+
+ return ib.get()
+
+
+def argwhere_3d(output_shape, condition):
+ """Compute for argwhere 3D
+
+ Parameters
+ ----------
+ condition : list of int or tvm.tir.Any
+ The output shape
+
+ out : tvm.te.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ condition_buf = tvm.tir.decl_buffer(
+ condition.shape, condition.dtype, "data_buf", data_alignment=8
+ )
+ out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+ out = te.extern(
+ [output_shape],
+ [condition],
+ lambda ins, outs: argwhere_3d_ir(ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[condition_buf],
+ out_buffers=[out_buf],
+ name="argwhere_3d",
+ tag="argwhere3d_gpu",
+ )
+
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+ return out
+
+ # sort the output from the least significant to the most significant
+ # column.
+ sort_func = _get_sort_func(1)
+
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+ for i in reversed(range(3)):
+ out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+ else:
+ for i in reversed(range(3)):
+ out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+ return out
+
+
+def argwhere_4d_ir(condition, out):
+ """Low level IR for argwhere 4D
+
+ Parameters
+ ----------
+ condition : Buffer
+ The condition buffer.
+
+ out : Buffer
+ The output buffer.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.tir.ir_builder.create()
+ a0 = condition.shape[0]
+ a1 = condition.shape[1]
+ a2 = condition.shape[2]
+ a3 = condition.shape[3]
+ s1 = a2 * a3
+ s2 = a1 * s1
+ s0 = a0 * s2
+
+ condition = ib.buffer_ptr(condition)
+ out = ib.buffer_ptr(out)
+
+ valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+ tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+ one_count = tvm.tir.const(1, dtype="int32")
+
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+
+ # Limit threads to a single block to make sure atomic_add works normally.
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ len_inner_for = s0 // nthread_tx + 1
+
+ fdiv = tvm.tir.floordiv
+ fmod = tvm.tir.floormod
+
+ valid_index[0] = 0
+
+ with ib.for_range(0, len_inner_for, name="i") as i:
+ idx = tx * len_inner_for + i
+ with ib.if_scope(idx < s0):
+ with ib.if_scope(condition[idx] != 0):
+ tmp[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+ one_count,
+ )
+ out[tmp[0] * 4] = fdiv(idx, s2)
+ out[tmp[0] * 4 + 1] = fdiv(fmod(idx, s2), s1)
+ out[tmp[0] * 4 + 2] = fdiv(fmod(idx, s1), a3)
+ out[tmp[0] * 4 + 3] = fmod(idx, a3)
+
+ return ib.get()
+
+
+def argwhere_4d(output_shape, condition):
+ """Compute for argwhere 4D
+
+ Parameters
+ ----------
+ condition : list of int or tvm.tir.Any
+ The output shape
+
+ out : tvm.te.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ condition_buf = tvm.tir.decl_buffer(
+ condition.shape, condition.dtype, "data_buf", data_alignment=8
+ )
+ out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+ out = te.extern(
+ [output_shape],
+ [condition],
+ lambda ins, outs: argwhere_4d_ir(ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[condition_buf],
+ out_buffers=[out_buf],
+ name="argwhere_4d",
+ tag="argwhere4d_gpu",
+ )
+
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+ return out
+
+ # sort the output from the least significant to the most significant
+ # column.
+ sort_func = _get_sort_func(1)
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+ for i in reversed(range(4)):
+ out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+ else:
+ for i in reversed(range(4)):
+ out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+
+ return out
+
+
+def argwhere_5d_ir(condition, out):
+ """Low level IR for argwhere 5D
+
+ Parameters
+ ----------
+ condition : Buffer
+ The condition buffer.
+
+ out : Buffer
+ The output buffer.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ ib = tvm.tir.ir_builder.create()
+ a0 = condition.shape[0]
+ a1 = condition.shape[1]
+ a2 = condition.shape[2]
+ a3 = condition.shape[3]
+ a4 = condition.shape[4]
+ s1 = a3 * a4
+ s2 = a2 * s1
+ s3 = a1 * s2
+ s0 = a0 * s3
+
+ condition = ib.buffer_ptr(condition)
+ out = ib.buffer_ptr(out)
+
+ valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+ tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+ one_count = tvm.tir.const(1, dtype="int32")
+
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+
+ # Limit threads to a single block to make sure atomic_add works normally.
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ len_inner_for = s0 // nthread_tx + 1
+
+ fdiv = tvm.tir.floordiv
+ fmod = tvm.tir.floormod
+
+ valid_index[0] = 0
+
+ with ib.for_range(0, len_inner_for, name="i") as i:
+ idx = tx * len_inner_for + i
+ with ib.if_scope(idx < s0):
+ with ib.if_scope(condition[idx] != 0):
+ tmp[0] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+ one_count,
+ )
+ out[tmp[0] * 5] = fdiv(idx, s3)
+ out[tmp[0] * 5 + 1] = fdiv(fmod(idx, s3), s2)
+ out[tmp[0] * 5 + 2] = fdiv(fmod(idx, s2), s1)
+ out[tmp[0] * 5 + 3] = fdiv(fmod(idx, s1), a4)
+ out[tmp[0] * 5 + 4] = fmod(idx, a4)
+
+ return ib.get()
+
+
+def argwhere_5d(output_shape, condition):
+ """Compute for argwhere 5D
+
+ Parameters
+ ----------
+ condition : list of int or tvm.tir.Any
+ The output shape
+
+ out : tvm.te.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ condition_buf = tvm.tir.decl_buffer(
+ condition.shape, condition.dtype, "data_buf", data_alignment=8
+ )
+ out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+ out = te.extern(
+ [output_shape],
+ [condition],
+ lambda ins, outs: argwhere_5d_ir(ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[condition_buf],
+ out_buffers=[out_buf],
+ name="argwhere_5d",
+ tag="argwhere5d_gpu",
+ )
+
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+ return out
+
+ # sort the output from the least significant to the most significant
+ # column.
+ sort_func = _get_sort_func(1)
+ if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+ for i in reversed(range(5)):
+ out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+ else:
+ for i in reversed(range(5)):
+ out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+ out2 = sort_func(out1, axis=0, dtype="int32")
+ out3 = squeeze(out2)
+ out = adv_index(out, [out3])
+
+ return out
+
+
+def argwhere(output_shape, condition):
+ """Find the indices of elements of a tensor that are non-zero.
+
+ Parameters
+ ----------
+ output_shape : tvm.te.Tensor
+ Tensor with output shape info.
+
+ condition : tvm.te.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.te.Tensor
+ Indices of non-zero elements.
+ """
+ if len(condition.shape) == 1:
+ return argwhere_1d(output_shape.shape, condition)
+ if len(condition.shape) == 2:
+ return argwhere_2d(output_shape.shape, condition)
+ if len(condition.shape) == 3:
+ return argwhere_3d(output_shape.shape, condition)
+ if len(condition.shape) == 4:
+ return argwhere_4d(output_shape.shape, condition)
+ if len(condition.shape) == 5:
+ return argwhere_5d(output_shape.shape, condition)
+ raise ValueError("Argwhere does not support rank higher than 5")
+
+
+def schedule_argwhere(outs):
+ """Schedule for argwhere on cuda.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of argwhere
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for argwhere
+ """
+ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+ s = te.create_schedule([x.op for x in outs])
+ scheduled_ops = []
+
+ def traverse(op):
+ if tag.is_injective(op.tag):
+ schedule_injective_from_existing(s, op.output(0))
+ for tensor in op.input_tensors:
+ if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+ traverse(tensor.op)
+ scheduled_ops.append(op)
+
+ for out in outs:
+ traverse(out.op)
+ return s
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ac14f5a..2a7f4eb 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -550,6 +550,8 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8),
]
+ is_ascend = 1 if is_ascend else 0
+
out = te.extern(
[data.shape, data.shape],
[data],
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index ee67e67..ddf8e98 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -219,30 +219,25 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
mod["main"] = relay.Function([x], y)
data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
expected = np.argwhere(data)
- for kind in ["debug", "vm"]:
- ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
- result = ex.evaluate()(data).asnumpy()
- assert result.shape == expected.shape
- tvm.testing.assert_allclose(result.flatten(), expected.flatten())
-
- # TODO(@zhiics) argwhere gpu schedule is currently not avaiable
- # check_result([data], mod, expected, flatten=True)
+ check_result([data], mod, expected, flatten=True)
-@tvm.testing.uses_gpu
+# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
+# to use thrust to guarantee the correct results which has been tested locally.
+# @tvm.testing.uses_gpu
def test_any_argwhere():
verify_any_argwhere(any_dims(1), (5,))
verify_any_argwhere(any_dims(2), (5, 5))
+ verify_any_argwhere(any_dims(2), (5, 5), "int32")
+ verify_any_argwhere(any_dims(2), (5, 5), "int8")
verify_any_argwhere(any_dims(3), (5, 5, 5))
verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
verify_any_argwhere(any_dims(1), (5,), "int32")
- verify_any_argwhere(any_dims(2), (5, 5), "int32")
verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
verify_any_argwhere(any_dims(1), (5,), "int8")
- verify_any_argwhere(any_dims(2), (5, 5), "int8")
verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
new file mode 100644
index 0000000..5cb7cd4
--- /dev/null
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test for argwhere operator"""
+import numpy as np
+
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+
+_argwhere_schedule = {
+ "generic": topi.generic.schedule_argwhere,
+ "gpu": topi.cuda.schedule_argwhere,
+}
+
+_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
+
+
+def verify_argwhere(data_shape):
+ dtype = "int32"
+ np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
+ np_out = np.argwhere(np_data)
+ out_shape = np_out.shape[0]
+ np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
+
+ out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
+ condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
+
+ def check_device(device, ctx):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist or device not in _argwhere_compute:
+ return
+
+ with tvm.target.Target(device):
+ out = _argwhere_compute[device](out_shape, condition)
+ s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule)
+ sch = s_func(out)
+
+ func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere")
+
+ args = [tvm.nd.array(np_shape, ctx)]
+ args.append(tvm.nd.array(np_data, ctx))
+ args.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=condition.dtype))
+ func(*args)
+ np.set_printoptions(threshold=np.inf)
+ tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))
+
+ for target, ctx in tvm.testing.enabled_targets():
+ check_device(target, ctx)
+
+
+# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
+# to use thrust to guarantee the correct results which has been tested locally.
+# @tvm.testing.uses_gpu
+def test_argwhere():
+ verify_argwhere((1,))
+ verify_argwhere((100,))
+ verify_argwhere((1, 1))
+ verify_argwhere((5, 3))
+ verify_argwhere((32, 64))
+ verify_argwhere((128, 65))
+ verify_argwhere((200, 500))
+ verify_argwhere((6, 5, 3))
+ verify_argwhere((1, 1, 1))
+ verify_argwhere((1, 1, 1, 1))
+ verify_argwhere((6, 4, 5, 3))
+ verify_argwhere((1, 1, 1, 1, 1))
+ verify_argwhere((6, 4, 5, 3, 7))
+
+
+if __name__ == "__main__":
+ test_argwhere()