Posted to commits@tvm.apache.org by co...@apache.org on 2020/12/01 19:59:26 UTC
[tvm] branch main updated: [TOPI] deformable_conv2d in NHWC (#6999)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 73a1a9a [TOPI] deformable_conv2d in NHWC (#6999)
73a1a9a is described below
commit 73a1a9a69f62281d61148280a023e58e6dcd08f0
Author: Wuwei Lin <vi...@gmail.com>
AuthorDate: Tue Dec 1 14:59:09 2020 -0500
[TOPI] deformable_conv2d in NHWC (#6999)
* [TOPI] deformable_conv2d in NHWC
* Update python/tvm/topi/generic/nn.py
Co-authored-by: Cody Yu <co...@gmail.com>
* Update python/tvm/topi/testing/deformable_conv2d_python.py
Co-authored-by: Cody Yu <co...@gmail.com>
* style
* fix
* style
Co-authored-by: Cody Yu <co...@gmail.com>
---
include/tvm/topi/detail/tensor_utils.h | 37 +++++++
python/tvm/topi/generic/nn.py | 18 ++++
python/tvm/topi/nn/deformable_conv2d.py | 110 ++++++++++++++++++++-
python/tvm/topi/testing/__init__.py | 2 +-
..._nchw_python.py => deformable_conv2d_python.py} | 49 +++++++++
src/topi/schedule.cc | 4 +
.../topi/python/test_topi_deformable_conv2d.py | 95 +++++++++++++++++-
7 files changed, 311 insertions(+), 4 deletions(-)
diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h
index 7004c35..65a760b 100644
--- a/include/tvm/topi/detail/tensor_utils.h
+++ b/include/tvm/topi/detail/tensor_utils.h
@@ -89,6 +89,43 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
D * x_lerp * y_lerp;
}
+/*!
+ * \brief Sample a point in a tensor using bilinear interpolation.
+ *
+ * \param input The input tensor.
+ * \param indices The index of the target point, which can be fractional.
+ * \param max_y The maximum valid y coordinate
+ * \param max_x The maximum valid x coordinate
+ *
+ * \return The interpolated value at the given index.
+ */
+inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
+ const PrimExpr max_y, const PrimExpr max_x) {
+ auto in_y = indices[1];
+ auto yf = tvm::floor(in_y);
+ auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
+
+ auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y));
+ auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
+ auto y_lerp = in_y - yf;
+
+ auto in_x = indices[2];
+ auto xf = tvm::floor(in_x);
+ auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x));
+
+ auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x));
+ auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
+ auto x_lerp = in_x - xf;
+
+ auto A = input(indices[0], y0, x0, indices[3]);
+ auto B = input(indices[0], y0, x1, indices[3]);
+ auto C = input(indices[0], y1, x0, indices[3]);
+ auto D = input(indices[0], y1, x1, indices[3]);
+
+ return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp +
+ D * x_lerp * y_lerp;
+}
+
} // namespace detail
} // namespace topi
} // namespace tvm
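For reference, the new NHWC helper mirrors the existing NCHW one; a minimal
NumPy sketch of the same sampling logic (illustrative only, not part of the
commit; it assumes in-bounds fractional coordinates, which the TE caller
below guards for) is:

    import numpy as np

    def bilinear_sample_nhwc_np(data, n, y, x, c):
        # Clamp the upper neighbor to the last valid row/column, exactly
        # as the C++ helper does with max_y/max_x.
        max_y, max_x = data.shape[1] - 1, data.shape[2] - 1
        y0, x0 = int(np.floor(y)), int(np.floor(x))
        y1, x1 = min(int(np.ceil(y)), max_y), min(int(np.ceil(x)), max_x)
        y_lerp, x_lerp = y - np.floor(y), x - np.floor(x)
        a, b = data[n, y0, x0, c], data[n, y0, x1, c]
        cc, d = data[n, y1, x0, c], data[n, y1, x1, c]
        return (a * (1 - x_lerp) * (1 - y_lerp) + b * x_lerp * (1 - y_lerp)
                + cc * (1 - x_lerp) * y_lerp + d * x_lerp * y_lerp)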
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 4bc3f97..60ccd0d 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -462,6 +462,24 @@ def schedule_deformable_conv2d_nchw(outs):
return _default_schedule(outs, False)
+def schedule_deformable_conv2d_nhwc(outs):
+ """Schedule for deformable_conv2d_nhwc.
+ We only use the default schedule here and rely on auto_scheduler.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of deformable_conv2d_nhwc
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)
+
+
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
diff --git a/python/tvm/topi/nn/deformable_conv2d.py b/python/tvm/topi/nn/deformable_conv2d.py
index a8c2745..780530c 100644
--- a/python/tvm/topi/nn/deformable_conv2d.py
+++ b/python/tvm/topi/nn/deformable_conv2d.py
@@ -21,7 +21,7 @@ from tvm import te
from .utils import get_pad_tuple
from ..utils import get_const_tuple
-from ..cpp.utils import bilinear_sample_nchw
+from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc
def deformable_conv2d_nchw(
@@ -130,3 +130,111 @@ def deformable_conv2d_nchw(
),
tag="deformable_conv2d_nchw",
)
+
+
+def deformable_conv2d_nhwc(
+ data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
+):
+ """Deformable conv2D operator in NHWC layout.
+
+ The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ 4-D with shape [batch, in_height, in_width, in_channel]
+
+ offset : tvm.te.Tensor
+ 4-D with shape [batch, out_height, out_width,
+ deformable_groups * filter_height * filter_width * 2].
+
+ kernel : tvm.te.Tensor
+ 4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+ strides : int or a list/tuple of two ints
+ stride size, or [stride_height, stride_width]
+
+ padding : int or a list/tuple of two ints
+ padding size, or [pad_height, pad_width]
+
+ dilation : int or a list/tuple of two ints
+ dilation size, or [dilation_height, dilation_width]
+
+ deformable_groups : int
+ number of deformable groups
+
+ groups : int
+ number of groups
+
+ out_dtype : str or None
+ output data type; if None, the data type of the input is used
+
+ Returns
+ -------
+ output : tvm.te.Tensor
+ 4-D with shape [batch, out_height, out_width, out_channel]
+ """
+ if out_dtype is None:
+ out_dtype = data.dtype
+
+ if isinstance(strides, int):
+ stride_h = stride_w = strides
+ else:
+ stride_h, stride_w = strides
+
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ batch, in_height, in_width, in_channel = get_const_tuple(data.shape)
+ kernel_h, kernel_w, channel, out_channel = get_const_tuple(kernel.shape)
+ _, out_height, out_width, _ = get_const_tuple(offset.shape)
+ assert in_channel % deformable_groups == 0, "Input channels must be divisible by deformable_groups"
+ assert groups == 1, "deformable_conv2d_nhwc does not support groups > 1"
+
+ ic_per_dgroup = channel // deformable_groups
+
+ dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+ dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+ pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+ rc = te.reduce_axis((0, in_channel), name="rc")
+ ry = te.reduce_axis((0, kernel_h), name="ry")
+ rx = te.reduce_axis((0, kernel_w), name="rx")
+
+ zero = tvm.tir.const(0.0, data.dtype)
+
+ def _bilinear(n, h, w, c):
+ outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width)
+ val = bilinear_sample_nhwc(data, (n, h, w, c), in_height - 1, in_width - 1)
+ return tvm.tir.if_then_else(outside, zero, val)
+
+ data_deform = te.compute(
+ (batch, kernel_h, kernel_w, in_channel, out_height, out_width),
+ lambda n, kh, kw, c, y, x: _bilinear(
+ n,
+ y * stride_h
+ - pad_top
+ + kh * dilation_h
+ + offset[
+ n, y, x, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2
+ ],
+ x * stride_w
+ - pad_left
+ + kw * dilation_w
+ + offset[
+ n,
+ y,
+ x,
+ c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + 1,
+ ],
+ c,
+ ),
+ tag="data_deform",
+ )
+ return te.compute(
+ (batch, out_height, out_width, out_channel),
+ lambda n, y, x, f: te.sum(
+ data_deform[n, ry, rx, rc, y, x].astype(out_dtype)
+ * kernel[ry, rx, rc, f].astype(out_dtype),
+ axis=[ry, rx, rc],
+ ),
+ tag="deformable_conv2d_nhwc",
+ )
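The last axis of the offset tensor packs one (dy, dx) pair per deformable
group and kernel position. The index arithmetic in the compute above can be
read as follows (a hedged illustration; offset_channel is not a TVM API):

    def offset_channel(c, kh, kw, kernel_h, kernel_w, ic_per_dgroup):
        # deformable group that input channel c belongs to
        dg = c // ic_per_dgroup
        # start of that group's (dy, dx) pairs, plus the pair for kernel
        # tap (kh, kw); dy lives at this index and dx at the next one
        return dg * kernel_h * kernel_w * 2 + (kh * kernel_w + kw) * 2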
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 0654344..85f13a7 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -31,7 +31,7 @@ from .conv3d_transpose_ncdhw_python import conv3d_transpose_ncdhw_python
from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
from .correlation_nchw_python import correlation_nchw_python
-from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
+from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
diff --git a/python/tvm/topi/testing/deformable_conv2d_nchw_python.py b/python/tvm/topi/testing/deformable_conv2d_python.py
similarity index 74%
rename from python/tvm/topi/testing/deformable_conv2d_nchw_python.py
rename to python/tvm/topi/testing/deformable_conv2d_python.py
index 6a7afb4..0930843 100644
--- a/python/tvm/topi/testing/deformable_conv2d_nchw_python.py
+++ b/python/tvm/topi/testing/deformable_conv2d_python.py
@@ -119,3 +119,52 @@ def deformable_conv2d_nchw_python(
b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c])
return b_np
+
+
+def deformable_conv2d_nhwc_python(
+ a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+):
+ """Deformable convolution operator in NHWC layout.
+
+ Parameters
+ ----------
+ a_np : numpy.ndarray
+ 4-D with shape [batch, in_height, in_width, in_channel]
+
+ offset_np : numpy.ndarray
+ 4-D with shape [batch, out_height, out_width,
+ deformable_groups * filter_height * filter_width * 2]
+
+ w_np : numpy.ndarray
+ 4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+ stride : int or a list/tuple of two ints
+ Stride size, or [stride_height, stride_width]
+
+ padding : int or str or a list/tuple of 2 or 4 ints
+ Padding size, or ['VALID', 'SAME'], or
+ [pad_height, pad_width] for 2 ints, or
+ [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+ dilation : int or a list/tuple of two ints
+ Dilation size, or [dilate_height, dilate_width]
+
+ deformable_groups : int
+ Number of deformable groups
+
+ groups : int
+ Number of groups
+
+ Returns
+ -------
+ b_np : np.ndarray
+ 4-D with shape [batch, out_height, out_width, out_channel]
+ """
+ a_np = np.transpose(a_np, [0, 3, 1, 2]) # NHWC -> NCHW
+ offset_np = np.transpose(offset_np, [0, 3, 1, 2]) # NHWC -> NCHW
+ w_np = np.transpose(w_np, [3, 2, 0, 1]) # HWIO -> OIHW
+ b_np = deformable_conv2d_nchw_python(
+ a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+ )
+ b_np = np.transpose(b_np, [0, 2, 3, 1]) # NCHW -> NHWC
+ return b_np
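A quick shape check for the new reference function (a sketch, not part of
the commit; shapes follow the docstring, values are random):

    import numpy as np
    from tvm.topi.testing import deformable_conv2d_nhwc_python

    a = np.random.uniform(size=(1, 7, 7, 16)).astype("float32")         # NHWC input
    offset = np.random.randn(1, 5, 5, 1 * 3 * 3 * 2).astype("float32")  # one (dy, dx) per kernel tap
    w = np.random.uniform(size=(3, 3, 16, 16)).astype("float32")        # HWIO kernel
    out = deformable_conv2d_nhwc_python(a, offset, w, 1, 0, 1, 1, 1)
    assert out.shape == (1, 5, 5, 16)                                   # NHWC output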
diff --git a/src/topi/schedule.cc b/src/topi/schedule.cc
index c315d40..f9400bf 100644
--- a/src/topi/schedule.cc
+++ b/src/topi/schedule.cc
@@ -190,6 +190,10 @@ TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw").set_body([](TVMArgs args,
*rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
});
+TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = detail::bilinear_sample_nhwc(args[0], args[1], args[2], args[3]);
+});
+
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
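The global registered above is what backs the
"from ..cpp.utils import bilinear_sample_nhwc" import added in
deformable_conv2d.py; it can also be looked up directly through the FFI
(a hedged example):

    import tvm
    f = tvm.get_global_func("topi.utils.bilinear_sample_nhwc")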
diff --git a/tests/python/topi/python/test_topi_deformable_conv2d.py b/tests/python/topi/python/test_topi_deformable_conv2d.py
index 34bfae7..cd6f33f 100644
--- a/tests/python/topi/python/test_topi_deformable_conv2d.py
+++ b/tests/python/topi/python/test_topi_deformable_conv2d.py
@@ -26,11 +26,15 @@ from tvm.topi.utils import get_const_tuple
import tvm.testing
-_deformable_conv2d_implement = {
+_deformable_conv2d_nchw_implement = {
"generic": (topi.nn.deformable_conv2d_nchw, topi.generic.schedule_deformable_conv2d_nchw),
"cuda": (topi.cuda.deformable_conv2d_nchw, topi.cuda.schedule_deformable_conv2d_nchw),
}
+_deformable_conv2d_nhwc_implement = {
+ "generic": (topi.nn.deformable_conv2d_nhwc, topi.generic.schedule_deformable_conv2d_nhwc),
+}
+
def verify_deformable_conv2d_nchw(
batch,
@@ -94,7 +98,7 @@ def verify_deformable_conv2d_nchw(
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
- fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_implement)
+ fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nchw_implement)
with tvm.target.Target(device):
C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
s = fschedule([C])
@@ -112,6 +116,86 @@ def verify_deformable_conv2d_nchw(
check_device(device)
+def verify_deformable_conv2d_nhwc(
+ batch,
+ in_channel,
+ in_size,
+ num_filter,
+ kernel,
+ stride,
+ padding,
+ dilation=1,
+ deformable_groups=1,
+ groups=1,
+):
+ print(
+ "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
+ % (
+ batch,
+ in_channel,
+ in_size,
+ num_filter,
+ kernel,
+ stride,
+ padding,
+ dilation,
+ deformable_groups,
+ groups,
+ )
+ )
+
+ A = te.placeholder((batch, in_size, in_size, in_channel), name="A")
+ out_size = (in_size - (kernel - 1) * dilation - 1 + 2 * padding) // stride + 1
+ Offset = te.placeholder(
+ (batch, out_size, out_size, deformable_groups * kernel * kernel * 2), name="offset"
+ )
+ W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
+ bias = te.placeholder((num_filter,), name="bias")
+
+ a_shape = get_const_tuple(A.shape)
+ offset_shape = get_const_tuple(Offset.shape)
+ w_shape = get_const_tuple(W.shape)
+ bias_shape = get_const_tuple(bias.shape)
+ dtype = A.dtype
+
+ @memoize("topi.tests.test_topi_deformable_conv2d_nchw.verify_deformable_conv2d_nhwc")
+ def get_ref_data():
+ a_np = np.random.uniform(size=a_shape).astype(dtype)
+ offset_np = np.random.randn(*offset_shape).astype(dtype)
+ w_np = np.random.uniform(size=w_shape).astype(dtype)
+ b_np = np.random.uniform(size=bias_shape).astype(dtype)
+ c_np = tvm.topi.testing.deformable_conv2d_nhwc_python(
+ a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+ )
+
+ return a_np, offset_np, w_np, c_np
+
+ a_np, offset_np, w_np, c_np = get_ref_data()
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not tvm.testing.device_enabled(device):
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nhwc_implement)
+ with tvm.target.Target(device):
+ C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
+ s = fschedule([C])
+
+ a = tvm.nd.array(a_np, ctx)
+ offset = tvm.nd.array(offset_np, ctx)
+ w = tvm.nd.array(w_np, ctx)
+ c = tvm.nd.empty(c_np.shape, dtype=c_np.dtype, ctx=ctx)
+
+ func = tvm.build(s, [A, Offset, W, C], device)
+ func(a, offset, w, c)
+ tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+ for device in ["llvm"]:
+ check_device(device)
+
+
@tvm.testing.uses_gpu
def test_deformable_conv2d_nchw():
verify_deformable_conv2d_nchw(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
@@ -119,5 +203,12 @@ def test_deformable_conv2d_nchw():
verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 2, dilation=2)
+def test_deformable_conv2d_nhwc():
+ verify_deformable_conv2d_nhwc(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
+ verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4)
+ verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 2, dilation=2)
+
+
if __name__ == "__main__":
test_deformable_conv2d_nchw()
+ test_deformable_conv2d_nhwc()
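To run only the new NHWC test through pytest (a usage sketch; pytest is
assumed to be installed):

    import pytest

    # select the nhwc test from the updated test file by name
    pytest.main(["tests/python/topi/python/test_topi_deformable_conv2d.py", "-k", "nhwc"])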