Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/06 21:27:58 UTC

[GitHub] [tvm] Lunderberg opened a new pull request, #12720: [TIR] Implement API for padded layout transformations

Lunderberg opened a new pull request, #12720:
URL: https://github.com/apache/tvm/pull/12720

   Implementation of an API in `tvm.tir.schedule` for layout transformations with padding, as part of https://github.com/apache/tvm/issues/12261, item "Insert pad value into generated TIR, using `tir::if_then_else`, `builtin::assume`, and `builtin::undef`".
   
   Following the RFC discussion [here](https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1170294348) and [here](https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1171290053), this commit preferentially rewrites the loops that surround a padded transformation where possible, in order to express padding in terms of `tir::if_then_else`.
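As a rough illustration of what "padding" means here, the following pure-Python sketch (not TVM code) applies the index map `lambda i: [i // 4, i % 4]` to a 14-element buffer. Because 14 is not a multiple of 4, the transformed 4x4 buffer has two slots with no source element. The sketch fills them by selecting on a padding predicate, the same kind of choice that `tir::if_then_else` expresses in generated TIR. The helper names `transformed_shape`, `is_padding`, and `apply_padded_layout` are hypothetical.

```python
def transformed_shape(n, factor):
    """Shape of the [i // factor, i % factor] layout for a length-n buffer."""
    # Round up so that every original index i < n has a slot.
    return ((n + factor - 1) // factor, factor)


def is_padding(io, ii, n, factor):
    """Padding predicate: True for transformed indices with no source element."""
    return io * factor + ii >= n


def apply_padded_layout(data, factor, pad_value):
    """Write `data` into the transformed layout, filling padding with pad_value."""
    n = len(data)
    rows, cols = transformed_shape(n, factor)
    out = []
    for io in range(rows):
        row = []
        for ii in range(cols):
            # Analogous to if_then_else(padding_predicate, pad_value, A[io * factor + ii])
            if is_padding(io, ii, n, factor):
                row.append(pad_value)
            else:
                row.append(data[io * factor + ii])
        out.append(row)
    return out


print(apply_padded_layout(list(range(14)), 4, pad_value=-1))
# prints [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, -1, -1]]
```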


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970059913


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Thank you for tagging @masahi; I had forgotten to do so. I *think* I have it set up correctly, based on the NVIDIA documentation and the similarity to the (16,32) shape, but I couldn't verify it definitively.



[GitHub] [tvm] vinx13 commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

vinx13 commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r972234799


##########
python/tvm/tir/function.py:
##########
@@ -389,17 +389,27 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
 
         final_indices = []
         axis_separators = []
-        for val in mapping:
-            if isinstance(val, tvm.ir.PrimExpr):
-                final_indices.append(val)
-            elif val is IndexMap.AXIS_SEPARATOR:
-                axis_separators.append(len(final_indices))
-            else:
-                raise TypeError(
-                    "Expected mapping function to return list of "
-                    "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR.  "
-                    f"Instead received {val} of type {type(val)}."
-                )
+
+        try:
+            iter(mapping)

Review Comment:
   What's the use case for this? According to the documentation, the mapping function should return a list; if that is no longer required, the documentation might also need an update.
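The rest of the `try` block is not shown in this excerpt, so its intent is not visible here. One plausible purpose for a `try: iter(mapping)` check is to accept a mapping function that returns a single expression instead of a list. A minimal sketch of that normalization follows; `normalize_mapping_output` is a hypothetical helper, a plain `object()` stands in for `IndexMap.AXIS_SEPARATOR`, and plain ints stand in for `tvm.ir.PrimExpr`.

```python
# Stand-in for IndexMap.AXIS_SEPARATOR; plain ints stand in for PrimExpr.
AXIS_SEPARATOR = object()


def normalize_mapping_output(mapping):
    """Split a mapping function's return value into indices and separators."""
    try:
        iter(mapping)
    except TypeError:
        # A single (non-iterable) expression is treated as a 1-element list.
        mapping = [mapping]

    final_indices = []
    axis_separators = []
    for val in mapping:
        if val is AXIS_SEPARATOR:
            # A separator records how many indices precede it.
            axis_separators.append(len(final_indices))
        elif isinstance(val, int):
            final_indices.append(val)
        else:
            raise TypeError(f"Expected int or AXIS_SEPARATOR, got {type(val)}")
    return final_indices, axis_separators


print(normalize_mapping_output(7))                       # ([7], [])
print(normalize_mapping_output([1, AXIS_SEPARATOR, 2]))  # ([1, 2], [1])
```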



##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2479,6 +2480,31 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]

Review Comment:
   Please document the assumptions made when `pad_value` is an IndexMap. As I recall, the RFC assumes it contains no BufferLoad from any buffer other than the current buffer.



[GitHub] [tvm] masahi commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

masahi commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970068321


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Ah, sorry, I was talking about `shared_16x16_to_ldmatrix_32x8_layout`. I need to remember how I came up with `shared_32x16_to_ldmatrix_32x16_layout`. I think it is used for int8 MMA.



[GitHub] [tvm] Lunderberg commented on pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on PR #12720:
URL: https://github.com/apache/tvm/pull/12720#issuecomment-1249797455

   @Hzfengsy Can you review/verify that the requested changes (use non-opaque blocks in unit tests) have been made? I think that's the only item remaining on the PR.


[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965322232


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object {
    * \param buffer_index The index of the buffer in block's read or write region.
    * \param buffer_index_type The type of the buffer index, kRead or kWrite.
    * \param index_map The transformation to apply.
+   * \param pad_value The value to write into padding introduced by the transformation.

Review Comment:
   Good call.  Added both here and in the python docstring.



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r971016392


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Thank you for looking into it!  I wasn't able to find any tests that explicitly validate the transform (e.g. use the transform to generate data in a specific layout, then pass through the mma), as all the tests either started with transformed data, only used the 16x16 shape, or replaced everything with the tensor intrinsic.
   
   I had put together [this standalone test](https://gist.github.com/Lunderberg/0c2a44de34e7e2a1d149c37b2a112f91) to convince myself of it. The main issue with the current index map is that it doesn't map to unique locations (512 input indices map to only 128 output indices). It only surfaced in this PR because the PR computes the inverse of the index map to determine whether and where padding is required.
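The uniqueness claim can be checked by brute force. This is a sketch, not the linked gist: it enumerates all 32x16 inputs and counts distinct outputs. The second output component is taken from the return value quoted in this thread (`8 * (j // 8) + (i // 16) * 4 + i % 4`) and is assumed to be unchanged by this diff, which only shows the `thread_id` line changing; the name `local_id` is introduced here for readability.

```python
def old_layout(i, j):
    # thread_id before this PR; local_id assumed unchanged by the PR.
    thread_id = (i % 4) + 4 * (j % 8)
    local_id = 8 * (j // 8) + (i // 16) * 4 + i % 4
    return thread_id, local_id


def new_layout(i, j):
    # thread_id after this PR; bits 2-3 of i are no longer discarded.
    thread_id = (i % 16) // 4 + 4 * (j % 8)
    local_id = 8 * (j // 8) + (i // 16) * 4 + i % 4
    return thread_id, local_id


def count_unique_outputs(layout, rows=32, cols=16):
    """Number of distinct output locations reached by a rows x cols input."""
    return len({layout(i, j) for i in range(rows) for j in range(cols)})


print(count_unique_outputs(old_layout))  # 128: four inputs collide per output
print(count_unique_outputs(new_layout))  # 512: one-to-one
```

In the old mapping, `i % 4` and `i // 16` together use only 3 of the 5 bits of `i`, so each output is hit by 4 distinct inputs; the new `thread_id` recovers the missing `(i % 16) // 4` bits.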



[GitHub] [tvm] vinx13 commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

vinx13 commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970041968


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   cc @masahi 



[GitHub] [tvm] Lunderberg commented on pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on PR #12720:
URL: https://github.com/apache/tvm/pull/12720#issuecomment-1245512099

   It looks like the final failing unit test is due to an incorrect mapping in `tir.tensor_intrin.cuda.shared_32x16_to_ldmatrix_32x16_layout`. It currently returns `[(i % 4) + 4 * (j % 8), 8 * (j // 8) + (i // 16) * 4 + i % 4]`, which doesn't cover all 32x16 output indices, and fails when `NonSurjectiveInverse` is applied to it.
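A brute-force stand-in (not how `NonSurjectiveInverse` is implemented) shows the distinction at play: output positions that no input reaches are tolerable and simply become padding, whereas two inputs landing on the same output leave no well-defined inverse. `brute_force_inverse` is a hypothetical helper for illustration only.

```python
import itertools


def brute_force_inverse(mapping, input_shape, output_shape):
    """Tabulate the inverse of `mapping` and the padding it would introduce."""
    inverse = {}
    for idx in itertools.product(*(range(n) for n in input_shape)):
        out = tuple(mapping(*idx))
        if out in inverse:
            # Two inputs land on one output: the map is not injective.
            raise ValueError(f"{idx} and {inverse[out]} both map to {out}")
        inverse[out] = idx
    # Output positions never reached by any input are padding.
    padding = [
        out
        for out in itertools.product(*(range(n) for n in output_shape))
        if out not in inverse
    ]
    return inverse, padding


# Non-surjective but injective: invertible, with two padding slots.
inverse, padding = brute_force_inverse(lambda i: [i // 4, i % 4], (14,), (4, 4))
print(padding)  # [(3, 2), (3, 3)]
```

Passing the old ldmatrix mapping, `lambda i, j: [(i % 4) + 4 * (j % 8), 8 * (j // 8) + (i // 16) * 4 + i % 4]`, with shapes `(32, 16)` raises `ValueError`, because the map is not injective.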


[GitHub] [tvm] masahi commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

masahi commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970067438


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Hmm, I think the original mapping is correct; it comes from p. 34 of this slide deck: https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21745-developing-cuda-kernels-to-push-tensor-cores-to-the-absolute-limit-on-nvidia-a100.pdf
   
   Sorry, I don't remember the details.



[GitHub] [tvm] Hzfengsy commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Hzfengsy commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r964422619


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object {
    * \param buffer_index The index of the buffer in block's read or write region.
    * \param buffer_index_type The type of the buffer index, kRead or kWrite.
    * \param index_map The transformation to apply.
+   * \param pad_value The value to write into padding introduced by the transformation.
    */
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
-                               BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
+                               BufferIndexType buffer_index_type, const IndexMap& index_map,
+                               const Optional<PrimExpr>& pad_value) = 0;

Review Comment:
   How about adding a default value? 
   ```suggestion
                                  const Optional<PrimExpr>& pad_value = NullOpt) = 0;
   ```



##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -329,5 +329,302 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name):
         )
 
 
+class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
+    pad_value = tvm.testing.parameter(None)
+
+    transformed_buffer = tvm.testing.parameter("A")
+
+    @pytest.fixture
+    def transform(self, pad_value, transformed_buffer):
+        def transform(mod):
+            sch = tir.Schedule(mod)
+            sch.transform_layout(
+                "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value
+            )
+            # sch.transform_block_layout("block", lambda i: [i // 4, i % 4])
+            return sch.mod
+
+        return transform
+
+
+class TestNoPadding(BasePaddingCompare):
+    """Transformations without padding do not depend on pad_value."""
+
+    pad_value = tvm.testing.parameter(None, 42)
+
+    def before():

Review Comment:
   It's great if we can support padding for opaque blocks. However, non-opaque blocks are the most common case.
   
   Could you please change the test cases to use non-opaque blocks?
   e.g.
   ```python
   with T.block():
       vi = T.axis.remap("S", [i])
       A[vi] = 0
   ```



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965015023


##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -329,5 +329,302 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name):
         )
 
 
+class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
+    pad_value = tvm.testing.parameter(None)
+
+    transformed_buffer = tvm.testing.parameter("A")
+
+    @pytest.fixture
+    def transform(self, pad_value, transformed_buffer):
+        def transform(mod):
+            sch = tir.Schedule(mod)
+            sch.transform_layout(
+                "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value
+            )
+            # sch.transform_block_layout("block", lambda i: [i // 4, i % 4])
+            return sch.mod
+
+        return transform
+
+
+class TestNoPadding(BasePaddingCompare):
+    """Transformations without padding do not depend on pad_value."""
+
+    pad_value = tvm.testing.parameter(None, 42)
+
+    def before():

Review Comment:
   For this specific instance, the test case for non-opaque blocks appears further down in the file. I've reordered these test cases so that the non-opaque block tests come first, and the opaque block tests are the ones called out as non-standard.
   
   I am in the process of updating all remaining examples in the tests to use non-opaque blocks, though I expect some of the updated tests to fail until https://github.com/apache/tvm/pull/12724 lands.



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r973381556


##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -329,5 +329,302 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name):
         )
 
 
+class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
+    pad_value = tvm.testing.parameter(None)
+
+    transformed_buffer = tvm.testing.parameter("A")
+
+    @pytest.fixture
+    def transform(self, pad_value, transformed_buffer):
+        def transform(mod):
+            sch = tir.Schedule(mod)
+            sch.transform_layout(
+                "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value
+            )
+            # sch.transform_block_layout("block", lambda i: [i // 4, i % 4])
+            return sch.mod
+
+        return transform
+
+
+class TestNoPadding(BasePaddingCompare):
+    """Transformations without padding do not depend on pad_value."""
+
+    pad_value = tvm.testing.parameter(None, 42)
+
+    def before():

Review Comment:
   Forgot to mark this conversation as resolved earlier.



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r968604370


##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2479,6 +2480,20 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
+
+            The value to be used for any padding introduced by the
+            transformation.
+
+            If None, the transformation may not introduce padding.
+
+            If an int, float or PrimExpr, the transformation is the
+            specific value to be present in the padding.
+
+            If an IndexMap or Callable, the transformation is the
+            value to be present in the padding in terms of the
+            transformed index.

Review Comment:
   Updated to pass `Optional<IndexMap> pad_value` throughout the C++ API, mirroring how `IndexMap index_map` is passed, and added a unit test to validate the functionality.
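The docstring above lists several accepted forms for `pad_value`. A pure-Python sketch of why one `Optional<IndexMap>` can carry all of them: each non-None form can be rewritten as a function of the transformed indices, with a constant becoming a function that ignores its arguments. `normalize_pad_value` is hypothetical, and plain Python numbers and callables stand in for `PrimExpr` and `IndexMap`.

```python
from typing import Callable, Optional, Union

PadValue = Union[int, float, Callable[..., float]]


def normalize_pad_value(pad_value: Optional[PadValue]) -> Optional[Callable[..., float]]:
    """Convert each accepted pad_value form into a function of the transformed indices."""
    if pad_value is None:
        # No pad value: the transformation must not introduce padding.
        return None
    if callable(pad_value):
        # Already expressed in terms of the transformed indices.
        return pad_value
    if isinstance(pad_value, (int, float)):
        # A constant becomes a function that ignores its indices.
        const = pad_value
        return lambda *indices: const
    raise TypeError(f"Unsupported pad_value type: {type(pad_value)}")


const_fn = normalize_pad_value(0)
index_fn = normalize_pad_value(lambda io, ii: io * 4 + ii)
print(const_fn(3, 2), index_fn(3, 2))  # 0 14
```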



[GitHub] [tvm] masahi commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

masahi commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970068321


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Ah, sorry, I was talking about `shared_16x32_to_ldmatrix_32x16_layout`. I need to remember how I came up with `shared_32x16_to_ldmatrix_32x16_layout`.



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965998079


##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -16,12 +16,580 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+class LayoutTransformPlanner : private StmtExprVisitor {

Review Comment:
   Thank you. I've added documentation here describing the general algorithm and when each way of handling padding may be used.
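`LayoutTransformPlanner::Finalize`, shown further down in this thread, tries a prologue plan, then a replacement plan, then an epilogue plan, and falls back to `NoPaddingRequired`. The Python sketch below models only that ordering. The prologue guard mirrors the condition visible in `FinalizeProloguePlan` (no writes to the buffer, a padding predicate that is not provably zero, and a defined `pad_value`). The replacement and epilogue conditions, and the `row_major_writes` field, are hypothetical placeholders, since those implementations are not shown here.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Analysis:
    num_writes: int          # BufferStores to the transformed buffer
    padding_required: bool   # padding predicate is not provably zero
    has_pad_value: bool      # a pad_value was supplied
    row_major_writes: bool   # hypothetical stand-in for the replacement criteria


def try_prologue(a: Analysis) -> Optional[str]:
    # Guard modeled on FinalizeProloguePlan.
    if a.num_writes or not a.padding_required or not a.has_pad_value:
        return None
    return "ProloguePlan"


def try_replacement(a: Analysis) -> Optional[str]:
    # Hypothetical criteria; the real conditions are not shown in this diff.
    if a.padding_required and a.has_pad_value and a.row_major_writes:
        return "ReplacementPlan"
    return None


def try_epilogue(a: Analysis) -> Optional[str]:
    # Hypothetical criteria; the real conditions are not shown in this diff.
    if a.padding_required and a.has_pad_value:
        return "EpiloguePlan"
    return None


PLANNERS: List[Callable[[Analysis], Optional[str]]] = [
    try_prologue,
    try_replacement,
    try_epilogue,
]


def finalize(a: Analysis) -> str:
    """Return the first plan that applies, in the same order as Finalize."""
    for planner in PLANNERS:
        plan = planner(a)
        if plan is not None:
            return plan
    return "NoPaddingRequired"


print(finalize(Analysis(0, True, True, False)))  # ProloguePlan
print(finalize(Analysis(1, True, True, True)))   # ReplacementPlan
```

Modeling the planners as an ordered list keeps the preference order in one place, which matches the `if / else if` chain in `Finalize`.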



[GitHub] [tvm] masahi commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

masahi commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970068321


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Ah, sorry, I was talking about `shared_16x16_to_ldmatrix_32x8_layout`. I need to remember how I came up with `shared_32x16_to_ldmatrix_32x16_layout`.



[GitHub] [tvm] Lunderberg merged pull request #12720: [TIR] Implement API for padded layout transformations

Lunderberg merged PR #12720:
URL: https://github.com/apache/tvm/pull/12720


[GitHub] [tvm] vinx13 commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

vinx13 commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965151002


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object {
    * \param buffer_index The index of the buffer in block's read or write region.
    * \param buffer_index_type The type of the buffer index, kRead or kWrite.
    * \param index_map The transformation to apply.
+   * \param pad_value The value to write into padding introduced by the transformation.

Review Comment:
   If `pad_value` is incorrect, it can affect the correctness of the program. It would be great to mention this explicitly.
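Why an incorrect `pad_value` matters: once padding is declared to hold a known value, later code may rely on it, for example by reducing over the entire padded buffer without a bounds check. A minimal pure-Python illustration, with hypothetical helper names, in which the consumer assumes the padding is zero:

```python
def pad_to_multiple(data, factor, fill):
    """Physically pad `data` to a multiple of `factor` using `fill`."""
    extra = (-len(data)) % factor
    return data + [fill] * extra


def padded_sum_assuming_zero_pad(padded):
    # Sums every element, padding included. This is only correct if the
    # padding really holds 0, i.e. if the declared pad value was accurate.
    return sum(padded)


data = [1, 2, 3, 4, 5, 6]
print(padded_sum_assuming_zero_pad(pad_to_multiple(data, 4, fill=0)))   # 21
print(padded_sum_assuming_zero_pad(pad_to_multiple(data, 4, fill=99)))  # 219
```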



##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2479,6 +2480,20 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
+
+            The value to be used for any padding introduced by the
+            transformation.
+
+            If None, the transformation may not introduce padding.
+
+            If an int, float or PrimExpr, the transformation is the
+            specific value to be present in the padding.
+
+            If an IndexMap or Callable, the transformation is the
+            value to be present in the padding in terms of the
+            transformed index.

Review Comment:
   The C++ side only accepts `Optional[PrimExpr]`, so it seems this is not supported?



##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -16,12 +16,580 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+class LayoutTransformPlanner : private StmtExprVisitor {
+ public:
+  // Statement to be inserted prior to the analyzed block
+  struct ProloguePlan {
+    Stmt prologue;
+  };
+
+  // Loops within the analyzed block that should be replaced
+  struct ReplacementPlan {
+    Map<For, Stmt> replacements;
+    Map<Block, Block> block_sref_reuse;
+  };
+
+  // The block to be inserted, along with the location at which it
+  // should be inserted.  The location will be either a For or a
+  // Block, and will be after all writes the transformed buffer.
+  struct EpiloguePlan {
+    Stmt insert_after;
+    Stmt new_block;
+  };
+
+  struct NoPaddingRequired {};
+
+  using TransformPlan =
+      std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;
+
+  static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
+                            IndexMap inverse, PrimExpr padding_predicate,
+                            Optional<PrimExpr> pad_value) {
+    LayoutTransformPlanner visitor(old_buffer);
+    visitor(block);
+    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
+  }
+
+ private:
+  explicit LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}
+
+  void VisitStmt_(const ForNode* op) override {
+    BindLoopVar context(this, GetRef<For>(op));
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const LetStmtNode* op) override {
+    BindLetVar context(this, op->var, op->value);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BlockRealizeNode* op) override {
+    BindBlockRealize context(this, GetRef<BlockRealize>(op));
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) override {
+    if (!op->buffer.same_as(old_buffer_)) {
+      return;
+    }
+
+    std::optional<std::pair<size_t, size_t>> loop_dependency_range = std::nullopt;
+    for (const auto& index : op->indices) {
+      if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) {
+        if (loop_dependency_range) {
+          loop_dependency_range = {
+              std::min(loop_dependency_range.value().first, index_depth.value().first),
+              std::max(loop_dependency_range.value().second, index_depth.value().second)};
+        } else {
+          loop_dependency_range = index_depth;
+        }
+      }
+    }
+
+    WriteInfo write_info;
+    write_info.store = GetRef<BufferStore>(op);
+    if (loop_dependency_range) {
+      size_t i = loop_dependency_range.value().first;
+      size_t j = loop_dependency_range.value().second;
+      ICHECK_LT(i, active_loops_.size());
+      ICHECK_LT(j, active_loops_.size());
+
+      write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1};
+    }
+    write_info.innermost_block_realize = innermost_block_realize_;
+
+    write_info.contains_row_major_traversal = [&]() -> bool {
+      const auto& loopnest = write_info.dependent_loopnest;
+      if (loopnest.empty()) {
+        return false;
+      }
+
+      if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) {
+        return false;
+      }
+
+      for (size_t i = 0; i < loopnest.size(); i++) {
+        const For& loop = loopnest[i];
+        const PrimExpr& buffer_dim = old_buffer_->shape[i];
+        PrimExpr index = Substitute(op->indices[i], active_let_bindings_);
+        bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) &&
+                                 ExprDeepEqual()(loop->extent, buffer_dim) &&
+                                 loop->kind == ForKind::kSerial;
+        if (!is_loop_over_axis) {
+          return false;
+        }
+      }
+
+      return true;
+    }();
+
+    write_info_.push_back(write_info);
+
+    // Don't need to continue recursing, as the entire goal was to
+    // find the BufferStore.
+  }
+
+  std::optional<std::pair<size_t, size_t>> LoopDependencyRange(const PrimExpr& expr) const {
+    std::optional<std::pair<size_t, size_t>> prev = std::nullopt;
+    for (const auto& var : UndefinedVars(expr)) {
+      auto it = loop_depth_lookup_.find(var.get());
+      if (it != loop_depth_lookup_.end()) {
+        if (prev.has_value()) {
+          prev = {std::min(prev.value().first, it->second.first),
+                  std::max(prev.value().second, it->second.second)};
+        } else {
+          prev = it->second;
+        }
+      }
+    }
+
+    return prev;
+  }
+
+  class BufferStoreReplacer : public StmtExprMutator {
+   public:
+    BufferStoreReplacer(std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store,
+                        std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
+                            replace_block_realize)
+        : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {}
+
+    Stmt VisitStmt_(const BufferStoreNode* op) final {
+      if (auto replacement = replace_store_(op)) {
+        auto store = Downcast<BufferStore>(replacement.value());
+        return StmtExprMutator::VisitStmt_(store.get());
+      } else {
+        return StmtExprMutator::VisitStmt_(op);
+      }
+    }
+
+    Stmt VisitStmt_(const BlockRealizeNode* op) final {
+      auto realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
+      if (auto replacement = replace_block_realize_(op, realize)) {
+        return replacement.value();
+      } else {
+        return std::move(realize);
+      }
+    }
+
+   private:
+    std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store_;
+    std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
+        replace_block_realize_;
+  };
+
+  TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
+                         PrimExpr padding_predicate, Optional<PrimExpr> pad_value) const {
+    if (auto prologue_plan =
+            FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value);
+        prologue_plan.has_value()) {
+      return prologue_plan.value();
+    } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse,
+                                                               padding_predicate, pad_value);
+               replacement_plan.has_value()) {
+      return replacement_plan.value();
+    } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse,
+                                                         padding_predicate, pad_value);
+               epilogue_plan.has_value()) {
+      return epilogue_plan.value();
+    } else {
+      return NoPaddingRequired();
+    }
+  }
+
+  std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<PrimExpr> pad_value) const {
+    if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    Map<Var, PrimExpr> loop_indices_to_block_indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      loop_indices_to_block_indices.Set(loop_var, block_var);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
+
+    PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value.value());
+    Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_assumptions";
+    auto read_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, Bool(true),
+                        Block(iter_vars, {read_region}, {}, block_name.str(), stmt));
+
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+    return ProloguePlan{stmt};
+  }
+
+  std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map,
+                                                         IndexMap inverse,
+                                                         PrimExpr padding_predicate,
+                                                         Optional<PrimExpr> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> {
+      if (!info.contains_row_major_traversal || !pad_value.defined() ||
+          is_zero(padding_predicate)) {
+        return NullOpt;
+      }
+
+      Array<PrimExpr> old_indices = info.store->indices;
+      PrimExpr if_then_else_condition = padding_predicate;
+      Array<PrimExpr> new_indices;
+      for (const auto& var : inverse->initial_indices) {
+        new_indices.push_back(var);
+      }
+
+      auto replace_block_realize =
+          [&]() -> std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)> {
+        auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional<Stmt> {
+          return NullOpt;
+        };
+        if (!info.innermost_block_realize) {
+          return no_change;
+        }
+        if (old_indices.empty()) {
+          return no_change;
+        }
+
+        BlockRealize block_realize = info.innermost_block_realize.value();
+        const auto& block = block_realize->block;
+
+        // Find the block iterators that are used to access the buffer.  Must be in the same order
+        // as they appear in the indices.
+        if (block->iter_vars.size() < old_indices.size()) {
+          return no_change;
+        }
+        const auto& iter_vars = block->iter_vars;
+        size_t block_index_start = 0;
+        for (; block_index_start < iter_vars.size() - old_indices.size(); block_index_start++) {
+          if (old_indices[0].same_as(iter_vars[block_index_start]->var)) {
+            break;
+          }
+        }
+        if (block_index_start > iter_vars.size() - old_indices.size()) {
+          return no_change;
+        }
+
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) ||
+              iter_vars[block_index_start + i]->iter_type != kDataPar) {
+            return no_change;
+          }
+        }
+
+        // If we got to this point, all indices used to access the
+        // buffer are virtual indices defined in the innermost block.
+        // Therefore, generate new virtual indices for iterating over
+        // the post-transform buffer.
+        Array<PrimExpr> new_iter_values;             // For BlockRealize
+        Array<IterVar> new_iter_vars;                // For Block
+        Array<PrimExpr> new_access_indices;          // For BufferStore
+        Map<Var, PrimExpr> loop_var_to_virtual_var;  // For updating if_then_else_condition
+
+        for (size_t i = 0; i < block_index_start; i++) {
+          new_iter_vars.push_back(iter_vars[i]);
+          new_iter_values.push_back(block_realize->iter_values[i]);
+        }
+
+        ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+        for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+          Var var = inverse->initial_indices[i];
+          PrimExpr dim = new_buffer->shape[i];
+          std::stringstream ss;
+          ss << "v_" << var->name_hint;
+          Var virtual_var(ss.str(), var.dtype());
+          new_iter_values.push_back(var);
+          new_iter_vars.push_back(IterVar(Range::FromMinExtent(0, dim), virtual_var, kDataPar));
+          new_access_indices.push_back(virtual_var);
+          loop_var_to_virtual_var.Set(var, virtual_var);
+        }
+
+        for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) {
+          new_iter_vars.push_back(iter_vars[i]);
+          new_iter_values.push_back(block_realize->iter_values[i]);
+        }
+
+        Map<Var, PrimExpr> old_virtual_var_to_new_virtual_var;
+        ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          Var var = Downcast<Var>(old_indices[i]);
+          PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
+          old_virtual_var_to_new_virtual_var.Set(var, expr);
+        }
+
+        if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var);
+        new_indices = new_access_indices;
+
+        return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values,
+                old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op,
+                                                    const BlockRealize& visited) -> Optional<Stmt> {
+          if (op == target_realize.get()) {
+            Block block = visited->block;
+            block =
+                Downcast<Block>(Substitute(std::move(block), old_virtual_var_to_new_virtual_var));
+            block.CopyOnWrite()->iter_vars = new_iter_vars;
+
+            BlockRealize realize = visited;
+            {
+              auto write_ptr = realize.CopyOnWrite();
+              write_ptr->block = block;
+              write_ptr->iter_values = new_iter_values;
+            }
+            return realize;
+          } else {
+            return NullOpt;
+          }
+        };
+      }();
+
+      bool all_stores_replaced = true;
+      auto replace_store = [&](const BufferStoreNode* op) -> Optional<Stmt> {
+        if (!op->buffer.same_as(info.store->buffer)) {
+          all_stores_replaced = false;
+          return NullOpt;
+        }
+        ICHECK_EQ(old_indices.size(), op->indices.size());
+        ExprDeepEqual expr_equal;
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!expr_equal(old_indices[i], op->indices[i])) {
+            all_stores_replaced = false;
+            return NullOpt;
+          }
+        }
+
+        return BufferStore(new_buffer,
+                           if_then_else(if_then_else_condition, pad_value.value(), op->value),
+                           new_indices);
+      };
+
+      BufferStoreReplacer replacer(replace_store, replace_block_realize);
+      Stmt stmt = replacer(info.dependent_loopnest.back()->body);
+      if (!all_stores_replaced) {
+        return NullOpt;
+      }
+
+      std::unordered_map<const VarNode*, PrimExpr> var_remap;
+      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
+      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
+        Var var = info.dependent_loopnest[i]->loop_var;
+        PrimExpr expr = inverse->final_indices[i];
+        var_remap[var.get()] = expr;
+      }
+      stmt = Substitute(std::move(stmt), var_remap);
+
+      ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+      for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+        size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+        Var loop_var = inverse->initial_indices[i];
+        PrimExpr extent = new_buffer->shape[i];
+        stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+      }
+
+      return stmt;
+    };
+
+    Map<For, Stmt> loop_replacements;
+
+    for (const auto& info : write_info_) {
+      if (info.dependent_loopnest.size()) {
+        if (auto opt_stmt = generate_if_then_else_block(info)) {
+          loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value());
+        }
+      }
+    }
+
+    if (loop_replacements.size()) {
+      return ReplacementPlan{std::move(loop_replacements)};
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<PrimExpr> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    Map<Var, PrimExpr> loop_indices_to_block_indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      loop_indices_to_block_indices.Set(loop_var, block_var);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
+
+    Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices);
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_padding";
+    auto write_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, padding_predicate,
+                        Block(iter_vars, {}, {write_region}, block_name.str(), stmt));
+
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+
+    const auto& info = write_info_.back();
+    Stmt insert_after = [&]() -> Stmt {
+      if (info.dependent_loopnest.size()) {
+        return info.dependent_loopnest.front();
+      } else if (info.innermost_block_realize) {
+        return info.innermost_block_realize.value();
+      } else {
+        LOG(FATAL) << "Write occurred outside of any block/loop";
+        return Stmt();
+      }
+    }();
+    return EpiloguePlan{insert_after, stmt};
+  }
+
+  struct BindLoopVar {
+    BindLoopVar(LayoutTransformPlanner* self, For for_node)
+        : self_(self), var_(for_node->loop_var) {
+      size_t loop_depth = self_->active_loops_.size();
+      self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth};
+      self_->active_loops_.push_back(std::move(for_node));
+    }
+    ~BindLoopVar() {
+      self_->active_loops_.pop_back();
+      self_->loop_depth_lookup_.erase(var_.get());
+    }
+    BindLoopVar(const BindLoopVar&) = delete;
+    BindLoopVar& operator=(const BindLoopVar&) = delete;
+    BindLoopVar(BindLoopVar&&) = delete;
+    BindLoopVar& operator=(BindLoopVar&&) = delete;
+
+    LayoutTransformPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindLetVar {
+    BindLetVar() {}
+    BindLetVar(LayoutTransformPlanner* self, Var var, PrimExpr value) : self_(self), var_(var) {
+      if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) {
+        self_->loop_depth_lookup_[var_.get()] = loop_depth.value();
+        self_->active_let_bindings_[var_.get()] = Substitute(value, self_->active_let_bindings_);
+      }
+    }
+    ~BindLetVar() {
+      if (self_) {
+        self_->loop_depth_lookup_.erase(var_.get());
+        self_->active_let_bindings_.erase(var_.get());
+      }
+    }
+    BindLetVar(const BindLetVar&) = delete;
+    BindLetVar& operator=(const BindLetVar&) = delete;
+    BindLetVar(BindLetVar&& other) : BindLetVar() { swap(other); }
+    BindLetVar& operator=(BindLetVar&& other) {
+      swap(other);
+      return *this;
+    }
+    void swap(BindLetVar& other) {
+      std::swap(self_, other.self_);
+      std::swap(var_, other.var_);
+    }
+
+    LayoutTransformPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindBlockRealize {
+    BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) {
+      ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size());
+      for (size_t i = 0; i < block_realize->iter_values.size(); i++) {
+        bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var,
+                                 block_realize->iter_values[i]);
+      }
+      cache_ = std::move(block_realize);
+      std::swap(self_->innermost_block_realize_, cache_);
+    }
+    ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); }
+    BindBlockRealize(const BindBlockRealize&) = delete;
+    BindBlockRealize& operator=(const BindBlockRealize&) = delete;
+    BindBlockRealize(BindBlockRealize&&) = delete;
+    BindBlockRealize& operator=(BindBlockRealize&&) = delete;
+
+    LayoutTransformPlanner* self_{nullptr};
+    Optional<BlockRealize> cache_;
+    std::vector<BindLetVar> bound_vars_;
+  };
+
+  struct WriteInfo {
+    // The BufferStore object
+    BufferStore store;
+
+    // The block realize that contains the store, if any.
+    Optional<BlockRealize> innermost_block_realize;
+
+    // The nested loops whose values contribute to the indices used in
+    // the store.  Not all loop variables in the loopnest need to
+    // contribute, but the first and last must.
+    std::vector<For> dependent_loopnest;
+
+    // Whether the padding could be represented as a tir::if_then_else
+    // node.  This requires that the surrounding loop iterators
+    // iterate over all pre-transformation buffer axes and that there
+    // are no data dependencies between loop iterations.
+    bool contains_row_major_traversal{false};
+  };
+
+  struct LoopEntry {};
+
+  std::vector<WriteInfo> write_info_;
+  std::vector<For> active_loops_;
+  std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_;
+  std::unordered_map<const VarNode*, PrimExpr> active_let_bindings_;
+  Optional<BlockRealize> innermost_block_realize_{NullOpt};

Review Comment:
   document these fields
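As a possible starting point for documenting these fields, here is a minimal pure-Python sketch of the bookkeeping that `active_loops_` and `loop_depth_lookup_` appear to perform, inferred from `BindLoopVar`, `BindLetVar`, and `LoopDependencyRange` in the diff. Entering a loop records its depth, and leaving it erases the entry. A let-bound variable inherits the loop range of its value. The depth range of an index is the (min, max) over its free variables. Variables are modeled as strings and expressions as sets of free-variable names. `LoopTracker` and its methods are hypothetical stand-ins, not TVM APIs.

```python
from contextlib import contextmanager


class LoopTracker:
    """Toy model of the planner's loop bookkeeping (hypothetical, not a TVM API)."""

    def __init__(self):
        self.active_loops = []       # like active_loops_: outermost loop first
        self.loop_depth_lookup = {}  # like loop_depth_lookup_: var -> (min_depth, max_depth)

    @contextmanager
    def bind_loop_var(self, var):
        # Like BindLoopVar: record the loop's depth on entry and erase it on exit.
        depth = len(self.active_loops)
        self.loop_depth_lookup[var] = (depth, depth)
        self.active_loops.append(var)
        try:
            yield
        finally:
            self.active_loops.pop()
            del self.loop_depth_lookup[var]

    @contextmanager
    def bind_let_var(self, var, value_free_vars):
        # Like BindLetVar: a let-bound variable inherits the loop range of its value.
        rng = self.loop_dependency_range(value_free_vars)
        if rng is not None:
            self.loop_depth_lookup[var] = rng
        try:
            yield
        finally:
            self.loop_depth_lookup.pop(var, None)

    def loop_dependency_range(self, free_vars):
        # Like LoopDependencyRange: merge the depth ranges of all free variables
        # that are bound by an enclosing loop (or let); None if there are none.
        prev = None
        for var in free_vars:
            if var in self.loop_depth_lookup:
                lo, hi = self.loop_depth_lookup[var]
                prev = (lo, hi) if prev is None else (min(prev[0], lo), max(prev[1], hi))
        return prev


tracker = LoopTracker()
with tracker.bind_loop_var("i"):
    with tracker.bind_loop_var("j"):
        with tracker.bind_let_var("x", {"i", "j"}):  # models `x = i + j`
            with tracker.bind_loop_var("k"):
                print(tracker.loop_dependency_range({"i", "k"}))  # (0, 2)
                print(tracker.loop_dependency_range({"x"}))       # (0, 1)
                print(tracker.loop_dependency_range({"n"}))       # None
print(tracker.loop_depth_lookup)  # {}
```

The context managers play the role of the RAII guards in the C++ code, so that the bindings are removed even if the traversal exits early.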



##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -16,12 +16,580 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+class LayoutTransformPlanner : private StmtExprVisitor {

Review Comment:
   document the high level algorithm
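One reading of the high-level algorithm, inferred from `Finalize` and the three `Finalize*Plan` helpers in the diff, is as follows. The planner first collects every store to the old buffer, then tries the strategies in a fixed priority order: prologue, then replacement, then epilogue, otherwise no padding. The pure-Python toy model below mirrors that order. It also considers a hypothetical 1-D buffer of length `n` split as `i -> (i // f, i % f)`, and shows that the `if_then_else` loop rewrite (ReplacementPlan) and a separate padding block (EpiloguePlan) leave the same values in the padded buffer. All function names here are illustrative, not TVM APIs.

```python
def choose_plan(num_writes, has_padding, pad_value, any_loopnest_rewritable):
    """Mirror the priority order of LayoutTransformPlanner::Finalize (toy model)."""
    if has_padding and pad_value is not None:
        if num_writes == 0:
            return "ProloguePlan"     # buffer only read: emit builtin::assume over the padding
        if any_loopnest_rewritable:
            return "ReplacementPlan"  # at least one loopnest rewritten with if_then_else
        return "EpiloguePlan"         # append a block that writes the padding
    # Mirrors the C++: also returned when padding exists but no pad_value was given.
    return "NoPaddingRequired"


def ceil_div(a, b):
    return -(-a // b)


def replacement_plan(values, f, pad_value):
    # Iterate over the *transformed* shape and apply the inverse index map.
    n = len(values)
    out = [[None] * f for _ in range(ceil_div(n, f))]
    for io in range(len(out)):
        for ii in range(f):
            i = io * f + ii       # inverse index map
            is_padding = i >= n   # padding predicate
            # if_then_else(padding_predicate, pad_value, original value)
            out[io][ii] = pad_value if is_padding else values[i]
    return out


def epilogue_plan(values, f, pad_value):
    # Keep the original loop over the old shape, then write the padding separately.
    n = len(values)
    out = [[None] * f for _ in range(ceil_div(n, f))]
    for i in range(n):
        out[i // f][i % f] = values[i]
    for io in range(len(out)):
        for ii in range(f):
            if io * f + ii >= n:
                out[io][ii] = pad_value
    return out


values = list(range(14))
assert replacement_plan(values, 4, 0) == epilogue_plan(values, 4, 0)
print(replacement_plan(values, 4, 0)[-1])  # [12, 13, 0, 0]
print(choose_plan(num_writes=1, has_padding=True, pad_value=0,
                  any_loopnest_rewritable=True))  # ReplacementPlan
```

Preferring the replacement plan keeps each element written exactly once, whereas the epilogue plan needs a second pass over the padded region.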




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r972246269


##########
python/tvm/tir/function.py:
##########
@@ -389,17 +389,27 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
 
         final_indices = []
         axis_separators = []
-        for val in mapping:
-            if isinstance(val, tvm.ir.PrimExpr):
-                final_indices.append(val)
-            elif val is IndexMap.AXIS_SEPARATOR:
-                axis_separators.append(len(final_indices))
-            else:
-                raise TypeError(
-                    "Expected mapping function to return list of "
-                    "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR.  "
-                    f"Instead received {val} of type {type(val)}."
-                )
+
+        try:
+            iter(mapping)

Review Comment:
   This was to allow the mapping function to return a single `PrimExpr`, or something that the ffi can convert into a `PrimExpr`.  Since it wouldn't make sense for the pad value to provide multiple outputs, I found myself frequently writing `lambda i, j: i + j` instead of `lambda i, j: [i + j]`.  I figured that since I was frequently making that mistake, later users would likely make it as well, so it would be best to support that functionality.
   
   Good call on the documentation, and I'll update the documentation for `from_func` and `from_func_with_separators` accordingly.
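Here is a minimal sketch of the normalization described above. It assumes only that a non-iterable return value should be treated as a one-element list, matching the `try: iter(mapping)` check in the diff. The helper name is hypothetical, and plain Python ints stand in for `PrimExpr`.

```python
def normalize_mapping_output(mapping):
    """Wrap a single (non-iterable) mapping result in a list (hypothetical helper)."""
    try:
        iter(mapping)
    except TypeError:
        # A single value, e.g. the result of `lambda i, j: i + j`.
        return [mapping]
    # Already a sequence, e.g. the result of `lambda i, j: [i + j]`.
    return list(mapping)


single = (lambda i, j: i + j)(2, 3)
listed = (lambda i, j: [i + j])(2, 3)
print(normalize_mapping_output(single))  # [5]
print(normalize_mapping_output(listed))  # [5]
```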




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r973042996


##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2479,6 +2480,31 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]

Review Comment:
   Thank you, and the docstring has been updated.  I've also added two unit tests, one that validates that an error is raised when the pad value loads from a different buffer, and one that specifies the intended behavior for pad value that loads from the transformed buffer.  The latter is currently marked with `pytest.mark.xfail`, as the support isn't implemented yet.




[GitHub] [tvm] Hzfengsy commented on pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on PR #12720:
URL: https://github.com/apache/tvm/pull/12720#issuecomment-1238779969

   cc @wrongtest-intellif @vinx13 



[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r964876309


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -601,9 +601,11 @@ class ScheduleNode : public runtime::Object {
    * \param buffer_index The index of the buffer in block's read or write region.
    * \param buffer_index_type The type of the buffer index, kRead or kWrite.
    * \param index_map The transformation to apply.
+   * \param pad_value The value to write into padding introduced by the transformation.
    */
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
-                               BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
+                               BufferIndexType buffer_index_type, const IndexMap& index_map,
+                               const Optional<PrimExpr>& pad_value) = 0;

Review Comment:
   Thank you for the catch, and updated.  This also makes the C++ API closer to the Python API.




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965325566


##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2479,6 +2480,20 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
+
+            The value to be used for any padding introduced by the
+            transformation.
+
+            If None, the transformation may not introduce padding.
+
+            If an int, float or PrimExpr, this is the specific
+            value to be present in the padding.
+
+            If an IndexMap or Callable, this gives the value to be
+            present in the padding in terms of the transformed
+            index.

Review Comment:
   Good point.  I had been thinking of it as the `(const Array<Var>&, const Array<PrimExpr>&)` call signature on the TE side for the transformation, and was avoiding introducing additional structures.  I had forgotten that the TIR schedule accepts an `IndexMap` for the transformation, and agree that the C++ side would be better expressed as an `Optional<IndexMap>` instead.
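Here is a minimal sketch of how the two documented forms could collapse into a single optional callable, the Python analogue of the `Optional<IndexMap>` suggested above. It assumes the docstring's convention that a constant pad value behaves like an index map that ignores the transformed indices. `normalize_pad_value` is a hypothetical helper, and Python numbers stand in for `PrimExpr`.

```python
from numbers import Number


def normalize_pad_value(pad_value):
    """Return None, or a callable mapping transformed indices to a pad value."""
    if pad_value is None:
        return None  # the transformation must not introduce padding
    if isinstance(pad_value, Number):
        return lambda *indices: pad_value  # constant: ignore the indices
    if callable(pad_value):
        return pad_value  # already expressed in terms of the transformed indices
    raise TypeError(f"Unsupported pad_value of type {type(pad_value)}")


const_pad = normalize_pad_value(0.0)
index_pad = normalize_pad_value(lambda io, ii: io * 4 + ii)
print(const_pad(3, 2))  # 0.0
print(index_pad(3, 2))  # 14
```

With a single representation, downstream code only has to handle "no pad value" or "a function of the transformed indices".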




[GitHub] [tvm] masahi commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
masahi commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r970073101


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)

Review Comment:
   Even if the index map is incorrect, it doesn't affect the correctness of tensorized MMA, since the index map is only used for pattern-matching purposes.
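Regardless of whether correctness depends on it, the index map can be sanity-checked cheaply. The diff shows only the `thread_id` line, so the sketch below assumes, as an assumption not quoted from the diff, that the function returns `(thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4)` as the (thread, local element) pair. Under that assumption, the check confirms that the 32x16 tile maps one-to-one onto 32 threads with 16 elements each. The old formula fails this check because it never uses `(i % 16) // 4`. A bijectivity check does not prove the layout matches the hardware's `ldmatrix` fragment layout, but it does catch mappings that drop an index bit.

```python
def layout_new(i, j):
    thread_id = (i % 16) // 4 + 4 * (j % 8)
    # Assumed local-index expression; not shown in the diff.
    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4


def layout_old(i, j):
    thread_id = (i % 4) + 4 * (j % 8)
    # Same assumed local-index expression as above.
    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4


def is_bijective_32x16(layout):
    # Every (thread, local) pair must be produced by exactly one (i, j) in the tile.
    images = {layout(i, j) for i in range(32) for j in range(16)}
    expected = {(t, k) for t in range(32) for k in range(16)}
    return images == expected


print(is_bijective_32x16(layout_new))  # True
print(is_bijective_32x16(layout_old))  # False
```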




[GitHub] [tvm] Lunderberg commented on a diff in pull request #12720: [TIR] Implement API for padded layout transformations

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r965997578


##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -16,12 +16,580 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+class LayoutTransformPlanner : private StmtExprVisitor {
+ public:
+  // Statement to be inserted prior to the analyzed block
+  struct ProloguePlan {
+    Stmt prologue;
+  };
+
+  // Loops within the analyzed block that should be replaced
+  struct ReplacementPlan {
+    Map<For, Stmt> replacements;
+    Map<Block, Block> block_sref_reuse;
+  };
+
+  // The block to be inserted, along with the location at which it
+  // should be inserted.  The location will be either a For or a
+  // Block, and will be after all writes to the transformed buffer.
+  struct EpiloguePlan {
+    Stmt insert_after;
+    Stmt new_block;
+  };
+
+  struct NoPaddingRequired {};
+
+  using TransformPlan =
+      std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;
+
+  static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
+                            IndexMap inverse, PrimExpr padding_predicate,
+                            Optional<PrimExpr> pad_value) {
+    LayoutTransformPlanner visitor(old_buffer);
+    visitor(block);
+    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
+  }
+
+ private:
+  explicit LayoutTransformPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}
+
+  void VisitStmt_(const ForNode* op) override {
+    BindLoopVar context(this, GetRef<For>(op));
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const LetStmtNode* op) override {
+    BindLetVar context(this, op->var, op->value);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BlockRealizeNode* op) override {
+    BindBlockRealize context(this, GetRef<BlockRealize>(op));
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) override {
+    if (!op->buffer.same_as(old_buffer_)) {
+      return;
+    }
+
+    std::optional<std::pair<size_t, size_t>> loop_dependency_range = std::nullopt;
+    for (const auto& index : op->indices) {
+      if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) {
+        if (loop_dependency_range) {
+          loop_dependency_range = {
+              std::min(loop_dependency_range.value().first, index_depth.value().first),
+              std::max(loop_dependency_range.value().second, index_depth.value().second)};
+        } else {
+          loop_dependency_range = index_depth;
+        }
+      }
+    }
+
+    WriteInfo write_info;
+    write_info.store = GetRef<BufferStore>(op);
+    if (loop_dependency_range) {
+      size_t i = loop_dependency_range.value().first;
+      size_t j = loop_dependency_range.value().second;
+      ICHECK_LT(i, active_loops_.size());
+      ICHECK_LT(j, active_loops_.size());
+
+      write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1};
+    }
+    write_info.innermost_block_realize = innermost_block_realize_;
+
+    write_info.contains_row_major_traversal = [&]() -> bool {
+      const auto& loopnest = write_info.dependent_loopnest;
+      if (loopnest.empty()) {
+        return false;
+      }
+
+      if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) {
+        return false;
+      }
+
+      for (size_t i = 0; i < loopnest.size(); i++) {
+        const For& loop = loopnest[i];
+        const PrimExpr& buffer_dim = old_buffer_->shape[i];
+        PrimExpr index = Substitute(op->indices[i], active_let_bindings_);
+        bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) &&
+                                 ExprDeepEqual()(loop->extent, buffer_dim) &&
+                                 loop->kind == ForKind::kSerial;
+        if (!is_loop_over_axis) {
+          return false;
+        }
+      }
+
+      return true;
+    }();
+
+    write_info_.push_back(write_info);
+
+    // Don't need to continue recursing, as the entire goal was to
+    // find the BufferStore.
+  }
+
+  std::optional<std::pair<size_t, size_t>> LoopDependencyRange(const PrimExpr& expr) const {
+    std::optional<std::pair<size_t, size_t>> prev = std::nullopt;
+    for (const auto& var : UndefinedVars(expr)) {
+      auto it = loop_depth_lookup_.find(var.get());
+      if (it != loop_depth_lookup_.end()) {
+        if (prev.has_value()) {
+          prev = {std::min(prev.value().first, it->second.first),
+                  std::max(prev.value().second, it->second.second)};
+        } else {
+          prev = it->second;
+        }
+      }
+    }
+
+    return prev;
+  }
+
+  class BufferStoreReplacer : public StmtExprMutator {
+   public:
+    BufferStoreReplacer(std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store,
+                        std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
+                            replace_block_realize)
+        : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {}
+
+    Stmt VisitStmt_(const BufferStoreNode* op) final {
+      if (auto replacement = replace_store_(op)) {
+        auto store = Downcast<BufferStore>(replacement.value());
+        return StmtExprMutator::VisitStmt_(store.get());
+      } else {
+        return StmtExprMutator::VisitStmt_(op);
+      }
+    }
+
+    Stmt VisitStmt_(const BlockRealizeNode* op) final {
+      auto realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
+      if (auto replacement = replace_block_realize_(op, realize)) {
+        return replacement.value();
+      } else {
+        return std::move(realize);
+      }
+    }
+
+   private:
+    std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store_;
+    std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
+        replace_block_realize_;
+  };
+
+  TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
+                         PrimExpr padding_predicate, Optional<PrimExpr> pad_value) const {
+    if (auto prologue_plan =
+            FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value);
+        prologue_plan.has_value()) {
+      return prologue_plan.value();
+    } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse,
+                                                               padding_predicate, pad_value);
+               replacement_plan.has_value()) {
+      return replacement_plan.value();
+    } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse,
+                                                         padding_predicate, pad_value);
+               epilogue_plan.has_value()) {
+      return epilogue_plan.value();
+    } else {
+      return NoPaddingRequired();
+    }
+  }
+
+  std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<PrimExpr> pad_value) const {
+    if (!write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    Map<Var, PrimExpr> loop_indices_to_block_indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      loop_indices_to_block_indices.Set(loop_var, block_var);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
+
+    PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value.value());
+    Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_assumptions";
+    auto read_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, Bool(true),
+                        Block(iter_vars, {read_region}, {}, block_name.str(), stmt));
+
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+    return ProloguePlan{stmt};
+  }
+
+  std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map,
+                                                         IndexMap inverse,
+                                                         PrimExpr padding_predicate,
+                                                         Optional<PrimExpr> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> {
+      if (!info.contains_row_major_traversal || !pad_value.defined() ||
+          is_zero(padding_predicate)) {
+        return NullOpt;
+      }
+
+      Array<PrimExpr> old_indices = info.store->indices;
+      PrimExpr if_then_else_condition = padding_predicate;
+      Array<PrimExpr> new_indices;
+      for (const auto& var : inverse->initial_indices) {
+        new_indices.push_back(var);
+      }
+
+      auto replace_block_realize =
+          [&]() -> std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)> {
+        auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional<Stmt> {
+          return NullOpt;
+        };
+        if (!info.innermost_block_realize) {
+          return no_change;
+        }
+        if (old_indices.empty()) {
+          return no_change;
+        }
+
+        BlockRealize block_realize = info.innermost_block_realize.value();
+        const auto& block = block_realize->block;
+
+        // Find the block iterators that are used to access the buffer.  Must be in the same order
+        // as they appear in the indices.
+        if (block->iter_vars.size() < old_indices.size()) {
+          return no_change;
+        }
+        const auto& iter_vars = block->iter_vars;
+        size_t block_index_start = 0;
+        for (; block_index_start <= iter_vars.size() - old_indices.size(); block_index_start++) {
+          if (old_indices[0].same_as(iter_vars[block_index_start]->var)) {
+            break;
+          }
+        }
+        if (block_index_start > iter_vars.size() - old_indices.size()) {
+          return no_change;
+        }
+
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) ||
+              iter_vars[block_index_start + i]->iter_type != kDataPar) {
+            return no_change;
+          }
+        }
+
+        // If we got to this point, all indices used to access the
+        // buffer are virtual indices defined in the innermost block.
+        // Therefore, generate new virtual indices for iterating over
+        // the post-transform buffer.
+        Array<PrimExpr> new_iter_values;             // For BlockRealize
+        Array<IterVar> new_iter_vars;                // For Block
+        Array<PrimExpr> new_access_indices;          // For BufferStore
+        Map<Var, PrimExpr> loop_var_to_virtual_var;  // For updating if_then_else_condition
+
+        for (size_t i = 0; i < block_index_start; i++) {
+          new_iter_vars.push_back(iter_vars[i]);
+          new_iter_values.push_back(block_realize->iter_values[i]);
+        }
+
+        ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+        for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+          Var var = inverse->initial_indices[i];
+          PrimExpr dim = new_buffer->shape[i];
+          std::stringstream ss;
+          ss << "v_" << var->name_hint;
+          Var virtual_var(ss.str(), var.dtype());
+          new_iter_values.push_back(var);
+          new_iter_vars.push_back(IterVar(Range::FromMinExtent(0, dim), virtual_var, kDataPar));
+          new_access_indices.push_back(virtual_var);
+          loop_var_to_virtual_var.Set(var, virtual_var);
+        }
+
+        for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) {
+          new_iter_vars.push_back(iter_vars[i]);
+          new_iter_values.push_back(block_realize->iter_values[i]);
+        }
+
+        Map<Var, PrimExpr> old_virtual_var_to_new_virtual_var;
+        ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          Var var = Downcast<Var>(old_indices[i]);
+          PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
+          old_virtual_var_to_new_virtual_var.Set(var, expr);
+        }
+
+        if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var);
+        new_indices = new_access_indices;
+
+        return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values,
+                old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op,
+                                                    const BlockRealize& visited) -> Optional<Stmt> {
+          if (op == target_realize.get()) {
+            Block block = visited->block;
+            block =
+                Downcast<Block>(Substitute(std::move(block), old_virtual_var_to_new_virtual_var));
+            block.CopyOnWrite()->iter_vars = new_iter_vars;
+
+            BlockRealize realize = visited;
+            {
+              auto write_ptr = realize.CopyOnWrite();
+              write_ptr->block = block;
+              write_ptr->iter_values = new_iter_values;
+            }
+            return realize;
+          } else {
+            return NullOpt;
+          }
+        };
+      }();
+
+      bool all_stores_replaced = true;
+      auto replace_store = [&](const BufferStoreNode* op) -> Optional<Stmt> {
+        if (!op->buffer.same_as(info.store->buffer)) {
+          all_stores_replaced = false;
+          return NullOpt;
+        }
+        ICHECK_EQ(old_indices.size(), op->indices.size());
+        ExprDeepEqual expr_equal;
+        for (size_t i = 0; i < old_indices.size(); i++) {
+          if (!expr_equal(old_indices[i], op->indices[i])) {
+            all_stores_replaced = false;
+            return NullOpt;
+          }
+        }
+
+        return BufferStore(new_buffer,
+                           if_then_else(if_then_else_condition, pad_value.value(), op->value),
+                           new_indices);
+      };
+
+      BufferStoreReplacer replacer(replace_store, replace_block_realize);
+      Stmt stmt = replacer(info.dependent_loopnest.back()->body);
+      if (!all_stores_replaced) {
+        return NullOpt;
+      }
+
+      std::unordered_map<const VarNode*, PrimExpr> var_remap;
+      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
+      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
+        Var var = info.dependent_loopnest[i]->loop_var;
+        PrimExpr expr = inverse->final_indices[i];
+        var_remap[var.get()] = expr;
+      }
+      stmt = Substitute(std::move(stmt), var_remap);
+
+      ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+      for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+        size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+        Var loop_var = inverse->initial_indices[i];
+        PrimExpr extent = new_buffer->shape[i];
+        stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+      }
+
+      return stmt;
+    };
+
+    Map<For, Stmt> loop_replacements;
+
+    for (const auto& info : write_info_) {
+      if (info.dependent_loopnest.size()) {
+        if (auto opt_stmt = generate_if_then_else_block(info)) {
+          loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value());
+        }
+      }
+    }
+
+    if (loop_replacements.size()) {
+      return ReplacementPlan{std::move(loop_replacements)};
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<PrimExpr> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    Map<Var, PrimExpr> loop_indices_to_block_indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      loop_indices_to_block_indices.Set(loop_var, block_var);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
+
+    Stmt stmt = BufferStore(new_buffer, pad_value.value(), indices);
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_padding";
+    auto write_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, padding_predicate,
+                        Block(iter_vars, {}, {write_region}, block_name.str(), stmt));
+
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+
+    const auto& info = write_info_.back();
+    Stmt insert_after = [&]() -> Stmt {
+      if (info.dependent_loopnest.size()) {
+        return info.dependent_loopnest.front();
+      } else if (info.innermost_block_realize) {
+        return info.innermost_block_realize.value();
+      } else {
+        LOG(FATAL) << "Write occurred outside of any block/loop";
+        return Stmt();
+      }
+    }();
+    return EpiloguePlan{insert_after, stmt};
+  }
+
+  struct BindLoopVar {
+    BindLoopVar(LayoutTransformPlanner* self, For for_node)
+        : self_(self), var_(for_node->loop_var) {
+      size_t loop_depth = self_->active_loops_.size();
+      self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth};
+      self_->active_loops_.push_back(std::move(for_node));
+    }
+    ~BindLoopVar() {
+      self_->active_loops_.pop_back();
+      self_->loop_depth_lookup_.erase(var_.get());
+    }
+    BindLoopVar(const BindLoopVar&) = delete;
+    BindLoopVar& operator=(const BindLoopVar&) = delete;
+    BindLoopVar(BindLoopVar&&) = delete;
+    BindLoopVar& operator=(BindLoopVar&&) = delete;
+
+    LayoutTransformPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindLetVar {
+    BindLetVar() {}
+    BindLetVar(LayoutTransformPlanner* self, Var var, PrimExpr value) : self_(self), var_(var) {
+      if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) {
+        self_->loop_depth_lookup_[var_.get()] = loop_depth.value();
+        self_->active_let_bindings_[var_.get()] = Substitute(value, self_->active_let_bindings_);
+      }
+    }
+    ~BindLetVar() {
+      if (self_) {
+        self_->loop_depth_lookup_.erase(var_.get());
+        self_->active_let_bindings_.erase(var_.get());
+      }
+    }
+    BindLetVar(const BindLetVar&) = delete;
+    BindLetVar& operator=(const BindLetVar&) = delete;
+    BindLetVar(BindLetVar&& other) : BindLetVar() { swap(other); }
+    BindLetVar& operator=(BindLetVar&& other) {
+      swap(other);
+      return *this;
+    }
+    void swap(BindLetVar& other) {
+      std::swap(self_, other.self_);
+      std::swap(var_, other.var_);
+    }
+
+    LayoutTransformPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindBlockRealize {
+    BindBlockRealize(LayoutTransformPlanner* self, BlockRealize block_realize) : self_(self) {
+      ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size());
+      for (size_t i = 0; i < block_realize->iter_values.size(); i++) {
+        bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var,
+                                 block_realize->iter_values[i]);
+      }
+      cache_ = std::move(block_realize);
+      std::swap(self_->innermost_block_realize_, cache_);
+    }
+    ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); }
+    BindBlockRealize(const BindBlockRealize&) = delete;
+    BindBlockRealize& operator=(const BindBlockRealize&) = delete;
+    BindBlockRealize(BindBlockRealize&&) = delete;
+    BindBlockRealize& operator=(BindBlockRealize&&) = delete;
+
+    LayoutTransformPlanner* self_{nullptr};
+    Optional<BlockRealize> cache_;
+    std::vector<BindLetVar> bound_vars_;
+  };
+
+  struct WriteInfo {
+    // The BufferStore object
+    BufferStore store;
+
+    // The block realize that contains the store, if any.
+    Optional<BlockRealize> innermost_block_realize;
+
+    // The nested loops whose values contribute to the indices used in
+    // the store.  Not all loop variables in the loopnest need to
+    // contribute, but the first and last must.
+    std::vector<For> dependent_loopnest;
+
+    // Whether the padding could be represented as a tir::if_then_else
+    // node.  This requires that the surrounding loop iterators
+    // iterate over all pre-transformation buffer axes, that there are
+    // no data dependencies between loop iterations, and that
+    bool contains_row_major_traversal{false};
+  };
+
+  struct LoopEntry {};
+
+  std::vector<WriteInfo> write_info_;
+  std::vector<For> active_loops_;
+  std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_;
+  std::unordered_map<const VarNode*, PrimExpr> active_let_bindings_;
+  Optional<BlockRealize> innermost_block_realize_{NullOpt};

Review Comment:
   Thank you! I've added documentation here for the member variables, along with a description of how they are used when collecting `WriteInfo`.
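`Finalize` above tries the padding strategies in a fixed order and returns the first one that applies. It starts with a prologue of `builtin::assume` statements, which is used only when no writes to the buffer were collected. Next is an in-place `if_then_else` rewrite of the writing loops, then an epilogue that writes the pad value, and finally `NoPaddingRequired`. The following is a minimal sketch of that selection pattern using `std::optional` and `std::variant`. The plan structs, the `Candidates` flags and `PlanName` are hypothetical simplifications, not TVM types.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-ins for the plan types returned by Finalize.
struct ProloguePlan { std::string note; };
struct ReplacementPlan { std::string note; };
struct EpiloguePlan { std::string note; };
struct NoPaddingRequired {};

using TransformPlan =
    std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;

// Hypothetical summary of the conditions that the Finalize*Plan helpers test.
struct Candidates {
  bool has_writes;        // write_info_ is non-empty
  bool needs_padding;     // padding predicate is not provably zero
  bool has_pad_value;     // pad_value.defined()
  bool can_replace_loop;  // at least one writing loopnest can be rewritten
};

std::optional<ProloguePlan> TryPrologue(const Candidates& c) {
  // Only for buffers with no collected writes (e.g. function inputs).
  if (c.has_writes || !c.needs_padding || !c.has_pad_value) return std::nullopt;
  return ProloguePlan{"assume the padding already holds pad_value"};
}

std::optional<ReplacementPlan> TryReplacement(const Candidates& c) {
  if (!c.has_writes || !c.needs_padding || !c.has_pad_value || !c.can_replace_loop) {
    return std::nullopt;
  }
  return ReplacementPlan{"rewrite the store with if_then_else"};
}

std::optional<EpiloguePlan> TryEpilogue(const Candidates& c) {
  if (!c.has_writes || !c.needs_padding || !c.has_pad_value) return std::nullopt;
  return EpiloguePlan{"append a loop that writes pad_value"};
}

// The first applicable plan wins, mirroring the if/else-if chain in Finalize.
TransformPlan Finalize(const Candidates& c) {
  if (auto plan = TryPrologue(c)) return *plan;
  if (auto plan = TryReplacement(c)) return *plan;
  if (auto plan = TryEpilogue(c)) return *plan;
  return NoPaddingRequired{};
}

// Human-readable name of the selected alternative.
const char* PlanName(const TransformPlan& plan) {
  static const char* const kNames[] = {"prologue", "replacement", "epilogue", "none"};
  assert(plan.index() < 4);
  return kNames[plan.index()];
}
```

Because each helper returns an empty optional when it does not apply, adding or reordering strategies only touches the chain in `Finalize`.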

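The three plans differ in where the pad value enters the generated TIR. `FinalizeReplacementPlan` rewrites the store as `if_then_else(padding_predicate, pad_value, value)` over the post-transform indices. `FinalizeEpiloguePlan` adds a separate block that writes `pad_value` wherever the padding predicate holds. `FinalizeProloguePlan` only emits `assume(!padding_predicate || buf[...] == pad_value)`. The sketch below models these on plain C++ arrays rather than TIR, assuming a hypothetical 14-element buffer mapped by `i -> [i / 4, i % 4]` onto a 4x4 buffer, so the padding predicate is `io * 4 + ii >= 14`. All names are illustrative and are not TVM APIs.

```cpp
#include <array>
#include <cassert>

// Hypothetical example: a 1-D buffer of kN elements is transformed by the
// index map i -> [i / 4, i % 4] into a kRows x kCols buffer.
constexpr int kN = 14;
constexpr int kCols = 4;
constexpr int kRows = (kN + kCols - 1) / kCols;  // 4 rows: 16 slots, 2 of them padding

using Padded = std::array<std::array<float, kCols>, kRows>;

// Inverse map: [io, ii] -> io * 4 + ii.
constexpr int Inverse(int io, int ii) { return io * kCols + ii; }

// Padding predicate: true for slots that have no pre-transform element.
constexpr bool IsPadding(int io, int ii) { return Inverse(io, ii) >= kN; }

// Stand-in for the value written by the original loop, A[i] = f(i).
float Producer(int i) {
  assert(i >= 0 && i < kN);  // must never be evaluated for a padding slot
  return 0.5f * static_cast<float>(i);
}

// Replacement-style: loop over the post-transform shape and choose the pad
// value with a conditional, analogous to tir::if_then_else.
Padded ReplacementPlan(float pad_value) {
  Padded out{};
  for (int io = 0; io < kRows; ++io) {
    for (int ii = 0; ii < kCols; ++ii) {
      out[io][ii] = IsPadding(io, ii) ? pad_value : Producer(Inverse(io, ii));
    }
  }
  return out;
}

// Epilogue-style: keep the original loop (writing through the forward map),
// then run a separate loop that writes only the padding slots.
Padded EpiloguePlan(float pad_value) {
  Padded out{};
  for (int i = 0; i < kN; ++i) out[i / kCols][i % kCols] = Producer(i);
  for (int io = 0; io < kRows; ++io) {
    for (int ii = 0; ii < kCols; ++ii) {
      if (IsPadding(io, ii)) out[io][ii] = pad_value;
    }
  }
  return out;
}

// Prologue-style, for a buffer that is only read: nothing is written, and the
// kernel may assume (!IsPadding(io, ii) || buf[io][ii] == pad_value) everywhere.
bool PaddingAssumptionHolds(const Padded& buf, float pad_value) {
  for (int io = 0; io < kRows; ++io) {
    for (int ii = 0; ii < kCols; ++ii) {
      if (IsPadding(io, ii) && buf[io][ii] != pad_value) return false;
    }
  }
  return true;
}
```

Both write-side plans yield the same buffer contents and differ only in loop structure. The PR prefers the replacement form whenever the surrounding loops can be rewritten, so that padding is expressed with `tir::if_then_else`.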

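Inside `replace_block_realize`, the planner looks for the position in the block's `iter_vars` where the store's indices appear as a contiguous, in-order run of data-parallel iterators. The following is a sketch of the same check written with `std::search`. The `IterVarInfo` record, the `IterType` enum and `FindIndexRun` are hypothetical stand-ins, and names stand in for `Var` identity. Because each block iterator is a distinct variable, the first match of the leading index is the only candidate, so this search should agree with the hand-written loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

enum class IterType { kDataPar, kCommReduce };

// Hypothetical stand-in for tir::IterVar: a variable name plus iterator type.
struct IterVarInfo {
  std::string var;
  IterType type;
};

// Returns the offset in iter_vars at which store_indices appear as a
// contiguous, in-order run of data-parallel iterators, or std::nullopt.
std::optional<std::size_t> FindIndexRun(const std::vector<IterVarInfo>& iter_vars,
                                        const std::vector<std::string>& store_indices) {
  if (store_indices.empty() || iter_vars.size() < store_indices.size()) {
    return std::nullopt;
  }
  auto it = std::search(iter_vars.begin(), iter_vars.end(), store_indices.begin(),
                        store_indices.end(),
                        [](const IterVarInfo& iv, const std::string& name) {
                          return iv.var == name;
                        });
  if (it == iter_vars.end()) return std::nullopt;

  std::size_t start = static_cast<std::size_t>(it - iter_vars.begin());
  assert(start + store_indices.size() <= iter_vars.size());
  // Every matched iterator must be data-parallel, as in the original check.
  for (std::size_t i = 0; i < store_indices.size(); ++i) {
    if (iter_vars[start + i].type != IterType::kDataPar) return std::nullopt;
  }
  return start;
}
```

For iterators `[vk (reduce), vi, vj]`, the indices `[vi, vj]` match at offset 1. The indices `[vj, vi]` (wrong order) and `[vk, vi]` (reduction iterator) are rejected.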

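`BindLoopVar`, `BindLetVar` and `BindBlockRealize` are scope guards. Each one records planner state when constructed and restores it when destroyed, so the planner's view of the enclosing loops always matches the current position in the traversal. The following is a simplified sketch of the `BindLoopVar` pattern. The `PlannerState` and `ScopedLoopBinding` names are hypothetical, and strings stand in for `Var` handles.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical, simplified planner state: the stack of enclosing loop
// variables and, for each bound variable, its (min, max) loop-depth range.
struct PlannerState {
  std::vector<std::string> active_loops;
  std::unordered_map<std::string, std::pair<std::size_t, std::size_t>> loop_depth_lookup;
};

// Scope guard in the style of BindLoopVar: binding happens in the
// constructor and is undone in the destructor, so an early return while
// visiting the loop body cannot leave stale entries behind.
class ScopedLoopBinding {
 public:
  ScopedLoopBinding(PlannerState* state, std::string var)
      : state_(state), var_(std::move(var)) {
    std::size_t depth = state_->active_loops.size();
    state_->loop_depth_lookup[var_] = {depth, depth};
    state_->active_loops.push_back(var_);
  }
  ~ScopedLoopBinding() {
    // Bindings must be released in LIFO order.
    assert(!state_->active_loops.empty() && state_->active_loops.back() == var_);
    state_->active_loops.pop_back();
    state_->loop_depth_lookup.erase(var_);
  }
  // Non-copyable, and therefore also non-movable, like BindLoopVar, so each
  // binding is undone exactly once.
  ScopedLoopBinding(const ScopedLoopBinding&) = delete;
  ScopedLoopBinding& operator=(const ScopedLoopBinding&) = delete;

 private:
  PlannerState* state_;
  std::string var_;
};
```

`BindLetVar` differs in that it is movable through `swap`. This appears to be so it can be held in the `std::vector<BindLetVar>` inside `BindBlockRealize`, since `emplace_back` may need to move existing elements when the vector grows.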