You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/16 05:12:27 UTC

[tvm] branch main updated: [TOPI][GPU] Mergepath sort with odd-even block sort (#7611)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d288bbc  [TOPI][GPU] Mergepath sort with odd-even block sort (#7611)
d288bbc is described below

commit d288bbc5df3660355adbf97f2f84ecd232e269ff
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Mar 15 23:12:16 2021 -0600

    [TOPI][GPU] Mergepath sort with odd-even block sort (#7611)
    
    * Mergepath sort with odd-even block sort
    
    * fix lint, add test
    
    * respond to review comments
    
    * speed up tests by reducing dtype skews
    
    * fix bad rebase
    
    * change threading to support vulkan
    
    * fix lint
    
    * only sort if the data is non-empty
    
    * fix lint again
    
    * fix for vk
    
    * move if to higher scope
    
    * fix typo
    
    Co-authored-by: Masahiro Masuda <ma...@gmail.com>
---
 include/tvm/tir/stmt.h               |   4 +
 python/tvm/topi/cuda/sort.py         | 604 +++++++++++++++++++++++++----------
 src/tir/transforms/storage_access.cc |   4 +
 tests/python/relay/test_op_level6.py |   9 +-
 4 files changed, 457 insertions(+), 164 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index ac660bf..6445bb1 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1313,6 +1313,10 @@ constexpr const char* fragment_shape = "fragment_shape";
 constexpr const char* fragment_layout = "fragment_layout";
 
 /*!
+ * \brief Mark that the kernel is hand-threaded and doesn't need syncs inserted
+ */
+constexpr const char* hand_threaded = "hand_threaded";
+/*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
  * \return true if it is a pragma key
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ca832ef..5ebd306 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -57,6 +57,20 @@ def _schedule_sort(outs):
     return s
 
 
+def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+    by = te.thread_axis("blockIdx.y")
+    bz = te.thread_axis("blockIdx.z")
+    ib.scope_attr(by, "thread_extent", nthread_by)
+    ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+    return tx, bx, by, bz
+
+
 def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
     """Initialize the output buffers by copying from inputs"""
     axis_mul_before = 1
@@ -78,16 +92,8 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
 
     # Copy the keys_in to initial output
     with ib.new_scope():
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(tx, "thread_extent", nthread_tx)
-        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
         tid = bx * nthread_tx + tx
-
-        by = te.thread_axis("blockIdx.y")
-        bz = te.thread_axis("blockIdx.z")
-        ib.scope_attr(by, "thread_extent", nthread_by)
-        ib.scope_attr(bz, "thread_extent", nthread_bz)
         idx = (by * shape[axis] + tid) * axis_mul_after + bz
         with ib.if_scope(tid < shape[axis]):
             keys_out[idx] = keys_in[idx]
@@ -97,6 +103,100 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
     return axis_mul_before, axis_mul_after
 
 
+## TODO(mbrookhart): These are effective optimization hyperparameters.
+## Perhaps we can autotune them?
+block_size = 128
+thread_work = 4
+
+
+def _odd_even_sort(
+    ib,
+    size,
+    axis_mul_before,
+    axis_mul_after,
+    is_ascend,
+    keys,
+    keys_swap,
+    values=None,
+    values_swap=None,
+):
+
+    nthread_tx = block_size // 2
+    nthread_bx = ceil_div(size, block_size)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+    with ib.new_scope():
+        ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
+        tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
+        tid = 2 * tx
+        start = bx * block_size
+
+        ## Create shared memory as syncable thread scratch space
+        tmp_keys_swap = ib.allocate(
+            keys_swap.dtype,
+            (block_size,),
+            name="temp_keys_swap",
+            scope="shared",
+        )
+        if values_swap is not None:
+            tmp_values_swap = ib.allocate(
+                values_swap.dtype,
+                (block_size,),
+                name="temp_values_swap",
+                scope="shared",
+            )
+
+        ## Create thread local data for swapping
+        temp_keys = ib.allocate(keys_swap.dtype, (1,), name="temp_keys", scope="local")
+        if values_swap is not None:
+            temp_values = ib.allocate(values_swap.dtype, (1,), name="temp_values", scope="local")
+
+        temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local")
+        temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local")
+        # Copy data to scratch space
+        base_idx = by * size * axis_mul_after + bz
+        with ib.for_range(0, 2) as n:
+            with ib.if_scope((tid + n + start) < size):
+                tmp_keys_swap[tid + n] = keys[base_idx + (tid + n + start) * axis_mul_after]
+                if values_swap is not None:
+                    tmp_values_swap[tid + n] = values[base_idx + (tid + n + start) * axis_mul_after]
+
+        ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+
+        idxm = tvm.tir.indexmod
+        # OddEvenTransposeSort
+        current_sort_num = tvm.tir.min(block_size, size - start)
+        with ib.for_range(0, current_sort_num) as k:
+            n = idxm(tid + k, 2)
+            with ib.if_scope(tid + n < current_sort_num - 1):
+                temp_cond1[0] = tmp_keys_swap[tid + n]
+                temp_cond2[0] = tmp_keys_swap[tid + n + 1]
+                if is_ascend:
+                    cond = temp_cond1[0] > temp_cond2[0]
+                else:
+                    cond = temp_cond1[0] < temp_cond2[0]
+                with ib.if_scope(cond):
+                    temp_keys[0] = tmp_keys_swap[tid + n]
+                    tmp_keys_swap[tid + n] = tmp_keys_swap[tid + n + 1]
+                    tmp_keys_swap[tid + n + 1] = temp_keys[0]
+                    if values_swap is not None:
+                        temp_values[0] = tmp_values_swap[tid + n]
+                        tmp_values_swap[tid + n] = tmp_values_swap[tid + n + 1]
+                        tmp_values_swap[tid + n + 1] = temp_values[0]
+            ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+
+        ## Copy sorted data to output
+        with ib.for_range(0, 2) as n:
+            with ib.if_scope(tid + n + start < size):
+                keys[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
+                keys_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
+                if values_swap is not None:
+                    values[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[tid + n]
+                    values_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[
+                        tid + n
+                    ]
+
+
 def _sort_common(
     ib,
     size,
@@ -110,22 +210,22 @@ def _sort_common(
 ):
     """Either sort only values or sort values by keys."""
 
-    ## we are looping over the array doing mergesort from the bottom up.
-    ## The outer loop runs on the host and launches a cuda kernel for each iteration
-    ## of the algorithm.
-    ## The basic idea is that at iteration 0, each thread does sort on 2 elements.
-    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
-    ## to deal with 4 total elements.
-    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
-    ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
-    ## On the final iteration of the algorithm, one thread will merge two sorted lists
-    ## to sort the entire array
+    ## This function performs a multi-level mergesort
+    ## For blocks of length <= block_size, it does odd-even transpose sort
+    ##    in GPU shared memory
+    ## For intermediate block sizes (>block_size, < max_threads * thread_work)
+    ##    it uses the mergepath algorithm https://arxiv.org/abs/1406.2628
+    ##    to merge blocks in parallel
+    ## At some point, the size of the blocks to be merged is too big for max_threads
+    ##    and we switch to using a dual-level mergepath where the outer mergepath
+    ##    finds the start/end locations of the inner mergepath so that we can split
+    ##    the merge into more blocks
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_by = axis_mul_before * axis_mul_after
+    nthread_bz = 1
     nthread_tx = max_threads
-    nthread_bx = ceil_div(size, max_threads)
-    nthread_by = axis_mul_before
-    nthread_bz = axis_mul_after
+    nthread_bx = ceil_div(size, nthread_tx)
 
     def compare(a, b):
         """
@@ -137,91 +237,234 @@ def _sort_common(
             out = b <= a
         return out
 
-    def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
-        """
-        Merge the two sections of the array assigned to this thread
-        """
-        # pylint: disable=arguments-out-of-order
-        # initialize iterators
+    # Sort the lower levels of the merge using odd-even sort; it's fast for small inputs
+    lower_lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
+    )
+
+    _odd_even_sort(
+        ib,
+        size,
+        axis_mul_before * axis_mul_after,
+        1,
+        is_ascend,
+        keys,
+        keys_swap,
+        values,
+        values_swap,
+    )
+
+    upper_lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
+    )
+
+    def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
+        first = ib.allocate("int64", (1,), name="first", scope="local")
+        mid = ib.allocate("int64", (1,), name="mid", scope="local")
+        last = ib.allocate("int64", (1,), name="last", scope="local")
+        first[0] = tvm.te.max(0, diag - bCount)
+        last[0] = tvm.te.min(diag, aCount)
+        with ib.while_loop(first[0] < last[0]):
+            mid = (first[0] + last[0]) >> 1
+            a = source[base_idx + (aStart + mid)]
+            b = source[base_idx + (bStart + diag - 1 - mid)]
+            with ib.if_scope(compare(a, b)):
+                first[0] = mid + 1
+            with ib.else_scope():
+                last[0] = mid
+        return first[0], last[0]
+
+    def serial_merge(
+        source,
+        dest,
+        source_idx,
+        dest_idx,
+        base_idx,
+        aCount,
+        bCount,
+        aStart,
+        bStart,
+        kStart,
+        diag,
+        step_count,
+        first,
+        last,
+    ):
         i = ib.allocate("int64", (1,), name="i", scope="local")
         j = ib.allocate("int64", (1,), name="j", scope="local")
-        i[0] = start
-        j[0] = middle
-        # set up indexes
-        base_idx = by * size * axis_mul_after + bz
-        # iterate over the output loop
-        with ib.for_range(0, end - start) as k:
-            i_idx = base_idx + i[0] * axis_mul_after
-            j_idx = base_idx + j[0] * axis_mul_after
-            k_idx = base_idx + (k + start) * axis_mul_after
-
-            def swap_values(source, dest, source_idx, dest_idx):
-                def assign_i():
-                    """assign i value to current output"""
-                    dest[k_idx] = source[i_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[i_idx]
-                    i[0] += 1
-
-                def assign_j():
-                    """assign j value to current output"""
-                    dest[k_idx] = source[j_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[j_idx]
-                    j[0] += 1
-
-                ## if both of the iterators are in range
-                with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
-                    # compare them and insert whichever is next into the output
-                    with ib.if_scope(compare(source[i_idx], source[j_idx])):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
-                # otherwise, simply copy the remainder of the valid iterator to the output
-                with ib.else_scope():
-                    with ib.if_scope(i[0] < middle):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
+        i[0] = aStart + first
+        j[0] = bStart + diag - last
+        with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
+            i_idx = base_idx + i[0]
+            j_idx = base_idx + j[0]
+            k_idx = base_idx + (kStart + diag + count)
+
+            def assign_i():
+                """assign i value to current output"""
+                dest[k_idx] = source[i_idx]
+                if values is not None:
+                    dest_idx[k_idx] = source_idx[i_idx]
+                i[0] += 1
 
-            # Switch which input is the source and which is the destination each iteration
-            with ib.if_scope(even):
-                swap_values(source, dest, source_idx, dest_idx)
+            def assign_j():
+                """assign j value to current output"""
+                dest[k_idx] = source[j_idx]
+                if values is not None:
+                    dest_idx[k_idx] = source_idx[j_idx]
+                j[0] += 1
+
+            ## if both of the iterators are in range
+            with ib.if_scope(tvm.tir.all(i[0] < aStart + aCount, j[0] < bStart + bCount)):
+                # compare them and insert whichever is next into the output
+                with ib.if_scope(compare(source[i_idx], source[j_idx])):
+                    assign_i()
+                with ib.else_scope():
+                    assign_j()
+            # otherwise, simply copy the remainder of the valid iterator to the output
             with ib.else_scope():
-                swap_values(dest, source, dest_idx, source_idx)
-
-    def mergesort(source, dest, source_idx, dest_idx, size, width, even):
-        # calculate the start, mid, and end points of this section
-        start = width * tid
-
-        with ib.if_scope(start < size):
-            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
-            end = cast(tvm.te.min(start + width, size), "int64")
-            # merge the start->middle and middle->end arrays
-            bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even)
+                with ib.if_scope(i[0] < aStart + aCount):
+                    assign_i()
+                with ib.else_scope():
+                    assign_j()
 
-    lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
-    )
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << l2_width
+    with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
+        width = 2 << (l2_width + lower_lim)
         # Define and launch the cuda kernel
         with ib.new_scope():
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            # Reduce the number of blocks as the work per thread grows
-            ib.scope_attr(
-                bx,
-                "thread_extent",
-                tvm.tir.generic.cast(ceil_div(size, width * max_threads), "int32"),
-            )
-            tid = bx * nthread_tx + tx
-
-            by = te.thread_axis("blockIdx.y")
-            bz = te.thread_axis("blockIdx.z")
-            ib.scope_attr(by, "thread_extent", nthread_by)
-            ib.scope_attr(bz, "thread_extent", nthread_bz)
+            target = tvm.target.Target.current()
+            if "vulkan" in str(target):
+                # Vulkan can't handle dynamic nthread, so we thread slightly differently
+                # for vulkan. We don't do this generally because it causes a 15% perf
+                # regression on other platforms
+                ntx = max_threads
+                nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
+                nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
+                tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+            else:
+                ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
+                nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
+                nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
+                tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+
+            def mergepath(
+                source,
+                dest,
+                source_idx,
+                dest_idx,
+                aCount,
+                bCount,
+                aStart,
+                bStart,
+                kStart,
+                step_count,
+                even,
+            ):
+                # pylint: disable=arguments-out-of-order
+                def merge(source, dest, source_idx, dest_idx):
+                    diag = tx * step_count
+                    first, last = get_merge_begin(
+                        source,
+                        by * size,
+                        aCount,
+                        bCount,
+                        aStart,
+                        bStart,
+                        diag,
+                        step_count,
+                    )
+                    # iterate over the output loop
+                    serial_merge(
+                        source,
+                        dest,
+                        source_idx,
+                        dest_idx,
+                        by * size,
+                        aCount,
+                        bCount,
+                        aStart,
+                        bStart,
+                        kStart,
+                        diag,
+                        step_count,
+                        first,
+                        last,
+                    )
+
+                with ib.if_scope(even):
+                    merge(source, dest, source_idx, dest_idx)
+                with ib.else_scope():
+                    merge(dest, source, dest_idx, source_idx)
+
+            def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+                # calculate the start, mid, and end points of this section
+                start = width * bz
+                middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
+                end = cast(tvm.te.min(start + width, size), "int64")
+                with ib.if_scope(start < size):
+                    with ib.if_scope(nbx == 1):
+                        ## merge the start->middle and middle->end arrays
+                        aCount = middle - start
+                        bCount = end - middle
+                        mergepath(
+                            source,
+                            dest,
+                            source_idx,
+                            dest_idx,
+                            aCount,
+                            bCount,
+                            start,
+                            middle,
+                            start,
+                            ceil_div(width, ntx),
+                            even,
+                        )
+                    with ib.else_scope():
+                        step_count = max_threads * thread_work
+                        diag = bx * step_count
+
+                        def do_merge(first, last):
+                            aStart = start + first
+                            bStart = middle + diag - last
+                            aCount = tvm.te.min(middle - aStart, step_count)
+                            bCount = tvm.te.min(end - bStart, step_count)
+                            mergepath(
+                                source,
+                                dest,
+                                source_idx,
+                                dest_idx,
+                                aCount,
+                                bCount,
+                                aStart,
+                                bStart,
+                                start + diag,
+                                thread_work,
+                                even,
+                            )
+
+                        with ib.if_scope(even):
+                            first, last = get_merge_begin(
+                                source,
+                                by * size,
+                                middle - start,
+                                end - middle,
+                                start,
+                                middle,
+                                diag,
+                                step_count,
+                            )
+                            do_merge(first, last)
+                        with ib.else_scope():
+                            first, last = get_merge_begin(
+                                dest,
+                                by * size,
+                                middle - start,
+                                end - middle,
+                                start,
+                                middle,
+                                diag,
+                                step_count,
+                            )
+                            do_merge(first, last)
 
             # Call the kernel
             mergesort(
@@ -233,29 +476,23 @@ def _sort_common(
                 width,
                 tvm.tir.indexmod(l2_width, 2) == 0,
             )
-
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(size, nthread_tx)
     ## if the final sorted data ended up in the swap, copy it to the real output
-    with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
+    with ib.if_scope(
+        tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
+    ):
         with ib.new_scope():
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
             tid = bx * nthread_tx + tx
-
-            by = te.thread_axis("blockIdx.y")
-            bz = te.thread_axis("blockIdx.z")
-            ib.scope_attr(by, "thread_extent", nthread_by)
-            ib.scope_attr(bz, "thread_extent", nthread_bz)
-            idx = (by * size + tid) * axis_mul_after + bz
+            idx = (by * axis_mul_after + bz) * size + tid
             with ib.if_scope(tid < size):
-                idx = (by * size + tid) * axis_mul_after + bz
                 keys[idx] = keys_swap[idx]
                 if values is not None:
                     values[idx] = values_swap[idx]
 
-    return ib.get()
-
 
 def sort_ir(
     data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
@@ -301,27 +538,30 @@ def sort_ir(
         assert indices_out_swap is not None
         indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    axis_mul_before, axis_mul_after = _sort_init(
-        ib,
-        shape,
-        axis,
-        data,
-        values_out,
-        indices_out,
-        value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
-    )
+    with ib.if_scope(shape[axis] > 0):
+        axis_mul_before, axis_mul_after = _sort_init(
+            ib,
+            shape,
+            axis,
+            data,
+            values_out,
+            indices_out,
+            value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
+        )
+
+        _sort_common(
+            ib,
+            shape[axis],
+            axis_mul_before,
+            axis_mul_after,
+            is_ascend,
+            values_out,
+            values_out_swap,
+            values=indices_out,
+            values_swap=indices_out_swap,
+        )
 
-    return _sort_common(
-        ib,
-        shape[axis],
-        axis_mul_before,
-        axis_mul_after,
-        is_ascend,
-        values_out,
-        values_out_swap,
-        values=indices_out,
-        values_swap=indices_out_swap,
-    )
+    return ib.get()
 
 
 def sort_by_key_ir(
@@ -376,27 +616,29 @@ def sort_by_key_ir(
     values_out = ib.buffer_ptr(values_out)
     values_out_swap = ib.buffer_ptr(values_out_swap)
 
-    axis_mul_before, axis_mul_after = _sort_init(
-        ib,
-        shape,
-        axis,
-        keys_in,
-        keys_out,
-        values_out,
-        value_init_func=lambda idx, _: values_in[idx],
-    )
-
-    return _sort_common(
-        ib,
-        shape[axis],
-        axis_mul_before,
-        axis_mul_after,
-        is_ascend,
-        keys_out,
-        keys_out_swap,
-        values=values_out,
-        values_swap=values_out_swap,
-    )
+    with ib.if_scope(shape[axis] > 0):
+        axis_mul_before, axis_mul_after = _sort_init(
+            ib,
+            shape,
+            axis,
+            keys_in,
+            keys_out,
+            values_out,
+            value_init_func=lambda idx, _: values_in[idx],
+        )
+
+        _sort_common(
+            ib,
+            shape[axis],
+            axis_mul_before,
+            axis_mul_after,
+            is_ascend,
+            keys_out,
+            keys_out_swap,
+            values=values_out,
+            values_swap=values_out_swap,
+        )
+    return ib.get()
 
 
 def sort(data, axis=-1, is_ascend=1):
@@ -419,16 +661,29 @@ def sort(data, axis=-1, is_ascend=1):
     out : tvm.te.Tensor
         The output of this function.
     """
+    ndim = len(data.shape)
+    axis = ndim + axis if axis < 0 else axis
+    if axis != ndim - 1:
+        # Prepare for sorting along axis -1.
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
     value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
     value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
+
     out = te.extern(
         [data.shape, data.shape],
         [data],
-        lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+        lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
         out_buffers=[value_buf, value_buf_swap],
         name="sort_gpu",
         tag="sort_gpu",
     )[0]
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        out = transpose(out, axes)
+
     return out
 
 
@@ -507,10 +762,18 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
     out : tvm.te.Tensor
         The output of this function.
     """
+    ndim = len(data.shape)
+    axis = ndim + axis if axis < 0 else axis
+    if axis != ndim - 1:
+        # Prepare for sorting along axis -1.
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
     value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
     value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
     indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
+
     out = te.extern(
         [data.shape, data.shape, data.shape, data.shape],
         [data],
@@ -518,7 +781,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
             ins[0],
             outs[0],
             outs[2],
-            axis,
+            -1,
             is_ascend,
             indices_out=outs[1],
             indices_out_swap=outs[3],
@@ -527,6 +790,11 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
         name="argsort_gpu",
         tag="argsort_gpu",
     )[1]
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        out = transpose(out, axes)
+
     return out
 
 
@@ -625,21 +893,30 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     ndim = len(data.shape)
     axis = axis + ndim if axis < 0 else axis
     assert 0 <= axis < ndim
+    dshape = data.shape
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
     values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
     values_swap_buf = tvm.tir.decl_buffer(
         data.shape, data.dtype, "values_swap_buf", data_alignment=8
     )
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
     indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8)
+
     if ret_type == "values":
         output = te.extern(
             [data.shape, data.shape],
             [data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+            lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
             out_buffers=[values_buf, values_swap_buf],
             name="topk_gpu",
             tag="topk_gpu",
         )[0]
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            output = transpose(output, axes)
     else:
         output = te.extern(
             [data.shape, data.shape, data.shape, data.shape],
@@ -648,7 +925,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
                 ins[0],
                 outs[0],
                 outs[2],
-                axis,
+                -1,
                 is_ascend,
                 indices_out=outs[1],
                 indices_out_swap=outs[3],
@@ -657,6 +934,11 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
             name="topk_gpu",
             tag="topk_gpu",
         )[0:2]
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            output[0] = transpose(output[0], axes)
+            output[1] = transpose(output[1], axes)
+
     if isinstance(k, int) and k < 1:
         if ret_type == "indices":
             return output[1]
@@ -668,7 +950,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
         if i == axis:
             end.append(k if isinstance(k, int) else tvm.te.size_var("dim"))
         else:
-            end.append(data.shape[i])
+            end.append(dshape[i])
     if ret_type == "both":
         values_out, indices_out = output
         values_out = strided_slice(values_out, beg, end, strides)
diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 38143c1..00002d3 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -132,6 +132,10 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
       StmtExprVisitor::VisitStmt_(op);
     }
     env_threads_.pop_back();
+  } else if (op->attr_key == attr::hand_threaded) {
+    // skip this pass on blocks that were hand_threaded
+    // this avoids control flow and read/write conflicts
+    // between hand-threaded kernels and automatic threading
   } else {
     StmtExprVisitor::VisitStmt_(op);
   }
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index 0dac69e..f4b785f 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -26,6 +26,7 @@ import tvm.testing
 @tvm.testing.uses_gpu
 def test_sort():
     def verify_sort(shape, axis, is_ascend, is_dyn=False):
+
         if is_dyn:
             x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
         else:
@@ -87,9 +88,11 @@ def test_argsort():
         for dtype in ["int32", "int64", "float32", "float64"]:
             verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
             verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn)
-            verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
-            verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
-            verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        dtype = "int32"
+        verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
 
 
 @tvm.testing.uses_gpu