You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/16 05:12:27 UTC
[tvm] branch main updated: [TOPI][GPU] Mergepath sort with odd-even
block sort (#7611)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d288bbc [TOPI][GPU] Mergepath sort with odd-even block sort (#7611)
d288bbc is described below
commit d288bbc5df3660355adbf97f2f84ecd232e269ff
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Mar 15 23:12:16 2021 -0600
[TOPI][GPU] Mergepath sort with odd-even block sort (#7611)
* Mergepath sort with odd-even block sort
* fix lint, add test
* respond to review comments
* speed up tests by reducing dtype skews
* fix bad rebase
* change threading to support vulkan
* fix lint
* only sort if the data is non-empty
* fix lint again
* fix for vk
* move if to higher scope
* fix typo
Co-authored-by: Masahiro Masuda <ma...@gmail.com>
---
include/tvm/tir/stmt.h | 4 +
python/tvm/topi/cuda/sort.py | 604 +++++++++++++++++++++++++----------
src/tir/transforms/storage_access.cc | 4 +
tests/python/relay/test_op_level6.py | 9 +-
4 files changed, 457 insertions(+), 164 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index ac660bf..6445bb1 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1313,6 +1313,10 @@ constexpr const char* fragment_shape = "fragment_shape";
constexpr const char* fragment_layout = "fragment_layout";
/*!
+ * \brief Mark that the kernel is hand threaded and doesn't need syncs inserted
+ */
+constexpr const char* hand_threaded = "hand_threaded";
+/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
* \return true if it is a pragma key
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ca832ef..5ebd306 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -57,6 +57,20 @@ def _schedule_sort(outs):
return s
+def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+ by = te.thread_axis("blockIdx.y")
+ bz = te.thread_axis("blockIdx.z")
+ ib.scope_attr(by, "thread_extent", nthread_by)
+ ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+ return tx, bx, by, bz
+
+
def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
"""Initialize the output buffers by copying from inputs"""
axis_mul_before = 1
@@ -78,16 +92,8 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
# Copy the keys_in to initial output
with ib.new_scope():
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tid = bx * nthread_tx + tx
-
- by = te.thread_axis("blockIdx.y")
- bz = te.thread_axis("blockIdx.z")
- ib.scope_attr(by, "thread_extent", nthread_by)
- ib.scope_attr(bz, "thread_extent", nthread_bz)
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
keys_out[idx] = keys_in[idx]
@@ -97,6 +103,100 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
return axis_mul_before, axis_mul_after
+## TODO(mbrookhart): These are effective optimization hyperparameters
+## Perhaps we can autotune?
+block_size = 128
+thread_work = 4
+
+
+def _odd_even_sort(
+ ib,
+ size,
+ axis_mul_before,
+ axis_mul_after,
+ is_ascend,
+ keys,
+ keys_swap,
+ values=None,
+ values_swap=None,
+):
+
+ nthread_tx = block_size // 2
+ nthread_bx = ceil_div(size, block_size)
+ nthread_by = axis_mul_before
+ nthread_bz = axis_mul_after
+ with ib.new_scope():
+ ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
+ tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
+ tid = 2 * tx
+ start = bx * block_size
+
+ ## Create shared memory as syncable thread scratch space
+ tmp_keys_swap = ib.allocate(
+ keys_swap.dtype,
+ (block_size,),
+ name="temp_keys_swap",
+ scope="shared",
+ )
+ if values_swap is not None:
+ tmp_values_swap = ib.allocate(
+ values_swap.dtype,
+ (block_size,),
+ name="temp_values_swap",
+ scope="shared",
+ )
+
+ ## Create thread local data for swapping
+ temp_keys = ib.allocate(keys_swap.dtype, (1,), name="temp_keys", scope="local")
+ if values_swap is not None:
+ temp_values = ib.allocate(values_swap.dtype, (1,), name="temp_values", scope="local")
+
+ temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local")
+ temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local")
+ # Copy data to scratch space
+ base_idx = by * size * axis_mul_after + bz
+ with ib.for_range(0, 2) as n:
+ with ib.if_scope((tid + n + start) < size):
+ tmp_keys_swap[tid + n] = keys[base_idx + (tid + n + start) * axis_mul_after]
+ if values_swap is not None:
+ tmp_values_swap[tid + n] = values[base_idx + (tid + n + start) * axis_mul_after]
+
+ ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+
+ idxm = tvm.tir.indexmod
+ # OddEvenTransposeSort
+ current_sort_num = tvm.tir.min(block_size, size - start)
+ with ib.for_range(0, current_sort_num) as k:
+ n = idxm(tid + k, 2)
+ with ib.if_scope(tid + n < current_sort_num - 1):
+ temp_cond1[0] = tmp_keys_swap[tid + n]
+ temp_cond2[0] = tmp_keys_swap[tid + n + 1]
+ if is_ascend:
+ cond = temp_cond1[0] > temp_cond2[0]
+ else:
+ cond = temp_cond1[0] < temp_cond2[0]
+ with ib.if_scope(cond):
+ temp_keys[0] = tmp_keys_swap[tid + n]
+ tmp_keys_swap[tid + n] = tmp_keys_swap[tid + n + 1]
+ tmp_keys_swap[tid + n + 1] = temp_keys[0]
+ if values_swap is not None:
+ temp_values[0] = tmp_values_swap[tid + n]
+ tmp_values_swap[tid + n] = tmp_values_swap[tid + n + 1]
+ tmp_values_swap[tid + n + 1] = temp_values[0]
+ ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+
+ ## Copy sorted data to output
+ with ib.for_range(0, 2) as n:
+ with ib.if_scope(tid + n + start < size):
+ keys[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
+ keys_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
+ if values_swap is not None:
+ values[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[tid + n]
+ values_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[
+ tid + n
+ ]
+
+
def _sort_common(
ib,
size,
@@ -110,22 +210,22 @@ def _sort_common(
):
"""Either sort only values or sort values by keys."""
- ## we are looping over the array doing mergesort from the bottom up.
- ## The outer loop runs on the host and launches a cuda kernel for each iteration
- ## of the algorithm.
- ## The basic idea is that at iteration 0, each thread does sort on 2 elements.
- ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
- ## to deal with 4 total elements.
- ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
- ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
- ## On the final iteration of the algorithm, one thread will merge two sorted lists
- ## to sort the entire array
+ ## This function performs a multi-level mergesort
+ ## For blocks of length <= block_size, it does odd-even transpose sort
+ ## in GPU shared memory
+ ## For intermediate block sizes (>block_size, < max_threads * thread_work)
+## it uses the mergepath algorithm https://arxiv.org/abs/1406.2628
+ ## to merge blocks in parallel
+ ## At some point, the size of the blocks to be merged is too big for max_threads
+ ## and we switch to using a dual-level mergepath where the outer mergepath
+ ## finds the start/end locations of the inner mergepath so that we can split
+ ## the merge into more blocks
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_by = axis_mul_before * axis_mul_after
+ nthread_bz = 1
nthread_tx = max_threads
- nthread_bx = ceil_div(size, max_threads)
- nthread_by = axis_mul_before
- nthread_bz = axis_mul_after
+ nthread_bx = ceil_div(size, nthread_tx)
def compare(a, b):
"""
@@ -137,91 +237,234 @@ def _sort_common(
out = b <= a
return out
- def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
- """
- Merge the two sections of the array assigned to this thread
- """
- # pylint: disable=arguments-out-of-order
- # initialize iterators
+ # Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
+ lower_lim = tvm.tir.generic.cast(
+ tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
+ )
+
+ _odd_even_sort(
+ ib,
+ size,
+ axis_mul_before * axis_mul_after,
+ 1,
+ is_ascend,
+ keys,
+ keys_swap,
+ values,
+ values_swap,
+ )
+
+ upper_lim = tvm.tir.generic.cast(
+ tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
+ )
+
+ def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
+ first = ib.allocate("int64", (1,), name="first", scope="local")
+ mid = ib.allocate("int64", (1,), name="mid", scope="local")
+ last = ib.allocate("int64", (1,), name="last", scope="local")
+ first[0] = tvm.te.max(0, diag - bCount)
+ last[0] = tvm.te.min(diag, aCount)
+ with ib.while_loop(first[0] < last[0]):
+ mid = (first[0] + last[0]) >> 1
+ a = source[base_idx + (aStart + mid)]
+ b = source[base_idx + (bStart + diag - 1 - mid)]
+ with ib.if_scope(compare(a, b)):
+ first[0] = mid + 1
+ with ib.else_scope():
+ last[0] = mid
+ return first[0], last[0]
+
+ def serial_merge(
+ source,
+ dest,
+ source_idx,
+ dest_idx,
+ base_idx,
+ aCount,
+ bCount,
+ aStart,
+ bStart,
+ kStart,
+ diag,
+ step_count,
+ first,
+ last,
+ ):
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
- i[0] = start
- j[0] = middle
- # set up indexes
- base_idx = by * size * axis_mul_after + bz
- # iterate over the output loop
- with ib.for_range(0, end - start) as k:
- i_idx = base_idx + i[0] * axis_mul_after
- j_idx = base_idx + j[0] * axis_mul_after
- k_idx = base_idx + (k + start) * axis_mul_after
-
- def swap_values(source, dest, source_idx, dest_idx):
- def assign_i():
- """assign i value to current output"""
- dest[k_idx] = source[i_idx]
- if values is not None:
- dest_idx[k_idx] = source_idx[i_idx]
- i[0] += 1
-
- def assign_j():
- """assign j value to current output"""
- dest[k_idx] = source[j_idx]
- if values is not None:
- dest_idx[k_idx] = source_idx[j_idx]
- j[0] += 1
-
- ## if both of the iterators are in range
- with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
- # compare them and insert whichever is next into the output
- with ib.if_scope(compare(source[i_idx], source[j_idx])):
- assign_i()
- with ib.else_scope():
- assign_j()
- # otherwise, simply copy the remainder of the valid iterator to the output
- with ib.else_scope():
- with ib.if_scope(i[0] < middle):
- assign_i()
- with ib.else_scope():
- assign_j()
+ i[0] = aStart + first
+ j[0] = bStart + diag - last
+ with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
+ i_idx = base_idx + i[0]
+ j_idx = base_idx + j[0]
+ k_idx = base_idx + (kStart + diag + count)
+
+ def assign_i():
+ """assign i value to current output"""
+ dest[k_idx] = source[i_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[i_idx]
+ i[0] += 1
- # Switch which input is the source and which is the destination each iteration
- with ib.if_scope(even):
- swap_values(source, dest, source_idx, dest_idx)
+ def assign_j():
+ """assign j value to current output"""
+ dest[k_idx] = source[j_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[j_idx]
+ j[0] += 1
+
+ ## if both of the iterators are in range
+ with ib.if_scope(tvm.tir.all(i[0] < aStart + aCount, j[0] < bStart + bCount)):
+ # compare them and insert whichever is next into the output
+ with ib.if_scope(compare(source[i_idx], source[j_idx])):
+ assign_i()
+ with ib.else_scope():
+ assign_j()
+ # otherwise, simply copy the remainder of the valid iterator to the output
with ib.else_scope():
- swap_values(dest, source, dest_idx, source_idx)
-
- def mergesort(source, dest, source_idx, dest_idx, size, width, even):
- # calculate the start, mid, and end points of this section
- start = width * tid
-
- with ib.if_scope(start < size):
- middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
- end = cast(tvm.te.min(start + width, size), "int64")
- # merge the start->middle and middle->end arrays
- bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even)
+ with ib.if_scope(i[0] < aStart + aCount):
+ assign_i()
+ with ib.else_scope():
+ assign_j()
- lim = tvm.tir.generic.cast(
- tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
- )
- with ib.for_range(0, lim, dtype="int64") as l2_width:
- width = 2 << l2_width
+ with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
+ width = 2 << (l2_width + lower_lim)
# Define and launch the cuda kernel
with ib.new_scope():
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- # Reduce the number of blocks as the work per thread grows
- ib.scope_attr(
- bx,
- "thread_extent",
- tvm.tir.generic.cast(ceil_div(size, width * max_threads), "int32"),
- )
- tid = bx * nthread_tx + tx
-
- by = te.thread_axis("blockIdx.y")
- bz = te.thread_axis("blockIdx.z")
- ib.scope_attr(by, "thread_extent", nthread_by)
- ib.scope_attr(bz, "thread_extent", nthread_bz)
+ target = tvm.target.Target.current()
+ if "vulkan" in str(target):
+ # Vulkan can't handle dynamic nthread, so we thread slightly differently
+ # for vulkan. We don't do this generally because it causes a 15% perf
+ # regression on other platforms
+ ntx = max_threads
+ nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
+ nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
+ tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+ else:
+ ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
+ nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
+ nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
+ tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+
+ def mergepath(
+ source,
+ dest,
+ source_idx,
+ dest_idx,
+ aCount,
+ bCount,
+ aStart,
+ bStart,
+ kStart,
+ step_count,
+ even,
+ ):
+ # pylint: disable=arguments-out-of-order
+ def merge(source, dest, source_idx, dest_idx):
+ diag = tx * step_count
+ first, last = get_merge_begin(
+ source,
+ by * size,
+ aCount,
+ bCount,
+ aStart,
+ bStart,
+ diag,
+ step_count,
+ )
+ # iterate over the output loop
+ serial_merge(
+ source,
+ dest,
+ source_idx,
+ dest_idx,
+ by * size,
+ aCount,
+ bCount,
+ aStart,
+ bStart,
+ kStart,
+ diag,
+ step_count,
+ first,
+ last,
+ )
+
+ with ib.if_scope(even):
+ merge(source, dest, source_idx, dest_idx)
+ with ib.else_scope():
+ merge(dest, source, dest_idx, source_idx)
+
+ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+ # calculate the start, mid, and end points of this section
+ start = width * bz
+ middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
+ end = cast(tvm.te.min(start + width, size), "int64")
+ with ib.if_scope(start < size):
+ with ib.if_scope(nbx == 1):
+ ## merge the start->middle and middle->end arrays
+ aCount = middle - start
+ bCount = end - middle
+ mergepath(
+ source,
+ dest,
+ source_idx,
+ dest_idx,
+ aCount,
+ bCount,
+ start,
+ middle,
+ start,
+ ceil_div(width, ntx),
+ even,
+ )
+ with ib.else_scope():
+ step_count = max_threads * thread_work
+ diag = bx * step_count
+
+ def do_merge(first, last):
+ aStart = start + first
+ bStart = middle + diag - last
+ aCount = tvm.te.min(middle - aStart, step_count)
+ bCount = tvm.te.min(end - bStart, step_count)
+ mergepath(
+ source,
+ dest,
+ source_idx,
+ dest_idx,
+ aCount,
+ bCount,
+ aStart,
+ bStart,
+ start + diag,
+ thread_work,
+ even,
+ )
+
+ with ib.if_scope(even):
+ first, last = get_merge_begin(
+ source,
+ by * size,
+ middle - start,
+ end - middle,
+ start,
+ middle,
+ diag,
+ step_count,
+ )
+ do_merge(first, last)
+ with ib.else_scope():
+ first, last = get_merge_begin(
+ dest,
+ by * size,
+ middle - start,
+ end - middle,
+ start,
+ middle,
+ diag,
+ step_count,
+ )
+ do_merge(first, last)
# Call the kernel
mergesort(
@@ -233,29 +476,23 @@ def _sort_common(
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)
-
+ nthread_by = axis_mul_before
+ nthread_bz = axis_mul_after
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(size, nthread_tx)
## if the final sorted data ended up in the swap, copy it to the real output
- with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
+ with ib.if_scope(
+ tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
+ ):
with ib.new_scope():
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tid = bx * nthread_tx + tx
-
- by = te.thread_axis("blockIdx.y")
- bz = te.thread_axis("blockIdx.z")
- ib.scope_attr(by, "thread_extent", nthread_by)
- ib.scope_attr(bz, "thread_extent", nthread_bz)
- idx = (by * size + tid) * axis_mul_after + bz
+ idx = (by * axis_mul_after + bz) * size + tid
with ib.if_scope(tid < size):
- idx = (by * size + tid) * axis_mul_after + bz
keys[idx] = keys_swap[idx]
if values is not None:
values[idx] = values_swap[idx]
- return ib.get()
-
def sort_ir(
data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
@@ -301,27 +538,30 @@ def sort_ir(
assert indices_out_swap is not None
indices_out_swap = ib.buffer_ptr(indices_out_swap)
- axis_mul_before, axis_mul_after = _sort_init(
- ib,
- shape,
- axis,
- data,
- values_out,
- indices_out,
- value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
- )
+ with ib.if_scope(shape[axis] > 0):
+ axis_mul_before, axis_mul_after = _sort_init(
+ ib,
+ shape,
+ axis,
+ data,
+ values_out,
+ indices_out,
+ value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
+ )
+
+ _sort_common(
+ ib,
+ shape[axis],
+ axis_mul_before,
+ axis_mul_after,
+ is_ascend,
+ values_out,
+ values_out_swap,
+ values=indices_out,
+ values_swap=indices_out_swap,
+ )
- return _sort_common(
- ib,
- shape[axis],
- axis_mul_before,
- axis_mul_after,
- is_ascend,
- values_out,
- values_out_swap,
- values=indices_out,
- values_swap=indices_out_swap,
- )
+ return ib.get()
def sort_by_key_ir(
@@ -376,27 +616,29 @@ def sort_by_key_ir(
values_out = ib.buffer_ptr(values_out)
values_out_swap = ib.buffer_ptr(values_out_swap)
- axis_mul_before, axis_mul_after = _sort_init(
- ib,
- shape,
- axis,
- keys_in,
- keys_out,
- values_out,
- value_init_func=lambda idx, _: values_in[idx],
- )
-
- return _sort_common(
- ib,
- shape[axis],
- axis_mul_before,
- axis_mul_after,
- is_ascend,
- keys_out,
- keys_out_swap,
- values=values_out,
- values_swap=values_out_swap,
- )
+ with ib.if_scope(shape[axis] > 0):
+ axis_mul_before, axis_mul_after = _sort_init(
+ ib,
+ shape,
+ axis,
+ keys_in,
+ keys_out,
+ values_out,
+ value_init_func=lambda idx, _: values_in[idx],
+ )
+
+ _sort_common(
+ ib,
+ shape[axis],
+ axis_mul_before,
+ axis_mul_after,
+ is_ascend,
+ keys_out,
+ keys_out_swap,
+ values=values_out,
+ values_swap=values_out_swap,
+ )
+ return ib.get()
def sort(data, axis=-1, is_ascend=1):
@@ -419,16 +661,29 @@ def sort(data, axis=-1, is_ascend=1):
out : tvm.te.Tensor
The output of this function.
"""
+ ndim = len(data.shape)
+ axis = ndim + axis if axis < 0 else axis
+ if axis != ndim - 1:
+ # Prepare for sorting along axis -1.
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
+
out = te.extern(
[data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+ lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
out_buffers=[value_buf, value_buf_swap],
name="sort_gpu",
tag="sort_gpu",
)[0]
+
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ out = transpose(out, axes)
+
return out
@@ -507,10 +762,18 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
out : tvm.te.Tensor
The output of this function.
"""
+ ndim = len(data.shape)
+ axis = ndim + axis if axis < 0 else axis
+ if axis != ndim - 1:
+ # Prepare for sorting along axis -1.
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
+
out = te.extern(
[data.shape, data.shape, data.shape, data.shape],
[data],
@@ -518,7 +781,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
ins[0],
outs[0],
outs[2],
- axis,
+ -1,
is_ascend,
indices_out=outs[1],
indices_out_swap=outs[3],
@@ -527,6 +790,11 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
name="argsort_gpu",
tag="argsort_gpu",
)[1]
+
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ out = transpose(out, axes)
+
return out
@@ -625,21 +893,30 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
ndim = len(data.shape)
axis = axis + ndim if axis < 0 else axis
assert 0 <= axis < ndim
+ dshape = data.shape
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ data = transpose(data, axes)
+
values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
values_swap_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "values_swap_buf", data_alignment=8
)
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8)
+
if ret_type == "values":
output = te.extern(
[data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+ lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
out_buffers=[values_buf, values_swap_buf],
name="topk_gpu",
tag="topk_gpu",
)[0]
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ output = transpose(output, axes)
else:
output = te.extern(
[data.shape, data.shape, data.shape, data.shape],
@@ -648,7 +925,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
ins[0],
outs[0],
outs[2],
- axis,
+ -1,
is_ascend,
indices_out=outs[1],
indices_out_swap=outs[3],
@@ -657,6 +934,11 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
name="topk_gpu",
tag="topk_gpu",
)[0:2]
+ if axis != ndim - 1:
+ axes = swap(list(range(ndim)), axis)
+ output[0] = transpose(output[0], axes)
+ output[1] = transpose(output[1], axes)
+
if isinstance(k, int) and k < 1:
if ret_type == "indices":
return output[1]
@@ -668,7 +950,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
if i == axis:
end.append(k if isinstance(k, int) else tvm.te.size_var("dim"))
else:
- end.append(data.shape[i])
+ end.append(dshape[i])
if ret_type == "both":
values_out, indices_out = output
values_out = strided_slice(values_out, beg, end, strides)
diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 38143c1..00002d3 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -132,6 +132,10 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
StmtExprVisitor::VisitStmt_(op);
}
env_threads_.pop_back();
+ } else if (op->attr_key == attr::hand_threaded) {
+ // skip this pass on blocks that were hand_threaded
+ // this avoids control flow and read/write conflicts
+ // between hand-threaded kernels and automatic threading
} else {
StmtExprVisitor::VisitStmt_(op);
}
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index 0dac69e..f4b785f 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -26,6 +26,7 @@ import tvm.testing
@tvm.testing.uses_gpu
def test_sort():
def verify_sort(shape, axis, is_ascend, is_dyn=False):
+
if is_dyn:
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
else:
@@ -87,9 +88,11 @@ def test_argsort():
for dtype in ["int32", "int64", "float32", "float64"]:
verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn)
- verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
- verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
- verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+ dtype = "int32"
+ verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+ verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+ verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+ verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
@tvm.testing.uses_gpu