You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/03 06:58:04 UTC

[tvm] branch main updated: [TOPI] Expose mem_scope from generic conv2d variants to be more reusable (#13680)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e5a7f5fb5f [TOPI] Expose mem_scope from generic conv2d variants to be more reusable (#13680)
e5a7f5fb5f is described below

commit e5a7f5fb5f3d503955daad59b1903cafa8f647ad
Author: Balint Cristian <cr...@gmail.com>
AuthorDate: Tue Jan 3 08:57:56 2023 +0200

    [TOPI] Expose mem_scope from generic conv2d variants to be more reusable (#13680)
    
    Expose mem_scope from generic conv2d variants to be more reusable
---
 python/tvm/topi/generic/conv2d.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index 76cd9a7d69..a4a37247c8 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -132,6 +132,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
     int8_elems=4,
     intrin=None,
     inline_fused=True,
+    mem_scope="global",
 ):
     """
     Defines the schedule for INT8 for Intel and ARM machines
@@ -186,7 +187,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
 
     # schedule 5-D NCHW[x]c conv
     C, O = conv_out, last
-    CC = s.cache_write(C, "global")
+    CC = s.cache_write(C, mem_scope)
 
     batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
@@ -279,6 +280,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
     int8_elems=4,
     intrin=None,
     inline_fused=False,
+    mem_scope="global",
 ):
     """
     Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
@@ -323,7 +325,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
         s[kernel_vec].parallel(parallel_axis)
 
     C, O = conv_out, last
-    CC = s.cache_write(C, "global")
+    CC = s.cache_write(C, mem_scope)
 
     batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
     oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)