You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/01/19 19:09:06 UTC
[tvm] branch main updated: [TOPI] Minor perf improvement for GPU scatter (#7233)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2290cc0 [TOPI] Minor perf improvement for GPU scatter (#7233)
2290cc0 is described below
commit 2290cc0f79e9f9c255e10bd3775c711591c34e99
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Jan 20 04:08:52 2021 +0900
[TOPI] Minor perf improvement for GPU scatter (#7233)
* improve scatter 4d init
* do not launch sorting based scatter for small input
* do not use hard coded num threads
* separate sort based implementation
* register scatter as autotvm task
* add missing import
* fix strategy
* add dedicated schedule and dummy flop
* add test tuning script
* try adding dummy knob
* skip random_fill when a tuning workload is from scatter
This reverts commit 1fed88321e640b509fc46fac7da3b3cb79719552.
* cleanup memcpy ir
* remove scatter tuning script
* make sure zero init arguments
* add comment on why skip random init for scatter
* restore ctx sync
Co-authored-by: masa <ma...@pop-os.localdomain>
---
python/tvm/autotvm/measure/measure_methods.py | 9 +-
python/tvm/relay/op/strategy/cuda.py | 15 ++-
python/tvm/relay/op/strategy/generic.py | 2 +-
python/tvm/topi/cuda/scatter.py | 179 +++++++++++++++-----------
4 files changed, 123 insertions(+), 82 deletions(-)
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index cb801ba..ffe4b97 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -30,6 +30,7 @@ import time
from random import getrandbits
from collections import namedtuple
import tempfile
+import numpy as np
import tvm._ffi
import tvm.ir.transform
@@ -560,9 +561,11 @@ def run_through_rpc(
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
- args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
- for arg in args:
- random_fill(arg)
+ args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
+ if "scatter" not in measure_input.task.name:
+ # the index tensor of scatter op cannot be randomly initialized
+ for arg in args:
+ random_fill(arg)
ctx.sync()
costs = time_f(*args).results
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 04c16dd..3863df0 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -783,10 +783,23 @@ def scatter_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
- wrap_topi_schedule(topi.generic.schedule_extern),
+ wrap_topi_schedule(topi.cuda.schedule_scatter),
name="scatter.cuda",
plevel=10,
)
+
+ rank = len(inputs[0].shape)
+
+ with SpecializedCondition(rank == 1):
+ if target.kind.name == "cuda" and get_global_func(
+ "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True
+ ):
+ strategy.add_implementation(
+ wrap_compute_scatter(topi.cuda.scatter_via_sort),
+ wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
+ name="scatter_via_sort.cuda",
+ plevel=9, # use the sequential version by default
+ )
return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 363832e..8dd9dc5 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1123,7 +1123,7 @@ def wrap_compute_scatter(topi_compute):
"""Wrap scatter topi compute"""
def _compute_scatter(attrs, inputs, _):
- return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)]
+ return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]
return _compute_scatter
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index be602c8..b34bd1d 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -17,16 +17,33 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Scatter operator """
import tvm
-from tvm import te
+from tvm import te, autotvm
from ..scatter import _verify_scatter_nd_inputs
+from ..generic import schedule_extern
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available
+from ..utils import prod
def ceil_div(a, b):
return (a + b - 1) // b
+def _memcpy_ir(ib, out_ptr, data_ptr, shape):
+ fused = prod(shape)
+ with ib.new_scope():
+ num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ num_blocks = ceil_div(fused, num_thread)
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(bx, "thread_extent", num_blocks)
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", num_thread)
+ tid = bx * num_thread + tx
+
+ with ib.if_scope(tid < fused):
+ out_ptr[tid] = data_ptr[tid]
+
+
def gen_ir_1d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 1d inputs
@@ -63,10 +80,7 @@ def gen_ir_1d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
- with ib.new_scope():
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(bx, "thread_extent", n)
- out_ptr[bx] = data_ptr[bx]
+ _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
@@ -114,8 +128,6 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
ret : tir
The computational ir.
"""
- warp_size = tvm.target.Target.current(False).thread_warp_size
-
n = data.shape[0]
c = data.shape[1]
@@ -124,16 +136,7 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
- with ib.new_scope():
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(bx, "thread_extent", n)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:
- j = j_ * warp_size + tx
- with ib.if_scope(j < c):
- idx = bx * c + j
- out_ptr[idx] = data_ptr[idx]
+ _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
@@ -205,18 +208,7 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
- with ib.new_scope():
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(bx, "thread_extent", n)
- by = te.thread_axis("blockIdx.y")
- ib.scope_attr(by, "thread_extent", c)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_:
- k = k_ * warp_size + tx
- with ib.if_scope(k < h):
- idx = (bx * c + by) * h + k
- out_ptr[idx] = data_ptr[idx]
+ _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
@@ -311,20 +303,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", n)
- j = te.thread_axis("blockIdx.y")
- ib.scope_attr(j, "thread_extent", c)
- k = te.thread_axis("blockIdx.z")
- ib.scope_attr(k, "thread_extent", h)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_:
- l = l_ * warp_size + tx
- with ib.if_scope(l < w):
- idx = ((i * c + j) * h + k) * w + l
- out_ptr[idx] = data_ptr[idx]
+ _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
@@ -417,7 +396,71 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
return ib.get()
-def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
+@autotvm.register_topi_compute("scatter.cuda")
+def scatter(cfg, data, indices, updates, axis=0):
+ """Update data at positions defined by indices with values in updates
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ indices : relay.Expr
+ The index locations to update.
+
+ updates : relay.Expr
+ The values to update.
+
+ axis : int
+ The axis to scatter on
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+ """
+ if axis < 0:
+ axis += len(data.shape)
+ assert axis >= 0
+ assert axis < len(data.shape)
+
+ rank = len(data.shape)
+ assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
+
+ ir_funcs = {
+ 1: gen_ir_1d,
+ 2: gen_ir_2d,
+ 3: gen_ir_3d,
+ 4: gen_ir_4d,
+ }
+
+ def update_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] = update
+
+ out_shape = data.shape
+ out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+
+ cfg.add_flop(1) # A dummy value to satisfy AutoTVM
+
+ out = te.extern(
+ [out_shape],
+ [data, indices, updates],
+ lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
+ dtype=data.dtype,
+ out_buffers=[out_buf],
+ name="scatter_gpu",
+ tag="scatter_gpu",
+ )
+
+ return out
+
+
+@autotvm.register_topi_schedule("scatter.cuda")
+def schedule_scatter(_, outs):
+ return schedule_extern(outs)
+
+
+def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
"""Generate scatter ir for 1d inputs, using a sorting based approach.
By sorting indices and comparing neighboring two indices, we can tell which
of elements in the indices tensor can scatter its update value into the output.
@@ -438,9 +481,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
updates : tir.Tensor
The values to update, sorted by indices.
- axis : int
- The axis to scatter on. It must be 0 for this function.
-
out : tir.Tensor
The output tensor.
@@ -449,7 +489,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
ret : tir
The computational ir.
"""
- assert axis == 0
n = data.shape[0]
ib = tvm.tir.ir_builder.create()
@@ -504,7 +543,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
return ib.get()
-def scatter(data, indices, updates, axis=0):
+@autotvm.register_topi_compute("scatter_via_sort.cuda")
+def scatter_via_sort(cfg, data, indices, updates, axis=0):
"""Update data at positions defined by indices with values in updates
Parameters
@@ -528,49 +568,34 @@ def scatter(data, indices, updates, axis=0):
"""
if axis < 0:
axis += len(data.shape)
- assert axis >= 0
- assert axis < len(data.shape)
+ assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
+ assert is_thrust_available(), "Thrust is required for this op"
- rank = len(data.shape)
- assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
-
- ir_funcs = {
- 1: gen_ir_1d,
- 2: gen_ir_2d,
- 3: gen_ir_3d,
- 4: gen_ir_4d,
- }
-
- def update_func(dst_ptr, dst_index, update):
- dst_ptr[dst_index] = update
+ cfg.add_flop(1) # A dummy value to satisfy AutoTVM
out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
- in_bufs = [data]
-
- if rank == 1 and is_thrust_available():
- ir_funcs[1] = gen_scatter_1d_thrust
- indices_sorted, updates_sorted = stable_sort_by_key_thrust(
- indices, updates, for_scatter=True
- )
- in_bufs += [indices_sorted, updates_sorted]
- else:
- in_bufs += [indices, updates]
+ indices_sorted, updates_sorted = stable_sort_by_key_thrust(indices, updates, for_scatter=True)
out = te.extern(
[out_shape],
- in_bufs,
- lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
+ [data, indices_sorted, updates_sorted],
+ lambda ins, outs: gen_scatter_1d_thrust(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
- name="scatter_gpu",
- tag="scatter_gpu",
+ name="scatter_via_sort_gpu",
+ tag="scatter_via_sort_gpu",
)
return out
+@autotvm.register_topi_schedule("scatter_via_sort.cuda")
+def schedule_scatter_via_sort(_, outs):
+ return schedule_extern(outs)
+
+
def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
"""Generate scatter add ir for 1d inputs, using atomic_add instruction