Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/30 08:10:19 UTC

[GitHub] [incubator-tvm] wsl-inspur opened a new pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

wsl-inspur opened a new pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485


   
   - Optimization of the Conv2d Winograd algorithm on Tensor Core for the NHWC layout.
   - Winograd with Tensor Core outperforms the original Winograd algorithm for all batch sizes.
   - However, Winograd performs worse than conv2d for large batch sizes when Tensor Core is enabled for both.
   - The performance improvement on ResNet-50 is fairly good for small batch sizes.
   
   Please see the RFC linked below for details:
   https://discuss.tvm.ai/t/rfc-tensor-core-optimization-of-winograd-conv2d-on-tensor-core/6543
   
   @Hzfengsy @Laurawly @vinx13 @jwfromm Please help to review.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] FrozenGene commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-624109424


   > @FrozenGene
   > 
   > We tested the layout as you suggested, and the results are listed below.
   > 
   > kernel:3x3x64x64 feature maps:56x56x64:
   > 
   > batch | kernel (alpha, alpha, ci, co) | kernel (alpha, alpha, co, ci)
   > -- | -- | --
   > 1 | 0.0762 | 0.0766
   > 2 | 0.0911 | 0.0931
   > 4 | 0.1197 | 0.124
   > 8 | 0.1979 | 0.1942
   > 16 | 0.3453 | 0.3577
   > 32 | 0.6613 | 0.7161
   > 256 | 5.5837 | 5.3269
   > 
   > kernel:3x3x256x256 feature maps:14x14x256:
   > 
   > batch | kernel (alpha, alpha, ci, co) | kernel (alpha, alpha, co, ci)
   > -- | -- | --
   > 1 | 0.0633 | 0.0694
   > 2 | 0.0825 | 0.0835
   > 4 | 0.1417 | 0.1562
   > 8 | 0.1829 | 0.1853
   > 16 | 0.264 | 0.277
   > 32 | 0.4506 | 0.4799
   > 256 | 3.9432 | 4.0867
   > 
   > Note: the weight transform was pre-computed. The benchmarks were run on a T4 GPU (16 GB, 70 W). Latency is reported in ms.
   > 
   > The two layouts perform at about the same level, and the (alpha, alpha, ci, co) kernel layout is slightly faster than (alpha, alpha, co, ci) in most cases.
   
   Thanks for testing on GPU. Just to double-check: were the layouts of the input tile, inverse, and so on changed as well? (Though I expect the performance would still be at about the same level as in your report.)





[GitHub] [incubator-tvm] FrozenGene commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-625176840


   > @FrozenGene
   > We ran the following tests to cover all the layouts:
   > 
   > Layout_1 (the same as in this PR)
   > input_tile = (P, CI, alpha, alpha)
   > data_pack = (alpha, alpha, P, CI)
   > bgemm = (alpha, alpha, P, CO)
   > inverse = (P, CO, m, m)
   > output = (N, H, W, CO)
   > kernel = (alpha, alpha, CI, CO)
   > 
   > Layout_2
   > input_tile = (P, CI, alpha, alpha)
   > data_pack = (alpha, alpha, P, CI)
   > bgemm = (alpha, alpha, P, CO)
   > inverse = (P, CO, m, m)
   > output = (N, H, W, CO)
   > kernel = (alpha, alpha, CO, CI)
   > 
   > Layout_3
   > input_tile = (alpha, alpha, P, CI)
   > data_pack = (alpha, alpha, P, CI)
   > bgemm = (alpha, alpha, P, CO)
   > inverse = (P, CO, m, m)
   > output = (N, H, W, CO)
   > kernel = (alpha, alpha, CI, CO)
   > 
   > Layout_4
   > input_tile = (alpha, alpha, P, CI)
   > data_pack = (alpha, alpha, P, CI)
   > bgemm = (alpha, alpha, P, CO)
   > inverse = (m, m, P, CO)
   > output = (N, H, W, CO)
   > kernel = (alpha, alpha, CO, CI)
   > 
   > The results are listed below.
   > 
   > kernel:3x3x64x64 feature maps:56x56x64:
   > 
   > batch | layout_1 | layout_2 | layout_3 | layout_4
   > -- | -- | -- | -- | --
   > 1 | 0.0762 | 0.0766 | 0.0758 | 0.0907
   > 2 | 0.0911 | 0.0931 | 0.0957 | 0.0939
   > 4 | 0.1197 | 0.124 | 0.1188 | 0.1257
   > 8 | 0.1979 | 0.1942 | 0.2026 | 0.2208
   > 16 | 0.3453 | 0.3577 | 0.3427 | 0.3833
   > 32 | 0.6613 | 0.7161 | 0.6615 | 0.7574
   > 256 | 5.5837 | 5.3269 | 5.772 | 5.7058
   > 
   > kernel:3x3x256x256 feature maps:14x14x256:
   > 
   > batch | layout_1 | layout_2 | layout_3 | layout_4
   > -- | -- | -- | -- | --
   > 1 | 0.0633 | 0.0694 | 0.0703 | 0.656
   > 2 | 0.0825 | 0.0835 | 0.0822 | 0.851
   > 4 | 0.1417 | 0.1562 | 0.1382 | 0.1502
   > 8 | 0.1829 | 0.1853 | 0.1833 | 0.182
   > 16 | 0.264 | 0.277 | 0.2832 | 0.2621
   > 32 | 0.4506 | 0.4799 | 0.4507 | 0.4773
   > 256 | 3.9432 | 4.0867 | 3.5088 | 4.5032
   > 
   > Note: the weight transform was pre-computed. The benchmarks were run on a T4 GPU (16 GB, 70 W). Latency is reported in ms.
   > 
   > All layouts perform at about the same level.
   
   Makes sense to me.





[GitHub] [incubator-tvm] FrozenGene commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-621791710


   For performance, have you tried any other layouts? I have some experience on CPU. A more suitable layout for NHWC input on CPU is:
   
   ```
     input_tile: alpha, alpha, P, CI
     data_pack: alpha, alpha, P, CI
     bgemm: alpha, alpha, P, CO
     inverse: m, m, P, CO
     output: N H W CO
     kernel: alpha alpha CO CI
   ```
   For the kernel, I chose `alpha alpha CO CI` because I want to vectorize over CI. On GPU, `alpha alpha CI CO` may be better.
   
   I compared your layout with the one I mentioned above: on skylake-512 your layout takes 0.388 ms and mine takes 0.375 ms. I used 20 threads on the workload (1, 56, 56, 64, 64). The results are stably reproducible.
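
The two kernel layouts discussed above hold the same values and differ only in which channel axis is innermost. The following is a minimal pure-Python sketch of that point, not code from the PR: the sizes are hypothetical and tiny, and `bgemm` and `read_kernel` are illustrative names. It computes the batched GEMM once per layout and checks that the results match.

```python
import random

ALPHA, P, CI, CO = 2, 3, 4, 5  # hypothetical tiny sizes, for illustration only
random.seed(0)

# data_pack[eps][nu][p][ci]
data_pack = [[[[random.random() for _ in range(CI)] for _ in range(P)]
              for _ in range(ALPHA)] for _ in range(ALPHA)]
# Kernel stored as (alpha, alpha, CI, CO).
kernel_ci_co = [[[[random.random() for _ in range(CO)] for _ in range(CI)]
                 for _ in range(ALPHA)] for _ in range(ALPHA)]
# The same values stored as (alpha, alpha, CO, CI): CI becomes the innermost axis.
kernel_co_ci = [[[[kernel_ci_co[e][n][ci][co] for ci in range(CI)]
                  for co in range(CO)]
                 for n in range(ALPHA)] for e in range(ALPHA)]


def bgemm(data, kernel, read_kernel):
    """bgemm[eps][nu][p][co] = sum over ci of data * kernel."""
    return [[[[sum(data[e][n][p][ci] * read_kernel(kernel, e, n, ci, co)
                   for ci in range(CI))
               for co in range(CO)]
              for p in range(P)]
             for n in range(ALPHA)]
            for e in range(ALPHA)]


# Each accessor hides the storage order of its kernel layout.
out_ci_co = bgemm(data_pack, kernel_ci_co, lambda k, e, n, ci, co: k[e][n][ci][co])
out_co_ci = bgemm(data_pack, kernel_co_ci, lambda k, e, n, ci, co: k[e][n][co][ci])
print("layouts agree:", out_ci_co == out_co_ci)
```

Because only the contiguous axis changes, the choice affects memory access patterns (vectorization on CPU, coalesced loads on GPU) rather than the numerical result.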





[GitHub] [incubator-tvm] wsl-inspur commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
wsl-inspur commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-624103994


   @FrozenGene 
   We tested the layout as you suggested, and the results are listed below.
   
   kernel:3x3x64x64 feature maps:56x56x64:
   
   batch | kernel (alpha, alpha, ci, co) | kernel (alpha, alpha, co, ci)
   -- | -- | --
   1 | 0.0762 | 0.0766
   2 | 0.0911 | 0.0931
   4 | 0.1197 | 0.124
   8 | 0.1979 | 0.1942
   16 | 0.3453 | 0.3577
   32 | 0.6613 | 0.7161
   256 | 5.5837 | 5.3269
   
   kernel:3x3x256x256 feature maps:14x14x256:
   
   batch | kernel (alpha, alpha, ci, co) | kernel (alpha, alpha, co, ci)
   -- | -- | --
   1 | 0.0633 | 0.0694
   2 | 0.0825 | 0.0835
   4 | 0.1417 | 0.1562
   8 | 0.1829 | 0.1853
   16 | 0.264 | 0.277
   32 | 0.4506 | 0.4799
   256 | 3.9432 | 4.0867
   
   Note: the weight transform was pre-computed. The benchmarks were run on a T4 GPU (16 GB, 70 W). Latency is reported in ms.
   
   The two layouts perform at about the same level, and the (alpha, alpha, ci, co) kernel layout is slightly faster than (alpha, alpha, co, ci) in most cases.





[GitHub] [incubator-tvm] FrozenGene commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-625178019


   Thanks @wsl-inspur @Hzfengsy !





[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#discussion_r417839497



##########
File path: topi/python/topi/cuda/conv2d_nhwc_winograd.py
##########
@@ -0,0 +1,639 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+# pylint: disable=too-many-arguments,too-many-locals
+# pylint: disable=too-many-statements
+"""Winograd template for cuda backend"""
+
+import tvm
+from tvm import te
+from tvm import autotvm
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+from .tensor_intrin import intrin_wmma_load_matrix_A
+from .tensor_intrin import intrin_wmma_load_matrix_W
+from .tensor_intrin import intrin_wmma_store_matrix
+from .tensor_intrin import intrin_wmma_gemm
+
+def _infer_tile_size(data, kernel):
+    """Compute the tile size"""
+    N, H, W, CI = get_const_tuple(data.shape)
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm tensorcore"""
+    A = data_pack
+    B = kernel_pack
+    C = bgemm
+    _, _, P, out_dim = get_const_tuple(C.shape)
+    out_dtype = C.dtype
+
+    # Explicit memory access
+    AS = s.cache_read(A, 'shared', [C])
+    BS = s.cache_read(B, 'shared', [C])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+    CF = s.cache_write(C, 'wmma.accumulator')
+    CS = s.cache_read(CF, 'shared', [C])
+
+    # Create tuning space
+    cfg.define_knob("block_row_warps", [1, 2, 4])
+    cfg.define_knob("block_col_warps", [1, 2, 4])
+    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8])
+    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8])
+    cfg.define_knob("chunk", [1, 2, 4, 8])
+    cfg.define_knob("offset", [0, 1, 2, 4, 8])
+    cfg.define_knob("offsetCS", [0, 1, 2, 4, 8])
+    cfg.define_knob("vec", [1, 2, 4, 8])
+
+    # Ensure that the default parameters are applicable when autotvm is not in use
+    if (P % 16 == 0 and out_dim % 16 == 0):
+        cfg.define_knob("wmma_m", [16, 8, 32])
+    elif (P % 32 == 0 and out_dim % 8 == 0):
+        cfg.define_knob("wmma_m", [32, 16, 8])
+    elif (P % 8 == 0 and out_dim % 32 == 0):
+        cfg.define_knob("wmma_m", [8, 16, 32])
+
+    warp_size = 32
+    wmma_k = 16
+    block_row_warps = cfg["block_row_warps"].val
+    block_col_warps = cfg["block_col_warps"].val
+    warp_row_tiles = cfg["warp_row_tiles"].val
+    warp_col_tiles = cfg["warp_col_tiles"].val
+    chunk = cfg["chunk"].val
+    offsetAB = cfg["offset"].val
+    offsetCS = cfg["offsetCS"].val
+    wmma_m = cfg["wmma_m"].val
+    vec = cfg["vec"].val
+
+    if wmma_m == 16:
+        wmma_n = 16
+    elif wmma_m == 8:
+        wmma_n = 32
+    elif wmma_m == 32:
+        wmma_n = 8
+
+    # Define the stride of intrin functions
+    AS_align = chunk * wmma_k + offsetAB
+    BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB
+    CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+    AS_stride = [AS_align, 1]
+    BS_stride = [BS_align, 1]
+    AF_stride = [wmma_k, 1]
+    BF_stride = [wmma_n * warp_col_tiles, 1]
+    CF_stride = [warp_col_tiles * wmma_n, 1]
+    CS_stride = [CS_align, 1]
+    block_x = te.thread_axis('blockIdx.x')
+    block_y = te.thread_axis('blockIdx.y')
+    block_z = te.thread_axis('blockIdx.z')
+    thread_x = te.thread_axis('threadIdx.x')
+    thread_y = te.thread_axis('threadIdx.y')
+    thread_z = te.thread_axis('threadIdx.z')
+
+    # Schedule for computation
+    block_factor_b = wmma_m * warp_row_tiles * block_row_warps
+    block_factor_o = wmma_n * warp_col_tiles * block_col_warps
+    alpha_1, alpha_2, b, o = C.op.axis
+    block_k = s[C].fuse(alpha_1, alpha_2)
+    block_i, bc = s[C].split(b, factor=block_factor_b)
+    block_j, oc = s[C].split(o, factor=block_factor_o)
+    s[C].reorder(block_k, block_i, block_j, bc, oc)
+    t = s[C].fuse(bc, oc)
+    t, vi = s[C].split(t, factor=vec)
+    t, tx = s[C].split(t, factor=warp_size)
+    t, ty = s[C].split(t, factor=block_row_warps)
+    t, tz = s[C].split(t, factor=block_col_warps)
+    s[C].bind(block_k, block_z)
+    s[C].bind(block_i, block_x)
+    s[C].bind(block_j, block_y)
+    s[C].bind(tz, thread_z)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].vectorize(vi)
+
+    # Schedule for wmma store
+    s[CS].compute_at(s[C], block_j)
+    _, _, bb, oo = CS.op.axis
+    s[CS].storage_align(bb, CS_align - 1, CS_align)
+    bb, bbi = s[CS].split(bb, factor=wmma_m)
+    oo, ooi = s[CS].split(oo, factor=wmma_n)
+    bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+    oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+    s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
+
+    # Schedule for wmma computation
+    s[CF].compute_at(s[CS], oo)
+    _, _, warp_i, warp_j = CF.op.axis
+    warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+    warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+    k, = CF.op.reduce_axis
+    k, _k = s[CF].split(k, factor=wmma_k)
+    ko, ki = s[CF].split(k, factor=chunk)
+    s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+    # Schedule for  wmma_matrix_a load
+    s[AF].compute_at(s[CF], ki)
+    _, _, b, i = AF.op.axis
+    b, b_ii = s[AF].split(b, factor=wmma_m)
+    i, i_jj = s[AF].split(i, factor=wmma_k)
+    s[AF].reorder(b, i, b_ii, i_jj)
+
+    # Schedule for  wmma_matrix_b load
+    s[BF].compute_at(s[CF], ki)
+    _, _, i, o = BF.op.axis
+    o, o_ii = s[BF].split(o, factor=wmma_n)
+    i, i_ii = s[BF].split(i, factor=wmma_k)
+    s[BF].reorder(i, o, i_ii, o_ii)
+
+    # Schedule for A's(B's) shared memory load
+    def shared_shedule(stage, strides):
+        s[stage].compute_at(s[CF], ko)
+        _, _, xo, yo = stage.op.axis
+        s[stage].storage_align(xo, strides - 1, strides)
+        t = s[stage].fuse(xo, yo)
+        t, vi = s[stage].split(t, factor=vec)
+        t, tx = s[stage].split(t, factor=warp_size)
+        t, ty = s[stage].split(t, factor=block_row_warps)
+        _, tz = s[stage].split(t, factor=block_col_warps)
+        s[stage].bind(ty, thread_y)
+        s[stage].bind(tz, thread_z)
+        s[stage].bind(tx, thread_x)
+        s[stage].vectorize(vi)
+
+    shared_shedule(AS, AS_align)
+    shared_shedule(BS, BS_align)
+
+    shape = (wmma_m, wmma_n, wmma_k)
+    in_dtype = 'float16'
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
+    BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *
+                                   BL_gemm[k_gemm, jj].astype(out_dtype),
+                                   axis=k_gemm), name='CL_compute')
+
+    # Lower the computation loops down to TensorCore hardware intrinsics
+    # by mapping the tensorcore to tensor intrinsics
+    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major",
+                                                    (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
+    s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major",
+                                                    (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16'))
+    s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride,
+                                          BF_stride, CF_stride, shape))
+    s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype,
+                                                  (wmma_m, wmma_n), (wmma_m, wmma_n)))
+
+
+def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm direct"""
+    b1, b2, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    # Create tuning space
+    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
+                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8])
+    cfg.define_knob("vector_bgemm", [1, 2, 4, 8])
+    offset_bgemm = cfg["offset_bgemm"].val
+    vector_bgemm = cfg["vector_bgemm"].val
+
+    C = bgemm
+    A0, B0 = kernel_pack, data_pack
+
+    # Designate the memory hierarchy
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    # Tile and bind spatial axes
+    b = s[bgemm].fuse(b1, b2)
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # Tile reduction axes
+    s[OL].compute_at(s[C], tx)
+    b1, b2, y, x = s[OL].op.axis
+    b = s[OL].fuse(b1, b2)
+    rc, = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    s[OL].reorder(rco, b, y, x, rci)
+
+    s[AA].compute_at(s[OL], rco)
+    _, _, k, n = s[AA].op.axis
+    AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * cfg["tile_x"].size[3]
+    s[AA].storage_align(k, AA_align - 1, AA_align)
+
+    s[BB].compute_at(s[OL], rco)
+    _, _, m, k = s[BB].op.axis
+    BB_align = offset_bgemm + cfg["tile_rc"].size[1]
+    s[BB].storage_align(m, BB_align - 1, BB_align)
+
+    # Schedule for A and B shared memory load
+    for load in [AA, BB]:
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, ti = s[load].split(fused, factor=vector_bgemm)
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+        s[load].vectorize(ti)
+
+
+def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                       use_tensorcore, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+    N, H, W, CI = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # Kernel tensor is raw tensor, do strict check
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1))
+        KH, KW, CI, CO = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transformed. This op is created by conv2d_alter_op.
+        # Dilation is not supported
+        alpha, _, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # Determine whether the shape is available with tensorcore
+    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                      (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                      (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+
+    if shape_judge and use_tensorcore:
+        trans_type = "float16"
+    else:
+        trans_type = data.dtype
+
+    # Compute transform matrix
+    A, _, _ = winograd_transform_matrices(m, r, out_dtype)
+    _, B, G = winograd_transform_matrices(m, r, data.dtype)
+
+    # Transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
+                                     te.sum((kernel[r_kh][r_kw][ci][co]) *
+                                            G[eps][r_kh] * G[nu][r_kw],
+                                            axis=[r_kh, r_kw]), name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    # Pack input tile
+    input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW)),
+                                     idxmod(idxdiv(p, nW), nH) * m + eps,
+                                     idxmod(p, nW) * m + nu,
+                                     c], name='d')
+
+    # Transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                           te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # Convert data type of input feature maps and weights for tensorcore
+    Transdata = te.compute(
+        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type))
+    TransFilter = te.compute(
+        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type))
+
+    # Do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
+                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),

Review comment:
       Shouldn't this be
   ```python
   TransFilter *  Transdata
   ```
   ? I am very curious how it could pass the correctness tests.
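
One way to read the question above is whether the operand order inside `te.sum` changes the result. Under that reading (an assumption about the reviewer's intent), it should not: each term is a scalar product, floating-point multiplication is commutative, and both orders accumulate the same terms in the same sequence. A minimal pure-Python sketch follows, where `reduce_ci` is a hypothetical stand-in for the reduction over `ci` and is not part of TVM:

```python
import random


def reduce_ci(data_row, filter_col, data_first=True):
    """Sum over ci of data[ci] * filter[ci], with a selectable operand order."""
    total = 0.0
    for d, f in zip(data_row, filter_col):
        term = d * f if data_first else f * d  # only the operand order differs
        total += term
    return total


random.seed(0)
CI = 64  # reduction length, matching the 64-channel workload in the thread
data_row = [random.uniform(-1.0, 1.0) for _ in range(CI)]
filter_col = [random.uniform(-1.0, 1.0) for _ in range(CI)]

same = reduce_ci(data_row, filter_col, True) == reduce_ci(data_row, filter_col, False)
print("operand order changes the result:", not same)
```

This only covers the element-wise product inside the reduction. It does not address which tensor is treated as matrix A or matrix B by the Tensor Core intrinsics.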







[GitHub] [incubator-tvm] wsl-inspur edited a comment on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
wsl-inspur edited a comment on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-625091214


   @FrozenGene 
   We ran the following tests to cover all the layouts:
   
   Layout_1 (the same as in this PR)
   input_tile = (P, CI, alpha, alpha)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (P, CO, m, m)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CI, CO)
   
   Layout_2
   input_tile = (P, CI, alpha, alpha)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (P, CO, m, m)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CO, CI)
   
   Layout_3
   input_tile = (alpha, alpha, P, CI)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (P, CO, m, m)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CI, CO)
   
   Layout_4
   input_tile = (alpha, alpha, P, CI)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (m, m, P, CO)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CO, CI)
   
   The results are listed below.
   
   kernel:3x3x64x64 feature maps:56x56x64:
   batch | layout_1 | layout_2 | layout_3 | layout_4
   -- | -- | -- | -- | --
   1 | 0.0762 | 0.0766 | 0.0758 | 0.0907
   2 | 0.0911 | 0.0931 | 0.0957 | 0.0939
   4 | 0.1197 | 0.124 | 0.1188 | 0.1257
   8 | 0.1979 | 0.1942 | 0.2026 | 0.2208
   16 | 0.3453 | 0.3577 | 0.3427 | 0.3833
   32 | 0.6613 | 0.7161 | 0.6615 | 0.7574
   256 | 5.5837 | 5.3269 | 5.772 | 5.7058
   
   kernel:3x3x256x256 feature maps:14x14x256:
   
   batch | layout_1 | layout_2 | layout_3 | layout_4
   -- | -- | -- | -- | --
   1 | 0.0633 | 0.0694 | 0.0703 | 0.656
   2 | 0.0825 | 0.0835 | 0.0822 | 0.851
   4 | 0.1417 | 0.1562 | 0.1382 | 0.1502
   8 | 0.1829 | 0.1853 | 0.1833 | 0.182
   16 | 0.264 | 0.277 | 0.2832 | 0.2621
   32 | 0.4506 | 0.4799 | 0.4507 | 0.4773
   256 | 3.9432 | 4.0867 | 3.5088 | 4.5032
   
   
   
   
   
   Note: the weight transform was pre-computed. The benchmarks were run on a T4 GPU (16 GB, 70 W). Latency is reported in ms.
   
   All layouts perform at about the same level.
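
For readers following the `P`, `alpha`, and `m` symbols in the layout lists above, here is a small pure-Python sketch that reproduces the tile arithmetic from `_infer_tile_size` and `nhwc_winograd_cuda` in this PR's diff for the two benchmarked workloads. It assumes a 3x3 kernel, stride 1, and padding 1 (the padding is not stated in the thread), and `winograd_shapes` is a hypothetical helper name, not part of TVM. It says nothing about which code path the reported timings actually used.

```python
def winograd_shapes(n, h, w, ci, co, k=3, pad=1):
    """Return (m, alpha, P, tensorcore_ok) for an NHWC conv2d at stride 1.

    Assumption: symmetric padding `pad`; mirrors the arithmetic in the PR diff.
    """
    m = 4 if h % 8 == 0 else 2          # tile size rule from _infer_tile_size
    alpha = k + m - 1                   # edge length of a transformed tile
    out_h = h + 2 * pad - k + 1         # output height at stride 1
    out_w = w + 2 * pad - k + 1         # output width at stride 1
    n_h = (out_h + m - 1) // m          # tiles along H (ceil division)
    n_w = (out_w + m - 1) // m          # tiles along W (ceil division)
    p = n * n_h * n_w                   # total number of tiles, "P" in the layouts
    # Shape condition from the diff for using the Tensor Core schedule.
    tensorcore_ok = (
        (p % 16 == 0 and ci % 16 == 0 and co % 16 == 0)
        or (p % 8 == 0 and ci % 16 == 0 and co % 32 == 0)
        or (p % 32 == 0 and ci % 16 == 0 and co % 8 == 0)
    )
    return m, alpha, p, tensorcore_ok


for batch in (1, 2, 8, 16):
    print("56x56x64  batch", batch, winograd_shapes(batch, 56, 56, 64, 64))
    print("14x14x256 batch", batch, winograd_shapes(batch, 14, 14, 256, 256))
```

Under these assumptions the 56x56 workload uses m = 4 (alpha = 6) and the 14x14 workload uses m = 2 (alpha = 4), so the two benchmarks exercise different transform sizes.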





[GitHub] [incubator-tvm] FrozenGene edited a comment on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene edited a comment on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-621791710


   For performance, have you tried any other layouts on GPU? I have some experience on CPU. A more suitable layout for NHWC input on CPU is:
   
   ```
     input_tile: alpha, alpha, P, CI
     data_pack: alpha, alpha, P, CI
     bgemm: alpha, alpha, P, CO
     inverse: m, m, P, CO
     output: N H W CO
     kernel: alpha alpha CO CI
   ```
   For the kernel, I chose `alpha alpha CO CI` because I want to vectorize over CI. On GPU, `alpha alpha CI CO` may be better.
   
   I compared your layout with the one I mentioned above: on skylake-512 your layout takes 0.388 ms and mine takes 0.375 ms. I used 20 threads on the workload (1, 56, 56, 64, 64). The results are stably reproducible.





[GitHub] [incubator-tvm] wsl-inspur commented on a change in pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
wsl-inspur commented on a change in pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#discussion_r417902315



##########
File path: topi/python/topi/cuda/conv2d_nhwc_winograd.py
##########
@@ -0,0 +1,639 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+# pylint: disable=too-many-arguments,too-many-locals
+# pylint: disable=too-many-statements
+"""Winograd template for cuda backend"""
+
+import tvm
+from tvm import te
+from tvm import autotvm
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+from .tensor_intrin import intrin_wmma_load_matrix_A
+from .tensor_intrin import intrin_wmma_load_matrix_W
+from .tensor_intrin import intrin_wmma_store_matrix
+from .tensor_intrin import intrin_wmma_gemm
+
+def _infer_tile_size(data, kernel):
+    """Compute the tile size"""
+    N, H, W, CI = get_const_tuple(data.shape)
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm tensorcore"""
+    A = data_pack
+    B = kernel_pack
+    C = bgemm
+    _, _, P, out_dim = get_const_tuple(C.shape)
+    out_dtype = C.dtype
+
+    # Explicit memory access
+    AS = s.cache_read(A, 'shared', [C])
+    BS = s.cache_read(B, 'shared', [C])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+    CF = s.cache_write(C, 'wmma.accumulator')
+    CS = s.cache_read(CF, 'shared', [C])
+
+    # Create tuning space
+    cfg.define_knob("block_row_warps", [1, 2, 4])
+    cfg.define_knob("block_col_warps", [1, 2, 4])
+    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8])
+    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8])
+    cfg.define_knob("chunk", [1, 2, 4, 8])
+    cfg.define_knob("offset", [0, 1, 2, 4, 8])
+    cfg.define_knob("offsetCS", [0, 1, 2, 4, 8])
+    cfg.define_knob("vec", [1, 2, 4, 8])
+
+    # Ensure that the default parameters are applicable when autotvm is not in use
+    if (P % 16 == 0 and out_dim % 16 == 0):
+        cfg.define_knob("wmma_m", [16, 8, 32])
+    elif (P % 32 == 0 and out_dim % 8 == 0):
+        cfg.define_knob("wmma_m", [32, 16, 8])
+    elif (P % 8 == 0 and out_dim % 32 == 0):
+        cfg.define_knob("wmma_m", [8, 16, 32])
+
+    warp_size = 32
+    wmma_k = 16
+    block_row_warps = cfg["block_row_warps"].val
+    block_col_warps = cfg["block_col_warps"].val
+    warp_row_tiles = cfg["warp_row_tiles"].val
+    warp_col_tiles = cfg["warp_col_tiles"].val
+    chunk = cfg["chunk"].val
+    offsetAB = cfg["offset"].val
+    offsetCS = cfg["offsetCS"].val
+    wmma_m = cfg["wmma_m"].val
+    vec = cfg["vec"].val
+
+    if wmma_m == 16:
+        wmma_n = 16
+    elif wmma_m == 8:
+        wmma_n = 32
+    elif wmma_m == 32:
+        wmma_n = 8
+
+    # Define the stride of intrin functions
+    AS_align = chunk * wmma_k + offsetAB
+    BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB
+    CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+    AS_stride = [AS_align, 1]
+    BS_stride = [BS_align, 1]
+    AF_stride = [wmma_k, 1]
+    BF_stride = [wmma_n * warp_col_tiles, 1]
+    CF_stride = [warp_col_tiles * wmma_n, 1]
+    CS_stride = [CS_align, 1]
+    block_x = te.thread_axis('blockIdx.x')
+    block_y = te.thread_axis('blockIdx.y')
+    block_z = te.thread_axis('blockIdx.z')
+    thread_x = te.thread_axis('threadIdx.x')
+    thread_y = te.thread_axis('threadIdx.y')
+    thread_z = te.thread_axis('threadIdx.z')
+
+    # Schedule for computation
+    block_factor_b = wmma_m * warp_row_tiles * block_row_warps
+    block_factor_o = wmma_n * warp_col_tiles * block_col_warps
+    alpha_1, alpha_2, b, o = C.op.axis
+    block_k = s[C].fuse(alpha_1, alpha_2)
+    block_i, bc = s[C].split(b, factor=block_factor_b)
+    block_j, oc = s[C].split(o, factor=block_factor_o)
+    s[C].reorder(block_k, block_i, block_j, bc, oc)
+    t = s[C].fuse(bc, oc)
+    t, vi = s[C].split(t, factor=vec)
+    t, tx = s[C].split(t, factor=warp_size)
+    t, ty = s[C].split(t, factor=block_row_warps)
+    t, tz = s[C].split(t, factor=block_col_warps)
+    s[C].bind(block_k, block_z)
+    s[C].bind(block_i, block_x)
+    s[C].bind(block_j, block_y)
+    s[C].bind(tz, thread_z)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].vectorize(vi)
+
+    # Schedule for wmma store
+    s[CS].compute_at(s[C], block_j)
+    _, _, bb, oo = CS.op.axis
+    s[CS].storage_align(bb, CS_align - 1, CS_align)
+    bb, bbi = s[CS].split(bb, factor=wmma_m)
+    oo, ooi = s[CS].split(oo, factor=wmma_n)
+    bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+    oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+    s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
+
+    # Schedule for wmma computation
+    s[CF].compute_at(s[CS], oo)
+    _, _, warp_i, warp_j = CF.op.axis
+    warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+    warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+    k, = CF.op.reduce_axis
+    k, _k = s[CF].split(k, factor=wmma_k)
+    ko, ki = s[CF].split(k, factor=chunk)
+    s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+    # Schedule for  wmma_matrix_a load
+    s[AF].compute_at(s[CF], ki)
+    _, _, b, i = AF.op.axis
+    b, b_ii = s[AF].split(b, factor=wmma_m)
+    i, i_jj = s[AF].split(i, factor=wmma_k)
+    s[AF].reorder(b, i, b_ii, i_jj)
+
+    # Schedule for  wmma_matrix_b load
+    s[BF].compute_at(s[CF], ki)
+    _, _, i, o = BF.op.axis
+    o, o_ii = s[BF].split(o, factor=wmma_n)
+    i, i_ii = s[BF].split(i, factor=wmma_k)
+    s[BF].reorder(i, o, i_ii, o_ii)
+
+    # Schedule for A's(B's) shared memory load
+    def shared_shedule(stage, strides):
+        s[stage].compute_at(s[CF], ko)
+        _, _, xo, yo = stage.op.axis
+        s[stage].storage_align(xo, strides - 1, strides)
+        t = s[stage].fuse(xo, yo)
+        t, vi = s[stage].split(t, factor=vec)
+        t, tx = s[stage].split(t, factor=warp_size)
+        t, ty = s[stage].split(t, factor=block_row_warps)
+        _, tz = s[stage].split(t, factor=block_col_warps)
+        s[stage].bind(ty, thread_y)
+        s[stage].bind(tz, thread_z)
+        s[stage].bind(tx, thread_x)
+        s[stage].vectorize(vi)
+
+    shared_shedule(AS, AS_align)
+    shared_shedule(BS, BS_align)
+
+    shape = (wmma_m, wmma_n, wmma_k)
+    in_dtype = 'float16'
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
+    BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *
+                                   BL_gemm[k_gemm, jj].astype(out_dtype),
+                                   axis=k_gemm), name='CL_compute')
+
+    # Lower the computation loops down to TensorCore hardware intrinsics
+    # by mapping the tensorcore to tensor intrinsics
+    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major",
+                                                    (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
+    s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major",
+                                                    (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16'))
+    s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride,
+                                          BF_stride, CF_stride, shape))
+    s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype,
+                                                  (wmma_m, wmma_n), (wmma_m, wmma_n)))
+
+
+def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm direct"""
+    b1, b2, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    # Create tuning space
+    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
+                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8])
+    cfg.define_knob("vector_bgemm", [1, 2, 4, 8])
+    offset_bgemm = cfg["offset_bgemm"].val
+    vector_bgemm = cfg["vector_bgemm"].val
+
+    C = bgemm
+    A0, B0 = kernel_pack, data_pack
+
+    # Designate the memory hierarchy
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    # Tile and bind spatial axes
+    b = s[bgemm].fuse(b1, b2)
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # Tile reduction axes
+    s[OL].compute_at(s[C], tx)
+    b1, b2, y, x = s[OL].op.axis
+    b = s[OL].fuse(b1, b2)
+    rc, = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    s[OL].reorder(rco, b, y, x, rci)
+
+    s[AA].compute_at(s[OL], rco)
+    _, _, k, n = s[AA].op.axis
+    AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * cfg["tile_x"].size[3]
+    s[AA].storage_align(k, AA_align - 1, AA_align)
+
+    s[BB].compute_at(s[OL], rco)
+    _, _, m, k = s[BB].op.axis
+    BB_align = offset_bgemm + cfg["tile_rc"].size[1]
+    s[BB].storage_align(m, BB_align - 1, BB_align)
+
+    # Schedule for A and B shared memory load
+    for load in [AA, BB]:
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, ti = s[load].split(fused, factor=vector_bgemm)
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+        s[load].vectorize(ti)
+
+
+def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                       use_tensorcore, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+    N, H, W, CI = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # Kernel tensor is raw tensor, do strict check
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1))
+        KH, KW, CI, CO = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transformed. This op is created by conv2d_alter_op.
+        # Dilation is not supported
+        alpha, _, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # Determine whether the shape is available with tensorcore
+    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                      (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                      (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+
+    if shape_judge and use_tensorcore:
+        trans_type = "float16"
+    else:
+        trans_type = data.dtype
+
+    # Compute transform matrix
+    A, _, _ = winograd_transform_matrices(m, r, out_dtype)
+    _, B, G = winograd_transform_matrices(m, r, data.dtype)
+
+    # Transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
+                                     te.sum((kernel[r_kh][r_kw][ci][co]) *
+                                            G[eps][r_kh] * G[nu][r_kw],
+                                            axis=[r_kh, r_kw]), name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    # Pack input tile
+    input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW)),
+                                     idxmod(idxdiv(p, nW), nH) * m + eps,
+                                     idxmod(p, nW) * m + nu,
+                                     c], name='d')
+
+    # Transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                           te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # Convert data type of input feature maps and weights for tensorcore
+    Transdata = te.compute(
+        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type))
+    TransFilter = te.compute(
+        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type))
+
+    # Do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
+                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),

Review comment:
       Tensor Core supports fp16 input and fp16/fp32 output. If the input data type is fp32, the input must be converted to fp16 before running on Tensor Core. This conversion causes only a limited loss of accuracy.
    
   We added a test file (test_topi_conv2d_nhwc_winograd.py) in this PR. It checks the correctness and accuracy of Conv2d with various shapes using both direct and Tensor Core Winograd, and it passes with an rtol of 2e-3.
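
As a rough illustration of why the fp32-to-fp16 input conversion costs little accuracy, the sketch below rounds the inputs of a dot product to half precision and compares the result with the unrounded reference. It is not the test file from this PR: `to_fp16` and `relative_error_fp16_inputs` are hypothetical helper names, the inputs are kept positive so that cancellation does not inflate the error, and accumulation is done in Python floats rather than on Tensor Core.

```python
import random
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    # The 'e' struct format is IEEE 754 binary16 (available since Python 3.6).
    return struct.unpack('e', struct.pack('e', x))[0]


def dot(a, b):
    """Plain dot product accumulated in Python floats."""
    return sum(x * y for x, y in zip(a, b))


def relative_error_fp16_inputs(n, seed=0):
    """Relative error of a length-n dot product when inputs are rounded to fp16."""
    rng = random.Random(seed)
    # Assumption: positive inputs in [0.5, 1.5], so no cancellation occurs.
    a = [rng.uniform(0.5, 1.5) for _ in range(n)]
    b = [rng.uniform(0.5, 1.5) for _ in range(n)]
    ref = dot(a, b)
    approx = dot([to_fp16(x) for x in a], [to_fp16(y) for y in b])
    return abs(approx - ref) / abs(ref)


rel_err = relative_error_fp16_inputs(64)
print("relative error with fp16 inputs:", rel_err)
```

Each fp16 rounding introduces a relative error of at most 2^-11, so a product of two rounded positive values is off by roughly 1e-3 at most, which is consistent with the 2e-3 tolerance mentioned above. Inputs with mixed signs can show larger relative error because of cancellation.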
   
   
   
   







[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#discussion_r417960925



##########
File path: topi/python/topi/cuda/conv2d_nhwc_winograd.py
##########
@@ -0,0 +1,639 @@
+    # Do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
+                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),

Review comment:
       OK. I think I misunderstood.







[GitHub] [incubator-tvm] tqchen commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-624695979


   cc @FrozenGene 





[GitHub] [incubator-tvm] wsl-inspur commented on pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

Posted by GitBox <gi...@apache.org>.
wsl-inspur commented on pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#issuecomment-625091214


   @FrozenGene 
   We ran the following tests to cover all the layouts:
   
   Layout_1 (the same as in this PR)
   input_tile = (P, CI, alpha, alpha)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (P, CO, m, m)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CI, CO)
   
   Layout_2
   input_tile = (P, CI, alpha, alpha)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (P, CO, m, m)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CO, CI)
   
   Layout_3
   input_tile = (alpha, alpha, P, CI)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (P, CO, m, m)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CI, CO)
   
   Layout_4
   input_tile = (alpha, alpha, P, CI)
   data_pack = (alpha, alpha, P, CI)
   bgemm = (alpha, alpha, P, CO)
   inverse = (m, m, P, CO)
   output = (N, H, W, CO)
   kernel = (alpha, alpha, CO, CI)
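
As a cross-check of the Layout_1 and Layout_2 shapes above, here is a minimal sketch in plain Python. It assumes `alpha = m + r - 1` for an `r x r` kernel, `P = N * nH * nW` with `nH = ceil(H / m)` and `nW = ceil(W / m)`, and that `H` and `W` are the output height and width. The function name `winograd_shapes` and the choice `m = 4` in the example are illustrative; the tile size actually used in the benchmarks is not stated in this comment.

```python
import math


def winograd_shapes(N, H, W, CI, CO, m, r=3, kernel_co_first=False):
    """Return the intermediate tensor shapes of Layout_1 (or Layout_2).

    H and W are the output height and width (assumption).
    kernel_co_first=False gives Layout_1, True gives Layout_2.
    """
    alpha = m + r - 1                          # size of one transformed tile
    nH, nW = math.ceil(H / m), math.ceil(W / m)
    P = N * nH * nW                            # total number of tiles
    kernel = (alpha, alpha, CO, CI) if kernel_co_first else (alpha, alpha, CI, CO)
    return {
        "input_tile": (P, CI, alpha, alpha),
        "data_pack": (alpha, alpha, P, CI),
        "bgemm": (alpha, alpha, P, CO),
        "inverse": (P, CO, m, m),
        "output": (N, H, W, CO),
        "kernel": kernel,
    }


# Example: batch 1, 56x56 output, 64 -> 64 channels, assumed tile size m = 4.
shapes = winograd_shapes(N=1, H=56, W=56, CI=64, CO=64, m=4)
for name, shape in shapes.items():
    print(f"{name:10s} = {shape}")
```

Only the `kernel` entry differs between Layout_1 and Layout_2; the batched GEMM reduces over `CI` in both cases.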
   
   The results are listed below.
   
   kernel: 3x3x64x64, feature maps: 56x56x64:
   
   batch | layout_1 | layout_2 | layout_3 | layout_4
   -- | -- | -- | -- | --
   1 | 0.0762 | 0.0766 | 0.0758 | 0.0907
   2 | 0.0911 | 0.0931 | 0.0957 | 0.0939
   4 | 0.1197 | 0.124 | 0.1188 | 0.1257
   8 | 0.1979 | 0.1942 | 0.2026 | 0.2208
   16 | 0.3453 | 0.3577 | 0.3427 | 0.3833
   32 | 0.6613 | 0.7161 | 0.6615 | 0.7574
   
   kernel: 3x3x256x256, feature maps: 14x14x256:
   
   batch | layout_1 | layout_2 | layout_3 | layout_4
   -- | -- | -- | -- | --
   1 | 0.0633 | 0.0694 | 0.0703 | 0.656
   2 | 0.0825 | 0.0835 | 0.0822 | 0.851
   4 | 0.1417 | 0.1562 | 0.1382 | 0.1502
   8 | 0.1829 | 0.1853 | 0.1833 | 0.182
   16 | 0.264 | 0.277 | 0.2832 | 0.2621
   32 | 0.4506 | 0.4799 | 0.4507 | 0.4773
   256 | 3.9432 | 4.0867 | 3.5088 | 4.5032
   
   
   
   Note: the weight transform was pre-computed. The benchmarks were run on a T4 GPU (16 GB, 70 W). Latency is reported in ms.
   
   We can see that the performance of all layouts is at the same level.
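
To put a number on "the same level", one rough measure is the relative gap between the fastest and slowest layout at each batch size. The sketch below does this for the first table (3x3x64x64 kernel, 56x56x64 feature maps); the latencies are copied from that table, while the `spread` helper and the choice of metric are illustrative assumptions rather than part of the original benchmark. On these numbers the gap falls between roughly 5% and 20%.

```python
# Latency in ms per batch size, in the order (layout_1, layout_2, layout_3, layout_4),
# copied from the 3x3x64x64 / 56x56x64 table.
latency_ms = {
    1: (0.0762, 0.0766, 0.0758, 0.0907),
    2: (0.0911, 0.0931, 0.0957, 0.0939),
    4: (0.1197, 0.124, 0.1188, 0.1257),
    8: (0.1979, 0.1942, 0.2026, 0.2208),
    16: (0.3453, 0.3577, 0.3427, 0.3833),
    32: (0.6613, 0.7161, 0.6615, 0.7574),
}


def spread(row):
    """Relative gap between the slowest and fastest layout in one table row."""
    return (max(row) - min(row)) / min(row)


for batch, row in latency_ms.items():
    print(f"batch {batch:2d}: slowest layout is {spread(row):.1%} behind the fastest")
```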
   
   
   
   
   
   
   

