You are viewing a plain text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/08/01 13:29:54 UTC

[tvm] branch unity updated: [Unity][Dlight] Avoid too large vectorization factor in caching (#15443)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 72bd41d3e8 [Unity][Dlight] Avoid too large vectorization factor in caching (#15443)
72bd41d3e8 is described below

commit 72bd41d3e8e0564dc68cb346ab8f4cff487196b3
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Tue Aug 1 06:29:48 2023 -0700

    [Unity][Dlight] Avoid too large vectorization factor in caching (#15443)
    
    * [Unity][Dlight] Avoid too large vectorization factor in caching
    
    * Update gemv.py
---
 python/tvm/dlight/gpu/gemv.py        | 9 ++++++---
 tests/python/dlight/test_gpu_gemv.py | 5 ++---
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 0d0e4845d4..4c11aa7780 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -241,9 +241,12 @@ class GEMV(ScheduleRule):
                 cache = sch.cache_read(rf, index, "shared")
                 sch.compute_at(cache, unit, preserve_unit_loops=True)
                 fused = sch.fuse(*sch.get_loops(cache)[5:])
-                _, _ty, _tx, _vec = sch.split(
-                    fused, [None, len_ty, len_tx, vec_bytes // type_bytes]
-                )
+                loop: tir.For = sch.get(fused)
+                vec_length = vec_bytes // type_bytes
+                if isinstance(loop.extent, tir.IntImm):
+                    # avoid introducing predicates when vector length is too large
+                    vec_length = min(loop.extent // len_ty // len_tx, vec_length)
+                _, _ty, _tx, _vec = sch.split(fused, [None, len_ty, len_tx, vec_length])
                 sch.bind(_ty, "threadIdx.y")
                 sch.bind(_tx, "threadIdx.x")
                 sch.vectorize(_vec)
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 82648fb867..fb6315f802 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -101,13 +101,12 @@ class TestGEMV(BaseBeforeAfter):
                             for ax0_ax1_ax2_ax3_fused_0 in range(1):
                                 for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
                                     for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
-                                        for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(8):
+                                        for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(4):
                                             with T.block("lv1637_shared"):
                                                 v0 = T.axis.spatial(1, 0)
                                                 v1 = T.axis.spatial(32, ax0_fused)
                                                 v2 = T.axis.spatial(1, 0)
-                                                v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused_0 * 256 + ax0_ax1_ax2_ax3_fused_1 * 256 + ax0_ax1_ax2_ax3_fused_2 * 8 + ax0_ax1_ax2_ax3_fused_3)
-                                                T.where(((ax0_ax1_ax2_ax3_fused_0 + ax0_ax1_ax2_ax3_fused_1) * 32 + ax0_ax1_ax2_ax3_fused_2) * 8 + ax0_ax1_ax2_ax3_fused_3 < 128)
+                                                v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 128 + ax0_ax1_ax2_ax3_fused_2 * 4 + ax0_ax1_ax2_ax3_fused_3)
                                                 T.reads(lv1637[v0, v1, v2, v3])
                                                 T.writes(lv1637_shared[v0, v1, v2, v3])
                                                 lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3]