You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/07/04 00:09:33 UTC
[tvm] branch unity updated: [Unity] Fix dlight reduction rule (#15194)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 11db81effd [Unity] Fix dlight reduction rule (#15194)
11db81effd is described below
commit 11db81effdc000c319661356b06ece115a91bad4
Author: Bohan Hou <sp...@gmail.com>
AuthorDate: Mon Jul 3 17:09:27 2023 -0700
[Unity] Fix dlight reduction rule (#15194)
Fix corner cases for dlight reduction
---
python/tvm/dlight/gpu/reduction.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index bfca76546f..8dcfa36697 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -32,6 +32,9 @@ class Reduction(ScheduleRule):
target: Target,
_: bool,
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+ if not isinstance(func, tir.PrimFunc):
+ return None
+
if target.kind.name == "cuda":
len_tx = 256
unroll_depth = 256
@@ -42,18 +45,20 @@ class Reduction(ScheduleRule):
sch = tir.Schedule(func)
block_infos = normalize_prim_func(sch)
block_infos = try_inline_contiguous_spatial(sch, block_infos)
- assert len(block_infos) > 0
+ if block_infos is None or len(block_infos) == 0:
+ return None
dom_kind = block_infos[0].dom_kind()
num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S"))
num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R"))
try:
+ # TODO: fix num_leading_s = 0 case
+ assert num_trailing_r > 0
for block in block_infos[1:-1]:
assert block.dom_kind() == dom_kind
assert block_infos[-1].is_injective()
assert len(block_infos[-1].dom_kind()) == len(dom_kind)
except AssertionError:
- print("Mismatch")
return None
loops = sch.get_loops(block_infos[-1].block_rv)