You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2020/11/08 04:55:01 UTC

[incubator-tvm] branch main updated: More flexible conv2d_NCHWc_int8 generic operator. (#6714)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e9ddeb  More flexible conv2d_NCHWc_int8 generic operator. (#6714)
5e9ddeb is described below

commit 5e9ddebc24472258ba1e290d8c621b93c01be47b
Author: Balint Cristian <cr...@gmail.com>
AuthorDate: Sun Nov 8 06:54:45 2020 +0200

    More flexible conv2d_NCHWc_int8 generic operator. (#6714)
---
 python/tvm/topi/generic/conv2d.py | 12 ++++++------
 python/tvm/topi/nn/conv2d.py      |  7 ++++---
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index f23cff3..7dd9aed 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -51,7 +51,7 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
         num_int8_elements,
     )
 
-    oc_bn = int32_lanes
+    oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
     ic_bn = 1
     for bn in range(oc_bn, 0, -4):
         if wkl.in_filter % bn == 0:
@@ -99,7 +99,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
         num_int8_elements,
     )
 
-    oc_bn = int32_lanes
+    oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
     ic_bn = 1
     for bn in range(oc_bn, 0, -4):
         if wkl.in_filter % bn == 0:
@@ -119,7 +119,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
 
 
 def schedule_conv_NCHWc_cpu_common_int8(
-    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
 ):
     """
     Defines the schedule for INT8 for Intel and ARM machines
@@ -180,7 +180,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
     ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
 
     assert oc_bn % int32_lanes == 0
-    assert ic_bn % 4 == 0  # 4 (u)int8 elements in (u)int32
+    assert ic_bn % int8_elems == 0  # (u)int8 elements in (u)int32
 
     oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
 
@@ -245,7 +245,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
 
 
 def schedule_conv_NCHWc_cpu_1x1_int8(
-    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
 ):
     """
     Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
@@ -305,7 +305,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
     kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
 
     assert oc_bn % int32_lanes == 0
-    assert ic_bn % 4 == 0  # 4 (u)int8 elements in (u)int32
+    assert ic_bn % int8_elems == 0  # (u)int8 elements in (u)int32
 
     oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
 
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 2e147fc..cd10c75 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -505,7 +505,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
 
 
 def conv2d_NCHWc_int8(
-    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
+    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4
 ):
     """Conv2D operator for nChw[x]c layout.
 
@@ -539,6 +539,9 @@ def conv2d_NCHWc_int8(
     out_dtype : str
         output data type
 
+    n_elems : int
+        number of int8 elements accumulated
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -588,7 +591,6 @@ def conv2d_NCHWc_int8(
     kw = te.reduce_axis((0, kernel_width), name="kw")
 
     if groups == 1:
-        n_elems = 4
         ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer")
         ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
         ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner")
@@ -611,7 +613,6 @@ def conv2d_NCHWc_int8(
             tag="conv2d_NCHWc_int8",
         )
     # for int8 group conv support
-    n_elems = 4
     ic_chunk = in_channel // ic_bn
     ic_outer = te.reduce_axis((0, ic_chunk // groups), name="ic_outer")
     ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")