You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2020/11/08 04:55:01 UTC
[incubator-tvm] branch main updated: More flexible
conv2d_NCHWc_int8 generic operator. (#6714)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5e9ddeb More flexible conv2d_NCHWc_int8 generic operator. (#6714)
5e9ddeb is described below
commit 5e9ddebc24472258ba1e290d8c621b93c01be47b
Author: Balint Cristian <cr...@gmail.com>
AuthorDate: Sun Nov 8 06:54:45 2020 +0200
More flexible conv2d_NCHWc_int8 generic operator. (#6714)
---
python/tvm/topi/generic/conv2d.py | 12 ++++++------
python/tvm/topi/nn/conv2d.py | 7 ++++---
2 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index f23cff3..7dd9aed 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -51,7 +51,7 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
num_int8_elements,
)
- oc_bn = int32_lanes
+ oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
@@ -99,7 +99,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
num_int8_elements,
)
- oc_bn = int32_lanes
+ oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
@@ -119,7 +119,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
def schedule_conv_NCHWc_cpu_common_int8(
- s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+ s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
):
"""
Defines the schedule for INT8 for Intel and ARM machines
@@ -180,7 +180,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
assert oc_bn % int32_lanes == 0
- assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32
+ assert ic_bn % int8_elems == 0 # (u)int8 elements in (u)int32
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
@@ -245,7 +245,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
def schedule_conv_NCHWc_cpu_1x1_int8(
- s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+ s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
):
"""
Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
@@ -305,7 +305,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
assert oc_bn % int32_lanes == 0
- assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32
+ assert ic_bn % int8_elems == 0 # (u)int8 elements in (u)int32
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 2e147fc..cd10c75 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -505,7 +505,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
def conv2d_NCHWc_int8(
- data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
+ data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4
):
"""Conv2D operator for nChw[x]c layout.
@@ -539,6 +539,9 @@ def conv2d_NCHWc_int8(
out_dtype : str
output data type
+ n_elems : int
+ number of int8 elements accumulated
+
Returns
-------
output : tvm.te.Tensor
@@ -588,7 +591,6 @@ def conv2d_NCHWc_int8(
kw = te.reduce_axis((0, kernel_width), name="kw")
if groups == 1:
- n_elems = 4
ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer")
ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner")
@@ -611,7 +613,6 @@ def conv2d_NCHWc_int8(
tag="conv2d_NCHWc_int8",
)
# for int8 group conv support
- n_elems = 4
ic_chunk = in_channel // ic_bn
ic_outer = te.reduce_axis((0, ic_chunk // groups), name="ic_outer")
ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")