You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "yzh119 (via GitHub)" <gi...@apache.org> on 2023/04/10 20:35:38 UTC

[GitHub] [tvm] yzh119 commented on a diff in pull request #14451: [TIR] Fix compute inline for dynamic loop and nontrivial store indices

yzh119 commented on code in PR #14451:
URL: https://github.com/apache/tvm/pull/14451#discussion_r1162065838


##########
tests/python/unittest/test_tir_schedule_compute_inline.py:
##########
@@ -894,6 +894,98 @@ def test_compute_inline_multi_consumer(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer)
 
 
+def test_compute_inline_layout_transformed_store():
+    @T.prim_func
+    def before(X: T.Buffer[(9, 96), "float32"]):
+        A = T.alloc_buffer([6, 9, 16], "float32")
+        B = T.alloc_buffer([3, 3, 96], "float32")
+        for i, j in T.grid(9, 96):
+            with T.block("producer"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                A[vj // 16, vi, vj % 16] = X[vi, vj]
+        for i0, i1, j in T.grid(3, 3, 96):
+            with T.block("consumer"):
+                vi0, vi1, vj = T.axis.remap("SSS", [i0, i1, j])
+                B[vi0, vi1, vj] = A[vj % 16, vi0 * 3 + vi1, vj // 16] + 1.0
+
+    @T.prim_func
+    def after(X: T.Buffer[(9, 96), "float32"]):
+        B = T.alloc_buffer([3, 3, 96], "float32")
+        for i0, i1, j in T.grid(3, 3, 96):
+            with T.block("consumer"):
+                vi0, vi1, vj = T.axis.remap("SSS", [i0, i1, j])
+                B[vi0, vi1, vj] = X[vi0 * 3 + vi1, vj // 16 + vj % 16 * 16] + 1.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    sch.compute_inline(sch.get_block("producer"))
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+def test_compute_inline_out_of_bound_consumer():
+    """The case is intentionally left for when the producer region
+    does not cover the consumers. Though the out-of-bound region values
+    are not defined, the current behavior is still to generate the inlined
+    computation using the rule of the producer.
+    """
+
+    @T.prim_func
+    def before():
+        A = T.alloc_buffer([8], "int32")
+        B = T.alloc_buffer([10], "int32")
+        for i in range(8):
+            with T.block("producer"):
+                vi = T.axis.remap("S", [i])
+                A[vi] = vi
+        for i in range(10):
+            with T.block("consumer"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi] + 1
+
+    @T.prim_func
+    def after():
+        B = T.alloc_buffer([10], "int32")
+        for i in range(10):
+            with T.block("consumer"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = vi + 1
+
+    sch = tir.Schedule(before, debug_mask="all")
+    sch.compute_inline(sch.get_block("producer"))
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+def test_compute_inline_dynamic_shape_producer():
+    @T.prim_func
+    def before(a: T.handle, c: T.handle, m: T.int32, n: T.int32):
+        A = T.match_buffer(a, (m, n))
+        B = T.alloc_buffer((m, n))
+        C = T.match_buffer(c, (m, n))
+        for i, j in T.grid(m, n):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(m, n):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def after(a: T.handle, c: T.handle, m: T.int32, n: T.int32):
+        A = T.match_buffer(a, (m, n))
+        C = T.match_buffer(c, (m, n))
+        for i, j in T.grid(m, n):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = A[vi, vj] * 2.0 + 1.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    sch.compute_inline(sch.get_block("B"))
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+

Review Comment:
   It would be better to add another "nested block" case with variable loop extents in the outer blocks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org