You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/07/04 00:09:33 UTC

[tvm] branch unity updated: [Unity] Fix dlight reduction rule (#15194)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 11db81effd [Unity] Fix dlight reduction rule (#15194)
11db81effd is described below

commit 11db81effdc000c319661356b06ece115a91bad4
Author: Bohan Hou <sp...@gmail.com>
AuthorDate: Mon Jul 3 17:09:27 2023 -0700

    [Unity] Fix dlight reduction rule (#15194)
    
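    Fix corner cases in the dlight reduction rule: return None instead of failing for non-PrimFunc inputs, empty block lists, and blocks without trailing reduction loops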
---
 python/tvm/dlight/gpu/reduction.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index bfca76546f..8dcfa36697 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -32,6 +32,9 @@ class Reduction(ScheduleRule):
         target: Target,
         _: bool,
     ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        if not isinstance(func, tir.PrimFunc):
+            return None
+
         if target.kind.name == "cuda":
             len_tx = 256
             unroll_depth = 256
@@ -42,18 +45,20 @@ class Reduction(ScheduleRule):
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
         block_infos = try_inline_contiguous_spatial(sch, block_infos)
-        assert len(block_infos) > 0
+        if block_infos is None or len(block_infos) == 0:
+            return None
 
         dom_kind = block_infos[0].dom_kind()
         num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S"))
         num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R"))
         try:
+            # TODO: fix num_leading_s = 0 case
+            assert num_trailing_r > 0
             for block in block_infos[1:-1]:
                 assert block.dom_kind() == dom_kind
             assert block_infos[-1].is_injective()
             assert len(block_infos[-1].dom_kind()) == len(dom_kind)
         except AssertionError:
-            print("Mismatch")
             return None
 
         loops = sch.get_loops(block_infos[-1].block_rv)