Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/14 18:03:18 UTC

[GitHub] [tvm] csullivan commented on a change in pull request #7099: [WIP][CUDA] Parallel Cuda Mergesort

csullivan commented on a change in pull request #7099:
URL: https://github.com/apache/tvm/pull/7099#discussion_r542555503



##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -94,64 +96,155 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
     values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
     if indices_out is not None:
         indices_out = ib.buffer_ptr(indices_out)
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = shape[axis] // max_threads + 1
-
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("vthread")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "virtual_thread", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
-    if indices_out is not None:
-        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
-
-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            base_idx = i * shape[axis] * axis_mul_after + j
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        idx = (by * shape[axis] + tid) * axis_mul_after + bz
+        with ib.if_scope(tid < shape[axis]):
+            idx = (by * shape[axis] + tid) * axis_mul_after + bz

Review comment:
       `idx` is assigned twice (L127/129); the first assignment, just above the `if_scope`, is dead and can be dropped.
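
A minimal sketch of the fix, assuming the surrounding ir_builder scope from the diff is otherwise unchanged: compute `idx` once, inside the bounds check where it is used.

    with ib.if_scope(tid < shape[axis]):
        # compute the flattened output index only for in-bounds threads
        idx = (by * shape[axis] + tid) * axis_mul_after + bz
        values_out[idx] = data[idx]
        if indices_out is not None:
            indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)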

##########
File path: python/tvm/topi/cuda/sort.py
##########
@@ -94,64 +96,155 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
             axis_mul_before *= value
         elif i > axis:
             axis_mul_after *= value
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
     values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
     if indices_out is not None:
         indices_out = ib.buffer_ptr(indices_out)
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = shape[axis] // max_threads + 1
-
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("vthread")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "virtual_thread", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
-    if indices_out is not None:
-        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
-
-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            base_idx = i * shape[axis] * axis_mul_after + j
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        idx = (by * shape[axis] + tid) * axis_mul_after + bz
+        with ib.if_scope(tid < shape[axis]):
+            idx = (by * shape[axis] + tid) * axis_mul_after + bz
+            values_out[idx] = data[idx]
+            if indices_out is not None:
+                indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+
+    source = values_out
+    dest = values_out_swap
+    source_idx = indices_out
+    dest_idx = indices_out_swap
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int32"
+    )
+    with ib.for_range(0, lim) as l2_width:
+        width = 2 << l2_width
+        slices = tvm.tir.indexdiv(shape[axis], (max_threads * width)) + 1

Review comment:
       nit: `slices` -> `num_slices`, since the variable holds a count rather than the slices themselves.
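
For illustration, a sketch of the pass loop with the suggested rename applied; the expressions are lifted from the diff, and only the `num_slices` name is new (it presumably counts the width-wide merge ranges covered per pass). Each pass doubles the sorted-run width, so ceil(log2(shape[axis])) passes suffice.

    # ceil(log2(n)) merge passes fully sort an axis of length n = shape[axis]
    lim = tvm.tir.generic.cast(
        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int32"
    )
    with ib.for_range(0, lim) as l2_width:
        width = 2 << l2_width  # sorted-run width doubles each pass: 2, 4, 8, ...
        # number of width-wide ranges to cover across the thread grid this pass
        num_slices = tvm.tir.indexdiv(shape[axis], (max_threads * width)) + 1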




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org