You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/30 08:26:26 UTC

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #5485: [TOPI][Winograd] Optimization of Conv2d Winograd algorithm on Tensor …

FrozenGene commented on a change in pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#discussion_r417839497



##########
File path: topi/python/topi/cuda/conv2d_nhwc_winograd.py
##########
@@ -0,0 +1,639 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+# pylint: disable=too-many-arguments,too-many-locals
+# pylint: disable=too-many-statements
+"""Winograd template for cuda backend"""
+
+import tvm
+from tvm import te
+from tvm import autotvm
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+from .tensor_intrin import intrin_wmma_load_matrix_A
+from .tensor_intrin import intrin_wmma_load_matrix_W
+from .tensor_intrin import intrin_wmma_store_matrix
+from .tensor_intrin import intrin_wmma_gemm
+
+def _infer_tile_size(data, kernel):
+    """Choose the Winograd tile size from the height of the input.
+
+    The second entry of the constant shape of ``data`` is taken as the
+    height. A height divisible by 8 selects a tile size of 4; any other
+    height selects 2. ``kernel`` is accepted but not consulted.
+    """
+    _, height, _, _ = get_const_tuple(data.shape)
+    return 4 if height % 8 == 0 else 2
+
+
+def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm tensorcore
+
+    Stages the two operands through 'shared' and 'wmma.matrix_a' /
+    'wmma.matrix_b' scopes, accumulates in 'wmma.accumulator', writes the
+    result back through 'shared', and tensorizes the load / gemm / store
+    stages with the wmma intrinsics imported at the top of the file.
+
+    Parameters
+    ----------
+    cfg
+        Tuning config; the knobs are defined on it here and their ``.val``
+        is read back immediately.
+    s
+        Schedule that is modified in place.
+    bgemm
+        Batched-gemm stage (called ``C`` below). Its shape is unpacked as
+        four constant extents and it has a single reduce axis.
+    data_pack
+        First gemm operand (called ``A`` below).
+    kernel_pack
+        Second gemm operand (called ``B`` below).
+
+    Returns
+    -------
+    None
+    """
+    A = data_pack
+    B = kernel_pack
+    C = bgemm
+    # Only the last two extents of the output are needed for the knob choice.
+    _, _, P, out_dim = get_const_tuple(C.shape)
+    out_dtype = C.dtype
+
+    # Explicit memory access
+    # A/B: original -> 'shared' -> wmma fragment; C: accumulator -> 'shared' -> C.
+    AS = s.cache_read(A, 'shared', [C])
+    BS = s.cache_read(B, 'shared', [C])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+    CF = s.cache_write(C, 'wmma.accumulator')
+    CS = s.cache_read(CF, 'shared', [C])
+
+    # Create tuning space
+    cfg.define_knob("block_row_warps", [1, 2, 4])
+    cfg.define_knob("block_col_warps", [1, 2, 4])
+    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8])
+    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8])
+    cfg.define_knob("chunk", [1, 2, 4, 8])
+    cfg.define_knob("offset", [0, 1, 2, 4, 8])
+    cfg.define_knob("offsetCS", [0, 1, 2, 4, 8])
+    cfg.define_knob("vec", [1, 2, 4, 8])
+
+    # Ensure that the default parameters are applicable when autotvm is not in use
+    # The first candidate of each list is the (wmma_m, wmma_n) pair that matches
+    # the divisibility test of its branch; presumably that first entry acts as
+    # the default - confirm against autotvm's fallback behaviour.
+    # NOTE(review): when none of the three branches matches, the "wmma_m" knob
+    # is never defined and the cfg["wmma_m"] lookup below is not guarded here;
+    # callers appear to be expected to check the shape beforehand - confirm.
+    if (P % 16 == 0 and out_dim % 16 == 0):
+        cfg.define_knob("wmma_m", [16, 8, 32])
+    elif (P % 32 == 0 and out_dim % 8 == 0):
+        cfg.define_knob("wmma_m", [32, 16, 8])
+    elif (P % 8 == 0 and out_dim % 32 == 0):
+        cfg.define_knob("wmma_m", [8, 16, 32])
+
+    warp_size = 32
+    wmma_k = 16
+    block_row_warps = cfg["block_row_warps"].val
+    block_col_warps = cfg["block_col_warps"].val
+    warp_row_tiles = cfg["warp_row_tiles"].val
+    warp_col_tiles = cfg["warp_col_tiles"].val
+    chunk = cfg["chunk"].val
+    offsetAB = cfg["offset"].val
+    offsetCS = cfg["offsetCS"].val
+    wmma_m = cfg["wmma_m"].val
+    vec = cfg["vec"].val
+
+    # wmma_n is derived from wmma_m so that wmma_m * wmma_n == 256 in every case.
+    # NOTE(review): there is no else branch, so wmma_n stays unbound for any
+    # other wmma_m value.
+    if wmma_m == 16:
+        wmma_n = 16
+    elif wmma_m == 8:
+        wmma_n = 32
+    elif wmma_m == 32:
+        wmma_n = 8
+
+    # Define the stride of intrin functions
+    # The *_align values are the row strides of the 'shared' buffers: the tile
+    # width plus the tunable padding ("offset" for A/B, "offsetCS" for C).
+    AS_align = chunk * wmma_k + offsetAB
+    BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB
+    CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+    AS_stride = [AS_align, 1]
+    BS_stride = [BS_align, 1]
+    AF_stride = [wmma_k, 1]
+    BF_stride = [wmma_n * warp_col_tiles, 1]
+    CF_stride = [warp_col_tiles * wmma_n, 1]
+    CS_stride = [CS_align, 1]
+    block_x = te.thread_axis('blockIdx.x')
+    block_y = te.thread_axis('blockIdx.y')
+    block_z = te.thread_axis('blockIdx.z')
+    thread_x = te.thread_axis('threadIdx.x')
+    thread_y = te.thread_axis('threadIdx.y')
+    thread_z = te.thread_axis('threadIdx.z')
+
+    # Schedule for computation
+    # Per-block tile extents along the third (b) and fourth (o) output axes.
+    block_factor_b = wmma_m * warp_row_tiles * block_row_warps
+    block_factor_o = wmma_n * warp_col_tiles * block_col_warps
+    alpha_1, alpha_2, b, o = C.op.axis
+    # The two leading axes are fused and bound to blockIdx.z below.
+    block_k = s[C].fuse(alpha_1, alpha_2)
+    block_i, bc = s[C].split(b, factor=block_factor_b)
+    block_j, oc = s[C].split(o, factor=block_factor_o)
+    s[C].reorder(block_k, block_i, block_j, bc, oc)
+    # The in-block tile is flattened, then peeled (innermost first) into a
+    # vectorized lane and the threadIdx.x / .y / .z loops.
+    t = s[C].fuse(bc, oc)
+    t, vi = s[C].split(t, factor=vec)
+    t, tx = s[C].split(t, factor=warp_size)
+    t, ty = s[C].split(t, factor=block_row_warps)
+    t, tz = s[C].split(t, factor=block_col_warps)
+    s[C].bind(block_k, block_z)
+    s[C].bind(block_i, block_x)
+    s[C].bind(block_j, block_y)
+    s[C].bind(tz, thread_z)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].vectorize(vi)
+
+    # Schedule for wmma store
+    s[CS].compute_at(s[C], block_j)
+    _, _, bb, oo = CS.op.axis
+    s[CS].storage_align(bb, CS_align - 1, CS_align)
+    # Split each axis into (outer, warp_*_tiles, wmma_m / wmma_n) pieces; the
+    # innermost (bbi, ooi) pair is what gets tensorized as the store.
+    bb, bbi = s[CS].split(bb, factor=wmma_m)
+    oo, ooi = s[CS].split(oo, factor=wmma_n)
+    bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+    oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+    s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
+
+    # Schedule for wmma computation
+    s[CF].compute_at(s[CS], oo)
+    _, _, warp_i, warp_j = CF.op.axis
+    warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+    warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+    k, = CF.op.reduce_axis
+    # Reduction is split into ko (per shared-memory chunk), ki (per fragment
+    # within a chunk) and _k (the wmma_k elements consumed by one intrinsic).
+    k, _k = s[CF].split(k, factor=wmma_k)
+    ko, ki = s[CF].split(k, factor=chunk)
+    s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+    # Schedule for  wmma_matrix_a load
+    s[AF].compute_at(s[CF], ki)
+    _, _, b, i = AF.op.axis
+    b, b_ii = s[AF].split(b, factor=wmma_m)
+    i, i_jj = s[AF].split(i, factor=wmma_k)
+    s[AF].reorder(b, i, b_ii, i_jj)
+
+    # Schedule for  wmma_matrix_b load
+    s[BF].compute_at(s[CF], ki)
+    _, _, i, o = BF.op.axis
+    o, o_ii = s[BF].split(o, factor=wmma_n)
+    i, i_ii = s[BF].split(i, factor=wmma_k)
+    s[BF].reorder(i, o, i_ii, o_ii)
+
+    # Schedule for A's(B's) shared memory load
+    def shared_shedule(stage, strides):
+        # Attaches `stage` at the outer reduction loop `ko`, aligns its third
+        # axis with `strides`, and applies the same vectorize / threadIdx
+        # split pattern that is used for C above.
+        s[stage].compute_at(s[CF], ko)
+        _, _, xo, yo = stage.op.axis
+        s[stage].storage_align(xo, strides - 1, strides)
+        t = s[stage].fuse(xo, yo)
+        t, vi = s[stage].split(t, factor=vec)
+        t, tx = s[stage].split(t, factor=warp_size)
+        t, ty = s[stage].split(t, factor=block_row_warps)
+        _, tz = s[stage].split(t, factor=block_col_warps)
+        s[stage].bind(ty, thread_y)
+        s[stage].bind(tz, thread_z)
+        s[stage].bind(tx, thread_x)
+        s[stage].vectorize(vi)
+
+    shared_shedule(AS, AS_align)
+    shared_shedule(BS, BS_align)
+
+    # Reference computation of one (wmma_m x wmma_k) @ (wmma_k x wmma_n) tile,
+    # with float16 inputs cast to out_dtype; it is passed to intrin_wmma_gemm.
+    shape = (wmma_m, wmma_n, wmma_k)
+    in_dtype = 'float16'
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
+    BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *
+                                   BL_gemm[k_gemm, jj].astype(out_dtype),
+                                   axis=k_gemm), name='CL_compute')
+
+    # Lower the computation loops down to TensorCore hardware intrinsics
+    # by mapping the tensorcore to tensor intrinsics
+    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major",
+                                                    (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
+    s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major",
+                                                    (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16'))
+    s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride,
+                                          BF_stride, CF_stride, shape))
+    s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype,
+                                                  (wmma_m, wmma_n), (wmma_m, wmma_n)))
+
+
+def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm direct
+
+    Tiles the batched gemm over blockIdx / vthread / threadIdx, accumulates
+    in a 'local' cache and stages both operands through 'shared' caches.
+
+    Parameters
+    ----------
+    cfg
+        Tuning config; the split and knob entries are defined on it here
+        and then applied / read back.
+    s
+        Schedule that is modified in place.
+    bgemm
+        Batched-gemm stage with four spatial axes and one reduce axis.
+    data_pack
+        Gemm operand cached below as ``BB``.
+    kernel_pack
+        Gemm operand cached below as ``AA``.
+
+    Returns
+    -------
+    None
+    """
+    b1, b2, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    # Constant extent of the first axis; the tile_b space below is defined
+    # over alpha * alpha, i.e. over the fusion of the two leading axes.
+    alpha = get_const_int(b1.dom.extent)
+
+    # Create tuning space
+    # tile_b only keeps candidates whose last three split sizes are all 1.
+    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
+                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8])
+    cfg.define_knob("vector_bgemm", [1, 2, 4, 8])
+    offset_bgemm = cfg["offset_bgemm"].val
+    vector_bgemm = cfg["vector_bgemm"].val
+
+    C = bgemm
+    # Note: here A0 is the kernel and B0 is the data.
+    A0, B0 = kernel_pack, data_pack
+
+    # Designate the memory hierarchy
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    # Tile and bind spatial axes
+    b = s[bgemm].fuse(b1, b2)
+    # nparts=1 yields an extent-1 outer loop that is kept outermost in the
+    # reorder below.
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # Tile reduction axes
+    s[OL].compute_at(s[C], tx)
+    # From here on b1, b2, y, x, b and rc refer to the axes of the 'local'
+    # cache stage OL, not to those of bgemm.
+    b1, b2, y, x = s[OL].op.axis
+    b = s[OL].fuse(b1, b2)
+    rc, = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    s[OL].reorder(rco, b, y, x, rci)
+
+    # Both shared caches are attached at the outer reduction loop rco.
+    s[AA].compute_at(s[OL], rco)
+    _, _, k, n = s[AA].op.axis
+    # Alignment for AA: product of the three inner tile_x split sizes plus
+    # the tunable padding.
+    AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * cfg["tile_x"].size[3]
+    s[AA].storage_align(k, AA_align - 1, AA_align)
+
+    s[BB].compute_at(s[OL], rco)
+    _, _, m, k = s[BB].op.axis
+    # Alignment for BB: inner tile_rc split size plus the tunable padding.
+    BB_align = offset_bgemm + cfg["tile_rc"].size[1]
+    s[BB].storage_align(m, BB_align - 1, BB_align)
+
+    # Schedule for A and B shared memory load
+    for load in [AA, BB]:
+        # All axes are flattened, then peeled (innermost first) into a
+        # vectorized lane and threadIdx.x / .y / .z loops whose factors are
+        # the third split size of tile_x / tile_y / tile_b respectively.
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, ti = s[load].split(fused, factor=vector_bgemm)
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+        s[load].vectorize(ti)
+
+
+def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                       use_tensorcore, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+    N, H, W, CI = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # Kernel tensor is raw tensor, do strict check
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1))
+        KH, KW, CI, CO = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transformed. This op is created by conv2d_alter_op.
+        # Dilation is not supported
+        alpha, _, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # Determine whether the shape is available with tensorcore
+    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                      (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                      (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+
+    if shape_judge and use_tensorcore:
+        trans_type = "float16"
+    else:
+        trans_type = data.dtype
+
+    # Compute transform matrix
+    A, _, _ = winograd_transform_matrices(m, r, out_dtype)
+    _, B, G = winograd_transform_matrices(m, r, data.dtype)
+
+    # Transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
+                                     te.sum((kernel[r_kh][r_kw][ci][co]) *
+                                            G[eps][r_kh] * G[nu][r_kw],
+                                            axis=[r_kh, r_kw]), name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    # Pack input tile
+    input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW)),
+                                     idxmod(idxdiv(p, nW), nH) * m + eps,
+                                     idxmod(p, nW) * m + nu,
+                                     c], name='d')
+
+    # Transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                           te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # Convert data type of input feature maps and weights for tensorcore
+    Transdata = te.compute(
+        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type))
+    TransFilter = te.compute(
+        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type))
+
+    # Do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
+                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),

Review comment:
       Shouldn't this be 
   ```python
   TransFilter *  Transdata
   ```
   ? I am very curious how it could pass the correctness testing.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org