Posted to dev@tvm.apache.org by Olivier Valery via Apache TVM Discuss <no...@discuss.tvm.ai> on 2020/11/05 13:52:30 UTC

[Apache TVM Discuss] [Development] Quantization and 3D convolution


I implemented conv3d with int8 as follows:

I created the file ```python/tvm/topi/cuda/conv3d_int8.py```, which implements the operation itself:

```
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=no-value-for-parameter
"""Int8 conv3d in NCDHWc layout"""
import tvm
from tvm import te
from tvm import autotvm

from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv3d import unpack_NCDHWc_to_ncdhw
from ..nn.util import get_pad_tuple3d
from ..util import simplify, get_const_tuple, traverse_inline


def conv3d_ncdhw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
    """Compute conv3d internally using conv3d_ncdhwc layout for int8 dtype"""
    assert data.dtype in ("int8", "uint8")
    assert kernel.dtype in ("int8", "uint8")
    assert data.dtype == kernel.dtype
    packed_out = conv3d_NCDHWc_int8(data, kernel, strides, padding, dilation, "NCDHW", out_dtype)
    return unpack_NCDHWc_to_ncdhw(packed_out, out_dtype)


def schedule_conv3d_ncdhw_int8(outs):
    """Create schedule for tensors"""
    return schedule_conv3d_NCDHWc_int8(outs)


@autotvm.register_topi_compute("conv3d_NCDHWc_int8.cuda")
def conv3d_NCDHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
    """Convolution operator in NCDHW[x]c layout for int8."""

    # print("conv3d_NCDHWc_int8")

    assert layout in ["NCDHW", "NCDHW4c"]

    ic_block_factor = 4
    oc_block_factor = 4

    pre_computed = len(kernel.shape) == 7
    if not pre_computed:
        batch, channels, depth, height, width = get_const_tuple(data.shape)
        assert (
            channels % ic_block_factor == 0
        ), "Number of input channels should be multiple of {}".format(ic_block_factor)
        packed_data = te.compute(
            (batch, channels // ic_block_factor, depth, height, width, ic_block_factor),
            lambda n, c, d, h, w, vc: data[n, c * ic_block_factor + vc, d, h, w],
            name="packed_data",
        )

        out_channels, in_channels, kernel_d, kernel_h, kernel_w = get_const_tuple(kernel.shape)
        assert (
            out_channels % oc_block_factor == 0
        ), "Number of output channels should be multiple of {}".format(oc_block_factor)
        packed_kernel = te.compute(
            (
                out_channels // oc_block_factor,
                in_channels // ic_block_factor,
                kernel_d,
                kernel_h,
                kernel_w,
                oc_block_factor,
                ic_block_factor,
            ),
            lambda oc_chunk, ic_chunk, kd, kh, kw, oc_block, ic_block: kernel[
                oc_chunk * oc_block_factor + oc_block,
                ic_chunk * ic_block_factor + ic_block,
                kd,
                kh,
                kw,
            ],
            name="packed_kernel",
        )

    else:
        packed_data = data
        packed_kernel = kernel

    batch, ic_chunk, in_depth, in_height, in_width, ic_block = get_const_tuple(packed_data.shape)
    oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape
    )
    assert isinstance(stride, int) or len(stride) == 3
    assert isinstance(dilation, int) or len(dilation) == 3

    if isinstance(stride, int):
        stride_d = stride_h = stride_w = stride
    else:
        stride_d, stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_d = dilation_h = dilation_w = dilation
    else:
        dilation_d, dilation_h, dilation_w = dilation

    # compute the output shape

    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (kernel_d, kernel_h, kernel_w)
    )
    # out_channel = num_filter
    out_depth = (in_depth - kernel_d + pad_front + pad_back) // stride_d + 1
    out_height = (in_height - kernel_h + pad_top + pad_down) // stride_h + 1
    out_width = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1

    oshape = (batch, oc_chunk, out_depth, out_height, out_width, oc_block)
    # compute graph
    pad_before = [0, 0, pad_front, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_back, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

    icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
    icb = te.reduce_axis((0, ic_block), name="ic_block")
    rz = te.reduce_axis((0, kernel_d), name="rz")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    conv = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: te.sum(
            pad_data[
                nn,
                icc,
                zz * stride_d + rz * dilation_d,
                yy * stride_h + ry * dilation_h,
                xx * stride_w + rx * dilation_w,
                icb,
            ].astype("int32")
            * packed_kernel[oc_chunk, icc, rz, ry, rx, oc_block, icb].astype("int32"),
            axis=[icc, rz, ry, rx, icb],
        ),
    )


    output = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: conv[nn, oc_chunk, zz, yy, xx, oc_block].astype(
            out_dtype
        ),
        tag="conv3d_NCDHWc_int8",
    )

    # num flop
    num_flop = (
        batch
        * oc_chunk
        * oc_block
        * out_depth
        * out_height
        * out_width
        * ic_chunk
        * ic_block
        * kernel_d
        * kernel_h
        * kernel_w
        * 2
    )
    cfg.add_flop(num_flop)

    return output


_dp4a = dp4a("shared", "shared", "local")


@autotvm.register_topi_schedule("conv3d_NCDHWc_int8.cuda")
def schedule_conv3d_NCDHWc_int8(cfg, outs):
    """Schedule conv3d int8 NCDHWc template"""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "conv3d_NCDHWc_int8":
            _schedule_conv3d_NCDHWc_int8(cfg, s, op.output(0), "NCDHW", "conv3d_NCDHWc_int8.cuda")

    traverse_inline(s, outs[0].op, _callback)
    return s


def _schedule_conv3d_NCDHWc_int8(cfg, s, output, layout, workload_name):

    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)
    if pad_data != packed_data:
        s[pad_data].compute_inline()

    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 6:
        n, f, d, y, x, c = s[output].op.axis
    else:
        # For task extraction during auto-tuning, the expected output is 5D. Since auto-tuning
        # tasks are created from scratch, the real auto-tuning will still happen on the 6D output.
        n, f, d, y, x = s[output].op.axis

    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_d", cfg.axis(d), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    kernel_scope, n = s[output].split(n, nparts=1)

    # bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)

    bf = s[output].fuse(n, bf)

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(bd, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vd, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tf, te.thread_axis("threadIdx.z"))
        s[output].bind(td, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_f"].size[2]
        n_ty = cfg["tile_d"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tf, td), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_d"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile reduction axes
    n, f, d, y, x, c = s[conv].op.axis
    rc, rd, ry, rx, rc_block = s[conv].op.reduce_axis

    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_rd", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    rdo, rdi = cfg["tile_rd"].apply(s, conv, rd)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
    s[conv].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, rdo, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, rdo, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, rdi, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    cache_loc = [rco, rdo, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # cooperative fetching
    for load in [AA, WW]:
        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)
        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s

```
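Here is the kind of standalone check I use to make sure the packed compute and schedule give the right result outside of Relay. This is just a sketch: it assumes the two entry points above are re-exported from ```tvm.topi.cuda``` (i.e. added to ```python/tvm/topi/cuda/__init__.py```) and that the GPU supports dp4a (sm_61 or newer).

```
import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi

data = te.placeholder((1, 16, 8, 8, 8), name="data", dtype="int8")
kernel = te.placeholder((32, 16, 3, 3, 3), name="kernel", dtype="int8")

with tvm.target.Target("cuda"):
    # Without tuning records a fallback config is used (a warning is printed).
    out = topi.cuda.conv3d_ncdhw_int8(data, kernel, (1, 1, 1), (1, 1, 1), (1, 1, 1), "int32")
    s = topi.cuda.schedule_conv3d_ncdhw_int8([out])

func = tvm.build(s, [data, kernel, out], "cuda")

a_np = np.random.randint(-8, 8, size=(1, 16, 8, 8, 8)).astype("int8")
w_np = np.random.randint(-8, 8, size=(32, 16, 3, 3, 3)).astype("int8")
# Compute the reference result in int32 to avoid overflow in the numpy reference.
c_np = tvm.topi.testing.conv3d_ncdhw_python(
    a_np.astype("int32"), w_np.astype("int32"), (1, 1, 1), (1, 1, 1)
)

ctx = tvm.gpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
c = tvm.nd.array(np.zeros(c_np.shape, dtype="int32"), ctx)
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np)
```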
In the file ```python/tvm/relay/op/strategy/cuda.py```, I hooked the new implementation into conv3d_strategy_cuda as follows:

```

@conv3d_strategy.register(["cuda", "gpu"])
def conv3d_strategy_cuda(attrs, inputs, out_type, target):
    """conv3d cuda strategy"""
    strategy = _op.OpStrategy()
    data, kernel = inputs
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    _, stride_h, stride_w = attrs.get_int_tuple("strides")
    _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
    assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
    if layout == "NCDHW":

        if attrs.groups == 1:
            assert kernel_layout == "OIDHW"
            if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
                assert data.dtype == kernel.dtype

                strategy.add_implementation(
                    wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_int8),
                    wrap_topi_schedule(topi.cuda.schedule_conv3d_NCDHWc_int8),
                    name="conv3d_ncdhw_int8.cuda",
                )
            else:
                strategy.add_implementation(
                    wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
                    name="conv3d_ncdhw.cuda",
                    plevel=10,
                )
            _, _, _, kh, kw = get_const_tuple(kernel.shape)
            if (
                2 < kh < 8
                and 2 < kw < 8
                and kh == kw
                and stride_h == 1
                and stride_w == 1
                and dilation_h == 1
                and dilation_w == 1
            ):
                strategy.add_implementation(
                    wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
                    name="conv3d_ncdhw_winograd.cuda",
                    plevel=5,
                )

    else:  # layout == "NDHWC":
        strategy.add_implementation(
            wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
            wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
            name="conv3d_ndhwc.cuda",
            plevel=10,
        )
        N, _, _, _, _ = get_const_tuple(data.shape)
        _, _, _, CI, CO = get_const_tuple(kernel.shape)
        if target.kind.name == "cuda":
            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                if (
                    (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
                    or (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
                    or (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
                ):
                    strategy.add_implementation(
                        wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
                        wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),
                        name="conv3d_ndhwc_tensorcore.cuda",
                        plevel=20,
                    )

    if target.kind.name == "cuda" and "cudnn" in target.libs:
        strategy.add_implementation(
            wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
            wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
            name="conv3d_cudnn.cuda",
            plevel=25,
        )
    return strategy

```
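
To confirm that the int8 path is actually selected by the strategy, I compile a small int8 conv3d through Relay. Again only a sketch; it just checks that compilation goes through the new implementation (it assumes the TOPI functions are exported as above and a CUDA-capable machine):

```
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 16, 8, 8, 8), dtype="int8")
weight = relay.var("weight", shape=(32, 16, 3, 3, 3), dtype="int8")
conv = relay.nn.conv3d(
    data, weight, channels=32, kernel_size=(3, 3, 3), padding=(1, 1, 1), out_dtype="int32"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# With int8 data/kernel the strategy above should pick conv3d_ncdhw_int8.cuda.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")
```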

In the file ```python/tvm/relay/quantize/_annotate.py```, I defined new annotation functions:

```
@register_annotate_function("nn.contrib_conv3d_NCDHWc")
def conv3d_ncdhwc_rewrite(ref_call, new_args, ctx):
    warnings.warn(
        "NCDHWc layout Conv3D detected, please use a lower "
        "optimization level before applying the quantization "
        "pass as quantization will have no effect here..."
    )


@register_annotate_function("nn.conv3d")
def conv3d_rewrite(ref_call, new_args, ctx):
    """Rewrite function for conv2d. Lhs of conv will be quantized to
    input field, and rhs of conv will be quantized to weight field.
    Output would be in activation field"""

    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

    assert rhs_kind is None
    rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)

    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
```

I also registered a new partition function in ```python/tvm/relay/quantize/_partition.py```:

```
@register_partition_function("nn.conv3d")
def conv3d_partition_function(ref_call, new_args, ctx):
    """Rewrite function for conv3d for partition"""
    data_cond, data = partition_expr_check(new_args[0])
    kernel_cond, kernel = partition_expr_check(new_args[1])

    assert not kernel_cond
    if data_cond:
        data = new_args[0].realize()
    ret = _forward_op(ref_call, [data, kernel])
    return QPartitionExpr(ret)

```

I also implemented Conv3dRealize:

```
Expr Conv3dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  const auto ref_attrs = ref_call->attrs.as<Conv3DAttrs>();
  auto attrs = make_object<Conv3DAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv3dRealize);
```
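
For reference, these three hooks (annotate, partition, realize) get exercised when the quantization pass runs over a model that contains nn.conv3d. A minimal sketch, where ```mod``` and ```params``` stand for a float32 model and the qconfig values are only examples:

```
import tvm
from tvm import relay

# `mod` / `params` are assumed to be a float32 Relay module containing nn.conv3d.
with relay.quantize.qconfig(
    calibrate_mode="global_scale", global_scale=8.0, skip_conv_layers=[]
):
    qmod = relay.quantize.quantize(mod, params)
```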

As mentioned previously, the int8-based 3D convolution alone gives the right result and can be optimized by TVM's auto-tuning module. However, during the compilation phase I often encounter the error mentioned above. I figured out that, depending on the schedule found by the auto-tuner, the error may or may not occur. I don't know how to solve this issue.
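
In case it helps to reproduce: the failure shows up when I compile the quantized model with a tuning log applied, roughly like this (the log file name is only a placeholder, and ```qmod``` is the quantized module from the sketch above):

```
import tvm
from tvm import autotvm, relay

# "conv3d_int8.log" is a placeholder for the records produced by the auto-tuner;
# whether the error appears depends on which schedule ends up being selected.
with autotvm.apply_history_best("conv3d_int8.log"):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(qmod, target="cuda")
```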





---
[Visit Topic](https://discuss.tvm.apache.org/t/quantization-and-3d-convolution/8338/2) to respond.
