Posted to commits@tvm.apache.org by mb...@apache.org on 2021/01/19 19:09:06 UTC

[tvm] branch main updated: [TOPI] Minor perf improvement for GPU scatter (#7233)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2290cc0  [TOPI] Minor perf improvement for GPU scatter (#7233)
2290cc0 is described below

commit 2290cc0f79e9f9c255e10bd3775c711591c34e99
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Jan 20 04:08:52 2021 +0900

    [TOPI] Minor perf improvement for GPU scatter (#7233)
    
    * improve scatter 4d init
    
    * do not launch sorting based scatter for small input
    
    * do not use hard coded num threads
    
    * separate sort based implementation
    
    * register scatter as autotvm task
    
    * add missing import
    
    * fix strategy
    
    * add dedicated schedule and dummy flop
    
    * add test tuning script
    
    * try adding dummy knob
    
    * skip random_fill when a tuning workload is from scatter
    
    This reverts commit 1fed88321e640b509fc46fac7da3b3cb79719552.
    
    * cleanup memcpy ir
    
    * remove scatter tuning script
    
    * make sure zero init arguments
    
    * add comment on why skip random init for scatter
    
    * restore ctx sync
    
    Co-authored-by: masa <ma...@pop-os.localdomain>
---
 python/tvm/autotvm/measure/measure_methods.py |   9 +-
 python/tvm/relay/op/strategy/cuda.py          |  15 ++-
 python/tvm/relay/op/strategy/generic.py       |   2 +-
 python/tvm/topi/cuda/scatter.py               | 179 +++++++++++++++-----------
 4 files changed, 123 insertions(+), 82 deletions(-)
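
This commit registers the GPU scatter kernels as AutoTVM tasks ("scatter.cuda" and, for the sort-based path, "scatter_via_sort.cuda") so they can be tuned like any other operator; the test tuning script mentioned in the log was removed again before merging. The sketch below is a hypothetical reconstruction of how such a task could be extracted with the standard relay/autotvm APIs on a CUDA-enabled build, not code from this commit.

    import tvm
    from tvm import relay, autotvm

    # Build a small Relay program containing a scatter op.
    data = relay.var("data", shape=(32, 128), dtype="float32")
    indices = relay.var("indices", shape=(32, 128), dtype="int64")
    updates = relay.var("updates", shape=(32, 128), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.scatter(data, indices, updates, axis=1))

    # After this commit, scatter should appear among the extracted tasks as
    # "scatter.cuda"; rank-1 inputs on a Thrust-enabled build additionally
    # yield "scatter_via_sort.cuda".
    tasks = autotvm.task.extract_from_program(mod, params={}, target="cuda")
    for task in tasks:
        print(task.name, task.args)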

diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index cb801ba..ffe4b97 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -30,6 +30,7 @@ import time
 from random import getrandbits
 from collections import namedtuple
 import tempfile
+import numpy as np
 
 import tvm._ffi
 import tvm.ir.transform
@@ -560,9 +561,11 @@ def run_through_rpc(
             raise AttributeError(
                 "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
             )
-        args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
-        for arg in args:
-            random_fill(arg)
+        args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
+        if "scatter" not in measure_input.task.name:
+            # the index tensor of scatter op cannot be randomly initialized
+            for arg in args:
+                random_fill(arg)
         ctx.sync()
 
         costs = time_f(*args).results
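
The hunk above zero-initializes the measurement buffers and skips random_fill for scatter workloads: reinterpreting random device bytes as indices can produce values far outside the data extent, so the kernel being timed could write out of bounds. A small NumPy sketch of the hazard (illustrative only, not part of the patch):

    import numpy as np

    out = np.zeros(8, dtype="float32")
    updates = np.ones(3, dtype="float32")

    # Zero-initialized indices, as the patch now guarantees, are always in range.
    indices = np.zeros(3, dtype="int64")
    out[indices] = updates  # every write lands at position 0

    # Reinterpreting random bytes as int64, which is roughly what random_fill
    # amounts to for an index tensor, can yield arbitrary values.
    bad_indices = np.frombuffer(np.random.bytes(3 * 8), dtype="int64")
    # out[bad_indices] = updates  # would almost certainly index out of bounds
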
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 04c16dd..3863df0 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -783,10 +783,23 @@ def scatter_cuda(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     strategy.add_implementation(
         wrap_compute_scatter(topi.cuda.scatter),
-        wrap_topi_schedule(topi.generic.schedule_extern),
+        wrap_topi_schedule(topi.cuda.schedule_scatter),
         name="scatter.cuda",
         plevel=10,
     )
+
+    rank = len(inputs[0].shape)
+
+    with SpecializedCondition(rank == 1):
+        if target.kind.name == "cuda" and get_global_func(
+            "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True
+        ):
+            strategy.add_implementation(
+                wrap_compute_scatter(topi.cuda.scatter_via_sort),
+                wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
+                name="scatter_via_sort.cuda",
+                plevel=9,  # use the sequential version by default
+            )
     return strategy
 
 
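With this change the CUDA strategy carries two scatter implementations: the sequential kernel at plevel=10 remains the default, while the sort-based kernel is registered at plevel=9 for rank-1 inputs when Thrust is available, so AutoTVM can still select it if tuning shows it is faster. Whether the Thrust path is even a candidate can be checked the same way the strategy does; a minimal probe (illustrative, not from the patch):

    import tvm

    # The sort-based scatter is only registered when TVM was built with Thrust,
    # i.e. when this packed function is present.
    has_thrust = (
        tvm.get_global_func("tvm.contrib.thrust.stable_sort_by_key", allow_missing=True)
        is not None
    )
    print("thrust stable_sort_by_key available:", has_thrust)
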
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 363832e..8dd9dc5 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1123,7 +1123,7 @@ def wrap_compute_scatter(topi_compute):
     """Wrap scatter topi compute"""
 
     def _compute_scatter(attrs, inputs, _):
-        return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)]
+        return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]
 
     return _compute_scatter
 
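The keyword argument is dropped here because scatter's compute is now wrapped by autotvm.register_topi_compute, and the AutoTVM task wrapper does not accept keyword arguments when recording a workload, so axis has to be passed positionally. A minimal sketch of that pattern, with a made-up task name and a trivial compute body:

    from tvm import autotvm, te

    @autotvm.register_topi_compute("toy_scatter.example")
    def toy_scatter(cfg, data, indices, updates, axis=0):
        cfg.add_flop(1)  # dummy value, as in the patch, to satisfy AutoTVM
        return te.compute(data.shape, lambda *i: data(*i), name="identity")

    # toy_scatter(d, i, u, axis=0)  # keyword form is rejected by the task wrapper
    # toy_scatter(d, i, u, 0)       # positional axis is what the strategy now uses
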
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index be602c8..b34bd1d 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -17,16 +17,33 @@
 # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
 """Scatter operator """
 import tvm
-from tvm import te
+from tvm import te, autotvm
 from ..scatter import _verify_scatter_nd_inputs
+from ..generic import schedule_extern
 from .nms import atomic_add
 from .sort import stable_sort_by_key_thrust, is_thrust_available
+from ..utils import prod
 
 
 def ceil_div(a, b):
     return (a + b - 1) // b
 
 
+def _memcpy_ir(ib, out_ptr, data_ptr, shape):
+    fused = prod(shape)
+    with ib.new_scope():
+        num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        num_blocks = ceil_div(fused, num_thread)
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", num_blocks)
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", num_thread)
+        tid = bx * num_thread + tx
+
+        with ib.if_scope(tid < fused):
+            out_ptr[tid] = data_ptr[tid]
+
+
 def gen_ir_1d(data, indices, updates, axis, out, update_func):
     """Generate scatter ir for 1d inputs
 
@@ -63,10 +80,7 @@ def gen_ir_1d(data, indices, updates, axis, out, update_func):
     out_ptr = ib.buffer_ptr(out)
     data_ptr = ib.buffer_ptr(data)
 
-    with ib.new_scope():
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", n)
-        out_ptr[bx] = data_ptr[bx]
+    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
 
     indices_ptr = ib.buffer_ptr(indices)
     updates_ptr = ib.buffer_ptr(updates)
@@ -114,8 +128,6 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
     ret : tir
         The computational ir.
     """
-    warp_size = tvm.target.Target.current(False).thread_warp_size
-
     n = data.shape[0]
     c = data.shape[1]
 
@@ -124,16 +136,7 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
     out_ptr = ib.buffer_ptr(out)
     data_ptr = ib.buffer_ptr(data)
 
-    with ib.new_scope():
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", n)
-        tx = te.thread_axis("threadIdx.x")
-        ib.scope_attr(tx, "thread_extent", warp_size)
-        with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:
-            j = j_ * warp_size + tx
-            with ib.if_scope(j < c):
-                idx = bx * c + j
-                out_ptr[idx] = data_ptr[idx]
+    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
 
     indices_ptr = ib.buffer_ptr(indices)
     updates_ptr = ib.buffer_ptr(updates)
@@ -205,18 +208,7 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func):
     out_ptr = ib.buffer_ptr(out)
     data_ptr = ib.buffer_ptr(data)
 
-    with ib.new_scope():
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", n)
-        by = te.thread_axis("blockIdx.y")
-        ib.scope_attr(by, "thread_extent", c)
-        tx = te.thread_axis("threadIdx.x")
-        ib.scope_attr(tx, "thread_extent", warp_size)
-        with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_:
-            k = k_ * warp_size + tx
-            with ib.if_scope(k < h):
-                idx = (bx * c + by) * h + k
-                out_ptr[idx] = data_ptr[idx]
+    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
 
     indices_ptr = ib.buffer_ptr(indices)
     updates_ptr = ib.buffer_ptr(updates)
@@ -311,20 +303,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
 
     out_ptr = ib.buffer_ptr(out)
     data_ptr = ib.buffer_ptr(data)
-    with ib.new_scope():
-        i = te.thread_axis("blockIdx.x")
-        ib.scope_attr(i, "thread_extent", n)
-        j = te.thread_axis("blockIdx.y")
-        ib.scope_attr(j, "thread_extent", c)
-        k = te.thread_axis("blockIdx.z")
-        ib.scope_attr(k, "thread_extent", h)
-        tx = te.thread_axis("threadIdx.x")
-        ib.scope_attr(tx, "thread_extent", warp_size)
-        with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_:
-            l = l_ * warp_size + tx
-            with ib.if_scope(l < w):
-                idx = ((i * c + j) * h + k) * w + l
-                out_ptr[idx] = data_ptr[idx]
+    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
 
     indices_ptr = ib.buffer_ptr(indices)
     updates_ptr = ib.buffer_ptr(updates)
@@ -417,7 +396,71 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
     return ib.get()
 
 
-def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
+@autotvm.register_topi_compute("scatter.cuda")
+def scatter(cfg, data, indices, updates, axis=0):
+    """Update data at positions defined by indices with values in updates
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    indices : relay.Expr
+        The index locations to update.
+
+    updates : relay.Expr
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    if axis < 0:
+        axis += len(data.shape)
+    assert axis >= 0
+    assert axis < len(data.shape)
+
+    rank = len(data.shape)
+    assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
+
+    ir_funcs = {
+        1: gen_ir_1d,
+        2: gen_ir_2d,
+        3: gen_ir_3d,
+        4: gen_ir_4d,
+    }
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
+    out_shape = data.shape
+    out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+
+    cfg.add_flop(1)  # A dummy value to satisfy AutoTVM
+
+    out = te.extern(
+        [out_shape],
+        [data, indices, updates],
+        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_gpu",
+        tag="scatter_gpu",
+    )
+
+    return out
+
+
+@autotvm.register_topi_schedule("scatter.cuda")
+def schedule_scatter(_, outs):
+    return schedule_extern(outs)
+
+
+def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
     """Generate scatter ir for 1d inputs, using a sorting based approach.
     By sorting indices and comparing neighboring two indices, we can tell which
     of elements in the indices tensor can scatter its update value into the output.
@@ -438,9 +481,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
     updates : tir.Tensor
         The values to update, sorted by indices.
 
-    axis : int
-        The axis to scatter on. It must be 0 for this function.
-
     out : tir.Tensor
         The output tensor.
 
@@ -449,7 +489,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
     ret : tir
         The computational ir.
     """
-    assert axis == 0
     n = data.shape[0]
 
     ib = tvm.tir.ir_builder.create()
@@ -504,7 +543,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
     return ib.get()
 
 
-def scatter(data, indices, updates, axis=0):
+@autotvm.register_topi_compute("scatter_via_sort.cuda")
+def scatter_via_sort(cfg, data, indices, updates, axis=0):
     """Update data at positions defined by indices with values in updates
 
     Parameters
@@ -528,49 +568,34 @@ def scatter(data, indices, updates, axis=0):
     """
     if axis < 0:
         axis += len(data.shape)
-    assert axis >= 0
-    assert axis < len(data.shape)
+    assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
+    assert is_thrust_available(), "Thrust is required for this op"
 
-    rank = len(data.shape)
-    assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
-
-    ir_funcs = {
-        1: gen_ir_1d,
-        2: gen_ir_2d,
-        3: gen_ir_3d,
-        4: gen_ir_4d,
-    }
-
-    def update_func(dst_ptr, dst_index, update):
-        dst_ptr[dst_index] = update
+    cfg.add_flop(1)  # A dummy value to satisfy AutoTVM
 
     out_shape = data.shape
     out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
 
-    in_bufs = [data]
-
-    if rank == 1 and is_thrust_available():
-        ir_funcs[1] = gen_scatter_1d_thrust
-        indices_sorted, updates_sorted = stable_sort_by_key_thrust(
-            indices, updates, for_scatter=True
-        )
-        in_bufs += [indices_sorted, updates_sorted]
-    else:
-        in_bufs += [indices, updates]
+    indices_sorted, updates_sorted = stable_sort_by_key_thrust(indices, updates, for_scatter=True)
 
     out = te.extern(
         [out_shape],
-        in_bufs,
-        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
+        [data, indices_sorted, updates_sorted],
+        lambda ins, outs: gen_scatter_1d_thrust(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
-        name="scatter_gpu",
-        tag="scatter_gpu",
+        name="scatter_via_sort_gpu",
+        tag="scatter_via_sort_gpu",
     )
 
     return out
 
 
+@autotvm.register_topi_schedule("scatter_via_sort.cuda")
+def schedule_scatter_via_sort(_, outs):
+    return schedule_extern(outs)
+
+
 def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
     """Generate scatter add ir for 1d inputs, using atomic_add instruction