You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/08 22:45:47 UTC

[GitHub] [tvm] mbrookhart commented on a change in pull request #7611: [TOPI][GPU] Mergepath sort with odd-even block sort

mbrookhart commented on a change in pull request #7611:
URL: https://github.com/apache/tvm/pull/7611#discussion_r589805036



##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -136,93 +236,223 @@ def compare(a, b):
             out = b <= a
         return out
 
-    def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
-        """
-        Merge the two sections of the array assigned to this thread
-        """
-        # pylint: disable=arguments-out-of-order
-        # initialize iterators
-        i[0] = start
-        j[0] = middle
-        # set up indexes
-        base_idx = by * size * axis_mul_after + bz
-        # iterate over the output loop
-        with ib.for_range(0, end - start) as k:
-            i_idx = base_idx + i[0] * axis_mul_after
-            j_idx = base_idx + j[0] * axis_mul_after
-            k_idx = base_idx + (k + start) * axis_mul_after
-
-            def swap_values(source, dest, source_idx, dest_idx):
-                def assign_i():
-                    """assign i value to current output"""
-                    dest[k_idx] = source[i_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[i_idx]
-                    i[0] += 1
-
-                def assign_j():
-                    """assign j value to current output"""
-                    dest[k_idx] = source[j_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[j_idx]
-                    j[0] += 1
-
-                ## if both of the iterators are in range
-                with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
-                    # compare them and insert whichever is next into the output
-                    with ib.if_scope(compare(source[i_idx], source[j_idx])):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
-                # otherwise, simply copy the remainder of the valid iterator to the output
-                with ib.else_scope():
-                    with ib.if_scope(i[0] < middle):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
+    # Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
+    lower_lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
+    )
 
-            # Switch which input is the source and which is the destination each iteration
-            with ib.if_scope(even):
-                swap_values(source, dest, source_idx, dest_idx)
-            with ib.else_scope():
-                swap_values(dest, source, dest_idx, source_idx)
-
-    def mergesort(source, dest, source_idx, dest_idx, size, width, even):
-        # calculate the start, mid, and end points of this section
-        start[0] = width * tid
-        with ib.if_scope(start[0] < size):
-            middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
-            end[0] = tvm.te.min(start[0] + width, size)
-            ## merge the start->middle and middle->end arrays
-            bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even)
-
-    lim = tvm.tir.generic.cast(
+    _odd_even_sort(
+        ib,
+        size,
+        axis_mul_before * axis_mul_after,
+        1,
+        is_ascend,
+        keys,
+        keys_swap,
+        values,
+        values_swap,
+    )
+
+    upper_lim = tvm.tir.generic.cast(
         tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
     )
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << l2_width
+
+    def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
+        first = ib.allocate("int64", (1,), name="first", scope="local")
+        mid = ib.allocate("int64", (1,), name="mid", scope="local")
+        last = ib.allocate("int64", (1,), name="last", scope="local")
+        first[0] = tvm.te.max(0, diag - bCount)
+        last[0] = tvm.te.min(diag, aCount)
+        with ib.while_loop(first[0] < last[0]):
+            mid[0] = (first[0] + last[0]) >> 1
+            a = source[base_idx + (aStart + mid[0])]
+            b = source[base_idx + (bStart + diag - 1 - mid[0])]
+            with ib.if_scope(compare(a, b)):
+                first[0] = mid[0] + 1
+            with ib.else_scope():
+                last[0] = mid[0]
+        return first, last
+
+    def serial_merge(
+        source,
+        dest,
+        source_idx,
+        dest_idx,
+        base_idx,
+        aCount,
+        bCount,
+        aStart,
+        bStart,
+        kStart,
+        diag,
+        step_count,
+        first,
+        last,
+    ):
+        i = ib.allocate("int64", (1,), name="i", scope="local")
+        j = ib.allocate("int64", (1,), name="j", scope="local")
+        i[0] = aStart + first[0]
+        j[0] = bStart + diag - last[0]
+        with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
+            i_idx = base_idx + i[0]
+            j_idx = base_idx + j[0]
+            k_idx = base_idx + (kStart + diag + count)
+
+            def assign_i():
+                """assign i value to current output"""
+                dest[k_idx] = source[i_idx]
+                if values is not None:
+                    dest_idx[k_idx] = source_idx[i_idx]
+                i[0] += 1
+
+            def assign_j():
+                """assign j value to current output"""
+                dest[k_idx] = source[j_idx]
+                if values is not None:
+                    dest_idx[k_idx] = source_idx[j_idx]
+                j[0] += 1
+
+            ## if both of the iterators are in range
+            with ib.if_scope(tvm.tir.all(i[0] < aStart + aCount, j[0] < bStart + bCount)):
+                # compare them and insert whichever is next into the output
+                with ib.if_scope(compare(source[i_idx], source[j_idx])):
+                    assign_i()
+                with ib.else_scope():
+                    assign_j()
+            # otherwise, simply copy the remainder of the valid iterator to the output
+            with ib.else_scope():
+                with ib.if_scope(i[0] < aStart + aCount):
+                    assign_i()
+                with ib.else_scope():
+                    assign_j()
+
+    with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
+        width = 2 << (l2_width + lower_lim)
         # Define and launch the cuda kernel
         with ib.new_scope():
-            i = ib.allocate("int64", (1,), name="i", scope="local")
-            j = ib.allocate("int64", (1,), name="j", scope="local")
-            start = ib.allocate("int64", (1,), name="start", scope="local")
-            middle = ib.allocate("int64", (1,), name="middle", scope="local")
-            end = ib.allocate("int64", (1,), name="end", scope="local")
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            # Reduce the number of blocks as the work per thread grows
-            ib.scope_attr(
-                bx,
-                "thread_extent",
-                tvm.tir.generic.cast(ceil_div(size, width * max_threads), "int32"),
-            )
-            tid = bx * nthread_tx + tx

Review comment:
       @masahi I'm kind of surprised spirv doesn't like it; I was doing a thread calculation before. Perhaps it's just complaining about nthread_tx? Perhaps we could just change `ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")` to `ntx = max_threads`. That should only hit performance slightly, and only on a couple of levels (there are already `if tid < x` statements everywhere). Would that work for spirv?
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org