Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/03/31 00:29:06 UTC

[GitHub] [incubator-tvm] jwfromm opened a new pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186
 
 
   This PR adds winograd compute functions and schedules for Conv3D in Topi. I also add an alter op layout pass to pre-transform weights, along with the associated relay operators. While writing the relay operators, I noticed quite a bit of inconsistency across the various convolution operators and decided to organize them and reduce code duplication. The make functions for all convolution variants are now shared where possible, so each N-D variant doesn't need its own. I also moved all type relation templates into `convolution.h` rather than having them split across two files.
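   For context, the two new ops are designed to be used as a pair: the weight transform runs once ahead of time and the convolution consumes the already-packed weights. A minimal sketch of the pairing (the argument list mirrors the existing conv2d winograd API; the exact conv3d signature shown here is an assumption, not copied from the PR):
   ```python
   from tvm import relay

   data = relay.var("data", shape=(1, 64, 16, 56, 56), dtype="float32")
   weight = relay.var("weight", shape=(64, 64, 3, 3, 3), dtype="float32")

   # Transform the weights once; for inference the precompute/FoldConstant
   # passes can bake this into a constant.
   tw = relay.nn.contrib_conv3d_winograd_weight_transform(weight, tile_size=4)

   # The convolution then skips the weight transform at runtime.
   out = relay.nn.contrib_conv3d_winograd_without_weight_transform(
       data, tw, tile_size=4, channels=64, kernel_size=(3, 3, 3),
       padding=(1, 1, 1))
   ```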


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403740056
 
 

 ##########
 File path: python/tvm/relay/op/nn/nn.py
 ##########
 @@ -1928,6 +1999,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
     return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
 
 
+def contrib_conv3d_winograd_weight_transform(weight,
+                                             tile_size):
+    r"""Weight Transformation part for 3D convolution with winograd algorithm.
+
+    We separate this as a single op to enable pre-compute for inference.
+    Use this together with nn.contrib_conv3d_winograd_without_weight_transform.
+
+    Parameters
+    ----------
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    tile_size : int
+        The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
 
 Review comment:
   ```suggestion
           The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)
   ```
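   For reference, in Winograd F(m, r) notation m is the output tile size and r is the kernel size, and the transformed tile spans alpha = m + r - 1 points per dimension. A minimal sketch of the relationship (matching the `alpha = KW + tile_size - 1` computation in the topi code quoted below):
   ```python
   def winograd_alpha(tile_size: int, kernel_size: int) -> int:
       # alpha = m + r - 1: width of the transformed tile per dimension.
       return tile_size + kernel_size - 1

   assert winograd_alpha(2, 3) == 4  # F(2x2x2, 3x3x3) uses 4x4x4 tiles
   assert winograd_alpha(4, 3) == 6  # F(4x4x4, 3x3x3) uses 6x6x6 tiles
   ```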


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403539442
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
 
 Review comment:
  My latest commit includes the `ci`/`co` axis swap, and it does seem to slightly improve performance.
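  One possible reading of that swap (a hypothetical sketch, not copied from the commit): store `kernel_pack` as `(alpha, alpha, alpha, CO, CI)` so that the reduction axis `ci` is innermost on the kernel side of the batched GEMM.
  ```python
  import tvm
  from tvm import te

  # Stand-in shapes and tensors so the sketch is self-contained.
  CO, CI, KD, KH, KW, P, alpha = 64, 64, 3, 3, 3, 512, 6
  kernel = te.placeholder((CO, CI, KD, KH, KW), name='kernel')
  G = te.placeholder((alpha, KD), name='G')  # kernel transform matrix
  data_pack = te.placeholder((alpha, alpha, alpha, CI, P), name='data_pack')

  r_kd = te.reduce_axis((0, KD), name='r_kd')
  r_kh = te.reduce_axis((0, KH), name='r_kh')
  r_kw = te.reduce_axis((0, KW), name='r_kw')
  # Hypothetical swapped layout: ci is now the innermost kernel_pack axis.
  kernel_pack = te.compute(
      (alpha, alpha, alpha, CO, CI),
      lambda omg, eps, nu, co, ci: te.sum(
          kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
          axis=[r_kd, r_kh, r_kw]),
      name='kernel_pack')

  ci = te.reduce_axis((0, CI), name='ci')
  bgemm = te.compute(
      (alpha, alpha, alpha, CO, P),
      lambda omg, eps, nu, co, p: te.sum(
          kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p],
          axis=[ci]),
      name='bgemm')
  ```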


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403414259
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning; if so, we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CI, CO),
+                lambda omg, eps, nu, ci, co: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                                pre_computed):
+    """Compute declaration for winograd without transforming depth"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, KD, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+    out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
+    D += pf + pb
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m-1) // m, (W + m-1) // m
+    P = N * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # During autotuning, don't count kernel packing as a time cost
+        # as it will later be removed via alter_op_layout.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, KD, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, KD, CI, CO),
+                lambda eps, nu, d, ci, co: te.sum(
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][c][d]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
+                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    rz = te.reduce_axis((0, KD), name='rz')
+    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
+                       te.sum(kernel_pack[eps][nu][rz][ci][co] *
+                              data_pack[eps][nu][ci][d * DSTR + rz][p],
+                              axis=[ci, rz]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # output
+    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
+                        inverse[co,
+                                d,
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        s[G].compute_inline()
+        r_a, r_b, r_c = s[kernel_pack].op.reduce_axis
+        # Could add additional unrolling by omg, eps, nu in the future.
+        for axis in [r_a, r_b, r_c]:
+            s[kernel_pack].unroll(axis)
+
+        fused = s[kernel_pack].fuse(ci, co)
+        bb, tt = s[kernel_pack].split(fused, 128)
+        s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c)
+        s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+        s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    ##### space definition begin #####
+    b1, b2, b3, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_b",
+        cfg.axis(alpha * alpha * alpha),
+        num_outputs=4,
+        filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
 
 Review comment:
  Since in 3D you have way more tiles, I think you could try more splits on `x` (the number of tiles, P). P should be much larger than CO (`y`), so using the same number of splits for `x` and `y` seems unbalanced to me.
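  A minimal sketch of the suggestion (hypothetical knob values, not taken from the PR): give `x` a deeper split than `y`, since P is typically much larger than CO.
  ```python
  # Hypothetical: one extra sub-axis for the much larger tile axis x.
  cfg.define_split("tile_y", y, num_outputs=4)
  cfg.define_split("tile_x", x, num_outputs=5)
  ```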


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403298913
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -662,96 +454,101 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
         ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
-                                      int num_inputs,
-                                      const Attrs& attrs,
-                                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  // each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
-      param->tile_size + data->shape[2] - 1,
-      param->tile_size + data->shape[3] - 1,
-      data->shape[0],
-      data->shape[1],
-  };
-
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
-                                                  data->dtype));
-  return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
-                                       int tile_size) {
-  auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
-  attrs->tile_size = tile_size;
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
-  return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+    .describe(R"code(Weight transformation of winograd fast convolution algorithm.
 
 Separate this into another operator in order to enable Precompute Pass to compute the
 weight transformation in advance.
 
 - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
+    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
 
 Review comment:
   ditto


[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r401330192
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,348 @@
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # skip this part during tuning to make recrods accurate
+            # this part will be pre-computed during pre-compute optimization pass
+            s[G].pragma(s[G].op.axis[0], 'debug_skip_region')
+            s[kernel_pack].pragma(eps, 'debug_skip_region')
 
 Review comment:
  I just found that when we use `debug_skip_region`, we create an empty value. However, an empty value makes the recorded performance incorrect. Did you check the performance after tuning? My workaround is to pass one `dirty` value to it, like this:
  ```python
      kernel_pack = tvm.compute((OMG, EPS, NU, CI, CO), lambda omg, eps, nu, ci, co:
                                weight[0][0][0][0][0],
                                name='kernel_pack')
  ```
  I think this issue exists in every place our AutoTVM templates use `debug_skip_region`.
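  (For comparison, the revised PR avoids `debug_skip_region` entirely by substituting a placeholder while tuning, as in the newer revision quoted earlier:)
  ```python
  # From the updated revision: during tuning, use a placeholder with the
  # packed shape so kernel packing is never part of the measured workload.
  if autotvm.GLOBAL_SCOPE.in_tuning:
      kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
                                   dtype=kernel.dtype,
                                   name='kernel_pack')
  ```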


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403301790
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -766,51 +563,41 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_tra
 .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
+    .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
 
 Review comment:
   indent


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403415068
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+    ##### space definition begin #####
+    b1, b2, b3, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_b",
+        cfg.axis(alpha * alpha * alpha),
+        num_outputs=4,
+        filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
 
 Review comment:
   Although this makes sense, I'm not quite sure how to increase the number of output axes (`num_outputs`) and still map each tile axis to a cuda `thread_axis`. Do you have any suggestions?
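  For reference, the conventional `num_outputs=4` mapping in TVM CUDA schedules binds the sub-axes as block / vthread / thread, leaving the last one as an inner serial loop; with `num_outputs=5`, the extra sub-axis would have to stay unbound or be fused into a neighbor. A sketch (assuming the `cfg`, `s`, `bgemm`, and `x` from the schedule above):
  ```python
  # Apply the split recorded in the config, then bind the usual way.
  bx, vx, tx, xi = cfg["tile_x"].apply(s, bgemm, x)
  s[bgemm].bind(bx, te.thread_axis("blockIdx.x"))
  s[bgemm].bind(vx, te.thread_axis("vthread"))
  s[bgemm].bind(tx, te.thread_axis("threadIdx.x"))
  # xi (and any extra inner sub-axis) remains a serial loop.
  ```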


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403414798
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transfomred. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CI, CO),
+                lambda omg, eps, nu, ci, co: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
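+    # Each p addresses one (batch, d-tile, h-tile, w-tile); the alpha^3 input
+    # window for that tile starts at the tile origin (a multiple of m).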
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
 
 Review comment:
   This is a good point. I made `co` the innermost axis because `conv2d_winograd` does the same, and I assumed there must be some reason for it. I'll try swapping the inner axes and see whether that affects performance.
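   For concreteness, here is a minimal standalone sketch of the swapped layout (placeholder shapes, untested; not what the PR currently does):
   ```python
   from tvm import te

   # Illustrative sizes only.
   alpha, CI, CO, P = 6, 64, 64, 512

   # Packed kernel with `ci` innermost instead of `co`.
   kernel_pack = te.placeholder((alpha, alpha, alpha, CO, CI), name='kernel_pack')
   data_pack = te.placeholder((alpha, alpha, alpha, CI, P), name='data_pack')

   # The batched GEMM then indexes the packed kernel as [co][ci].
   ci = te.reduce_axis((0, CI), name='ci')
   bgemm = te.compute(
       (alpha, alpha, alpha, CO, P),
       lambda omg, eps, nu, co, p: te.sum(
           kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p],
           axis=[ci]),
       name='bgemm')
   ```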


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403422713
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
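+    # nD, nH, nW count the m-sized output tiles along each spatial axis;
+    # P is the total number of tiles across the whole batch.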
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning; if so, we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CI, CO),
+                lambda omg, eps, nu, ci, co: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
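+    # Each p addresses one (batch, d-tile, h-tile, w-tile); the alpha^3 input
+    # window for that tile starts at the tile origin (a multiple of m).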
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                                pre_computed):
+    """Compute declaration for winograd without transforming depth"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, KD, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+    out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
+    D += pf + pb
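+    # Only H and W are Winograd-transformed; depth stays a direct reduction (rz in bgemm).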
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m-1) // m, (W + m-1) // m
+    P = N * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # During autotuning don't count kernel packing as a time cost
+        # as it will later be removed via alter_op_layout.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, KD, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, KD, CI, CO),
+                lambda eps, nu, d, ci, co: te.sum(
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][c][d]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
+                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    rz = te.reduce_axis((0, KD), name='rz')
+    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
+                       te.sum(kernel_pack[eps][nu][rz][ci][co] *
+                              data_pack[eps][nu][ci][d * DSTR + rz][p],
+                              axis=[ci, rz]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # output
+    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
+                        inverse[co,
+                                d,
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        s[G].compute_inline()
+        r_a, r_b, r_c = s[kernel_pack].op.reduce_axis
+        # Could add additional unrolling by omg, eps, nu in the future.
+        for axis in [r_a, r_b, r_c]:
+            s[kernel_pack].unroll(axis)
+
+        fused = s[kernel_pack].fuse(ci, co)
+        bb, tt = s[kernel_pack].split(fused, 128)
+        s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c)
+        s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+        s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    ##### space definition begin #####
+    b1, b2, b3, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_b",
+        cfg.axis(alpha * alpha * alpha),
+        num_outputs=4,
+        filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
 
 Review comment:
   Oh, `num_outputs` just indicates how many sub-axes to split each axis into. There's a `max_factor` argument that can limit how many different split patterns to try. By leaving it unspecified I think we're telling autotvm to try everything.
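   To make that concrete, a knob defined this way is consumed later in the same template roughly like this (a fragment reusing `cfg`, `s`, `bgemm`, and `y` from the schedule above, so it's a sketch rather than standalone code):
   ```python
   # 4-way split of y. With no max_factor or filter constraint, AutoTVM
   # enumerates every factorization of y's extent as a candidate.
   cfg.define_split("tile_y", y, num_outputs=4)
   # The chosen configuration entity applies the split at schedule time,
   # yielding four nested sub-axes (e.g. block / vthread / thread / inner).
   by, vy, ty, yi = cfg["tile_y"].apply(s, bgemm, y)
   ```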


[GitHub] [incubator-tvm] icemelon9 merged pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 merged pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186
 
 
   


[GitHub] [incubator-tvm] icemelon9 commented on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#issuecomment-609491509
 
 
   Thanks @jwfromm @masahi @FrozenGene. This is now merged.


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403301513
 
 

 ##########
 File path: src/relay/op/nn/convolution.h
 ##########
 @@ -360,6 +363,540 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+
+// Winograd convolution shape relations
+template<typename AttrType>
 
 Review comment:
   I haven't seen anything other than `Conv2DWinogradWeightTransform` using this template. I wonder why it is made a template function?


[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r401330192
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,348 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if D % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        r_kd = te.reduce_axis((0, KD), name='r_kd')
+        r_kh = te.reduce_axis((0, KH), name='r_kh')
+        r_kw = te.reduce_axis((0, KW), name='r_kw')
+        kernel_pack = te.compute(
+            (alpha, alpha, alpha, CI, CO),
+            lambda omg, eps, nu, ci, co: te.sum(
+                kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                axis=[r_kd, r_kh, r_kw]),
+            name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # skip this part during tuning to make records accurate
+            # this part will be pre-computed during pre-compute optimization pass
+            s[G].pragma(s[G].op.axis[0], 'debug_skip_region')
+            s[kernel_pack].pragma(eps, 'debug_skip_region')
 
 Review comment:
   I just found that when we use `debug_skip_region`, we create an empty value. However, an empty value makes the recorded performance incorrect. Did you check the performance after tuning? My workaround is to pass one `dirty` value to it, like this:
   ```python
       kernel_pack = tvm.compute((OMG, EPS, NU, CI, CO), lambda omg, eps, nu, ci, co:
                                 weight[0][0][0][0][0],
                                 name='kernel_pack')
   ```
   I think this issue exists in all the places where our AutoTVM templates use `debug_skip_region`.
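   For reference, the newer revision of this file quoted earlier in the thread sidesteps `debug_skip_region` here entirely: while `autotvm.GLOBAL_SCOPE.in_tuning` is set, it swaps the packed kernel for a plain placeholder of the packed shape, so the transform never enters the measured region at all:
   ```python
   if autotvm.GLOBAL_SCOPE.in_tuning:
       kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
                                    dtype=kernel.dtype,
                                    name='kernel_pack')
   ```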


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403539789
 
 

 ##########
 File path: src/relay/op/nn/convolution.h
 ##########
 @@ -360,6 +363,540 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+
+// Winograd convolution shape relations
+template<typename AttrType>
 
 Review comment:
   Thanks @masahi, making them inline functions is a better design. My latest commit changes all relevant templates into inline functions.


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403412011
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -198,141 +302,32 @@ with the layer input to produce a tensor of outputs.
 .add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
 
+
 // relay.nn.conv2d_transpose
 TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
 
-bool Conv2DTransposeRel(const Array<Type>& types,
-                        int num_inputs,
-                        const Attrs& attrs,
-                        const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  static const Layout kNCHW("NCHW");
-  static const Layout kOIHW("OIHW");
-
-  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
-  CHECK(param != nullptr);
-  const Layout in_layout(param->data_layout);
-  const Layout kernel_layout(param->kernel_layout);
-
-  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
-  CHECK(trans_in_layout.defined())
-    << "Conv only support input layouts that are convertible from NCHW."
-    << " But got " << in_layout;
-
-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
-  CHECK(trans_kernel_layout.defined())
-    << "Conv only support kernel layouts that are convertible from OIHW."
-    << " But got "<< kernel_layout;
-
-  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
-  CHECK(trans_out_layout.defined())
-    << "Conv only support output layouts that are convertible from NCHW."
-    << " But got " << out_layout;
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-
-  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
-
-  // infer weight if the kernel_size and channels are defined
-  if (param->kernel_size.defined() && param->channels.defined()) {
-    CHECK_EQ(param->kernel_size.size(), 2);
-    CHECK_EQ(param->dilation.size(), 2);
-
-    Array<IndexExpr> wshape({dshape_nchw[1],
-            indexdiv(param->channels, param->groups),
-            param->kernel_size[0],
-            param->kernel_size[1]});
-
-    wshape = trans_kernel_layout.BackwardShape(wshape);
-    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
-    channels = param->channels;
-
-    // assign result to reporter
-    reporter->Assign(types[1], TensorType(wshape, data->dtype));
-  } else {
-    // use weight to infer the conv shape.
-    if (weight == nullptr) return false;
-    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
-    if (param->kernel_size.defined()) {
-      CHECK_EQ(param->kernel_size.size(), 2);
-      // check the size
-      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
-            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
-          << "Conv2D: shape of weight is inconsistent with kernel_size, "
-          << " kernel_size=" << param->kernel_size
-          << " wshape=" << Array<IndexExpr>(wshape);
-    }
-    if (param->channels.defined()) {
-      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
-          << "Conv2D: shape of weight is inconsistent with channels, "
-          << " channels=" << param->channels
-          << " wshape=" << Array<IndexExpr>(wshape);
-    }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
-    channels = wshape[1];
-    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
-  }
-  // dilation
-  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-  IndexExpr pad_h, pad_w;
-  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
-                 pad_h + param->output_padding[0]));
-  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
-                 pad_w + param->output_padding[1]));
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  oshape = trans_out_layout.BackwardShape(oshape);
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
-  return true;
-}
-
-
-Expr MakeConv2DTranspose(Expr data,
-                         Expr weight,
-                         Array<IndexExpr> strides,
-                         Array<IndexExpr> padding,
-                         Array<IndexExpr> dilation,
-                         int groups,
-                         IndexExpr channels,
-                         Array<IndexExpr> kernel_size,
-                         std::string data_layout,
-                         std::string kernel_layout,
-                         std::string out_layout,
-                         Array<IndexExpr> output_padding,
-                         DataType out_dtype) {
-  auto attrs = make_object<Conv2DTransposeAttrs>();
-  attrs->channels = std::move(channels);
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->output_padding = std::move(output_padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.conv2d_transpose");
-  return Call(op, {data, weight}, Attrs(attrs), {});
-}
-
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
-.set_body_typed(MakeConv2DTranspose);
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   Array<IndexExpr> output_padding,
+                   DataType out_dtype) {
+  return MakeConvTranspose<Conv2DTransposeAttrs>(
+    data, weight, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
+});
 
 RELAY_REGISTER_OP("nn.conv2d_transpose")
-.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
+    .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
 
 Review comment:
   indent


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403411388
 
 

 ##########
 File path: include/tvm/relay/attrs/nn.h
 ##########
 @@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
   }
 };
 
+/*! \brief Attributes used in 3d winograd convolution operators */
+struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
+  int tile_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
+    TVM_ATTR_FIELD(tile_size)
+      .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
 
 Review comment:
   2x2x2?


[GitHub] [incubator-tvm] jwfromm commented on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#issuecomment-606923274
 
 
   @merrymercy, there's one little bit of this PR that isn't quite working yet. After autotuning, the alter_op_layout pass successfully converts `conv3d_winograd` to `contrib_conv3d_winograd_without_weight_transform`, but then the autotvm dispatcher complains about not being able to find the new op. What should happen is that the same schedule used for `conv3d_winograd` is applied. As far as I can tell, everything is completely analogous to the conv2d case, which works fine. Is there some special casing somewhere that makes this work for conv2d?


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403299155
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -662,96 +454,101 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
         ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
-                                      int num_inputs,
-                                      const Attrs& attrs,
-                                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  // each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
-      param->tile_size + data->shape[2] - 1,
-      param->tile_size + data->shape[3] - 1,
-      data->shape[0],
-      data->shape[1],
-  };
-
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
-                                                  data->dtype));
-  return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
-                                       int tile_size) {
-  auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
-  attrs->tile_size = tile_size;
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
-  return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+    .describe(R"code(Weight transformation of winograd fast convolution algorithm.
 
 Separate this into another operator in order to enable Precompute Pass to compute the
 weight transformation in advance.
 
 - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
+    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
+    .set_num_inputs(1)
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_support_level(10)
+    .add_type_rel("Conv2DWinogradWeightTransform",
+                  Conv2DWinogradWeightTransformRel<ConvWinogradWeightTransformAttrs>);
+
+// relay.nn.contrib_conv3d_winograd_without_weight_transform
+TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   int tile_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConvWinograd<Conv3DWinogradAttrs>(
+    data, weight, tile_size, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform");
+});
+
+RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
+    .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
 
 Review comment:
   ditto


[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r401734815
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,348 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if D % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        r_kd = te.reduce_axis((0, KD), name='r_kd')
+        r_kh = te.reduce_axis((0, KH), name='r_kh')
+        r_kw = te.reduce_axis((0, KW), name='r_kw')
+        kernel_pack = te.compute(
+            (alpha, alpha, alpha, CI, CO),
+            lambda omg, eps, nu, ci, co: te.sum(
+                kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                axis=[r_kd, r_kh, r_kw]),
+            name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # skip this part during tuning to make records accurate
+            # this part will be pre-computed during pre-compute optimization pass
+            s[G].pragma(s[G].op.axis[0], 'debug_skip_region')
+            s[kernel_pack].pragma(eps, 'debug_skip_region')
 
 Review comment:
   The right way would be to do it like this PR: https://github.com/apache/incubator-tvm/pull/5200
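   For reference, a minimal sketch of the technique used there, assuming the packed-kernel shape from the diff above (sizes and dtype are illustrative only):

   ```python
   from tvm import autotvm, te

   alpha, CI, CO = 6, 64, 64  # illustrative sizes
   if autotvm.GLOBAL_SCOPE.in_tuning:
       # Stand in a placeholder with the packed shape so the kernel
       # transform is never built or timed while tuning.
       kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
                                    dtype='float32', name='kernel_pack')
   else:
       kernel_pack = None  # the real code computes the packed kernel here
   ```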


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403413820
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CI, CO),
+                lambda omg, eps, nu, ci, co: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
 
 Review comment:
   Could bringing `ci` to the innermost axis be more efficient? Because you are reducing over `ci`, you are making a lot of strided accesses.
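   For illustration, a hedged sketch of a CI-innermost layout (the `_t` tensor names and sizes are hypothetical, not from this PR):

   ```python
   from tvm import te

   alpha, CI, CO, P = 6, 64, 64, 256  # illustrative sizes
   kernel_pack_t = te.placeholder((alpha, alpha, alpha, CO, CI), name='kernel_pack_t')
   data_pack_t = te.placeholder((alpha, alpha, alpha, P, CI), name='data_pack_t')

   ci = te.reduce_axis((0, CI), name='ci')
   # With CI stored innermost, each step of the ci reduction reads
   # adjacent elements instead of striding by P (or CO).
   bgemm = te.compute(
       (alpha, alpha, alpha, CO, P),
       lambda omg, eps, nu, co, p: te.sum(
           kernel_pack_t[omg, eps, nu, co, ci] * data_pack_t[omg, eps, nu, p, ci],
           axis=[ci]),
       name='bgemm')
   ```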


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403412557
 
 

 ##########
 File path: src/relay/op/nn/convolution.h
 ##########
 @@ -360,6 +363,540 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+
+// Winograd convolution shape relations
+template<typename AttrType>
 
 Review comment:
   Why not make them inline functions instead of templates? If a function is not intended to be polymorphic, making it a template can be confusing.


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403740788
 
 

 ##########
 File path: python/tvm/relay/op/nn/nn.py
 ##########
 @@ -295,13 +295,84 @@ def conv3d(data,
         strides = (strides, strides, strides)
     if isinstance(dilation, int):
         dilation = (dilation, dilation, dilation)
-    if isinstance(padding, int):
-        padding = (padding, padding, padding)
+    padding = get_pad_tuple3d(padding)
     return _make.conv3d(data, weight, strides, padding, dilation,
                         groups, channels, kernel_size, data_layout,
                         kernel_layout, out_layout, out_dtype)
 
 
+def contrib_conv3d_winograd_without_weight_transform(data,
+                                                     weight,
+                                                     tile_size,
+                                                     strides=(1, 1, 1),
+                                                     padding=(0, 0, 0),
+                                                     dilation=(1, 1, 1),
+                                                     groups=1,
+                                                     channels=None,
+                                                     kernel_size=None,
+                                                     data_layout="NCDHW",
+                                                     kernel_layout="OIDHW",
+                                                     out_layout="",
+                                                     out_dtype=""):
+    r"""3D convolution with winograd algorithm.
+
+    The basic parameters are the same as the ones in vanilla conv3d.
+    It assumes the weight is pre-transformed by nn.contrib_conv3d_winograd_weight_transform
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    tile_size : int
+        The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
 
 Review comment:
   ```suggestion
           The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)
   ```
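   For anyone skimming, a quick sketch of the relation behind this notation, matching the shape logic in the diffs above (the helper name is made up):

   ```python
   def packed_kernel_shape(tile_size, kernel_size, ci, co):
       # 3-D Winograd F(m x m x m, r x r x r): alpha = m + r - 1 per axis
       alpha = tile_size + kernel_size - 1
       return (alpha, alpha, alpha, ci, co)

   assert packed_kernel_shape(2, 3, 64, 64) == (4, 4, 4, 64, 64)  # F(2x2x2, 3x3x3)
   assert packed_kernel_shape(4, 3, 64, 64) == (6, 6, 6, 64, 64)  # F(4x4x4, 3x3x3)
   ```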


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403298871
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -662,96 +454,101 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
         ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
-                                      int num_inputs,
-                                      const Attrs& attrs,
-                                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  // each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
-      param->tile_size + data->shape[2] - 1,
-      param->tile_size + data->shape[3] - 1,
-      data->shape[0],
-      data->shape[1],
-  };
-
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
-                                                  data->dtype));
-  return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
-                                       int tile_size) {
-  auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
-  attrs->tile_size = tile_size;
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
-  return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+    .describe(R"code(Weight transformation of winograd fast convolution algorithm.
 
 Review comment:
   ditto


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403411388
 
 

 ##########
 File path: include/tvm/relay/attrs/nn.h
 ##########
 @@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
   }
 };
 
+/*! \brief Attributes used in 3d winograd convolution operators */
+struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
+  int tile_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
+    TVM_ATTR_FIELD(tile_size)
+      .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
 
 Review comment:
   2x2x2 etc?


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403422713
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CI, CO),
+                lambda omg, eps, nu, ci, co: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                                pre_computed):
+    """Compute declaration for winograd without transforming depth"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, KD, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+    out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
+    D += pf + pb
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m-1) // m, (W + m-1) // m
+    P = N * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # During autotuning dont count kernel packing as a time cost
+        # as it will later be removed via alter_op_layout.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, KD, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, KD, CI, CO),
+                lambda eps, nu, d, ci, co: te.sum(
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][c][d]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
+                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    rz = te.reduce_axis((0, KD), name='rz')
+    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
+                       te.sum(kernel_pack[eps][nu][rz][ci][co] *
+                              data_pack[eps][nu][ci][d * DSTR + rz][p],
+                              axis=[ci, rz]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # output
+    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
+                        inverse[co,
+                                d,
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        s[G].compute_inline()
+        r_a, r_b, r_c = s[kernel_pack].op.reduce_axis
+        # Could add additional unrolling by omg, eps, nu in the future.
+        for axis in [r_a, r_b, r_c]:
+            s[kernel_pack].unroll(axis)
+
+        fused = s[kernel_pack].fuse(ci, co)
+        bb, tt = s[kernel_pack].split(fused, 128)
+        s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c)
+        s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+        s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    ##### space definition begin #####
+    b1, b2, b3, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_b",
+        cfg.axis(alpha * alpha * alpha),
+        num_outputs=4,
+        filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
 
 Review comment:
   Oh, `num_outputs` just indicates how many sub-axes to split into. There's a `num_splits` argument that can set how many different split patterns to try. By leaving it unspecified I think we're telling autotvm to try everything.
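   As a sketch of that distinction (a toy schedule with illustrative names; not code from this PR):

   ```python
   from tvm import autotvm, te

   A = te.placeholder((1024,), name='A')
   B = te.compute((1024,), lambda i: A[i] * 2, name='B')
   s = te.create_schedule(B.op)

   cfg = autotvm.get_config()
   y = s[B].op.axis[0]
   # num_outputs=4: every candidate splits y into exactly four sub-axes;
   # the tuner enumerates the possible factorizations of the extent.
   cfg.define_split("tile_y", y, num_outputs=4)
   yo, ym, yn, yi = cfg["tile_y"].apply(s, B, y)
   ```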


[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403298477
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -511,133 +401,35 @@ said convolution.
                 out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
 
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv1DTransposeAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(2)
-.add_type_rel("Conv1DTranspose", Conv1DTransposeRel);
-
+    .set_attrs_type<Conv1DTransposeAttrs>()
 
 Review comment:
   Remove the indent to be consistent.


[GitHub] [incubator-tvm] jwfromm commented on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#issuecomment-606324854
 
 
   @merrymercy @icemelon9 @comaniac @masahi this PR is likely most relevant to you guys, can you take a look and let me know what you think?


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r401965687
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,348 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if D % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        r_kd = te.reduce_axis((0, KD), name='r_kd')
+        r_kh = te.reduce_axis((0, KH), name='r_kh')
+        r_kw = te.reduce_axis((0, KW), name='r_kw')
+        kernel_pack = te.compute(
+            (alpha, alpha, alpha, CI, CO),
+            lambda omg, eps, nu, ci, co: te.sum(
+                kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                axis=[r_kd, r_kh, r_kw]),
+            name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO unrolling by omg, eps, nu may improve performance but
+    # in some cases causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # skip this part during tuning to make records accurate
+            # this part will be pre-computed during pre-compute optimization pass
+            s[G].pragma(s[G].op.axis[0], 'debug_skip_region')
+            s[kernel_pack].pragma(eps, 'debug_skip_region')
 
 Review comment:
   Thanks for pointing this out. I've made changes to use the placeholder technique instead of `debug_skip_region` and found autotuning to work as well as before.


[GitHub] [incubator-tvm] jwfromm removed a comment on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm removed a comment on issue #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#issuecomment-606923274
 
 
   @merrymercy, there's one little bit of this PR that isn't quite working yet. After autotuning, the alter_op_layout pass successfully converts `conv3d_winograd` to `contrib_conv3d_winograd_without_weight_transform`, but then the autotvm dispatcher complains about not being able to find the new op. What should happen is that the same schedule used for `conv3d_winograd` is applied. As far as I can tell, everything is completely analogous to the conv2d case, which works fine. Is there some special casing somewhere that makes this work for conv2d?
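   For context, a hedged sketch of the rewrite the alter-op-layout pass produces here (parameters hard-coded for illustration; the real pass copies them from the conv3d attributes):

   ```python
   from tvm import relay

   def rewrite_to_winograd(data, weight):
       tile_size = 4
       # pre-transform the weight once so inference skips the packing step
       weight_t = relay.nn.contrib_conv3d_winograd_weight_transform(
           weight, tile_size)
       return relay.nn.contrib_conv3d_winograd_without_weight_transform(
           data, weight_t, tile_size=tile_size, channels=64,
           kernel_size=(3, 3, 3), padding=(1, 1, 1))
   ```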


[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403335554
 
 

 ##########
 File path: src/relay/op/nn/convolution.h
 ##########
 @@ -360,6 +363,540 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+
+// Winograd convolution shape relations
+template<typename AttrType>
 
 Review comment:
   I wanted to be consistent across all type relations for convolution functions. Previously the functions were split about evenly between convolution.cc and convolution.h. Since these now all live in a header file, they play nicer as template functions than as raw function definitions.


[GitHub] [incubator-tvm] masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403415504
 
 

 ##########
 File path: topi/python/topi/cuda/conv3d_winograd.py
 ##########
 @@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # kernel tensor is pre-transformed. this op is created by alter op layout.
+        # dilation is not supported
+        alpha, _, _, CI, CO = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CI, CO),
+                lambda omg, eps, nu, ci, co: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
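+    # indexdiv/indexmod are floor div/mod that assume non-negative operands,
+    # which lets the simplifier fold the tile-index arithmetic below.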
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
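+    # (alpha^3 independent GEMMs: one (CO x CI) x (CI x P) product per
+    # transformed-domain position)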
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][ci][co] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                                pre_computed):
+    """Compute declaration for winograd without transforming depth"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transformed; this op is created by the alter_op_layout pass.
+        # Dilation is not supported.
+        alpha, _, KD, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+    out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
+    D += pf + pb
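+    # Only H and W go through the Winograd transform; depth is convolved
+    # directly, so the padded input depth is kept for the GEMM's depth reduction.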
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # During autotuning, don't count kernel packing as a time cost,
+        # as it will later be removed via alter_op_layout.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, KD, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, KD, CI, CO),
+                lambda eps, nu, d, ci, co: te.sum(
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][c][d]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
+                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # do batch gemm
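+    # The rz reduction performs a direct (non-Winograd) strided convolution
+    # along depth inside the batched GEMM.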
+    ci = te.reduce_axis((0, CI), name='ci')
+    rz = te.reduce_axis((0, KD), name='rz')
+    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
+                       te.sum(kernel_pack[eps][nu][rz][ci][co] *
+                              data_pack[eps][nu][ci][d * DSTR + rz][p],
+                              axis=[ci, rz]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # output
+    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
+                        inverse[co,
+                                d,
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    cfg.add_flop(2 * N * CO * out_depth * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO: unrolling by omg, eps, nu may improve performance, but in some
+    # cases it causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, ci, co = s[kernel_pack].op.axis
+        s[G].compute_inline()
+        r_a, r_b, r_c = s[kernel_pack].op.reduce_axis
+        # Could add additional unrolling by omg, eps, nu in the future.
+        for axis in [r_a, r_b, r_c]:
+            s[kernel_pack].unroll(axis)
+
+        fused = s[kernel_pack].fuse(ci, co)
+        bb, tt = s[kernel_pack].split(fused, 128)
+        s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c)
+        s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+        s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    ##### space definition begin #####
+    b1, b2, b3, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_b",
+        cfg.axis(alpha * alpha * alpha),
+        num_outputs=4,
+        filter=lambda x: x.size[-3:] == [1, 1, 1])
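+    # Note: the filter above pins the three inner factors of tile_b to 1, so
+    # the fused alpha axes end up in a single outer (block) axis.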
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
 
 Review comment:
   Sorry, my autotvm-fu is very low. I thought `num_outputs=4` meant "try 4 values for tile_x and pick the best one". Are you talking about something else?
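
   For context, `num_outputs=4` does not mean "try 4 candidate values": it splits the axis into 4 nested sub-axes, and the tuner then searches over every factorization of the axis length into 4 factors. Below is a minimal sketch of defining and applying such a split knob, loosely following TVM's simple-template tutorial; the `matmul` template, task name, and axis names are illustrative and not taken from this PR.

   ```python
   import tvm
   from tvm import te, autotvm

   @autotvm.template("demo/matmul")  # hypothetical task name, not from this PR
   def matmul(N, L, M, dtype):
       A = te.placeholder((N, L), name="A", dtype=dtype)
       B = te.placeholder((L, M), name="B", dtype=dtype)
       k = te.reduce_axis((0, L), name="k")
       C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
       s = te.create_schedule(C.op)

       cfg = autotvm.get_config()
       y, x = s[C].op.axis
       # num_outputs=2 means "split into 2 sub-axes"; the search space then
       # holds every pair (f1, f2) with f1 * f2 == extent, not 2 candidates.
       cfg.define_split("tile_y", y, num_outputs=2)
       cfg.define_split("tile_x", x, num_outputs=2)
       yo, yi = cfg["tile_y"].apply(s, C, y)
       xo, xi = cfg["tile_x"].apply(s, C, x)
       s[C].reorder(yo, xo, yi, xi)
       return s, [A, B, C]
   ```

   For example, with an axis of length 512, `num_outputs=2` yields 10 candidates (one per divisor of 512), while `num_outputs=4` yields 220 factorizations for the tuner to explore.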
