You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2020/12/01 19:59:26 UTC

[tvm] branch main updated: [TOPI] deformable_conv2d in NHWC (#6999)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 73a1a9a  [TOPI] deformable_conv2d in NHWC (#6999)
73a1a9a is described below

commit 73a1a9a69f62281d61148280a023e58e6dcd08f0
Author: Wuwei Lin <vi...@gmail.com>
AuthorDate: Tue Dec 1 14:59:09 2020 -0500

    [TOPI] deformable_conv2d in NHWC (#6999)
    
    * [TOPI] deformable_conv2d in NHWC
    
    * Update python/tvm/topi/generic/nn.py
    
    Co-authored-by: Cody Yu <co...@gmail.com>
    
    * Update python/tvm/topi/testing/deformable_conv2d_python.py
    
    Co-authored-by: Cody Yu <co...@gmail.com>
    
    * style
    
    * fix
    
    * style
    
    Co-authored-by: Cody Yu <co...@gmail.com>
---
 include/tvm/topi/detail/tensor_utils.h             |  37 +++++++
 python/tvm/topi/generic/nn.py                      |  18 ++++
 python/tvm/topi/nn/deformable_conv2d.py            | 110 ++++++++++++++++++++-
 python/tvm/topi/testing/__init__.py                |   2 +-
 ..._nchw_python.py => deformable_conv2d_python.py} |  49 +++++++++
 src/topi/schedule.cc                               |   4 +
 .../topi/python/test_topi_deformable_conv2d.py     |  95 +++++++++++++++++-
 7 files changed, 311 insertions(+), 4 deletions(-)

diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h
index 7004c35..65a760b 100644
--- a/include/tvm/topi/detail/tensor_utils.h
+++ b/include/tvm/topi/detail/tensor_utils.h
@@ -89,6 +89,43 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>&
          D * x_lerp * y_lerp;
 }
 
+/*!
+ * \brief Sample a point in an NHWC tensor using bilinear interpolation.
+ *
+ * \param input The input tensor, indexed as input(n, y, x, c).
+ * \param indices The (n, y, x, c) index of the target point; y and x can be fractional.
+ * \param max_y The largest valid y index; the upper sample row is clamped to it.
+ * \param max_x The largest valid x index; the upper sample column is clamped to it.
+ * \note Only y1/x1 are clamped; y0/x0 (the floors) are not range-checked by this helper.
+ * \return The bilinearly interpolated value at the given index.
+ */
+inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
+                                     const PrimExpr max_y, const PrimExpr max_x) {
+  auto in_y = indices[1];
+  auto yf = tvm::floor(in_y);
+  auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
+
+  auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y));
+  auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
+  auto y_lerp = in_y - yf;
+
+  auto in_x = indices[2];
+  auto xf = tvm::floor(in_x);
+  auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x));
+
+  auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x));
+  auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
+  auto x_lerp = in_x - xf;
+
+  auto A = input(indices[0], y0, x0, indices[3]);
+  auto B = input(indices[0], y0, x1, indices[3]);
+  auto C = input(indices[0], y1, x0, indices[3]);
+  auto D = input(indices[0], y1, x1, indices[3]);
+
+  return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp +
+         D * x_lerp * y_lerp;
+}
+
 }  // namespace detail
 }  // namespace topi
 }  // namespace tvm
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 4bc3f97..60ccd0d 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -462,6 +462,24 @@ def schedule_deformable_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
+def schedule_deformable_conv2d_nhwc(outs):
+    """Schedule for deformable_conv2d_nhwc.
+    Only the default schedule is provided here; tuning is left to auto_scheduler.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of deformable_conv2d_nhwc
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
 
diff --git a/python/tvm/topi/nn/deformable_conv2d.py b/python/tvm/topi/nn/deformable_conv2d.py
index a8c2745..780530c 100644
--- a/python/tvm/topi/nn/deformable_conv2d.py
+++ b/python/tvm/topi/nn/deformable_conv2d.py
@@ -21,7 +21,7 @@ from tvm import te
 
 from .utils import get_pad_tuple
 from ..utils import get_const_tuple
-from ..cpp.utils import bilinear_sample_nchw
+from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc
 
 
 def deformable_conv2d_nchw(
@@ -130,3 +130,111 @@ def deformable_conv2d_nchw(
         ),
         tag="deformable_conv2d_nchw",
     )
+
+
+def deformable_conv2d_nhwc(
+    data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
+):
+    """Deformable conv2D operator in NHWC layout.
+
+    The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    offset : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width,
+                        deformable_groups * filter_height * filter_width * 2].
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    strides : int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+
+    padding : int or a list/tuple of two ints
+        padding size, or [pad_height, pad_width]
+
+    dilation : int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    deformable_groups : int
+        number of deformable groups
+
+    groups : int
+        number of groups (only groups == 1 is supported)
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    if out_dtype is None:
+        out_dtype = data.dtype
+
+    if isinstance(strides, int):
+        stride_h = stride_w = strides
+    else:
+        stride_h, stride_w = strides
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    batch, in_height, in_width, in_channel = get_const_tuple(data.shape)
+    kernel_h, kernel_w, _, out_channel = get_const_tuple(kernel.shape)
+    _, out_height, out_width, _ = get_const_tuple(offset.shape)
+    assert in_channel % deformable_groups == 0, "Input channels must divide deformable group size"
+    assert groups == 1, "deformable_conv2d_nhwc does not support groups > 1"
+
+    ic_per_dgroup = in_channel // deformable_groups
+
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    rc = te.reduce_axis((0, in_channel), name="rc")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+
+    zero = tvm.tir.const(0.0, data.dtype)
+
+    def _bilinear(n, h, w, c):
+        outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width)
+        val = bilinear_sample_nhwc(data, (n, h, w, c), in_height - 1, in_width - 1)
+        return tvm.tir.if_then_else(outside, zero, val)
+
+    data_deform = te.compute(
+        (batch, kernel_h, kernel_w, in_channel, out_height, out_width),
+        lambda n, kh, kw, c, y, x: _bilinear(
+            n,
+            y * stride_h
+            - pad_top
+            + kh * dilation_h
+            + offset[
+                n, y, x, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2
+            ],
+            x * stride_w
+            - pad_left
+            + kw * dilation_w
+            + offset[
+                n,
+                y,
+                x,
+                c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + 1,
+            ],
+            c,
+        ),
+        tag="data_deform",
+    )
+    return te.compute(
+        (batch, out_height, out_width, out_channel),
+        lambda n, y, x, f: te.sum(
+            data_deform[n, ry, rx, rc, y, x].astype(out_dtype)
+            * kernel[ry, rx, rc, f].astype(out_dtype),
+            axis=[ry, rx, rc],
+        ),
+        tag="deformable_conv2d_nhwc",
+    )
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 0654344..85f13a7 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -31,7 +31,7 @@ from .conv3d_transpose_ncdhw_python import conv3d_transpose_ncdhw_python
 from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
 from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
 from .correlation_nchw_python import correlation_nchw_python
-from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
+from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
 from .softmax_python import softmax_python, log_softmax_python
diff --git a/python/tvm/topi/testing/deformable_conv2d_nchw_python.py b/python/tvm/topi/testing/deformable_conv2d_python.py
similarity index 74%
rename from python/tvm/topi/testing/deformable_conv2d_nchw_python.py
rename to python/tvm/topi/testing/deformable_conv2d_python.py
index 6a7afb4..0930843 100644
--- a/python/tvm/topi/testing/deformable_conv2d_nchw_python.py
+++ b/python/tvm/topi/testing/deformable_conv2d_python.py
@@ -119,3 +119,52 @@ def deformable_conv2d_nchw_python(
         b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c])
 
     return b_np
+
+
+def deformable_conv2d_nhwc_python(
+    a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+):
+    """Deformable convolution in NHWC layout, computed via the NCHW reference.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    offset_np : numpy.ndarray
+        4-D with shape [batch, out_height, out_width,
+                        deformable_groups * filter_height * filter_width * 2]
+
+    w_np : numpy.ndarray
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str or a list/tuple of 2 or 4 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+    dilation : int or a list/tuple of two ints
+        Dilation size, or [dilate_height, dilate_width]
+
+    deformable_groups : int
+        Number of deformable groups
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    a_np = np.transpose(a_np, [0, 3, 1, 2])  # NHWC -> NCHW
+    offset_np = np.transpose(offset_np, [0, 3, 1, 2])  # NHWC -> NCHW
+    w_np = np.transpose(w_np, [3, 2, 0, 1])  # HWIO -> OIHW
+    b_np = deformable_conv2d_nchw_python(
+        a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+    )
+    b_np = np.transpose(b_np, [0, 2, 3, 1])  # NCHW -> NHWC
+    return b_np
diff --git a/src/topi/schedule.cc b/src/topi/schedule.cc
index c315d40..f9400bf 100644
--- a/src/topi/schedule.cc
+++ b/src/topi/schedule.cc
@@ -190,6 +190,10 @@ TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw").set_body([](TVMArgs args,
   *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
 });
 
+TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = detail::bilinear_sample_nhwc(args[0], args[1], args[2], args[3]);
+});
+
 /*! \brief Builder function for instantiating schedules. */
 using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
     const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
diff --git a/tests/python/topi/python/test_topi_deformable_conv2d.py b/tests/python/topi/python/test_topi_deformable_conv2d.py
index 34bfae7..cd6f33f 100644
--- a/tests/python/topi/python/test_topi_deformable_conv2d.py
+++ b/tests/python/topi/python/test_topi_deformable_conv2d.py
@@ -26,11 +26,15 @@ from tvm.topi.utils import get_const_tuple
 import tvm.testing
 
 
-_deformable_conv2d_implement = {
+_deformable_conv2d_nchw_implement = {
     "generic": (topi.nn.deformable_conv2d_nchw, topi.generic.schedule_deformable_conv2d_nchw),
     "cuda": (topi.cuda.deformable_conv2d_nchw, topi.cuda.schedule_deformable_conv2d_nchw),
 }
 
+_deformable_conv2d_nhwc_implement = {
+    "generic": (topi.nn.deformable_conv2d_nhwc, topi.generic.schedule_deformable_conv2d_nhwc),
+}
+
 
 def verify_deformable_conv2d_nchw(
     batch,
@@ -94,7 +98,7 @@ def verify_deformable_conv2d_nchw(
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
-        fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_implement)
+        fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nchw_implement)
         with tvm.target.Target(device):
             C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
             s = fschedule([C])
@@ -112,6 +116,86 @@ def verify_deformable_conv2d_nchw(
         check_device(device)
 
 
+def verify_deformable_conv2d_nhwc(
+    batch,
+    in_channel,
+    in_size,
+    num_filter,
+    kernel,
+    stride,
+    padding,
+    dilation=1,
+    deformable_groups=1,
+    groups=1,
+):
+    print(
+        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
+        % (
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+            deformable_groups,
+            groups,
+        )
+    )
+
+    A = te.placeholder((batch, in_size, in_size, in_channel), name="A")
+    out_size = (in_size - (kernel - 1) * dilation - 1 + 2 * padding) // stride + 1
+    Offset = te.placeholder(
+        (batch, out_size, out_size, deformable_groups * kernel * kernel * 2), name="offset"
+    )
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
+    bias = te.placeholder((num_filter,), name="bias")
+
+    a_shape = get_const_tuple(A.shape)
+    offset_shape = get_const_tuple(Offset.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_deformable_conv2d_nchw.verify_deformable_conv2d_nhwc")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        offset_np = np.random.randn(*offset_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        c_np = tvm.topi.testing.deformable_conv2d_nhwc_python(
+            a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups
+        )
+
+        return a_np, offset_np, w_np, c_np
+
+    a_np, offset_np, w_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not tvm.testing.device_enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nhwc_implement)
+        with tvm.target.Target(device):
+            C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
+            s = fschedule([C])
+
+            a = tvm.nd.array(a_np, ctx)
+            offset = tvm.nd.array(offset_np, ctx)
+            w = tvm.nd.array(w_np, ctx)
+            c = tvm.nd.empty(c_np.shape, dtype=c_np.dtype, ctx=ctx)
+
+            func = tvm.build(s, [A, Offset, W, C], device)
+            func(a, offset, w, c)
+            tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+
 @tvm.testing.uses_gpu
 def test_deformable_conv2d_nchw():
     verify_deformable_conv2d_nchw(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
@@ -119,5 +203,12 @@ def test_deformable_conv2d_nchw():
     verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 2, dilation=2)
 
 
+def test_deformable_conv2d_nhwc():
+    verify_deformable_conv2d_nhwc(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
+    verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4)
+    verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 2, dilation=2)
+
+
 if __name__ == "__main__":
     test_deformable_conv2d_nchw()
+    test_deformable_conv2d_nhwc()