You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by me...@apache.org on 2022/09/15 22:28:37 UTC

[tvm] 02/02: gemm use any combination

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch micro/dense_dsp_smaller_config
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 1915c412a831a3a1df4b3819df41e86502a533a3
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Thu Sep 15 15:28:03 2022 -0700

    gemm use any combination
---
 python/tvm/topi/arm_cpu/mprofile/dsp/dense.py             | 2 +-
 python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/dense.py b/python/tvm/topi/arm_cpu/mprofile/dsp/dense.py
index 5630636c92..cecb461874 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/dense.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/dense.py
@@ -32,7 +32,7 @@ def dense_dsp_compute(cfg, data, weight, bias=None, out_dtype=None):
     M, K = get_const_tuple(data.shape)
     N, _ = get_const_tuple(weight.shape)
 
-    factor = 2
+    factor = 16
     # import pdb; pdb.set_trace()
     if M % factor == 0:
         cfg.define_split("tile_x", M, policy="factors", num_outputs=2, filter=lambda x: x.size[-1] % factor == 0)
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
index ffc48eaabd..41f22516ba 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -207,7 +207,7 @@ __STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
-  if ( {M} < 16 || {N} < 16 ) {{
+  if ( {M} < 16 && {N} < 16 ) {{
     retcode = gemm_{M}x{K}x{N}_body_loop_{uniq_id}(aa, bb, cc, A_stride, B_stride, C_stride);
     goto out;
   }}
@@ -313,7 +313,7 @@ __STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
-  if ( {M} < 16 || {N} < 16 ) {{
+  if ( {M} < 16 && {N} < 16 ) {{
     retcode = gemm_{M}x{K}x{N}_update_loop_{uniq_id}(aa, bb, cc, A_stride, B_stride, C_stride);
     goto out;
   }}
@@ -393,7 +393,7 @@ __STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
     int A_stride, int B_stride, int C_stride) {{
   int32_t retcode = 0;
 
-  if ( {M} < 2 || {N} < 2 ) {{
+  if ( {M} < 2 && {N} < 2 ) {{
     retcode = gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(aa, bb, cc, A_stride, B_stride, C_stride);
     goto out;
   }}
@@ -471,7 +471,7 @@ __STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
     int A_stride, int B_stride, int C_stride) {{
   int32_t retcode = 0;
 
-  if ( {M} < 2 || {N} < 2 ) {{
+  if ( {M} < 2 && {N} < 2 ) {{
     retcode = gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(aa, bb, cc, A_stride, B_stride, C_stride);
     goto out;
   }}