You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/12/21 13:48:40 UTC

[tvm] branch main updated: [CUDA] Parallel Cuda Mergesort (#7099)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38273ee  [CUDA] Parallel Cuda Mergesort (#7099)
38273ee is described below

commit 38273eeb39bd9b1ef642bd8e940e732f19ee98e8
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Dec 21 06:48:29 2020 -0700

    [CUDA] Parallel Cuda Mergesort (#7099)
---
 python/tvm/driver/build_module.py              |   2 +-
 python/tvm/topi/cuda/sort.py                   | 292 +++++++++++++++++++------
 tests/python/relay/test_any.py                 |   7 +-
 tests/python/relay/test_op_level6.py           |   8 +-
 tests/python/topi/python/test_topi_argwhere.py |   7 +-
 5 files changed, 234 insertions(+), 82 deletions(-)

diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 058bd62..dc9d741 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host):
                 lambda f: "calling_conv" not in f.attrs
                 or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
             ),
-            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
+            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
             tvm.tir.transform.LowerTVMBuiltin(),
             tvm.tir.transform.LowerDeviceStorageAccessInfo(),
             tvm.tir.transform.LowerCustomDatatypes(),
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ea14905..039ebe3 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return
 """Sort related operators """
 import tvm
 from tvm import te
@@ -62,7 +62,9 @@ def _schedule_sort(outs):
     return s
 
 
-def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
+def sort_ir(
+    data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
+):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
 
     Parameters
@@ -70,8 +72,11 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
     data: Buffer
         Buffer of input data. Data will be sorted in place.
 
-    output : Buffer
-        Output buffer of indicies of sorted tensor with same shape as data.
+    values_out : Buffer
+        Output buffer of values of sorted tensor with same shape as data.
+
+    values_out_swap : Buffer
+        Output buffer of values with same shape as data to use as swap.
 
     axis : Int
         Axis long which to sort the input tensor.
@@ -79,11 +84,21 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
     is_ascend : Boolean
         Whether to sort in ascending or descending order.
 
+    indices_out : Buffer
+        Output buffer of indices of sorted tensor with same shape as data.
+
+    indices_out_swap : Buffer
+        Output buffer of indices with same shape as data to use as swap.
+
     Returns
     -------
     stmt : Stmt
         The result IR statement.
     """
+
+    def ceil_div(a, b):
+        return tvm.tir.indexdiv(a + b - 1, b)
+
     axis_mul_before = 1
     axis_mul_after = 1
     shape = data.shape
@@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
     values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
     if indices_out is not None:
         indices_out = ib.buffer_ptr(indices_out)
-    nthread_tx = max_threads
-    nthread_bx = shape[axis] // max_threads + 1
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
-    if indices_out is not None:
-        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
+    # Set up threading
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(shape[axis], max_threads)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+
+    # Copy the data to initial output
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        idx = (by * shape[axis] + tid) * axis_mul_after + bz
+        with ib.if_scope(tid < shape[axis]):
+            values_out[idx] = data[idx]
+            if indices_out is not None:
+                indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+
+    ## we are looping over the array doing mergesort from the bottom up.
+    ## The outer loop runs on the host and launches a cuda kernel for each iteration
+    ## of the algorithm.
+    ## The basic idea is that at iteration 0, each thread sorts 2 elements.
+    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
+    ## to deal with 4 total elements.
+    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
+    ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
+    ## On the final iteration of the algorithm, one thread will merge two sorted lists
+    ## to sort the entire array
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+        # Define and launch the cuda kernel
+        with ib.new_scope():
+            i = ib.allocate("int64", (1,), name="i", scope="local")
+            j = ib.allocate("int64", (1,), name="j", scope="local")
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            # Reduce the number of blocks as the work per thread grows
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            bz = te.thread_axis("blockIdx.z")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+            def compare(a, b):
+                """
+                Compare a and b in proper ascending or descending order
+                """
+                if is_ascend:
+                    out = a <= b
+                else:
+                    out = b <= a
+                return out
+
+            def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
+                """
+                Merge the two sections of the array assigned to this thread
+                """
+                # pylint: disable=arguments-out-of-order
+                # initialize iterators
+                i[0] = start
+                j[0] = middle
+                # set up indexes
+                base_idx = by * shape[axis] * axis_mul_after + bz
+                # iterate over the output loop
+                with ib.for_range(0, end - start) as k:
+                    i_idx = base_idx + i[0] * axis_mul_after
+                    j_idx = base_idx + j[0] * axis_mul_after
+                    k_idx = base_idx + (k + start) * axis_mul_after
+
+                    def swap_values(source, dest, source_idx, dest_idx):
+                        def assign_i():
+                            """assign i value to current output"""
+                            dest[k_idx] = source[i_idx]
+                            if indices_out is not None:
+                                dest_idx[k_idx] = source_idx[i_idx]
+                            i[0] += 1
+
+                        def assign_j():
+                            """assign j value to current output"""
+                            dest[k_idx] = source[j_idx]
+                            if indices_out is not None:
+                                dest_idx[k_idx] = source_idx[j_idx]
+                            j[0] += 1
+
+                        ## if both of the iterators are in range
+                        with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
+                            # compare them and insert whichever is next into the output
+                            with ib.if_scope(compare(source[i_idx], source[j_idx])):
+                                assign_i()
+                            with ib.else_scope():
+                                assign_j()
+                        # otherwise, simply copy the remainder of the valid iterator to the output
+                        with ib.else_scope():
+                            with ib.if_scope(i[0] < middle):
+                                assign_i()
+                            with ib.else_scope():
+                                assign_j()
+
+                    # Switch which input is the source and which is the destination each iteration
+                    with ib.if_scope(even):
+                        swap_values(source, dest, source_idx, dest_idx)
+                    with ib.else_scope():
+                        swap_values(dest, source, dest_idx, source_idx)
+
+            def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+                # calculate the start, mid, and end points of this section
+                start[0] = width * tid
+                with ib.if_scope(start[0] < size):
+                    middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
+                    end[0] = tvm.te.min(start[0] + width, size)
+                    ## merge the start->middle and middle->end arrays
+                    bottom_up_merge(
+                        source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even
+                    )
 
-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            base_idx = i * shape[axis] * axis_mul_after + j
+            # Call the kernel
+            mergesort(
+                values_out,
+                values_out_swap,
+                indices_out,
+                indices_out_swap,
+                shape[axis],
+                width,
+                tvm.tir.indexmod(l2_width, 2) == 0,
+            )
+
+    ## if the final sorted data ended up in the swap, copy it to the real output
+    with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            bz = te.thread_axis("blockIdx.z")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            ib.scope_attr(bz, "thread_extent", nthread_bz)
+            idx = (by * shape[axis] + tid) * axis_mul_after + bz
             with ib.if_scope(tid < shape[axis]):
-                values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
+                idx = (by * shape[axis] + tid) * axis_mul_after + bz
+                values_out[idx] = values_out_swap[idx]
                 if indices_out is not None:
-                    indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast(
-                        tid, indices_out.dtype
-                    )
-    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod
-
-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            current_sort_num = shape[axis]
-            base_idx = i * shape[axis] * axis_mul_after + j
-            # OddEvenTransposeSort
-            with ib.for_range(0, current_sort_num) as k:
-                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
-                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
-                    if is_ascend:
-                        cond = tvm.tir.all(
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            values_out[offset] > values_out[offset + axis_mul_after],
-                        )
-                    else:
-                        cond = tvm.tir.all(
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            values_out[offset] < values_out[offset + axis_mul_after],
-                        )
-                    with ib.if_scope(cond):
-                        temp_data[0] = values_out[offset]
-                        values_out[offset] = values_out[offset + axis_mul_after]
-                        values_out[offset + axis_mul_after] = temp_data[0]
-                        if indices_out is not None:
-                            temp_index[0] = indices_out[offset]
-                            indices_out[offset] = indices_out[offset + axis_mul_after]
-                            indices_out[offset + axis_mul_after] = temp_index[0]
-                ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+                    indices_out[idx] = indices_out_swap[idx]
 
     return ib.get()
 
@@ -336,14 +469,13 @@ def sort(data, axis=-1, is_ascend=1):
     out : tvm.te.Tensor
         The output of this function.
     """
-    dtype = "float32"
     value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
-    indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+    value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
     out = te.extern(
         [data.shape, data.shape],
         [data],
-        lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-        out_buffers=[value_buf, indices_buf],
+        lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+        out_buffers=[value_buf, value_buf_swap],
         name="sort_gpu",
         tag="sort_gpu",
     )[0]
@@ -449,12 +581,24 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
         )
     else:
         value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
+        value_swap_buf = tvm.tir.decl_buffer(
+            data.shape, data.dtype, "value_swap_buf", data_alignment=8
+        )
         indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+        indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
         out = te.extern(
-            [data.shape, data.shape],
+            [data.shape, data.shape, data.shape, data.shape],
             [data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-            out_buffers=[value_buf, indices_buf],
+            lambda ins, outs: sort_ir(
+                ins[0],
+                outs[0],
+                outs[2],
+                axis,
+                is_ascend,
+                indices_out=outs[1],
+                indices_out_swap=outs[3],
+            ),
+            out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
             name="argsort_gpu",
             tag="argsort_gpu",
         )[1]
@@ -564,25 +708,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     axis = axis + ndim if axis < 0 else axis
     assert 0 <= axis < ndim
     values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
+    values_swap_buf = tvm.tir.decl_buffer(
+        data.shape, data.dtype, "values_swap_buf", data_alignment=8
+    )
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
+    indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8)
     if ret_type == "values":
         output = te.extern(
-            [data.shape],
+            [data.shape, data.shape],
             [data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend),
-            out_buffers=[values_buf],
+            lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+            out_buffers=[values_buf, values_swap_buf],
             name="topk_gpu",
             tag="topk_gpu",
-        )
+        )[0]
     else:
         output = te.extern(
-            [data.shape, data.shape],
+            [data.shape, data.shape, data.shape, data.shape],
             [data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
-            out_buffers=[values_buf, indices_buf],
+            lambda ins, outs: sort_ir(
+                ins[0],
+                outs[0],
+                outs[2],
+                axis,
+                is_ascend,
+                indices_out=outs[1],
+                indices_out_swap=outs[3],
+            ),
+            out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf],
             name="topk_gpu",
             tag="topk_gpu",
-        )
+        )[0:2]
     if isinstance(k, int) and k < 1:
         if ret_type == "indices":
             return output[1]
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index dfc03c0..e6812aa 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -250,9 +250,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
     check_result([data], mod, expected, flatten=True)
 
 
-# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
-# to use thrust to guarantee the correct results which has been tested locally.
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_any_argwhere():
     verify_any_argwhere(any_dims(1), (5,))
     verify_any_argwhere(any_dims(2), (5, 5))
@@ -839,8 +837,7 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
     check_result(in_vals, mod, ref_out)
 
 
-# TODO(kevinthesun): enable this test when Thrust is available in ci.
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_any_topk():
     verify_any_topk(any_dims(1), 5, (10,), "float32")
     verify_any_topk(any_dims(2), 2, (6, 3), "int32")
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index a5ce1fd..0dac69e 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -53,6 +53,8 @@ def test_sort():
         verify_sort((2, 3, 4), axis=0, is_ascend=False, is_dyn=is_dyn)
         verify_sort((1, 4, 6), axis=1, is_ascend=True, is_dyn=is_dyn)
         verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn)
+        verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn)
+        verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn)
 
 
 @tvm.testing.uses_gpu
@@ -66,9 +68,9 @@ def test_argsort():
         func = relay.Function([x], z)
         x_data = np.random.uniform(size=shape).astype("float32")
         if is_ascend:
-            ref_res = np.argsort(x_data, axis=axis)
+            ref_res = np.argsort(x_data, axis=axis, kind="stable")
         else:
-            ref_res = np.argsort(-x_data, axis=axis)
+            ref_res = np.argsort(-x_data, axis=axis, kind="stable")
 
         if is_dyn:
             backends = ["vm", "debug"]
@@ -86,6 +88,8 @@ def test_argsort():
             verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
             verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn)
             verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+            verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+            verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
 
 
 @tvm.testing.uses_gpu
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
index 4330308..69993d2 100644
--- a/tests/python/topi/python/test_topi_argwhere.py
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -60,15 +60,10 @@ def verify_argwhere(data_shape):
         tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))
 
     for target, ctx in tvm.testing.enabled_targets():
-        # TODO(zhiics) Enable argwhere gpu test after sort is fixed.
-        if ctx.device_type != 1:
-            continue
         check_device(target, ctx)
 
 
-# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
-# to use thrust to guarantee the correct results which has been tested locally.
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_argwhere():
     verify_argwhere((1,))
     verify_argwhere((100,))