Posted to commits@tvm.apache.org by tq...@apache.org on 2020/12/21 13:48:40 UTC
[tvm] branch main updated: [CUDA] Parallel Cuda Mergesort (#7099)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 38273ee [CUDA] Parallel Cuda Mergesort (#7099)
38273ee is described below
commit 38273eeb39bd9b1ef642bd8e940e732f19ee98e8
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Dec 21 06:48:29 2020 -0700
[CUDA] Parallel Cuda Mergesort (#7099)
---
python/tvm/driver/build_module.py | 2 +-
python/tvm/topi/cuda/sort.py | 292 +++++++++++++++++++------
tests/python/relay/test_any.py | 7 +-
tests/python/relay/test_op_level6.py | 8 +-
tests/python/topi/python/test_topi_argwhere.py | 7 +-
5 files changed, 234 insertions(+), 82 deletions(-)
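Context for the diff below: the commit replaces the previous O(n^2) odd-even transposition sort kernel in python/tvm/topi/cuda/sort.py with a bottom-up parallel mergesort that ping-pongs between the output buffer and a new swap buffer. As orientation, here is a minimal single-threaded Python sketch of the same bottom-up scheme; it is illustrative only (names such as merge_pass are not part of the commit). On the GPU each pass becomes one kernel launch and each thread merges one width-sized section.

import math

def merge_pass(src, dst, width):
    """Merge adjacent sorted runs of length width // 2 from src into dst."""
    n = len(src)
    for start in range(0, n, width):
        middle = min(start + width // 2, n)
        end = min(start + width, n)
        i, j = start, middle
        for k in range(start, end):
            # Take from the left run while it holds the smaller element;
            # on ties take the left run, which keeps the sort stable.
            if i < middle and (j >= end or src[i] <= src[j]):
                dst[k] = src[i]
                i += 1
            else:
                dst[k] = src[j]
                j += 1

def mergesort(data):
    n = len(data)
    src, dst = list(data), [None] * n
    num_passes = math.ceil(math.log2(n)) if n > 1 else 0
    for l2_width in range(num_passes):
        merge_pass(src, dst, 2 << l2_width)   # widths 2, 4, 8, ...
        src, dst = dst, src                   # ping-pong between buffers
    return src

assert mergesort([3, 1, 4, 1, 5, 9, 2, 6]) == [1, 1, 2, 3, 4, 5, 6, 9]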
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 058bd62..dc9d741 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host):
lambda f: "calling_conv" not in f.attrs
or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
),
- tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
+ tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
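(The one-line change above tags host-side functions, i.e. everything other than device kernel launches, with the host target rather than the device target. Presumably this is needed because the new sort emits a host loop that launches a device kernel per merge pass, and those host functions must lower against the host target.)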
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ea14905..039ebe3 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return
"""Sort related operators """
import tvm
from tvm import te
@@ -62,7 +62,9 @@ def _schedule_sort(outs):
return s
-def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
+def sort_ir(
+ data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
+):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters
@@ -70,8 +72,11 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
data: Buffer
Buffer of input data. Data will be sorted in place.
- output : Buffer
- Output buffer of indicies of sorted tensor with same shape as data.
+ values_out : Buffer
+ Output buffer of values of sorted tensor with same shape as data.
+
+ values_out_swap : Buffer
+ Output buffer of values with same shape as data to use as swap.
axis : Int
Axis along which to sort the input tensor.
@@ -79,11 +84,21 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
is_ascend : Boolean
Whether to sort in ascending or descending order.
+ indices_out : Buffer
+ Output buffer of indices of sorted tensor with same shape as data.
+
+ indices_out_swap : Buffer
+ Output buffer of indices with same shape as data to use as swap.
+
Returns
-------
stmt : Stmt
The result IR statement.
"""
+
+ def ceil_div(a, b):
+ return tvm.tir.indexdiv(a + b - 1, b)
+
axis_mul_before = 1
axis_mul_after = 1
shape = data.shape
@@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
axis_mul_before *= value
elif i > axis:
axis_mul_after *= value
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
ib = tvm.tir.ir_builder.create()
+
data = ib.buffer_ptr(data)
values_out = ib.buffer_ptr(values_out)
+ values_out_swap = ib.buffer_ptr(values_out_swap)
if indices_out is not None:
indices_out = ib.buffer_ptr(indices_out)
- nthread_tx = max_threads
- nthread_bx = shape[axis] // max_threads + 1
+ assert indices_out_swap is not None
+ indices_out_swap = ib.buffer_ptr(indices_out_swap)
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- tid = bx * nthread_tx + tx
- temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
- if indices_out is not None:
- temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
+ # Set up threading
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(shape[axis], max_threads)
+ nthread_by = axis_mul_before
+ nthread_bz = axis_mul_after
+
+ # Copy the data to the initial output
+ with ib.new_scope():
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * nthread_tx + tx
+
+ by = te.thread_axis("blockIdx.y")
+ bz = te.thread_axis("blockIdx.z")
+ ib.scope_attr(by, "thread_extent", nthread_by)
+ ib.scope_attr(bz, "thread_extent", nthread_bz)
+ idx = (by * shape[axis] + tid) * axis_mul_after + bz
+ with ib.if_scope(tid < shape[axis]):
+ values_out[idx] = data[idx]
+ if indices_out is not None:
+ indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
+
+ ## We are looping over the array doing mergesort from the bottom up.
+ ## The outer loop runs on the host and launches a CUDA kernel for each iteration
+ ## of the algorithm.
+ ## The basic idea is that at iteration 0, each thread sorts 2 elements.
+ ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
+ ## to deal with 4 total elements.
+ ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
+ ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc.
+ ## On the final iteration of the algorithm, one thread will merge two sorted lists
+ ## to sort the entire array.
+ lim = tvm.tir.generic.cast(
+ tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
+ )
+ with ib.for_range(0, lim, dtype="int64") as l2_width:
+ width = 2 << l2_width
+ # Define and launch the CUDA kernel
+ with ib.new_scope():
+ i = ib.allocate("int64", (1,), name="i", scope="local")
+ j = ib.allocate("int64", (1,), name="j", scope="local")
+ start = ib.allocate("int64", (1,), name="start", scope="local")
+ middle = ib.allocate("int64", (1,), name="middle", scope="local")
+ end = ib.allocate("int64", (1,), name="end", scope="local")
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ # Reduce the number of blocks as the work per thread grows
+ ib.scope_attr(
+ bx,
+ "thread_extent",
+ tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"),
+ )
+ tid = bx * nthread_tx + tx
+
+ by = te.thread_axis("blockIdx.y")
+ bz = te.thread_axis("blockIdx.z")
+ ib.scope_attr(by, "thread_extent", nthread_by)
+ ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+ def compare(a, b):
+ """
+ Compare a and b in proper ascending or descending order
+ """
+ if is_ascend:
+ out = a <= b
+ else:
+ out = b <= a
+ return out
+
+ def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
+ """
+ Merge the two sections of the array assigned to this thread
+ """
+ # pylint: disable=arguments-out-of-order
+ # initialize iterators
+ i[0] = start
+ j[0] = middle
+ # set up indexes
+ base_idx = by * shape[axis] * axis_mul_after + bz
+ # iterate over the output loop
+ with ib.for_range(0, end - start) as k:
+ i_idx = base_idx + i[0] * axis_mul_after
+ j_idx = base_idx + j[0] * axis_mul_after
+ k_idx = base_idx + (k + start) * axis_mul_after
+
+ def swap_values(source, dest, source_idx, dest_idx):
+ def assign_i():
+ """assign i value to current output"""
+ dest[k_idx] = source[i_idx]
+ if indices_out is not None:
+ dest_idx[k_idx] = source_idx[i_idx]
+ i[0] += 1
+
+ def assign_j():
+ """assign j value to current output"""
+ dest[k_idx] = source[j_idx]
+ if indices_out is not None:
+ dest_idx[k_idx] = source_idx[j_idx]
+ j[0] += 1
+
+ ## if both of the iterators are in range
+ with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
+ # compare them and insert whichever is next into the output
+ with ib.if_scope(compare(source[i_idx], source[j_idx])):
+ assign_i()
+ with ib.else_scope():
+ assign_j()
+ # otherwise, simply copy the remainder of the valid iterator to the output
+ with ib.else_scope():
+ with ib.if_scope(i[0] < middle):
+ assign_i()
+ with ib.else_scope():
+ assign_j()
+
+ # Switch which input is the source and which is the destination each iteration
+ with ib.if_scope(even):
+ swap_values(source, dest, source_idx, dest_idx)
+ with ib.else_scope():
+ swap_values(dest, source, dest_idx, source_idx)
+
+ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+ # calculate the start, mid, and end points of this section
+ start[0] = width * tid
+ with ib.if_scope(start[0] < size):
+ middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
+ end[0] = tvm.te.min(start[0] + width, size)
+ ## merge the start->middle and middle->end arrays
+ bottom_up_merge(
+ source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even
+ )
- with ib.for_range(0, axis_mul_before) as i:
- with ib.for_range(0, axis_mul_after) as j:
- base_idx = i * shape[axis] * axis_mul_after + j
+ # Call the kernel
+ mergesort(
+ values_out,
+ values_out_swap,
+ indices_out,
+ indices_out_swap,
+ shape[axis],
+ width,
+ tvm.tir.indexmod(l2_width, 2) == 0,
+ )
+
+ ## if the final sorted data ended up in the swap, copy it to the real output
+ with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
+ with ib.new_scope():
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * nthread_tx + tx
+
+ by = te.thread_axis("blockIdx.y")
+ bz = te.thread_axis("blockIdx.z")
+ ib.scope_attr(by, "thread_extent", nthread_by)
+ ib.scope_attr(bz, "thread_extent", nthread_bz)
+ idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
- values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
+ idx = (by * shape[axis] + tid) * axis_mul_after + bz
+ values_out[idx] = values_out_swap[idx]
if indices_out is not None:
- indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast(
- tid, indices_out.dtype
- )
- ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
- idxd = tvm.tir.indexdiv
- idxm = tvm.tir.indexmod
-
- with ib.for_range(0, axis_mul_before) as i:
- with ib.for_range(0, axis_mul_after) as j:
- current_sort_num = shape[axis]
- base_idx = i * shape[axis] * axis_mul_after + j
- # OddEvenTransposeSort
- with ib.for_range(0, current_sort_num) as k:
- with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
- offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
- if is_ascend:
- cond = tvm.tir.all(
- 2 * tid + idxm(k, 2) + 1 < current_sort_num,
- values_out[offset] > values_out[offset + axis_mul_after],
- )
- else:
- cond = tvm.tir.all(
- 2 * tid + idxm(k, 2) + 1 < current_sort_num,
- values_out[offset] < values_out[offset + axis_mul_after],
- )
- with ib.if_scope(cond):
- temp_data[0] = values_out[offset]
- values_out[offset] = values_out[offset + axis_mul_after]
- values_out[offset + axis_mul_after] = temp_data[0]
- if indices_out is not None:
- temp_index[0] = indices_out[offset]
- indices_out[offset] = indices_out[offset + axis_mul_after]
- indices_out[offset + axis_mul_after] = temp_index[0]
- ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+ indices_out[idx] = indices_out_swap[idx]
return ib.get()
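The copy-back at the end of sort_ir follows from the ping-pong scheme: the loop alternates source and destination each pass (the `even` flag), so pass 0 writes values_out into the swap buffer, pass 1 writes back, and so on. After `lim` passes the sorted data therefore sits in the swap buffer exactly when `lim` is odd, which is what the `indexmod(lim, 2) == 1` guard checks. A plain-Python sanity check of that bookkeeping (hypothetical helper, not part of the commit):

def final_buffer(num_passes):
    buf = "values_out"
    for _ in range(num_passes):
        buf = "swap" if buf == "values_out" else "values_out"
    return buf

for lim in range(6):
    # The kernel copies swap -> values_out precisely when lim is odd.
    assert (final_buffer(lim) == "swap") == (lim % 2 == 1)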
@@ -336,14 +469,13 @@ def sort(data, axis=-1, is_ascend=1):
out : tvm.te.Tensor
The output of this function.
"""
- dtype = "float32"
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
- indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+ value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
out = te.extern(
[data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
- out_buffers=[value_buf, indices_buf],
+ lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+ out_buffers=[value_buf, value_buf_swap],
name="sort_gpu",
tag="sort_gpu",
)[0]
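A hedged usage sketch of the updated sort, assuming a CUDA-enabled build of TVM at this commit (API names such as topi.cuda.schedule_sort follow this version of the codebase):

import numpy as np
import tvm
from tvm import te, topi

# Sketch only: build and run the GPU sort on a shape the new tests exercise.
data = te.placeholder((1, 2000), name="data", dtype="float32")
with tvm.target.Target("cuda"):
    out = topi.cuda.sort(data, axis=1, is_ascend=True)
    s = topi.cuda.schedule_sort(out)
f = tvm.build(s, [data, out], "cuda")

ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(1, 2000)).astype("float32")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(a_np), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.sort(a_np, axis=1), rtol=1e-5)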
@@ -449,12 +581,24 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
)
else:
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
+ value_swap_buf = tvm.tir.decl_buffer(
+ data.shape, data.dtype, "value_swap_buf", data_alignment=8
+ )
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+ indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
out = te.extern(
- [data.shape, data.shape],
+ [data.shape, data.shape, data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
- out_buffers=[value_buf, indices_buf],
+ lambda ins, outs: sort_ir(
+ ins[0],
+ outs[0],
+ outs[2],
+ axis,
+ is_ascend,
+ indices_out=outs[1],
+ indices_out_swap=outs[3],
+ ),
+ out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
name="argsort_gpu",
tag="argsort_gpu",
)[1]
@@ -564,25 +708,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
axis = axis + ndim if axis < 0 else axis
assert 0 <= axis < ndim
values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
+ values_swap_buf = tvm.tir.decl_buffer(
+ data.shape, data.dtype, "values_swap_buf", data_alignment=8
+ )
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
+ indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_swap_buf", data_alignment=8)
if ret_type == "values":
output = te.extern(
- [data.shape],
+ [data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend),
- out_buffers=[values_buf],
+ lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+ out_buffers=[values_buf, values_swap_buf],
name="topk_gpu",
tag="topk_gpu",
- )
+ )[0]
else:
output = te.extern(
- [data.shape, data.shape],
+ [data.shape, data.shape, data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
- out_buffers=[values_buf, indices_buf],
+ lambda ins, outs: sort_ir(
+ ins[0],
+ outs[0],
+ outs[2],
+ axis,
+ is_ascend,
+ indices_out=outs[1],
+ indices_out_swap=outs[3],
+ ),
+ out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf],
name="topk_gpu",
tag="topk_gpu",
- )
+ )[0:2]
if isinstance(k, int) and k < 1:
if ret_type == "indices":
return output[1]
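Because te.extern now also declares the swap buffers as outputs, the result list has to be sliced before returning: sort keeps only [0], argsort keeps [1], and topk keeps [0] or [0:2] depending on ret_type, dropping the swap buffers in every case. A hedged topk sketch under the same assumptions as the sort example above:

import tvm
from tvm import te, topi

# Sketch only; API names per this version of the codebase.
data = te.placeholder((1, 1024), name="data", dtype="float32")
with tvm.target.Target("cuda"):
    values, indices = topi.cuda.topk(
        data, k=5, axis=-1, ret_type="both", is_ascend=False, dtype="int64"
    )
    s = topi.cuda.schedule_topk([values, indices])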
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index dfc03c0..e6812aa 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -250,9 +250,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
check_result([data], mod, expected, flatten=True)
-# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
-# to use thrust to guarantee the correct results which has been tested locally.
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
def test_any_argwhere():
verify_any_argwhere(any_dims(1), (5,))
verify_any_argwhere(any_dims(2), (5, 5))
@@ -839,8 +837,7 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
check_result(in_vals, mod, ref_out)
-# TODO(kevinthesun): enable this test when Thrust is available in ci.
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
def test_any_topk():
verify_any_topk(any_dims(1), 5, (10,), "float32")
verify_any_topk(any_dims(2), 2, (6, 3), "int32")
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index a5ce1fd..0dac69e 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -53,6 +53,8 @@ def test_sort():
verify_sort((2, 3, 4), axis=0, is_ascend=False, is_dyn=is_dyn)
verify_sort((1, 4, 6), axis=1, is_ascend=True, is_dyn=is_dyn)
verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn)
+ verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn)
+ verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn)
@tvm.testing.uses_gpu
@@ -66,9 +68,9 @@ def test_argsort():
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
- ref_res = np.argsort(x_data, axis=axis)
+ ref_res = np.argsort(x_data, axis=axis, kind="stable")
else:
- ref_res = np.argsort(-x_data, axis=axis)
+ ref_res = np.argsort(-x_data, axis=axis, kind="stable")
if is_dyn:
backends = ["vm", "debug"]
@@ -86,6 +88,8 @@ def test_argsort():
verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+ verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+ verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
@tvm.testing.uses_gpu
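The kind="stable" change in the argsort reference above is load-bearing for the new large shapes: random arrays of 2000+ elements will contain ties, and the GPU mergesort resolves ties in favor of the earlier index (the a <= b in compare), i.e. it is stable. NumPy's default introsort makes no such guarantee, so an unstable reference could legitimately return different indices for equal values and fail the comparison. For example:

import numpy as np

x = np.array([2.0, 1.0, 1.0], dtype="float32")
np.argsort(x, kind="stable")  # array([1, 2, 0]): ties keep original order
# The default kind may legally order the two 1.0 entries either way.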
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
index 4330308..69993d2 100644
--- a/tests/python/topi/python/test_topi_argwhere.py
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -60,15 +60,10 @@ def verify_argwhere(data_shape):
tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))
for target, ctx in tvm.testing.enabled_targets():
- # TODO(zhiics) Enable argwhere gpu test after sort is fixed.
- if ctx.device_type != 1:
- continue
check_device(target, ctx)
-# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
-# to use thrust to guarantee the correct results which has been tested locally.
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
def test_argwhere():
verify_argwhere((1,))
verify_argwhere((100,))