Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/06 20:38:03 UTC

[GitHub] [tvm-rfcs] Lunderberg opened a new pull request, #77: [RFC] Buffer Layout Padding

Lunderberg opened a new pull request, #77:
URL: https://github.com/apache/tvm-rfcs/pull/77

   This RFC introduces a method to specify padding to be applied as part of a buffer layout transformation, to be used when the desired layout does not evenly tile the buffer being transformed, and simplifications that can be performed based on these padded buffers.
   
   The motivating examples are primarily in the "Implementation options" section, which goes through several desired usages of the buffer padding, and how they can be automatically derived using the TIR primitives/transformations described in earlier sections.
   
   TODO: Rendered Markdown link
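
To make the padding described above concrete, here is a minimal sketch in plain Python (not TIR) that enumerates which transformed indices have no counterpart in the original buffer. The helper name `padded_layout` is hypothetical and is not part of TVM; it assumes a 1-d buffer and that every transformed axis starts at 0.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    `transform` maps a logical index i to a tuple of transformed indices.
    Hypothetical helper for illustration only; not a TVM API.
    """
    mapped = [tuple(transform(i)) for i in range(extent)]
    ndim = len(mapped[0])
    # Smallest shape (with each axis starting at 0) that holds every mapped index.
    shape = tuple(max(idx[d] for idx in mapped) + 1 for d in range(ndim))
    used = set(mapped)
    # Any transformed index that no logical index maps to is padding.
    padding = [idx for idx in product(*(range(n) for n in shape)) if idx not in used]
    return shape, padding


shape, padding = padded_layout(14, lambda i: (i // 4, i % 4))
print(shape)    # (4, 4)
print(padding)  # [(3, 2), (3, 3)]
```

Applying the same helper with an offset, for example `lambda i: ((i + 2) // 8, (i + 2) % 8)`, moves the padding to the start of the buffer.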


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r890650624


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this shape has 16 elements, two more than the original buffer, because
+the transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
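+# Padding is introduced, and wraps to the beginning of the array.
+# The flat index 4*io + ii runs over 14 and 15 in the padding, which
+# wrap to the first two elements, B[0,0] and B[0,1].
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])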
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether trading branching for
+overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   I agree with the need to introduce primitives that handle padding along with the transformation.
   
   The main question is whether we should introduce an IR semantics change to enable this use case. On one hand, introducing new IR semantics certainly makes TIR itself more expressive, but it also brings additional complexity to every primitive/pass that handles the IR. Special primitives may also be needed to handle the newly introduced semantics.
   
   It would be great for us to explore such capabilities (of layout padding) without introducing additional complexity in the IR.
   
   Back to our goal of introducing padding. It should be possible to have an explicit buffer transformation stage that copies the data into the target padded buffer (with predication), then runs the computation on the padded values.
   
   There are certainly some tradeoffs here, but decoupling the padding behavior into a separate stage of IR computation should allow us to reuse more primitives without having to specialize them for BufferConstraint.
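
   A minimal sketch of the separate padding stage suggested above, written in plain Python rather than TIR. It uses the row-summation example from the RFC: one stage copies `A` into a padded buffer with a predicated write, and the compute stage then runs without a conditional. The function names `pad_stage` and `row_summation_padded` are hypothetical, and the choice of `0.0` as the pad value is an assumption that only works because adding zero does not change a sum.

```python
def pad_stage(A, cols, tile, pad_value=0.0):
    """Copy A[i][j] into A_pad[i][j // tile][j % tile], filling padding explicitly.

    Hypothetical stand-in for an explicit buffer transformation stage.
    """
    n_outer = (cols + tile - 1) // tile  # number of tiles, rounded up
    A_pad = []
    for row in A:
        padded_row = []
        for j_outer in range(n_outer):
            chunk = []
            for j_inner in range(tile):
                j = tile * j_outer + j_inner
                # Predicated copy: real data inside the original extent,
                # pad_value in the padding introduced by the layout.
                chunk.append(row[j] if j < cols else pad_value)
            padded_row.append(chunk)
        A_pad.append(padded_row)
    return A_pad


def row_summation_padded(A_pad):
    """Sum each row of the padded buffer with no bounds check in the loop."""
    B = []
    for padded_row in A_pad:
        total = 0.0
        for chunk in padded_row:
            for value in chunk:
                total = total + value
        B.append(total)
    return B


A = [[float(i * 14 + j) for j in range(14)] for i in range(16)]
B = row_summation_padded(pad_stage(A, cols=14, tile=4))
```

   The cost of this approach is the extra copy into the padded buffer; the benefit is that the compute stage needs no knowledge of the padding beyond the value written by the copy stage.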
   





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894941045


##########
rfcs/0077-layout-transform-padding.md:
##########

Review Comment:
   Of course, for the sake of discussion the example is limited to two convolutions. In some cases there are multiple (N) back-to-back contractions with padded transformations, and handling these at the graph level can require similar non-local information/hints across the sequence of operators, to Nth order.





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┬──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io+ii-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension rewriting and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference-like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me that these approaches essentially stem from the two canonical compiler approaches:
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   * A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   * A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and materialize a function to legalize between layers only when necessary. 
   
   The main difference I see is that A0 pushes more work into manual hardware specific optimization (compute definitions + patterns/rewriters) that is not as easily repurposable for other hardware targets; whereas A1 provides the infrastructure for more general compiler simplification to proceed from hardware semantics the user provides about the buffer when transforming the layout at schedule time. 
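
The constraint-flowing idea in A1 can be sketched in plain Python. This is only an illustration under stated assumptions: `BufferConstraint`, `flow_constraints`, and the operator tuples are hypothetical names invented for this sketch and are not TVM APIs. Each operator declares what it requires of its input padding and what it guarantees about its output padding; a conversion is materialized only where the two disagree.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class BufferConstraint:
    """Hypothetical stand-in for the layout/padding facts attached to a buffer."""
    layout: str       # textual description of the index mapping
    pad_value: float  # value guaranteed to be stored in the padding


# (operator name, constraint required on input or None, constraint guaranteed on output)
Op = Tuple[str, Optional[BufferConstraint], BufferConstraint]


def flow_constraints(ops: List[Op]) -> List[str]:
    """Walk a linear chain of operators, inserting a conversion only on a contradiction."""
    plan: List[str] = []
    current: Optional[BufferConstraint] = None  # facts known about the flowing buffer
    for name, requires, produces in ops:
        if requires is not None and current != requires:
            # Contradiction between producer guarantee and consumer requirement.
            have = "unknown" if current is None else current.pad_value
            plan.append(f"legalize(pad {have} -> {requires.pad_value})")
        plan.append(name)
        current = produces
    return plan


tiled = "i -> [i//4, i%4]"
zero = BufferConstraint(tiled, 0.0)
neg_inf = BufferConstraint(tiled, float("-inf"))

plan = flow_constraints([
    ("add", None, zero),            # element-wise op, no requirement on its input padding
    ("row_sum", zero, zero),        # agrees with the producer, nothing inserted
    ("maxpool", neg_inf, neg_inf),  # wants -inf padding, contradicts the zero padding
])
print(plan)
```

Under A0, by contrast, a conversion would be inserted around every operator first and the redundant ones pruned afterwards by pattern matching; in this toy chain both strategies would end with a single conversion in front of `maxpool`.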





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension rewriting and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   * A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   * A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and to materialize a function to legalize. 
   
   The main difference I see is that A0 pushes more work into manual hardware-specific optimization (compute definitions + patterns/rewriters) that is not as easily repurposable for other hardware targets; whereas A1 provides the infrastructure for more general compiler simplification to proceed from hardware semantics the user provides about the buffer when transforming the layout at schedule time. 





[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1169188372

   These make sense, and agreed that the TIR->global feedback is important for enabling the layout reflow.  Going back through the discussion, I think we're converging on agreement on what features are required, and the main questions remaining are how best to provide annotations for non-local information, and how best to express layout transformations while scheduling.
   
   I've made some updates to the text of the RFC, based on the discussions here, primarily to remove the proposed changes to TIR data structures.  This follows your comment from a few days ago, which brought up `__builtin_assume` as a comparison.
   
   * Adding an intrinsic `tir::builtin::assume`, which corresponds to the `__builtin_assume` LLVM intrinsic.  The emphasis is that these assumptions are primarily to expose non-local information for use in local simplifications.
   * Removing `BufferConstraint` entirely.  The RFC no longer proposes any changes to TIR data structures, only the `assume` and `undef` intrinsics.
   * Describing what assumptions can/should be placed into a PrimFunc when hoisting stages out into independent PrimFuncs, and what transformations are legal based on the choice of exposed assumptions.
   * Capturing some of the discussion here about the dangers of altering a PrimFunc's interface, and the limited cases where it may be altered.
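
As a rough illustration of the first bullet, here is a plain-Python sketch (not TIR) of how an exposed assumption about non-local information can enable a local simplification. The function names are hypothetical, the row layout (14 logical values followed by 2 padding values) is borrowed from the RFC's running example, and the `assert` merely stands in for the proposed `assume` intrinsic; it is not how the intrinsic would be implemented.

```python
LOGICAL_EXTENT = 14  # number of values defined by the compute definition
PADDED_EXTENT = 16   # extent after the [i//4, i%4] transformation


def row_sum_branching(row):
    """Sum one padded row, guarding every access to the padding."""
    total = 0.0
    for j_outer in range(4):
        for j_inner in range(4):
            if 4 * j_outer + j_inner < LOGICAL_EXTENT:
                total += row[4 * j_outer + j_inner]
    return total


def row_sum_assuming_zero_padding(row):
    """Sum one padded row, relying on the producer having zeroed the padding."""
    # Stand-in for the assumption: the padding is known to contain zeros.
    assert all(value == 0.0 for value in row[LOGICAL_EXTENT:PADDED_EXTENT])
    total = 0.0
    for j_outer in range(4):
        for j_inner in range(4):
            # Adding 0.0 from the padding leaves the sum unchanged, so no branch.
            total += row[4 * j_outer + j_inner]
    return total


padded_row = [float(i) for i in range(LOGICAL_EXTENT)] + [0.0, 0.0]
```

Both functions return the same sum for `padded_row`, but only because the zero-padding condition holds. With arbitrary padding values, only the branching version would remain correct.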




[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891831175


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals that
+  do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+### New Transform - Hoist Expression

Review Comment:
   These are different alternatives, or they can be combined on certain workloads. There is a discussion of the performance issue with matmul when we just change a dimension from 128 to 127: https://discuss.tvm.apache.org/t/te-vectorize-do-we-have-plan-to-support-vectorize-for-non-divisible-split/12469
   I think it might be a good working example. Below is what the user gets with the loop split `j`: 127 -> (4, 32)
   
   ```python
   for i in range(127):
       for k in range(127):
           for j.outer in range(4):
                for j.inner in T.vectorized(32):
                    if T.likely(j.outer * 32 + j.inner < 127, dtype="bool"):
                        C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
   ```
   
   The issue is that a complex condition has to be introduced to preserve the program semantics. This hurts performance, and in general we cannot vectorize a program that contains control flow.
   
   Now I understand we have different alternatives to handle this:
   - Loop partition
   
     We can already annotate the loop var with a hint, using non-imperative loop partitioning.
   
     `for j.outer in range(4, annotations={"pragma_loop_partition_hint": 1})`
   
     After `LoopPartition` pass (and simplify) it becomes:
     ```python
     for i in range(127):
         for k in range(127):
             # j.outer in [0, 3)
             for j.outer in range(3):
                  for j.inner in T.vectorized(32):
                       # condition is const true, optimize out
                       C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
             # j.outer in [3, 4): a single iteration, so the loop is removed and j.outer == 3
             for j.inner in T.vectorized(31):
                  # condition becomes j.inner < 31, hoisted into the loop extent
                  C[i*127 + 3*32 + j.inner] += A[i*127 + k] * B[k*127 + 3*32 + j.inner]
     ```
     The conditional branch is then eliminated in each loop part, which makes the code friendlier to performance optimizations such as vectorization. The "imperative" partition proposal is simply that we can partition during the schedule phase when one wants to schedule the parts differently, for example by giving them different vectorization widths.
   
   - Loop padding
   
     With the current RFC, I understand we can pad the innermost dimension of `C` and `B` to 128 and drop the condition directly. It then becomes the following. (IIUC, we may also insert code that fills the edges with "arbitrary" values, and then optimize it out?)
     ```python
     for i in range(127):
       for k in range(127):
           for j.outer in range(4):
                for j.inner in T.vectorized(32):
                    C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
     ```
   
     In this particular case, I believe padding is the better choice, since we get very neat code with minimal overcompute. We can also use the padding trick on the different loop parts in the loop partition alternative.
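
As a rough illustration of the two alternatives above, here is a pure-Python sketch (not TVM code). The helper names `matmul_guarded`, `matmul_partitioned`, and `matmul_padded` are made up for this note, and `N = 7`, `TILE = 4` stand in for 127 and 32. It assumes the padded buffers are indexed with the padded row stride and that the padding of `B` holds zeros.

```python
N, TILE = 7, 4                     # stand-ins for 127 and 32
OUTER = (N + TILE - 1) // TILE     # number of tiles along j
NP = OUTER * TILE                  # padded extent of the j dimension


def matmul_guarded(A, B):
    """Split loop with a bounds check in the innermost body."""
    C = [0.0] * (N * N)
    for i in range(N):
        for k in range(N):
            for jo in range(OUTER):
                for ji in range(TILE):
                    if jo * TILE + ji < N:
                        j = jo * TILE + ji
                        C[i * N + j] += A[i * N + k] * B[k * N + j]
    return C


def matmul_partitioned(A, B):
    """Full tiles run unguarded; the tail tile gets a shorter loop."""
    C = [0.0] * (N * N)
    full, tail = N // TILE, N % TILE
    for i in range(N):
        for k in range(N):
            for jo in range(full):
                for ji in range(TILE):
                    j = jo * TILE + ji
                    C[i * N + j] += A[i * N + k] * B[k * N + j]
            for ji in range(tail):
                j = full * TILE + ji
                C[i * N + j] += A[i * N + k] * B[k * N + j]
    return C


def matmul_padded(A, B):
    """Pad the j dimension of B and C to NP, compute unguarded, then crop."""
    Bp = [0.0] * (N * NP)          # padding columns of B stay zero
    for k in range(N):
        for j in range(N):
            Bp[k * NP + j] = B[k * N + j]
    Cp = [0.0] * (N * NP)
    for i in range(N):
        for k in range(N):
            for jo in range(OUTER):
                for ji in range(TILE):
                    j = jo * TILE + ji
                    Cp[i * NP + j] += A[i * N + k] * Bp[k * NP + j]
    # Drop the padding columns to recover the logical N x N result.
    return [Cp[i * NP + j] for i in range(N) for j in range(N)]


A = [float(i % 5) for i in range(N * N)]
B = [float((3 * i) % 7) for i in range(N * N)]
reference = matmul_guarded(A, B)
```

Calling `matmul_partitioned(A, B)` and `matmul_padded(A, B)` should both reproduce `reference`. The padded version does `NP - N` extra multiply-adds per `(i, k)` pair, which is the overcompute being traded for the removed branch.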





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891630186


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this no-op can be proven, allowing the conditional to be
+removed automatically.  Since whether trading branching for
+overcompute is beneficial depends on the schedule, these options
+are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   @tqchen @Hzfengsy Thank you, and I definitely agree on minimizing the number of IR changes being made.  (Also, phew, this ended up being a longer reply than I had expected, which probably means that whatever the result of this thread, the "Rationale and Alternatives" section should be updated.)
   
   @areusch The example transformations are largely present in the "Implementation Options" section.  The goal of that section was to describe different example transformations that we'd like to be able to make, and to ensure that they could be made using the functionality introduced earlier in the RFC.  It wasn't until this morning that I realized that there should also be links in the other direction, pointing from the proposed IR changes to the motivating use case.
   
   Below is the general rationale, with high-level implementations.
   
   To start, here are the desired properties of an implementation.
   
   1. No changes to existing TIR data structures
   2. No additional meaning to existing TIR data structures
   3. Simplifications can use constraints from multiple buffers
   4. No ordering required between `transform_layout` and fuse/split/reorder.  (Conditional statements that are removed using the buffer constraints are typically introduced by loop rewrites.)
   5. Can be used to describe out-of-bounds access (e.g. texture memory clamping on a GPU) that returns a default value.
   6. Only allocate memory when required or requested
   
   
   Implementations considered
   
   - A. All buffer transformations introduce new stage
   
     - Pro: No coordination required between different operators.
     - Con: Any producer/consumer interactions must be recognized by operator fusion/device planning.
     - Con: Cannot apply to primfunc input/outputs.  (e.g. To de-duplicate operators that differ only by underlying layout, such as `topi.nn.conv2d_hwcn`, `topi.nn.conv2d_nchw`, `topi.nn.conv2d_NCHWc`, etc.)
     - Con: May introduce unnecessary data copies, if the constraint required by the consumer is already met.
   
   - B. Perform one `transform_layout` at a time.  For each one, simplify using provided constraints, do not store constraints afterward.
     
     - Con: Can only use the constraints of a single buffer at a time, so it cannot express simplifications that rely on the padding in multiple buffers (e.g. the [elementwise operator](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#apply-operator-element-wise-over-the-transformation-padding) example).
   
     - Con: Requires loop rewriting to be done either inside `transform_layout` or prior to calling `transform_layout`.
     - Con: Can't be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.
     
     
   - C. Perform all `transform_layout` in a single function call, passing all layout transforms and padding constraints.
   
     - Pro: Simplifications may use constraints of all buffers being transformed.
     - Con: Requires changing the calling convention for layout transformations.
     - Con: Requires loop rewriting to be done either inside `transform_layout` or prior to calling `transform_layout`.
     - Con: Can't be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.
       
   - D. Express buffer constraints using existing `AssertStmt`
   
     In pseudocode, each consumer would have roughly the loopnest below. However, this would still need to have some way of indicating that the constraint should be removed when lowering, and should not produce any runtime assertions.
     
     ```python
     for indices in T.grid(*transform_shape):
         if padding_predicate(indices):
             T.Assert(buf[indices] == pad_value(indices))
     ```
   
     - Pro: No change to TIR data structures
     - Pro: No change required for calling convention for layout transformations.
     - Pro: Simplifications may use constraints of all buffers being transformed.
     - Pro: Can be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.
     - Pro: No ordering between loop/layout transforms, because the constraints can be determined from the TIR.
     - Con: Additional meaning attached to existing TIR data structures.
     - Con: Can only describe a fixed number of assertions, wouldn't be able to express a default value for all out-of-bounds reads.
   
   - E. Express buffer constraints as a field in `PrimFuncNode::attrs`
   
     - Con: Passes that replace `Buffer` objects must be aware of this attribute, in order to update the `Buffer` object stored in it.
       
   - F. Express buffer constraints as a new member variable in `Buffer`
   
     - Con: Requires a change to TIR data structures
     - Pro: No change required for calling convention for layout transformations.
     - Pro: Simplifications may use constraints of all buffers being transformed.
     - Pro: Can be applied to use cases outside of layout transformations (e.g. texture memory clamping on a GPU), where simplifications could benefit from assumed constraints.
     - Pro: Can rewrite loop structure later, use existing constraints.
     
     
   1. No changes to existing TIR data structures
   2. No additional meaning to existing TIR data structures
   3. Simplifications can use constraints from multiple buffers
   4. No ordering required between `transform_layout` and fuse/split/reorder.
   5. Can be used to describe out-of-bounds access (e.g. texture memory clamping on a GPU) that returns a default value.
   6. Only allocate memory when required or requested
     
   - A. All buffer transformations introduce new stage
   - B. Perform one `transform_layout` at a time.  For each one, simplify using provided constraints, do not store constraints afterward.
   - C. Perform all `transform_layout` in a single function call, passing all layout transforms and padding constraints.
   - D. Express buffer constraints using existing `AssertStmt`
   - E. Express buffer constraints as a field in `PrimFuncNode::attrs`
   - F. Express buffer constraints as a new member variable in `Buffer`
     
     
   |        | Goal 1 | Goal 2 | Goal 3 | Goal 4 | Goal 5 | Goal 6 |
   |--------|--------|--------|--------|--------|--------|--------|
   | Impl A | :heavy_check_mark:    | :heavy_check_mark:    | :x:    | :x:    | :x:    | :x:    |
   | Impl B | :heavy_check_mark:    | :heavy_check_mark:    | :x:    | :x:    | :x:    | :heavy_check_mark:    |
   | Impl C | :heavy_check_mark:    | :heavy_check_mark:    | :heavy_check_mark:    | :x:    | :x:    | :heavy_check_mark:    |
   | Impl D | :heavy_check_mark:    | :x:    | :heavy_check_mark:    | :heavy_check_mark:    | :x:    | :heavy_check_mark:    |
   | Impl E | :heavy_check_mark:    | :x:    | :heavy_check_mark:    | :heavy_check_mark:    | :heavy_check_mark:    | :heavy_check_mark:    |
   | Impl F | :x:    | :heavy_check_mark:    | :heavy_check_mark:    | :heavy_check_mark:    | :heavy_check_mark:    | :heavy_check_mark:    |
   
   The implementations that would satisfy the largest number of the desired goals would be adding the member variable `BufferNode::constraints`, or adding a field to `PrimFuncNode::attrs` that holds the constraints.  Between the two, I lean toward having it as an explicit member variable, so that incorrect usage appears as a compilation error when compiling TVM, but would find either implementation acceptable.
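
To make the difference between options E and F more concrete, here is a toy model in plain Python. It is only a sketch: the `Buffer` dataclass, its `constraint` field, the `attrs` dictionary, and `rename_buffer` are hypothetical stand-ins invented for this note, and say nothing about how TVM's `Buffer` or `PrimFuncNode::attrs` are actually implemented. It shows why a pass that replaces a buffer object must know about a side table (option E), while a member variable travels with the copy (option F).

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True, eq=False)   # eq=False keeps identity-based hashing
class Buffer:
    name: str
    shape: Tuple[int, ...]
    # Option F (hypothetical field): the constraint lives on the buffer.
    constraint: Optional[str] = None


def rename_buffer(buf: Buffer, new_name: str) -> Buffer:
    """Stand-in for any pass that replaces a Buffer with a new object."""
    return replace(buf, name=new_name)


# Option E: constraints kept in a side table keyed by the buffer object.
old_e = Buffer("A", (4, 4))
attrs = {old_e: "padding == 0.0"}
new_e = rename_buffer(old_e, "A_global")
constraint_via_attrs = attrs.get(new_e)      # the pass never updated attrs

# Option F: the constraint is copied along with the other fields.
old_f = Buffer("A", (4, 4), constraint="padding == 0.0")
new_f = rename_buffer(old_f, "A_global")
constraint_via_member = new_f.constraint
```

In this model the constraint is silently lost under option E unless `rename_buffer` is taught about `attrs`, whereas under option F forgetting the field when constructing a `Buffer` by hand is the kind of mistake a compiler could flag in the real C++ code.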





[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1165724403

   >  In general, a PrimFunc's interface could only be changed when calls into the PrimFunc are also modified to remain compatible.
   
   Agreed, that is what I originally intended to say 
   
   > Is there a better term than "scheduling primitive" to describe layout transformations that impact input/output buffers? I think the difference is between context-independent transformations that may be performed on a PrimFunc without changing, as opposed to context-dependent transformations that may only be performed as part of a graph-level transformation.
   
   There are a few options. One approach would be to allow a schedule primitive to modify multiple functions (including callers); we might need this for more complicated cases.
   
   In our particular example, however, the idea is that the schedule primitive does not modify the input/output buffers, but instead introduces preproc and postproc stages with a clear hint that they should be lifted out (i.e. we are doing the same thing in two steps).
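
A minimal sketch of that two-step idea, in plain Python rather than TIR. All names here (`pad_stage`, `crop_stage`, `kernel`, `scheduled_op`) are hypothetical, and the kernel is assumed to be element-wise, so values in the padding never affect the logical outputs.

```python
WIDTH = 4    # tile width of the transformed layout
PAD = 0.0    # value written into the padding by the preproc stage


def pad_stage(x):
    """Preproc: [n] -> [ceil(n / WIDTH)][WIDTH], padding filled with PAD."""
    rows = (len(x) + WIDTH - 1) // WIDTH
    flat = list(x) + [PAD] * (rows * WIDTH - len(x))
    return [flat[r * WIDTH:(r + 1) * WIDTH] for r in range(rows)]


def crop_stage(tiles, n):
    """Postproc: flatten the tiles and drop the padding."""
    return [v for row in tiles for v in row][:n]


def kernel(tiles):
    """Element-wise body that only sees the padded layout, no bounds checks."""
    return [[2.0 * v + 1.0 for v in row] for row in tiles]


def scheduled_op(x):
    """Step 1: interface unchanged, preproc/postproc kept inside the function."""
    return crop_stage(kernel(pad_stage(x)), len(x))


def two_calls_unlifted(x):
    """Two scheduled ops back to back: crop followed immediately by pad."""
    return scheduled_op(scheduled_op(x))


def two_calls_lifted(x):
    """Step 2: stages lifted to the caller, the inner crop/pad pair removed."""
    tiles = pad_stage(x)
    tiles = kernel(kernel(tiles))
    return crop_stage(tiles, len(x))


x = [float(i) for i in range(14)]
```

Both `two_calls_unlifted(x)` and `two_calls_lifted(x)` should give the same 14 values. For kernels that are not element-wise, removing the intermediate crop/pad pair would additionally require the producer's padding values to match what the consumer assumes.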
   
   
   




[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893925849


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
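The padding locations quoted in the examples above can be cross-checked by enumerating every transformed index.  The following is a minimal pure-Python sketch rather than TVM code: `padded_layout` is a hypothetical helper, and it assumes a 1-d source buffer whose transformed indices are non-negative and start at zero along each axis.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    Hypothetical helper for illustration only; not part of TVM.
    """
    # Every transformed index that corresponds to a real element.
    used = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Smallest shape that holds all used indices (assumes indices start at 0).
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    # Anything inside that shape with no source element is padding.
    padding = sorted(set(product(*(range(n) for n in shape))) - used)
    return shape, padding


shape, padding = padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8])
print(shape, padding)  # (2, 8) [(0, 0), (0, 1)]
```

Enumerating indices is only practical for small static shapes; it is meant as a check on the reasoning, not as a description of how the padding predicate would be derived symbolically.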
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
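The postlude can be emulated outside of TVM to see its effect on a concrete buffer.  This is a sketch under stated assumptions: nested Python lists stand in for the transformed buffer, `write_padding` is a hypothetical helper restricted to 2-d buffers, and the predicate `4*io + ii >= 14` is the padding predicate for a `[14]` buffer stored as `[4, 4]`.

```python
from itertools import product


def write_padding(buf, shape, is_padding, pad_value):
    """Write pad_value(io, ii) to every index flagged by is_padding.

    Hypothetical helper for illustration; handles 2-d buffers only.
    """
    for io, ii in product(*(range(n) for n in shape)):
        if is_padding(io, ii):
            buf[io][ii] = pad_value(io, ii)
    return buf


# A [14] buffer stored as [4, 4]; None marks slots not yet written.
B = [[None] * 4 for _ in range(4)]
for i in range(14):
    B[i // 4][i % 4] = float(i)

# The producer's postlude: fill the two padding slots with zeros.
write_padding(B, (4, 4), lambda io, ii: 4 * io + ii >= 14, lambda io, ii: 0.0)
```

After the call, the last row of `B` holds the two real values followed by the two padding values.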
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
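The equivalence claimed above can be checked numerically.  The sketch below is plain Python rather than TIR: `make_padded`, `row_sum_branching`, and `row_sum_overcompute` are hypothetical names, and the input uses small integer-valued floats so that reordering the additions cannot change the result.

```python
def make_padded(pad):
    """Build a [16, 4, 4] nested list from a logical [16, 14] array.

    The two padding slots at the end of each row are filled with `pad`.
    """
    A = []
    for i in range(16):
        flat = [float(i + j) for j in range(14)] + [pad, pad]
        A.append([flat[4 * jo : 4 * jo + 4] for jo in range(4)])
    return A


def row_sum_branching(A):
    """Row sums with a guard that skips the padding."""
    out = []
    for row in A:
        total = 0.0
        for j_outer in range(4):
            for j_inner in range(4):
                if 4 * j_outer + j_inner < 14:
                    total += row[j_outer][j_inner]
        out.append(total)
    return out


def row_sum_overcompute(A):
    """Row sums over all 16 slots per row, padding included."""
    return [sum(sum(tile) for tile in row) for row in A]


zero_padded = make_padded(0.0)
assert row_sum_branching(zero_padded) == row_sum_overcompute(zero_padded)
```

With any non-zero padding value the two versions disagree, which is why the branch can only be removed once the padding contents are known.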
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform-remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive-reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+A new S-TIR transformation `Schedule.sequential_buffer_access` should
+be introduced, which rewrites iteration loops according to the access
+pattern of a buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):

Review Comment:
   Thank you, and I can update the examples accordingly, though it will be a bit before I have the availability to do so.  (I had been primarily focusing on the algebraic manipulations during the initial drafting.)





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r909606577


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [Clang `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
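The simplification rules for `undef` listed above can be illustrated with a toy model.  This is only a sketch and not the TIR simplifier: `Undef`, `simplify_mul`, and `simplify_sub` are hypothetical names, and reusing the same Python object is an assumed stand-in for a value bound once in a `tir::LetStmt`.

```python
class Undef:
    """Toy stand-in for a valid but arbitrary value."""

    def __repr__(self):
        return "undef"


def simplify_mul(a, b):
    # 0 * undef may be simplified to zero.
    if a == 0 or b == 0:
        return 0
    # Any other pure expression that uses undef simplifies to undef.
    if isinstance(a, Undef) or isinstance(b, Undef):
        return Undef()
    return a * b


def simplify_sub(a, b):
    if isinstance(a, Undef) or isinstance(b, Undef):
        # The same object models a single let-bound undef, so x - x is zero.
        if a is b:
            return 0
        # Two separate instances may hold different values.
        return Undef()
    return a - b


x = Undef()
print(simplify_mul(0, Undef()), simplify_sub(Undef(), Undef()), simplify_sub(x, x))
```

The model covers only multiplication and subtraction of scalars; it says nothing about how buffer stores of `undef` are removed during lowering.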
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+`cache_read` and `cache_write` can be used outside of any loop, with
+the same scope as the uncached buffer.  The layout of the cache can
+then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', index_map=lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', index_map=lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', index_map=lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
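
As a plain-Python check of the simplification described above (a sketch only, not TVM code; the names `full` and `simplified` are illustrative), the guarded full-grid loop and the reduced loop touch the same elements:

```python
# Plain-Python model of the padding loop for a [14] buffer mapped to [4, 4].

# Full form: iterate over the whole grid and guard with the padding predicate.
full = [(io, ii) for io in range(4) for ii in range(4) if 4 * io + ii >= 14]

# Simplified form: the loop over io collapses to the single value io == 3,
# and the loop over ii is restricted to 2 <= ii < 4.
simplified = [(3, ii) for ii in range(2, 4)]

# Both forms visit exactly the two padded coordinates.
assert full == simplified
```

The full form keeps the same `T.grid(4, 4)` iteration space as the other transformed loops, which is presumably what makes it the easier form to merge with adjacent loopnests.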
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
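
The two loop structures above can be compared in plain Python. This is only a model of the rewrite (nested lists stand in for the buffer, and `flat_loop` and `sequential_loop` are hypothetical helper names), but it shows why the guard is needed when the extent is 14 and why it is redundant when the extent is 16:

```python
def flat_loop(n):
    """Iterate over the original index i, writing to transformed indices."""
    A = [[None] * 4 for _ in range(4)]
    for i in range(n):
        A[i // 4][i % 4] = i
    return A


def sequential_loop(n):
    """Iterate over the transformed indices in row-major order.

    The guard keeps the loop from writing elements that the original
    loop never touched (the padding).
    """
    A = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < n:
                A[io][ii] = 4 * io + ii
    return A


# Extent 16 tiles evenly, so the guard is always true.
assert flat_loop(16) == sequential_loop(16)
# Extent 14 leaves two padded elements untouched in both versions.
assert flat_loop(14) == sequential_loop(14)
```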
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
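
To check that the two loop orders compute the same result, the loopnests can be modelled in plain Python. This is a sketch under stated assumptions: nested lists stand in for the `[4, 4]` buffers, the input data is arbitrary, and the function names are illustrative rather than part of any TVM API.

```python
def reference(A, F):
    """Original 1-D computation: B[i] = sum_f F[f] * A[i + f]."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def option1(A2, F):
    """Loops follow B's layout; A is read at shifted indices."""
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def option2(A2, F):
    """Loops follow A's layout; B is written at shifted indices.

    B2 starts at zero, which plays the role of the separate
    'init_compute' loopnest.
    """
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    # Python's // and % floor toward negative infinity,
                    # matching the simplification shown in the comments.
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = [(3 * k + 1) % 7 for k in range(16)]   # arbitrary input data
F = [1, -2, 3]                              # arbitrary filter
A2 = [A[4 * r:4 * r + 4] for r in range(4)]  # A in the [4, 4] layout

flat1 = [v for row in option1(A2, F) for v in row]
assert flat1[:14] == reference(A, F)
assert option1(A2, F) == option2(A2, F)
```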
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-op.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 1:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0

Review Comment:
   Thank you for the catch, and corrected line 851 to `if i//8 == 0`.
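
A brute-force check in plain Python (illustrative only, not part of the RFC) shows why `i//8 == 0` is the condition that matches `i < 8` over the loop range used in the example, while `i//8 == 1` selects the opposite branch:

```python
# Loop range used in the example: for i in T.serial(16)
domain = range(16)

# `i < 8` and `i // 8 == 0` select exactly the same iterations, so the two
# conditionals can be merged with their then-cases together.
assert all((i < 8) == (i // 8 == 0) for i in domain)

# `i // 8 == 1` selects the complementary iterations on this range, so it
# cannot be treated as an identical condition.
assert all((i < 8) != (i // 8 == 1) for i in domain)
```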





[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1180607081

   cc @Hzfengsy @wrongtest-intellif  it would be great if you can also take a followup look




[GitHub] [tvm-rfcs] kparzysz-quic commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
kparzysz-quic commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r909051867


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
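
The shapes and padding locations quoted in the comments above can be reproduced with a small plain-Python helper. This is a sketch, not the TVM implementation: `padded_layout` is a hypothetical name, and it assumes a 1-D source buffer and an index map that only produces non-negative indices.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (shape, padding) for a 1-D buffer with `extent` elements.

    `shape` is the smallest box that holds every transformed index, and
    `padding` lists the coordinates in that box that no original index
    maps to.
    """
    used = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(index_map(0))
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    every = set(product(*(range(s) for s in shape)))
    return shape, sorted(every - used)


def tile(i):
    return [i // 8, i % 8]


def shifted(i):
    return [(i + 2) // 8, (i + 2) % 8]


# A has shape [16], B has shape [14], as in the examples above.
assert padded_layout(16, tile) == ((2, 8), [])
assert padded_layout(14, tile) == ((2, 8), [(1, 6), (1, 7)])
assert padded_layout(14, shifted) == ((2, 8), [(0, 0), (0, 1)])

shape, padding = padded_layout(16, shifted)
assert shape == (3, 8)
assert padding == [(0, 0), (0, 1)] + [(2, k) for k in range(2, 8)]
```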
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
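
The producer postlude can be modelled in plain Python to see what ends up in the padding. This is a minimal sketch assuming the `[i//4, i%4]` transform from the examples above; `transform_with_padding` is a hypothetical helper, and a nested list stands in for the transformed buffer.

```python
def transform_with_padding(data, pad_value):
    """Model of applying i -> [i//4, i%4] with a callable `pad_value`.

    `pad_value` receives the transformed indices (io, ii), matching the
    callable form accepted by `transform_layout` in this RFC.
    """
    n = len(data)
    rows = -(-n // 4)  # ceil(n / 4)
    out = [[None] * 4 for _ in range(rows)]

    # Main write stage: every original value lands at its transformed index.
    for i, value in enumerate(data):
        out[i // 4][i % 4] = value

    # Postlude: write pad_value wherever the padding predicate holds.
    for io in range(rows):
        for ii in range(4):
            if 4 * io + ii >= n:
                out[io][ii] = pad_value(io, ii)
    return out


B = transform_with_padding(list(range(14)), lambda io, ii: 0.0)
assert B[3] == [12, 13, 0.0, 0.0]
```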
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
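
The trade-off can be illustrated in plain Python. This is a sketch, not TVM code: nested lists stand in for the `[16, 4, 4]` buffer, and `pad_rows`, `row_sum_branching` and `row_sum_overcompute` are hypothetical names. Dropping the conditional is only valid when the padding holds the identity of the reduction, which is zero for a sum.

```python
def pad_rows(A, pad):
    """Reshape each row of 14 values to 4x4, filling the padding with `pad`."""
    out = []
    for row in A:
        padded = list(row) + [pad] * (16 - len(row))
        out.append([padded[4 * jo:4 * jo + 4] for jo in range(4)])
    return out


def row_sum_branching(A3):
    """Skip the padding with a conditional (no assumption on its contents)."""
    B = [0.0] * len(A3)
    for i in range(len(A3)):
        for jo in range(4):
            for ji in range(4):
                if 4 * jo + ji < 14:
                    B[i] += A3[i][jo][ji]
    return B


def row_sum_overcompute(A3):
    """Sum every element, including the padding (no conditional)."""
    B = [0.0] * len(A3)
    for i in range(len(A3)):
        for jo in range(4):
            for ji in range(4):
                B[i] += A3[i][jo][ji]
    return B


A = [[float(i + j) for j in range(14)] for i in range(16)]

# With zero padding, the conditional can be dropped safely.
zero_padded = pad_rows(A, 0.0)
assert row_sum_overcompute(zero_padded) == row_sum_branching(zero_padded)

# With any other padding value, overcompute changes the result.
one_padded = pad_rows(A, 1.0)
assert row_sum_overcompute(one_padded) != row_sum_branching(one_padded)
```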
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+These primitives can be used outside of any loop, with the same scope
+as the uncached buffer.  The layout of the cache can then be
+transformed to operate on a reshaped buffer without modifying the
+calling signature of the original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independent of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 1:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0

Review Comment:
   Lines 862 and 865 are swapped.  Either that, or line 851 should read `if i//8 == 0:`.
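
   As a minimal sketch of why the two conditions cannot be merged as written
   (plain Python enumeration over the loop domain `0 <= i < 16`, not TVM's
   analyzer; the helper name `conditions_equivalent` is hypothetical), the
   relationship between them can be checked directly:

```python
def conditions_equivalent(cond_a, cond_b, domain):
    """Return True if the two predicates agree at every point of the domain."""
    return all(cond_a(i) == cond_b(i) for i in domain)


# Loop domain of the example: for i in T.serial(16)
loop_domain = range(16)

# Condition of the first IfThenElse in the example.
first_cond = lambda i: i < 8

# Does `i < 8` match `i//8 == 0`, `i//8 == 1`, or the negation of the latter?
same_as_floordiv_zero = conditions_equivalent(
    first_cond, lambda i: i // 8 == 0, loop_domain
)
same_as_floordiv_one = conditions_equivalent(
    first_cond, lambda i: i // 8 == 1, loop_domain
)
negation_of_floordiv_one = conditions_equivalent(
    first_cond, lambda i: not (i // 8 == 1), loop_domain
)
```

   On this domain `i < 8` agrees with `i//8 == 0` and is the exact negation of
   `i//8 == 1`, so merging the conditionals as written would require the two
   `B` branches to be swapped.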



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, or to reducing branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is analogous to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to be
+  identical.  For example, the expression `undef - undef` may not be
+  simplified to zero.  If this behavior is desired, the `undef` may be
+  assigned in a `tir::LetStmt`, so that each use refers to the same value.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
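The simplification rules listed for `undef` can be sketched with a small toy model. This is plain Python, not the TIR implementation: the `Undef` class and the `simplify_mul` and `simplify_sub` helpers are hypothetical, and Python object identity is used here as a stand-in for binding an `undef` once in a let statement.

```python
class Undef:
    """A valid but arbitrary value.  Each instance is distinct."""

    def __repr__(self) -> str:
        return "undef"


def is_undef(value) -> bool:
    return isinstance(value, Undef)


def simplify_mul(lhs, rhs):
    """Simplify `lhs * rhs`."""
    # 0 * undef may be simplified to zero, for integers and floats.
    for operand in (lhs, rhs):
        if not is_undef(operand) and operand == 0:
            return operand
    # Any other pure expression that uses undef simplifies to undef.
    if is_undef(lhs) or is_undef(rhs):
        return Undef()
    return lhs * rhs


def simplify_sub(lhs, rhs):
    """Simplify `lhs - rhs`."""
    # Only the *same* undef (modelled as the same object, as if bound
    # once by a let) is known to cancel with itself.
    if is_undef(lhs) and lhs is rhs:
        return 0
    # Two separate undef instances may hold different values.
    if is_undef(lhs) or is_undef(rhs):
        return Undef()
    return lhs - rhs


zero_times_undef = simplify_mul(0, Undef())
separate_instances = simplify_sub(Undef(), Undef())
shared = Undef()
same_instance = simplify_sub(shared, shared)
```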
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains
+    # the same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
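To make the two-stage behaviour concrete, the following plain-Python simulation (not TVM code) stores a 14-element buffer using the layout `i -> [i // 4, i % 4]` and then writes the pad value into the locations that have no pre-transformation index. The function name `transform_with_padding` is hypothetical.

```python
from typing import List, Optional


def transform_with_padding(
    src: List[float], inner: int, pad_value: float
) -> List[List[Optional[float]]]:
    """Store `src` using the layout i -> [i // inner, i % inner], then
    fill the introduced padding with `pad_value`."""
    outer = (len(src) + inner - 1) // inner  # ceil(len(src) / inner)
    # None marks locations that have not been written yet.
    dst: List[List[Optional[float]]] = [[None] * inner for _ in range(outer)]

    # Write stage: same loop iteration as before, transformed indices.
    for i, value in enumerate(src):
        dst[i // inner][i % inner] = value

    # Padding stage: iterate over the transformed shape and write
    # pad_value wherever no pre-transformation index exists.
    for io in range(outer):
        for ii in range(inner):
            if inner * io + ii >= len(src):
                dst[io][ii] = pad_value
    return dst


A = [float(i) for i in range(14)]
A_cache = transform_with_padding(A, inner=4, pad_value=-1.0)
# The last row holds two real values followed by two padded values.
last_row = A_cache[3]
```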
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
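The claimed simplification of the padding loop can be checked by brute force. The snippet below is plain Python written only for this check; it enumerates the locations written by the guarded loop over the full `(4, 4)` grid and compares them with the locations visited by the reduced loop.

```python
extent = 14

# Locations written by the unsimplified padding loop.
guarded = [
    (io, ii)
    for io in range(4)
    for ii in range(4)
    if 4 * io + ii >= extent
]

# Locations written after fixing io = 3 and reducing ii to 2 <= ii < 4.
simplified = [(3, ii) for ii in range(2, 4)]

same_locations = guarded == simplified
```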
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
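The effect of the guard can be sketched in plain Python (this is a simulation, not TVM code, and both function names are hypothetical). The two loop nests below write the same values to the same locations, and the guarded version never touches the padding.

```python
from typing import List, Optional

Grid = List[List[Optional[int]]]


def write_linear(extent: int, inner: int) -> Grid:
    """Loop over the original index, writing to transformed locations."""
    outer = (extent + inner - 1) // inner
    buf: Grid = [[None] * inner for _ in range(outer)]
    for i in range(extent):
        buf[i // inner][i % inner] = i
    return buf


def write_sequential(extent: int, inner: int) -> Grid:
    """Loop over the transformed indices, guarded so that elements not
    previously accessed are still not accessed."""
    outer = (extent + inner - 1) // inner
    buf: Grid = [[None] * inner for _ in range(outer)]
    for io in range(outer):
        for ii in range(inner):
            if inner * io + ii < extent:
                buf[io][ii] = inner * io + ii
    return buf


linear_result = write_linear(14, 4)
sequential_result = write_sequential(14, 4)
```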
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
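As a sanity check on the index arithmetic in both options, the plain-Python simulation below (not TVM code; all function names are hypothetical) runs the original 1-d convolution and the two rewritten loop nests on the same data. Python's `//` and `%` use floor semantics, which is what the simplified indices rely on when `ii - f` is negative.

```python
from typing import List


def conv1d_reference(A: List[int], F: List[int]) -> List[int]:
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def conv1d_option1(A2: List[List[int]], F: List[int]) -> List[List[int]]:
    """Loops follow B's layout."""
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B2[io][ii] = 0
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv1d_option2(A2: List[List[int]], F: List[int]) -> List[List[int]]:
    """Loops follow A's layout, with initialization in its own loop."""
    B2 = [[0] * 4 for _ in range(4)]
    for i in range(14):
        B2[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


def flatten(B2: List[List[int]]) -> List[int]:
    """Read back the 14 elements that exist in the original layout."""
    return [B2[i // 4][i % 4] for i in range(14)]


A = list(range(1, 17))
F = [1, 2, 3]
A2 = [A[4 * io: 4 * io + 4] for io in range(4)]

expected = conv1d_reference(A, F)
option1 = flatten(conv1d_option1(A2, F))
option2 = flatten(conv1d_option2(A2, F))
```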
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  chains of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
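The first rule above (an overwritten store is a no-op) can be sketched on a toy statement list. This is plain Python and only models stores to single locations; the `Store` record and `remove_overwritten_stores` function are hypothetical and are not part of `tvm::tir::NoOpRemover`.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Tuple

Location = Tuple[str, int]  # (buffer name, index)


@dataclass(frozen=True)
class Store:
    target: Location
    reads: FrozenSet[Location] = frozenset()
    label: str = ""


def remove_overwritten_stores(stmts: List[Store]) -> List[Store]:
    """Drop a store when the store that directly follows it writes the
    same location without reading that location."""
    kept: List[Store] = []
    for stmt in stmts:
        # Repeat, since removing one store may expose an earlier store
        # to the same location that is now also overwritten unread.
        while kept and kept[-1].target == stmt.target and stmt.target not in stmt.reads:
            kept.pop()
        kept.append(stmt)
    return kept


program = [
    Store(("B", 0), label="B[0] = 0.0"),
    Store(("B", 0), label="B[0] = 1.0"),
    Store(("B", 0), reads=frozenset({("B", 0), ("A", 0)}), label="B[0] = B[0] + A[0]"),
]
simplified = remove_overwritten_stores(program)
remaining = [stmt.label for stmt in simplified]
```

The check is deliberately conservative: a store to a different location between the two writes blocks the removal, even if it does not read the overwritten value.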
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals that do
+  not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, this should be used for
+  later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on a value
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
+
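The notions of identical and complementary conditions used for merging conditionals can be illustrated with a brute-force check over the loop domain. This is plain Python, and the exhaustive enumeration is only a stand-in for the proofs a real simplifier would attempt; the function name `classify` is hypothetical.

```python
from itertools import product
from typing import Callable, Iterable, Tuple


def classify(
    cond_a: Callable[..., bool],
    cond_b: Callable[..., bool],
    domain: Iterable[Tuple[int, ...]],
) -> str:
    """Compare two conditions at every point of the loop domain."""
    points = list(domain)
    if all(cond_a(*p) == cond_b(*p) for p in points):
        return "identical"
    if all(cond_a(*p) != cond_b(*p) for p in points):
        return "complementary"
    return "unrelated"


# Different functional forms, same truth value for 0 <= i < 16.
kind_1d = classify(
    lambda i: i < 8,
    lambda i: i // 8 == 0,
    ((i,) for i in range(16)),
)

# Exactly one of the two holds at each point of a 4x4 grid.
kind_2d = classify(
    lambda i, j: 4 * i + j < 14,
    lambda i, j: i == 3 and j >= 2,
    product(range(4), range(4)),
)
```

Conditions that read buffer values are outside the scope of this check, since their truth value can change between the two conditionals.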
+### New Transform - Hoist Expression
+
+A new utility `HoistExpression`, which is a generalization of the
+current `HoistIfThenElse` pass.  The transformation `HoistExpression`
+would apply to the entire body of the `PrimFunc`, and would be used to
+avoid duplication of functionality between `HoistIfThenElse` and
+`HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist any let
+    bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
+
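For readers more familiar with Python, the bit-flag configuration can be mirrored with `enum.IntFlag`. This is only an illustrative sketch: the member names copy the C++ enums in the text, while the `should_hoist` helper and the variable names are hypothetical and do not describe an existing TVM API.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    kNone = 0
    kIfElseStmt = 1 << 0
    kIfElseExpr = 1 << 1
    kBooleanExpression = 1 << 2


class HoistLetBinding(IntFlag):
    kNone = 0
    kRequiredByCondition = 1 << 0
    kLetStmt = 1 << 1
    kLetExpr = 1 << 2


def should_hoist(config: HoistConditional, kind: HoistConditional) -> bool:
    """True if `kind` is enabled in `config`."""
    return bool(config & kind)


# Configuration described in the text as roughly matching the
# existing HoistIfThenElse pass.
legacy_conditional = HoistConditional.kIfElseStmt
legacy_let = HoistLetBinding.kNone

# Flags combine with bitwise OR.
both_if_forms = HoistConditional.kIfElseStmt | HoistConditional.kIfElseExpr
```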
+### New Transform - Reduce Loop Extents
+
+Reduce the extent of loops based on conditionals present in the body
+of the loop.
+
+For any non-vectorized `tir::For` loop (`ForKind::kSerial` or
+`ForKind::kParallel`), if the body is a conditional and the
+conditional's `else_case` is empty, determine if the expression is of
+the form `(loop_var $CMP_OP const) && (...)`.  If so, use the comparison
+operator to reduce the loop extent, such that the loop skips values for
+which the comparison is provably false.
+
+TODO: Double-check that this isn't already implemented elsewhere.
+
+TODO: Check if it is implementable using `IntSetAnalyzer`.
+
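A minimal sketch of the extent reduction for a single `loop_var CMP const` term is given below. It is plain Python working on integer ranges rather than TIR nodes, the function name `reduce_extent` is hypothetical, and it makes no claim about whether `IntSetAnalyzer` could provide the same result.

```python
from typing import Tuple


def reduce_extent(loop_min: int, extent: int, cmp_op: str, const: int) -> Tuple[int, int]:
    """Return (new_min, new_extent) for a loop over
    range(loop_min, loop_min + extent) whose body only runs when
    `loop_var <cmp_op> const` holds."""
    lo, hi = loop_min, loop_min + extent  # half-open interval [lo, hi)
    if cmp_op == "<":
        hi = min(hi, const)
    elif cmp_op == "<=":
        hi = min(hi, const + 1)
    elif cmp_op == ">":
        lo = max(lo, const + 1)
    elif cmp_op == ">=":
        lo = max(lo, const)
    elif cmp_op == "==":
        lo, hi = max(lo, const), min(hi, const + 1)
    else:
        raise ValueError(f"unsupported comparison: {cmp_op}")
    return lo, max(hi - lo, 0)


# For a guard `i == 0 and j < 2` inside `for i, j in T.grid(4, 4)`,
# each term reduces the loop over its own variable.
i_range = reduce_extent(0, 4, "==", 0)
j_range = reduce_extent(0, 4, "<", 2)
```

An extent of zero means the guard is never true on the original range, so the loop could be removed entirely.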
+Below is an example of how this can work alongside `HoistExpression`
+to simplify the initialization of padding.
+
+```python
+# Original function.
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i, j in T.grid(4, 4):
+        if i == 0 and j < 2:
+            A[i, j] = 0.0

Review Comment:
   This would look something like
   ```python
   for i, j in T.grid(0..1, 0..2):  # I'm making the ranges more verbose for clarity
     A[i, j] = 0.0
   for i, j in T.grid(0..1, 2..4):
     pass                           # j < 2 is false
   for i, j in T.grid(1..4, 0..2):
     pass                           # i == 0 is false
   for i, j in T.grid(1..4, 2..4):
     pass                           # i == 0 is false, j < 2 is false
   ```



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a corresponding
+index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┬──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
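
The four cases above can be reproduced by enumerating the image of the index map. This is only an illustrative sketch in plain Python: `transformed_shape` and `padded_indices` are hypothetical helpers, and they assume that every transformed index is non-negative, so that each dimension starts at zero.

```python
from itertools import product


def transformed_shape(extent, index_map):
    """Smallest zero-based shape that contains index_map(i) for 0 <= i < extent."""
    images = [index_map(i) for i in range(extent)]
    rank = len(images[0])
    return [max(image[dim] for image in images) + 1 for dim in range(rank)]


def padded_indices(extent, index_map):
    """Transformed indices that are not the image of any original index."""
    used = {tuple(index_map(i)) for i in range(extent)}
    shape = transformed_shape(extent, index_map)
    return [idx for idx in product(*(range(s) for s in shape)) if idx not in used]


no_offset = lambda i: [i // 8, i % 8]
with_offset = lambda i: [(i + 2) // 8, (i + 2) % 8]

pad_A = padded_indices(16, no_offset)           # no padding
pad_B = padded_indices(14, no_offset)           # padding at the end
pad_B_offset = padded_indices(14, with_offset)  # padding at the start
pad_A_offset = padded_indices(16, with_offset)  # padding at both ends
```

Enumeration is only practical for small static shapes, which is all that is needed to check the examples by hand.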
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
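
As a rough illustration of the producer postlude, the following plain-Python sketch lays out a 1-d list as `[outer, tile]` and then fills the padding. `transform_with_padding` is a hypothetical helper, and passing the buffer itself to `pad_value` is an assumption made here so that a wrap-around padding value can read from the transformed buffer.

```python
def transform_with_padding(data, tile, pad_value):
    """Lay out `data` as [outer, tile], then run the padding postlude."""
    extent = len(data)
    outer = -(-extent // tile)  # ceiling division
    buf = [[None] * tile for _ in range(outer)]

    # Values that have a counterpart in the original buffer.
    for i, value in enumerate(data):
        buf[i // tile][i % tile] = value

    # Postlude: write pad_value wherever the padding predicate holds.
    for io in range(outer):
        for ii in range(tile):
            if tile * io + ii >= extent:
                buf[io][ii] = pad_value(buf, io, ii)
    return buf


data = list(range(14))
zero_padded = transform_with_padding(data, 4, lambda buf, io, ii: 0)
wrapped = transform_with_padding(
    data, 4, lambda buf, io, ii: buf[0][(4 * io + ii - 14) % 4]
)
```

In the wrap-around case the padding value is derived from the flattened index `4*io + ii`, so the two padding elements repeat the first two elements of the array.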
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether the tradeoff between branching
+and overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
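
The equivalence that justifies removing the conditional can be checked numerically. This is a minimal plain-Python sketch, using nested lists rather than TIR buffers; `make_padded`, `row_sum_branching` and `row_sum_overcompute` are hypothetical names introduced only for this illustration.

```python
def row_sum_branching(A, valid):
    """Sum each row of a [rows, outer, inner] buffer, skipping the padding."""
    out = []
    for row in A:
        total = 0.0
        for j_outer, tile in enumerate(row):
            for j_inner, value in enumerate(tile):
                if len(tile) * j_outer + j_inner < valid:
                    total += value
        out.append(total)
    return out


def row_sum_overcompute(A):
    """Sum each row without a conditional; correct only for zero padding."""
    out = []
    for row in A:
        total = 0.0
        for tile in row:
            for value in tile:
                total += value
        out.append(total)
    return out


def make_padded(rows, valid, tile, pad):
    """Build a [rows, ceil(valid/tile), tile] buffer with `pad` in the padding."""
    outer = -(-valid // tile)
    return [
        [
            [float(r + j) if j < valid else pad for j in range(jo * tile, (jo + 1) * tile)]
            for jo in range(outer)
        ]
        for r in range(rows)
    ]


zero_padded = make_padded(rows=16, valid=14, tile=4, pad=0.0)
junk_padded = make_padded(rows=16, valid=14, tile=4, pad=99.0)
```

With zero in the padding both versions agree. With any other padding value only the branching version is correct, which is why the simplification needs the padding annotation.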
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
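
The simplification rules listed above can be modelled with a toy rewriter. This sketch is plain Python, not TIR: the `Undef` class and the `simplify_mul`/`simplify_sub` helpers are hypothetical, and object identity is used here as a stand-in for binding a single `undef` in a let.

```python
class Undef:
    """A valid but arbitrary value; each instance is independent of the others."""

    def __repr__(self):
        return "undef"


def is_undef(x):
    return isinstance(x, Undef)


def simplify_mul(lhs, rhs):
    # 0 * undef may be simplified to zero.
    if (not is_undef(lhs) and lhs == 0) or (not is_undef(rhs) and rhs == 0):
        return 0
    # Any other pure expression that uses undef may be simplified to undef.
    if is_undef(lhs) or is_undef(rhs):
        return Undef()
    return lhs * rhs


def simplify_sub(lhs, rhs):
    if is_undef(lhs) or is_undef(rhs):
        # Only the *same* value (as if bound once by a let) cancels.
        if lhs is rhs:
            return 0
        return Undef()
    return lhs - rhs
```

For instance, `simplify_sub(Undef(), Undef())` stays `undef`, while subtracting a single shared instance from itself gives zero.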
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
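
Because `T.assume` does not produce a runtime check, whoever supplies the transformed `A` is responsible for making the assumption true. The following plain-Python sketch evaluates the same predicate on nested lists. `satisfies_padding_assumption` is a hypothetical debugging helper, not a proposed API.

```python
def satisfies_padding_assumption(A, extent, pad_value):
    """Check `tile*io + ii < extent or A[io][ii] == pad_value` for every element."""
    tile = len(A[0])
    return all(
        tile * io + ii < extent or A[io][ii] == pad_value
        for io in range(len(A))
        for ii in range(tile)
    )


good = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, -1, -1]]
bad = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, -1, 0]]
```

Here `good` satisfies the assumption for `extent=14` and `pad_value=-1`, while `bad` violates it at index `[3, 3]`.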
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
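
The claimed simplification of the padding loop can be confirmed by listing the indices that each form writes. This is a plain-Python check of the iteration space only, under the `[14]` to `[4, 4]` example used above.

```python
# Padding loop as generated: full grid with a predicate.
generated = [(io, ii) for io in range(4) for ii in range(4) if 4 * io + ii >= 14]

# Simplified form: `io` is fixed to 3, and `ii` ranges over 2 <= ii < 4.
simplified = [(3, ii) for ii in range(2, 4)]
```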
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independent of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
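
The rewrite above must preserve both the set of elements written and the values stored. A minimal plain-Python sketch of that check follows. The two functions are hypothetical stand-ins that record the stores of the loop nests before and after `sequential_buffer_access`.

```python
def stores_before_rewrite(extent=14):
    """(indices, value) pairs in the order written by `for i in T.serial(extent)`."""
    return [((i // 4, i % 4), i) for i in range(extent)]


def stores_after_rewrite(extent=14):
    """The same stores, iterating over the transformed indices directly."""
    stores = []
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < extent:
                stores.append(((io, ii), 4 * io + ii))
    return stores
```

Without the conditional, the rewritten loop would also write to `[3, 2]` and `[3, 3]`, which the original loop never touches.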
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
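
Both rewritten loop nests should compute the same convolution as the original. The sketch below re-implements the three versions on plain Python lists to check the index simplifications, relying on Python's `//` and `%` rounding toward negative infinity, which is assumed to match the floor division intended in the listings. The function names are hypothetical.

```python
def conv_original(A, F):
    """Reference: 1-d buffers, B[i] = sum_f F[f] * A[i + f]."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] = B[i] + F[f] * A[i + f]
    return B


def conv_option1(A, F):
    """Loops follow B's layout; A and the result use the [4, 4] layout."""
    B = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B[io][ii] = 0
                for f in range(3):
                    B[io][ii] += F[f] * A[io + (ii + f) // 4][(ii + f) % 4]
    return B


def conv_option2(A, F):
    """Loops follow A's layout; the initialization is a separate loop nest."""
    B = [[None] * 4 for _ in range(4)]
    for i in range(14):
        B[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    # (ii - f) may be negative; floor division keeps this valid.
                    B[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A[io][ii]
    return B


def flatten(buf, extent):
    return [buf[i // 4][i % 4] for i in range(extent)]


A_flat = list(range(1, 17))
F = [1, 2, 3]
A_tiled = [A_flat[4 * io : 4 * io + 4] for io in range(4)]
expected = conv_original(A_flat, F)
```

In both options the two padding elements of `B` are never written, so they keep their initial placeholder value.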
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
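
The first rule in the list above (a store that is overwritten before being read is a no-op) can be sketched on a simplified statement list. This is plain Python, not the `NoOpRemover` implementation: the `(target, reads)` encoding of a store and the name `remove_overwritten_stores` are assumptions made for the illustration, and only directly adjacent stores are considered.

```python
def remove_overwritten_stores(stores):
    """Drop a store when the next store overwrites it without reading it.

    Each store is a `(target, reads)` pair, where `target` is a
    `(buffer_name, index)` tuple and `reads` is the set of such tuples
    that the stored value depends on.
    """
    kept = []
    for pos, (target, reads) in enumerate(stores):
        if pos + 1 < len(stores):
            next_target, next_reads = stores[pos + 1]
            if next_target == target and target not in next_reads:
                continue  # the first store is a no-op
        kept.append((target, reads))
    return kept


# B[0] = 0.0 followed by B[0] = A[0]: the first store is dead.
dead_first = [(("B", 0), set()), (("B", 0), {("A", 0)})]

# B[0] = 0.0 followed by B[0] = B[0] + A[0]: the first store is read back.
live_first = [(("B", 0), set()), (("B", 0), {("B", 0), ("A", 0)})]
```

The second case is the reason the rule requires that the later store does not read the earlier value.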
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, conditionals whose conditions depend on buffer values
+  should not be merged.  Only conditionals that do not depend on
+  mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, this should be used for
+  later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on values
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
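The `n//8` simplification above follows from interval reasoning over floor division. Below is a minimal plain-Python sketch of that reasoning. The helpers `floordiv_bounds` and `try_fold_floordiv` are hypothetical names used only for illustration; they are not part of TVM, and the actual simplifier may work differently.

```python
def floordiv_bounds(lo, hi, divisor):
    """Return (min, max) of n // divisor over the half-open range lo <= n < hi.

    Assumes a positive divisor, for which floor division is monotonic in n.
    """
    if divisor <= 0:
        raise ValueError("divisor must be positive")
    if hi <= lo:
        raise ValueError("range must be non-empty")
    return lo // divisor, (hi - 1) // divisor


def try_fold_floordiv(lo, hi, divisor):
    """Return the constant value of n // divisor if the assumed range of n
    pins it to a single value, otherwise None."""
    low, high = floordiv_bounds(lo, hi, divisor)
    return low if low == high else None


# With the assumption 0 <= n < 8, n // 8 can only take one value.
folded = try_fold_floordiv(0, 8, 8)
```

If the assumed range were widened to `0 <= n < 9`, the helper returns `None`, and the expression would have to be left as-is.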
+
+### New Transform - Hoist Expression
+
+A new utility `HoistExpression` is introduced as a generalization of
+the current `HoistIfThenElse` pass.  The transformation
+`HoistExpression` would apply to the entire body of the `PrimFunc`,
+and would be used to avoid duplicating functionality between
+`HoistIfThenElse` and `HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist
+    any let bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
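The bit-flag configuration described in this section can be mirrored in Python to show how options would combine. This is only an illustrative sketch: the `IntFlag` classes below copy the names from the C++ enums above, and the helper `is_enabled` is hypothetical. The RFC does not specify a Python-side API.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    kNone = 0
    kIfElseStmt = 1 << 0
    kIfElseExpr = 1 << 1
    kBooleanExpression = 1 << 2


class HoistLetBinding(IntFlag):
    kNone = 0
    kRequiredByCondition = 1 << 0
    kLetStmt = 1 << 1
    kLetExpr = 1 << 2


def is_enabled(config, flag):
    """Return True if `flag` is set in the combined bit-flag `config`."""
    return bool(config & flag)


# Hoist conditions of if/else statements and arbitrary boolean
# expressions, but not the conditions of if_then_else calls.
conditional_config = HoistConditional.kIfElseStmt | HoistConditional.kBooleanExpression

# Rough analogue of the existing HoistIfThenElse behavior, as described above.
legacy_config = (HoistConditional.kIfElseStmt, HoistLetBinding.kNone)
```

Because each option occupies its own bit, any subset of the options can be requested with a bitwise OR, and `kNone` disables the category entirely.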
+
+### New Transform - Reduce Loop Extents
+
+Reduce the extent of loops based on conditionals present in the body
+of the loop.
+
+For any non-vectorized `tir::For` loop (`ForKind::kSerial` or
+`ForKind::kParallel`), if the body is a conditional and the
+conditional's `else_case` is empty, determine if the condition is of
+the form `(loop_var $CMP_OP const) && (...)`.  If so, use the
+comparison operator to reduce the loop extent, such that the loop
+skips values for which the comparison is provably false.
+
+TODO: Double-check that this isn't already implemented elsewhere.
+
+TODO: Check if it is implementable using `IntSetAnalyzer`.
+
+Below is an example of how this can work alongside `HoistExpression`
+to simplify the initialization of padding.
+
+```python
+# Original function.
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i, j in T.grid(4, 4):
+        if i == 0 and j < 2:
+            A[i, j] = 0.0
+
+
+# After hoisting with HoistConditional::kBooleanExpression
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i in T.serial(4):
+        if i == 0:
+            for j in T.serial(4):
+                if j < 2:
+                    A[i, j] = 0.0
+
+
+# After reducing the extents of serial loops
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    i = 0
+    for j in T.serial(2):
+        A[i, j] = 0.0
+```
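The extent reduction itself amounts to intersecting the loop range with the range allowed by the comparison. A minimal plain-Python sketch follows, assuming a loop that starts at zero and a comparison of the loop variable against an integer constant. The function `reduce_loop_range` is a hypothetical helper for illustration, not a TVM API.

```python
def reduce_loop_range(extent, cmp_op, const):
    """Return (min, extent) of the reduced loop.

    The original loop iterates over 0 <= v < extent, and the body is
    guarded by `v cmp_op const` with no else-branch.  Iterations for
    which the comparison is false are skipped.
    """
    lo, hi = 0, extent  # half-open range [lo, hi)
    if cmp_op == "<":
        hi = min(hi, const)
    elif cmp_op == "<=":
        hi = min(hi, const + 1)
    elif cmp_op == ">":
        lo = max(lo, const + 1)
    elif cmp_op == ">=":
        lo = max(lo, const)
    elif cmp_op == "==":
        lo, hi = max(lo, const), min(hi, const + 1)
    else:
        raise ValueError(f"unsupported comparison: {cmp_op}")
    # An empty intersection means the loop body never executes.
    return lo, max(hi - lo, 0)


# The two loops from the example above: `i == 0` and `j < 2`, each
# starting from an extent of 4.
i_range = reduce_loop_range(4, "==", 0)
j_range = reduce_loop_range(4, "<", 2)
```

A loop reduced to an extent of one, as for `i` here, can then be replaced by a binding of the loop variable, which matches the final form shown in the example.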
+
+
+
+### Utility - Merge Adjacent Loops
+
+If it does not impact the resulting computation, loops may be merged
+together.  This is a valid transformation if both loops are serial
+loops, the loops have the same indices, and if the merging respects
+data dependencies.  This would be exposed as a metaschedule primitive,
+which takes as input the `LoopRV` to be merged.
+
+For adjacent loops, to prove that there is no data dependency, two
+conditions must hold.
+
+1. For all loop indices `i` and `j` where `i > j`, the set of indices
+   written by the first loop in iteration `i` is disjoint from the set
+   of indices accessed by the second loop in iteration `j`.  That is,
+   merging the loops wouldn't cause the second loop body to read
+   partial values, nor would it cause the first loop body to overwrite
+   a value produced by the second loop body.
+
+2. For all loop indices `i` and `j` where `i > j`, the set of indices
+   read by the first loop in iteration `i` is disjoint from the set
+   of indices written by the second loop in iteration `j`.  That is,
+   merging the loops wouldn't cause the second loop body to overwrite
+   values that are still required by the first loop body.
+
+Element-wise loops do not have any data dependencies, and adjacent
+element-wise loops may be merged.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+
+    for i in T.serial(16):
+        A[i] = 1.0
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j].
+#    c. intersection([i], [j]) = [i] if i==j else [].
+#    d. If i>j, the intersection is empty
+#
+# 2. a. In iteration i, loop 1 reads from index [].
+#    b. In iteration j, loop 2 writes to index [j]
+#    c. intersection([], [j]) = []
+#    d. For all i,j, the intersection is empty
+#
+# Therefore, this merger is valid
+
+# After merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+        A[i] = 1.0
+```
+
+The second loop may read indices that were written in an earlier
+iteration.  Merging would not impact the result.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+
+    for i in T.serial(16):
+        if i > 0:
+            A[i] = A[i - 1] + 1.0
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j,j-1].
+#    c. i>j implies that i!=j and i!=j-1.
+#    d. For all i,j where i>j, the intersection is empty.
+#
+# 2. a. In iteration i, loop 1 reads from index [].
+#    b. In iteration j, loop 2 writes to index [j]
+#    c. For all i,j, intersection([], [j]) = [].
+#
+# Therefore, this merger is valid
+
+
+# After merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+        if i > 0:
+            A[i] = A[i - 1] + 1.0
+```
+
+The second loop may not read indices that were written in a later
+iteration of the first loop.  In this case, merging would impact the
+output values.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = i
+
+    for i in T.serial(16):
+        if 0 < i < 15:
+            A[i] = A[i - 1] + A[i] + A[i + 1]
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j-1,j,j+1].
+#    c. If i==j+1, then intersection([j+1], [j-1,j,j+1]) = [j+1],
+#       which is non-empty.
+#
+# Therefore, this merger is not valid.
+```
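The dependency check used in the three examples can be sketched as a brute-force test in plain Python. After merging, iteration `j` of the second body runs before iteration `i` of the first body whenever `i > j`, so only those pairs change their relative order and need to be checked. The function `can_merge_adjacent_loops` and its callback-based interface are hypothetical and chosen only for illustration; an implementation in TVM would more likely reason symbolically over the loop variables.

```python
from typing import Callable, Set

IndexSet = Callable[[int], Set[int]]


def can_merge_adjacent_loops(
    extent: int,
    writes_1: IndexSet,
    reads_1: IndexSet,
    reads_2: IndexSet,
    writes_2: IndexSet,
) -> bool:
    """Check whether two adjacent serial loops of equal extent can be merged.

    Each callback maps an iteration number to the set of buffer indices
    that the corresponding loop body reads or writes in that iteration.
    """
    for i in range(extent):
        for j in range(i):  # only pairs with i > j are reordered
            accessed_2 = reads_2(j) | writes_2(j)
            # Loop 1 must not write anything loop 2 has already touched.
            if writes_1(i) & accessed_2:
                return False
            # Loop 2 must not overwrite anything loop 1 still has to read.
            if reads_1(i) & writes_2(j):
                return False
    return True


def no_access(_iteration: int) -> Set[int]:
    return set()


def own_index(iteration: int) -> Set[int]:
    return {iteration}


# First example above: two element-wise loops writing A[i].
elementwise_ok = can_merge_adjacent_loops(16, own_index, no_access, no_access, own_index)
```

The other two examples differ only in the access sets of the second loop: reading `A[i - 1]` keeps the merge valid, while reading `A[i + 1]` makes the check fail.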
+
+### New Primitive - Remove Branching Through Overcompute
+
+A new transform which attempts to reduce branching by allowing
+overcompute.  It takes an argument to specify which block it should be
+applied within.
+
+For each `IfThenElse` statement, check if the
+`IfThenElseNode::else_case` is a simplified form of the
+`IfThenElseNode::then_case`.  This check is done by simplifying
+`then_case`, under the assumption that `condition` is false, and
+substituting the known value in a `BufferConstraint` in any
+`BufferLoad` for which the predicate can be proven to be true.  If
+this simplified form is identical to the `else_case`, then the entire
+if/else block can be replaced with `then_case`.  Otherwise, this check
+is repeated with the roles reversed, to see if the `else_case` can be
+simplified down to the `then_case`.  If neither simplification holds,
+then no change is made.
+
+For example, consider the following 1-d convolution, where both the
+input and output buffers have a layout transformation applied.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(14,), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + A[i + f]
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(14,), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + A[(i + f) // 4, (i + f) % 4]
+
+
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=0.0)
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(4, 4), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + A[(i + f) // 4, (i + f) % 4]
+
+        for io,ii in T.grid(4,4):
+            if io==3 and ii>=2:
+                B[io,ii] = 0.0
+
+
+# sched.sequential_buffer_access(block='compute', buffer='B')
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(4, 4), "float32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 0 <= 4*io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    B[io, ii] = B[io, ii] + A[io + (ii + f) // 4, (ii + f) % 4]
+
+        for io,ii in T.grid(4,4):
+            if io==3 and ii>=2:
+                B[io,ii] = 0.0
+```
+
+
+We'd like to remove the conditional `if 0 <= 4*io + ii < 14` in the
+compute loop.  In order to do so, we need to prove that the body of
+the conditional is a no-op in the case where the conditional is false.
+
+Using the [updated `DomainTouched`
+utility](#enhancement-remove-no-op), this else-block would be a no-op.
+It is a write to `B[io,ii]` predicated on `4*io+ii >= 14`, followed by
+a write to `B[io,ii]` predicated on `io==3 and ii>=2`, without a read
+in between.  Since these predicates are equivalent, the first write is
+a no-op.
+
+```python
+# sched.remove_branching_through_overcompute(block='compute')

Review Comment:
   Does this only apply to outputs?  I think we should have a per-buffer directive that indicates that out-of-bounds access is allowed. The only thing in question is how to determine/specify that out-of-bounds reads from inputs are OK.  The user can add padding -INF to inputs to maxpool, but how does the maxpool compute know that it can use the out-of-bounds values?
   
   Whether to actually utilize this should probably be left to the compiler.  Auto-scheduling should not be a replacement for compiler optimizations.



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
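The shapes and padding locations quoted in the comments above can be reproduced by enumerating the transformed indices. The following plain-Python sketch does so by brute force; `padded_layout` is a hypothetical helper written for this illustration and is not how `transform_layout` is implemented.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (shape, padding) for a 1-d buffer of the given extent.

    `transform` maps an original index to a list of transformed indices.
    `shape` is the smallest transformed shape holding every mapped index,
    and `padding` lists the transformed indices that no original index
    maps to, in row-major order.
    """
    mapped = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(mapped)))
    shape = tuple(max(idx[d] for idx in mapped) + 1 for d in range(ndim))
    padding = [idx for idx in product(*(range(s) for s in shape)) if idx not in mapped]
    return shape, padding


# Shape [16] with i -> [i//8, i%8]: no padding.
shape_a, padding_a = padded_layout(16, lambda i: [i // 8, i % 8])

# Shape [14] with the same transform: padding at the end.
shape_b, padding_b = padded_layout(14, lambda i: [i // 8, i % 8])

# Shape [14] with an offset of 2: padding at the beginning.
shape_c, padding_c = padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8])

# Shape [16] with an offset of 2: padding at both ends.
shape_d, padding_d = padded_layout(16, lambda i: [(i + 2) // 8, (i + 2) % 8])
```

Enumerating every index is only practical for small static shapes, but it makes the relationship between the transform, the offset, and the resulting padding easy to inspect.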
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since whether trading branching for
+overcompute is beneficial depends on the schedule, these options are
+exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
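The equivalence that this section relies on can be checked numerically with plain Python lists standing in for the buffers. The helpers below (`pad_rows`, `row_sum_branching`, `row_sum_overcompute`) are hypothetical and written only for this sketch; they mirror the two `row_summation` variants above without using TVM.

```python
def pad_rows(a, pad_value=0.0):
    """Reshape each 14-element row to 4x4, filling the 2 padded slots."""
    padded = []
    for row in a:
        flat = list(row) + [pad_value] * (16 - len(row))
        padded.append([flat[4 * jo : 4 * jo + 4] for jo in range(4)])
    return padded


def row_sum_branching(a_padded):
    """Row sums that skip the padding with an explicit conditional."""
    out = []
    for row in a_padded:
        total = 0.0
        for j_outer in range(4):
            for j_inner in range(4):
                if 4 * j_outer + j_inner < 14:
                    total += row[j_outer][j_inner]
        out.append(total)
    return out


def row_sum_overcompute(a_padded):
    """Row sums over all 16 slots, relying on the padding value."""
    out = []
    for row in a_padded:
        total = 0.0
        for j_outer in range(4):
            for j_inner in range(4):
                total += row[j_outer][j_inner]
        out.append(total)
    return out


a = [[float(14 * i + j) for j in range(14)] for i in range(16)]
zero_padded = pad_rows(a, pad_value=0.0)
results_match = row_sum_branching(zero_padded) == row_sum_overcompute(zero_padded)
```

With any non-zero padding value the two versions disagree, which is why the simplification is only legal when the padding value is known.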
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+`cache_read` and `cache_write` can be used outside of any loop, with
+the same scope as the uncached buffer.  The layout of the cache can
+then be transformed to operate on a reshaped buffer without modifying
+the calling signature of the original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
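
The loop-range simplification described above can be sanity-checked by enumeration. This is a small sketch in plain Python rather than a TIR pass; the helper name `padding_indices` is hypothetical.

```python
def padding_indices(extent=14, factor=4):
    """Indices written by the unsimplified padding loop over the full grid."""
    rows = -(-extent // factor)  # ceil(extent / factor)
    return [
        (io, ii)
        for io in range(rows)
        for ii in range(factor)
        if factor * io + ii >= extent
    ]


# The simplified form: `io` is fixed to 3, and `ii` ranges over 2 <= ii < 4.
reduced_loop = [(3, ii) for ii in range(2, 4)]
full_loop = padding_indices()
```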
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience: a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
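
As a rough illustration of why the conditional is needed, the sketch below compares the elements visited by the original loop with those visited by the rewritten, guarded loop nest. It is plain Python under the assumption of a `i -> [i//factor, i%factor]` transformation; the function names are hypothetical and are not part of the proposed API.

```python
def original_order(extent, factor):
    """Transformed indices visited by the loop `for i in range(extent)`."""
    return [(i // factor, i % factor) for i in range(extent)]


def sequential_order(extent, factor):
    """Transformed indices visited by the rewritten (io, ii) loop nest."""
    rows = -(-extent // factor)  # ceil(extent / factor)
    visited = []
    for io in range(rows):
        for ii in range(factor):
            # Guard against elements that the original loop never accessed.
            if factor * io + ii < extent:
                visited.append((io, ii))
    return visited
```

With the guard in place, both loops visit the same elements in the same order; without it, the rewritten loop would also touch the padding.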
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
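
Both loop orderings can be checked against the untransformed computation. The following is a pure-Python sketch using nested lists in place of TIR buffers; the helper names (`conv1d_reference`, `to_tiles`, `conv1d_option1`, `conv1d_option2`) are assumptions for illustration. Python's `//` and `%` use floor semantics, which is what the index simplifications for negative `ii - f` rely on here.

```python
def conv1d_reference(A, F):
    """Untransformed 1-d convolution over flat lists."""
    n_out = len(A) - len(F) + 1
    return [sum(F[f] * A[i + f] for f in range(len(F))) for i in range(n_out)]


def to_tiles(flat, factor=4):
    """Apply i -> [i // factor, i % factor]; padding is left as 0."""
    rows = -(-len(flat) // factor)
    tiles = [[0] * factor for _ in range(rows)]
    for i, value in enumerate(flat):
        tiles[i // factor][i % factor] = value
    return tiles


def conv1d_option1(A_t, F, n_out=14):
    """Loops follow B's layout."""
    B_t = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < n_out:
                B_t[io][ii] = 0
                for f in range(len(F)):
                    B_t[io][ii] += F[f] * A_t[io + (ii + f) // 4][(ii + f) % 4]
    return B_t


def conv1d_option2(A_t, F, n_out=14):
    """Loops follow A's layout, with the initialization split out."""
    B_t = [[0] * 4 for _ in range(4)]
    for i in range(n_out):
        B_t[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(len(F)):
                if 0 <= 4 * io + ii - f < n_out:
                    B_t[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A_t[io][ii]
    return B_t
```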
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
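
The first rule in the list above (an overwritten store is a no-op) can be sketched on a toy representation. This is not the `tvm::tir::NoOpRemover` implementation; the `Store` record and `remove_overwritten_stores` function are hypothetical, and only immediately adjacent stores with constant indices are considered.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Store:
    """A toy BufferStore: writes buffer[index], reading the listed locations."""
    buffer: str
    index: int
    reads: frozenset = frozenset()


def remove_overwritten_stores(stores):
    """Drop a store when the next store overwrites it without reading it."""
    kept = []
    for pos, store in enumerate(stores):
        target = (store.buffer, store.index)
        following = stores[pos + 1] if pos + 1 < len(stores) else None
        overwritten = (
            following is not None
            and (following.buffer, following.index) == target
            and target not in following.reads
        )
        if not overwritten:
            kept.append(store)
    return kept
```

A store such as `B[0] = B[0] + A[1]` reads the previous value, so the store before it is kept.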
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals
+  that do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, this should be used for
+  later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on a value
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
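
The notions of "identical" and "complementary" conditions used above can be illustrated by brute force over the loop domain. This is a swapped-in technique for illustration only: the simplifier would need to prove these relationships symbolically, whereas the hypothetical helpers below simply enumerate every iteration.

```python
from itertools import product


def conditions_identical(cond_a, cond_b, domain):
    """True if both conditions agree at every point of the domain."""
    return all(bool(cond_a(*p)) == bool(cond_b(*p)) for p in domain)


def conditions_complementary(cond_a, cond_b, domain):
    """True if exactly one of the conditions holds at every point."""
    return all(bool(cond_a(*p)) != bool(cond_b(*p)) for p in domain)


# Iteration domains for `T.serial(16)` and `T.grid(4, 4)`.
loop_1d = [(i,) for i in range(16)]
grid_2d = list(product(range(4), range(4)))
```

For example, `i < 8` and `i//8 == 0` are identical over `loop_1d` despite having different functional forms, while `4*i + j < 14` and `i == 3 and j >= 2` are complementary over `grid_2d`.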
+
+### New Transform - Hoist Expression
+
+`HoistExpression` is a new utility that generalizes the current
+`HoistIfThenElse` pass.  The transformation `HoistExpression`
+would apply to the entire body of the `PrimFunc`, and would be used to
+avoid duplication of functionality between `HoistIfThenElse` and
+`HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist any let
+    bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
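
To show how such bit flags combine, here is a hypothetical Python mirror of the two enums using `enum.IntFlag`. The RFC text above defines them only in C++; the Python classes and the `should_hoist` helper are assumptions for illustration, not a proposed binding.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    kNone = 0
    kIfElseStmt = 1 << 0
    kIfElseExpr = 1 << 1
    kBooleanExpression = 1 << 2


class HoistLetBinding(IntFlag):
    kNone = 0
    kRequiredByCondition = 1 << 0
    kLetStmt = 1 << 1
    kLetExpr = 1 << 2


def should_hoist(config, kind):
    """True if the bit for `kind` is set in `config`."""
    return bool(config & kind)


# Hoist both statement-level and expression-level if/else conditions,
# but not arbitrary boolean expressions.
config = HoistConditional.kIfElseStmt | HoistConditional.kIfElseExpr
```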
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
+
+### New Transform - Reduce Loop Extents

Review Comment:
   I don't think this is necessary.  We could simply reuse loop partitioning, and break off pieces of the nest that will never execute.



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
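
The shapes and padding locations quoted in the examples above can be reproduced by enumerating the transformed indices. The sketch below is plain Python, independent of TVM; the helper `padded_layout` is hypothetical, and it assumes a one-dimensional source buffer whose transformed indices start at zero along every axis.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (transformed shape, padding indices) for a 1-d buffer."""
    # Transformed indices that correspond to an element of the original buffer.
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Minimal shape that contains every used index.
    shape = [max(idx[d] for idx in used) + 1 for d in range(ndim)]
    # Everything else inside that shape is padding.
    padding = [
        idx for idx in product(*(range(s) for s in shape)) if idx not in used
    ]
    return shape, padding
```

For instance, `padded_layout(14, lambda i: [i//8, i%8])` reports the shape `[2, 8]` with padding at `(1, 6)` and `(1, 7)`.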
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-op.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:

Review Comment:
   Which condition should be kept?  How do we decide?





[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1164440693

   Added some examples to build on top of @Lunderberg 's example
   
   ## Transformation
   
   The main differences between annotation and special handling are:
   
   - Annotation is not necessary for the correctness of the program, but it may provide hints towards future optimizations
   - Without annotation, the program still runs correctly, but certain optimizations may not trigger
   
   ### Step 0: Produce temp stages with annotation
   
   The transformation produces temporary buffers (AC and BC), where the relation between those buffers and A, B is recorded in two blocks (preproc and postproc).
   
   Note that these additional annotations are hints for the compiler to perform future optimizations (e.g. to lift them out or cancel them). Our eventual goal could be to reason about those properties directly from the code, but annotations provide a first shortcut.
   
   ```python
   @T.prim_func
   def grow(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")
   
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   
       for i, j in T.grid(4, 4):
           BC[i, j] = 2 * AC[i, j]
   
       for i in T.serial(14):
           with T.block():
               # hint that this is a cropping operation,
               # where we know that the remaining part in B is 0
               # Additionally, the remaining uncovered values
               # are assumed to be 0, if not provided then no assumptions are made
               T.block_attr("postproc", ["crop", 0])
               B[i] = BC[i // 4, i % 4]
   
   @T.prim_func
   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       for i in T.grid(14):
           B[i] = A[i] + 1
   
   @R.func
   def main(A: T.Tensor[14, "int32"]):
       lv0 = call_tir(grow, [A], (14))
       # an intermediate stage to show non-local reflowing
       lv1 = call_tir(addone, [lv0], (14))
       lv2 = call_tir(grow, [lv1], (14))
       ...
   
   ```
   
   Note that the special crop annotation comes with an `assumed_value`, which is provided as part of the transformation (and we can actually prove that it is safe if our layout transformation starts from B and goes backwards).
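   
   To make the crop annotation concrete, here is a minimal plain-Python sketch (not TVM code; the helpers `pad`, `grow_packed` and `crop` are hypothetical stand-ins for the blocks above, operating on nested lists). Under those assumptions it illustrates that padding with 0, doubling on the packed layout, and cropping gives the same result as doubling the 14 original elements, and that the padding of `BC` stays 0 because `2 * 0 == 0`.
   
```python
def pad(a, rows=4, cols=4, pad_value=0):
    """Pack a flat list into rows x cols, filling the tail with pad_value."""
    return [
        [a[cols * io + ii] if cols * io + ii < len(a) else pad_value
         for ii in range(cols)]
        for io in range(rows)
    ]


def grow_packed(ac):
    """Element-wise doubling on the packed layout (padding included)."""
    return [[2 * x for x in row] for row in ac]


def crop(bc, n):
    """Unpack the first n logical elements from the packed layout."""
    cols = len(bc[0])
    return [bc[i // cols][i % cols] for i in range(n)]


a = list(range(1, 15))  # 14 logical elements
bc = grow_packed(pad(a))
result = crop(bc, len(a))

# Values left in the padding of BC, i.e. positions with 4*io + ii >= 14.
padding = [bc[io][ii] for io in range(4) for ii in range(4) if 4 * io + ii >= 14]
```
   
   Here `result` matches the unpadded computation, and `padding` holds only zeros, which is the value that the `["crop", 0]` annotation asserts.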
   
   ### Step 1: Reconstruct constraint at TIR-Graph level
   
   By looking at the primfunc, we know that there is a desire to split out the preproc stage and postproc stage to the graph. It is totally fine for the compiler to choose not to do so, and the program remains valid. But let us say we choose to lift them out:
   
   ```python
   @T.prim_func
   def grow_packed(AC: T.Buffer[[4,4], "int32"], BC: T.Buffer[[4,4], "int32"]):
       for i, j in T.grid(4, 4):
           BC[i, j] = 2 * AC[i, j]
   
   @T.prim_func
   def pad(A: T.Buffer[14, "int32"], AC: T.Buffer[[4, 4], "int32"]):
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   
   @T.prim_func
   def crop_with_pad_assume(BC: T.Buffer[[4,4], "int32"], B: T.Buffer[14, "int32"]):
       # Note that this crop carries a pad assertion (of other values of BC)
       for i in T.serial(14):
           with T.block():
               T.block_attr("postproc", ["crop", 0])
               B[i] = BC[i // 4, i % 4]
   
   @R.func
   def main(A: T.Tensor[14, "int32"]):
       lv0 = call_tir(pad, [A], (4, 4))
       lv1 = call_tir(grow_packed, [lv0], (4, 4))
       # These are two things that we want to use for global format reflowing
       lv2 = call_tir(crop_with_pad_assume, [lv1], (14))
       lv3 = call_tir(addone, [lv2], (14))
       lv4 = call_tir(pad, [lv3], (4, 4))
       lv5 = call_tir(grow_packed, [lv4], (4, 4))
       lv6 = call_tir(crop_with_pad_assume, [lv5], (14))
   ```
   
   ### Step 2: Global Reflowing of layouts
   
   Now as a last step, let us say we will do global reflowing.
   
   - Start from reverse topological order of the DAG,
   - Whenever we encounter a pad, we reconstruct an in-memory data structure (something like BufferConstraint, e.g. PadMapping(constraint, pad_value=0))
   - We try to “backprop” the PadMapping throughout the graph
   - Each function needs to have its own TIR analysis of how it flows things back. For example, in the case of `addone`, we can safely flow PadMapping back, changing `addone` to `addone_packed` by analyzing the TIR. If `addone` were an elementwise exp, however, we would need to insert a select operator (because `exp(0)=1`), and the message to the input becomes `PadMapping(constraint, pad_value=undef)`.
   - When `PadMapping` meets `crop_with_pad_assume`, we can attempt to simplify and cancel out
   - When there are branches, transpositions at the graph level, or other more complicated issues, we might choose to materialize
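   
   The per-function analysis in the list above can be sketched in plain Python. This is only an illustration under simplifying assumptions: `flow_pad_value` is a hypothetical helper, the operator is a pure element-wise Python callable, and the "analysis" is just evaluating the operator on the pad value. A real implementation would reason about the TIR instead.
   
```python
import math


def flow_pad_value(op, input_pad, required_output_pad):
    """Decide how an element-wise op interacts with a padded input.

    Returns "cancel" when op(input_pad) already equals the pad value the
    consumer expects, and "select" when the padding has to be rewritten
    explicitly (e.g. with a select operator).
    """
    if op(input_pad) == required_output_pad:
        return "cancel"
    return "select"


def double(x):
    return 2 * x


decision_double = flow_pad_value(double, 0, 0)   # 2 * 0 == 0
decision_exp = flow_pad_value(math.exp, 0, 0)    # exp(0) == 1, not 0
```
   
   With these inputs the doubling operator preserves the zero padding, while `exp` does not, which corresponds to the select case described in the list.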
   
   ### Discussion
   
   There are a few key properties that are really desirable here:
   
   - Transformations of a PrimFunc do not change the PrimFunc interface: this is really important so we can transform a PrimFunc without worrying about how the graph interacts with it (as the interface remains the same, we can lift out the blocks earlier)
   - There are implicit assumptions generated (`crop_with_pad_assume`) to enable some simplification (otherwise a select is necessary, which is also not as bad). Note that the assumptions are generated under a global context (when we transform padding we actually know that the overflowing fields are 0). But extra care is needed when we attempt to move `crop_with_pad_assume`, as it really depends on the value property of its input. The high-level gist is that we should not do that; instead, the global reflowing of layout will reflow the `PadMapping` to `crop_with_pad_assume` and then cancel it out.
   
   Talking about “constraints”, it is also useful to talk about categories of them. Roughly, we can divide them into three categories.
   
   - static_assert: We want to assert some invariant of the code, and it is necessary to “prove” that it holds at compile time; otherwise a compilation error needs to be raised.
   - (runtime) assert: We want to assert some invariant of the code. It is not necessary to “prove” that it holds, but we need to do runtime checking if it cannot be proved.
   - assume (from __builtin_assume): We want to assert some invariant of the code, and it is not necessary to “prove” that it holds during compilation.
   
   All three types of constraints can be helpful. In our particular case, `assume` is being generated in `crop_with_pad_assume`.
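   
   As a rough illustration of the three categories, the following plain-Python sketch maps a constraint kind and a static proof result to what a compiler would do with it. The function `lower_constraint` and its string results are hypothetical names chosen for this example, and the sketch only distinguishes "proven" from "could not be proven".
   
```python
def lower_constraint(kind, proven):
    """Map a constraint kind to a compiler action.

    proven is True if the condition was established at compile time,
    and False if the prover could not establish it.
    """
    if kind == "static_assert":
        # Must be provable, otherwise compilation fails.
        return "no-op" if proven else "compile error"
    if kind == "assert":
        # Falls back to a check at runtime when the proof fails.
        return "no-op" if proven else "runtime check"
    if kind == "assume":
        # Trusted as given: usable for simplification, never checked.
        return "no-op"
    raise ValueError(f"unknown constraint kind: {kind}")
```
   
   The only category that never produces a check or an error for an unproven condition is `assume`, which is why it has to be introduced with care.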




[GitHub] [tvm-rfcs] vinx13 commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163633675

   > > For example, we may introduce explicit cache stage to add the padding, and mark this block for later processing.
   > 
   > Wouldn't that require a "remove entirely" annotation that was suggested against [here](https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163019805)? I could see how we could mark a transformation to be hoisted out later, but when some simplifications require the constraint to be expressed in the producer, and others in the consumer, exposing it to both `PrimFuncs` for local simplifications would require either duplication of the block, or maintaining non-local information only for a single pass. If the stage is duplicated, all but one of the duplicates would need to be marked as temporary. If the information is only retained for a single pass, then any scheduling/optimization of a single subgraph would require walking through the entire end-to-end model.
   
   @tqchen may clarify. I think the suggestion is to mark the stage and lift it to the graph for global flowing, rather than removing it (though from the perspective of the subgraph (PrimFunc), it is removed from the PrimFunc).




[GitHub] [tvm-rfcs] vinx13 commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163447517

   > So long as the constraints can be statically searched for, this approach makes sense to me. I would be more concerned about adding additional semantics to existing nodes, such as a AttrStmt node
   
   It doesn't add additional semantics; the computation semantics stay the same. It is a hint to the graph compiler. Here is an example using `block_attr`: https://github.com/tlc-pack/relax/pull/161/files#diff-0c5223fca97ad1b31a686364a9acc65f59282bb256ba7fd70d9241986828abe5R46-R50




[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r890745531


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original 14: the transformed
+coordinates `(3,2)` and `(3,3)` do not have a corresponding index on
+the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform---remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.

Review Comment:
   For S-TIR, we actually have three sorts of "layouts":
   - the loop iter layout, controlled by `split`/`fuse`/`reorder`
   - the block iter binding
   - the buffer layout
   
   I can imagine one may want to transform all of them simultaneously, or some combination of the three. For example, https://github.com/apache/tvm/pull/11485 transforms both the loop and the block binding.
   
   What if we create a uniformly designed primitive, or uniform guidelines, for all useful combinations, and then build dedicated APIs like `sequential_buffer_access` on top of that uniform interface?
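
To show that these layouts can vary independently, here is a small pure-Python model. It is a sketch under stated assumptions rather than S-TIR code: `index_map`, `storage`, and `visited` are hypothetical names, and the block iter binding is not modelled at all. The buffer layout is changed first while the loop still walks the logical index, and then the loop is rewritten to walk the transformed buffer.

```python
# Hypothetical pure-Python model, not the S-TIR API.

def index_map(i):
    """Buffer layout: logical index i is stored at [i // 4, i % 4]."""
    return (i // 4, i % 4)


LOGICAL_EXTENT = 14
storage = [[None] * 4 for _ in range(4)]

# The loop layout still follows the logical index; only the buffer
# layout has been transformed.
for i in range(LOGICAL_EXTENT):
    io, ii = index_map(i)
    storage[io][ii] = i

# The loop layout now follows the transformed buffer (as after a split),
# so a guard is needed to skip the two padding elements.
visited = []
for io in range(4):
    for ii in range(4):
        i = 4 * io + ii
        if i < LOGICAL_EXTENT:
            visited.append(storage[io][ii])
```

Both loop nests touch the same 14 logical elements, which is why a uniform interface would need to state explicitly which of the layouts a given primitive rewrites.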
   
   





[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891760089


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original 14: the transformed
+coordinates `(3,2)` and `(3,3)` do not have a corresponding index on
+the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would avoid branch divergence
+when the loop is bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, it can be proven that the padding contains zero, allowing the
+conditional to be removed automatically.  Because whether trading
+branching for overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
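
The equivalence that makes this removal legal can be checked with a small plain-Python model. This is only a sketch under stated assumptions: it does not use TVM, the helper names (`to_tiled`, `row_sum_branching`, `row_sum_overcompute`) are made up for illustration, and the buffers are nested lists with the `(16, 4, 4)` layout from the example.

```python
import random

ROWS, COLS, TILE = 16, 14, 4
N_TILES = -(-COLS // TILE)  # ceiling division: 4 tiles of 4 columns


def to_tiled(A, pad):
    """Apply [i, j] -> [i, j//4, j%4], filling the padding with `pad`."""
    tiled = [[[pad] * TILE for _ in range(N_TILES)] for _ in range(ROWS)]
    for i in range(ROWS):
        for j in range(COLS):
            tiled[i][j // TILE][j % TILE] = A[i][j]
    return tiled


def row_sum_branching(A_t):
    """Row sum that skips the padding with a conditional."""
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(N_TILES):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:
                    acc += A_t[i][jo][ji]
        out.append(acc)
    return out


def row_sum_overcompute(A_t):
    """Row sum that also reads the padding (no conditional)."""
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(N_TILES):
            for ji in range(TILE):
                acc += A_t[i][jo][ji]
        out.append(acc)
    return out


rng = random.Random(0)
A = [[rng.random() for _ in range(COLS)] for _ in range(ROWS)]
A_zero = to_tiled(A, pad=0.0)  # padding satisfies the annotated constraint
A_one = to_tiled(A, pad=1.0)   # padding violates it
```

With zero padding the two variants agree; with any other padding value they do not, which is why the constraint on the padding has to be known before the branch can be dropped.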
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
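
A quick way to see that the reduced loop is equivalent is to enumerate the cells each form writes. The sketch below is plain Python rather than TIR, and the function names are hypothetical. It uses the general padding predicate `4*io + ii >= 14`, which for a `[14]` buffer mapped to `[4,4]` selects exactly `io == 3` with `2 <= ii < 4`.

```python
def padded_cells_full_grid():
    """Cells written by the guarded loop over the full 4x4 grid."""
    cells = []
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii >= 14:  # padding predicate
                cells.append((io, ii))
    return cells


def padded_cells_reduced():
    """Cells written after removing the `io` loop and narrowing `ii`."""
    io = 3
    return [(io, ii) for ii in range(2, 4)]
```

Both functions visit the same two cells in the same order, so the narrowed loop can replace the guarded one once merging with neighbouring loopnests is no longer needed.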
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
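
The effect of the rewritten loopnest in the uneven case can be modelled in plain Python. This is a sketch only: it does not call any TVM API, and `original_order` and `sequential_order` are hypothetical names for the loop structure before and after the rewrite. Each function returns the final buffer contents and the order in which elements were written.

```python
EXTENT = 14  # extent of the pre-transformation buffer


def original_order():
    """Loop over the pre-transformation index, as left by transform_layout."""
    buf = [[None] * 4 for _ in range(4)]
    visited = []
    for i in range(EXTENT):
        buf[i // 4][i % 4] = i
        visited.append((i // 4, i % 4))
    return buf, visited


def sequential_order():
    """Loop over the transformed indices, guarded against the padding."""
    buf = [[None] * 4 for _ in range(4)]
    visited = []
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < EXTENT:
                buf[io][ii] = 4 * io + ii
                visited.append((io, ii))
    return buf, visited
```

The guard keeps the two padded elements untouched (they stay `None` in this model), which matches the requirement that elements not previously accessed remain unaccessed.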
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
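
The index simplifications in the comments above can be sanity-checked with a plain-Python model of the three loopnests. This is a sketch under assumptions: it does not use TVM, the function names are hypothetical, and it relies on Python's `//` and `%` rounding toward negative infinity, which is assumed here to match the floor-division semantics intended for `(ii - f) // 4` and `(ii - f) % 4`.

```python
A = list(range(16))  # arbitrary integer input
F = [1, 2, 3]        # arbitrary integer filter
A_t = [[A[4 * io + ii] for ii in range(4)] for io in range(4)]


def reference():
    """Original 1-d convolution over the untransformed buffers."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def option1():
    """Loops rewritten to follow B's layout."""
    B_t = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B_t[io][ii] = 0
                for f in range(3):
                    B_t[io][ii] += F[f] * A_t[io + (ii + f) // 4][(ii + f) % 4]
    return B_t


def option2():
    """Loops rewritten to follow A's layout, with a separate init loopnest."""
    B_t = [[None] * 4 for _ in range(4)]
    for i in range(14):
        B_t[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B_t[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A_t[io][ii]
    return B_t


def flatten(B_t):
    """Read the 14 defined elements back in pre-transformation order."""
    return [B_t[i // 4][i % 4] for i in range(14)]
```

Integer data is used so that the different accumulation order in Option 2 cannot introduce rounding differences; with floating-point data the results could differ in the last bits.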
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-op.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals
+  that do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
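
The notions of "identical" and "complementary" conditions used in this section can be illustrated by brute force over a loop domain. The sketch below is plain Python and is not how `StmtSimplifier` would decide this (a real implementation would presumably use the arithmetic analyzer to prove each implication); `identical` and `complementary` are hypothetical helper names.

```python
import itertools


def identical(cond_a, cond_b, domain):
    """Each condition implies the other at every point of the domain."""
    return all(cond_a(*p) == cond_b(*p) for p in domain)


def complementary(cond_a, cond_b, domain):
    """Each condition implies the negation of the other on the domain."""
    return all(cond_a(*p) != cond_b(*p) for p in domain)


# One-dimensional loop: for i in T.serial(16)
serial_16 = [(i,) for i in range(16)]
# Two-dimensional loop: for i, j in T.grid(4, 4)
grid_4x4 = list(itertools.product(range(4), range(4)))

same_form = identical(lambda i: i < 8, lambda i: i // 8 == 0, serial_16)
opposite_1d = complementary(lambda i: i < 8, lambda i: i // 8 == 1, serial_16)
opposite_2d = complementary(
    lambda i, j: 4 * i + j < 14,
    lambda i, j: i == 3 and j >= 2,
    grid_4x4,
)
```

Note that whether two conditions are identical or complementary depends on the loop domain: `i < 8` and `i // 8 == 0` agree on `0 <= i < 16`, but would not agree if `i` could be negative.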
+
+### New Transform - Hoist Expression

Review Comment:
   @wrongtest I remember you also proposed imperative loop partitioning in https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807. Could you comment on how this (and the other related utilities / primitives in this RFC) relates to the one you proposed?





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891359018


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this shape has 16 elements, and the transformed coordinates `(3,2)` and
+`(3,3)` do not correspond to any index in the workload range
+`0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
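
The shapes and padding locations quoted in the comments above can be reproduced with a short plain-Python helper. This is a sketch, not part of the proposed API: `padded_layout` is a hypothetical function, and it assumes that "minimum padding" means the smallest `[n, tile]` shape, starting at index 0, that contains every transformed index.

```python
def padded_layout(extent, offset, tile):
    """Shape and padded indices for lambda i: [(i+offset)//tile, (i+offset)%tile]."""
    used = {((i + offset) // tile, (i + offset) % tile) for i in range(extent)}
    n_outer = max(outer for outer, _ in used) + 1
    padding = sorted(
        (outer, inner)
        for outer in range(n_outer)
        for inner in range(tile)
        if (outer, inner) not in used
    )
    return (n_outer, tile), padding


# Tensor A has shape [16], tensor B has shape [14].
a_plain = padded_layout(16, offset=0, tile=8)
b_plain = padded_layout(14, offset=0, tile=8)
b_offset = padded_layout(14, offset=2, tile=8)
a_offset = padded_layout(16, offset=2, tile=8)
```

For example, `b_plain` is `((2, 8), [(1, 6), (1, 7)])`, matching the padding at the end of the buffer described in the first example.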
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (io-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
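The claim that zero-filled padding makes the conditional redundant can be checked outside of TVM. The following is a minimal plain-Python sketch, not part of the RFC: the buffers are nested lists, and the helper names `row_sum_branching` and `row_sum_overcompute` are hypothetical stand-ins for the two `PrimFunc` variants.

```python
import random

rng = random.Random(0)
A_logical = [[rng.uniform(-1.0, 1.0) for _ in range(14)] for _ in range(16)]

# Transformed layout [i, j//4, j%4]; the two padded slots per row hold 0.0.
A_padded = [[[0.0] * 4 for _ in range(4)] for _ in range(16)]
for i in range(16):
    for j in range(14):
        A_padded[i][j // 4][j % 4] = A_logical[i][j]


def row_sum_branching(A):
    # Skips the padding with an explicit conditional.
    B = [0.0] * 16
    for i in range(16):
        for j_outer in range(4):
            for j_inner in range(4):
                if 4 * j_outer + j_inner < 14:
                    B[i] = B[i] + A[i][j_outer][j_inner]
    return B


def row_sum_overcompute(A):
    # Reads the padding as well; valid only because it holds 0.0.
    B = [0.0] * 16
    for i in range(16):
        for j_outer in range(4):
            for j_inner in range(4):
                B[i] = B[i] + A[i][j_outer][j_inner]
    return B


branching = row_sum_branching(A_padded)
overcompute = row_sum_overcompute(A_padded)
```

Both variants add the same elements in the same order, and the extra `+ 0.0` steps leave each partial sum unchanged, so the results compare equal. With any other padding value the two would diverge.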
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform---remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
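The producer behaviour for this case can be simulated in plain Python. This is a sketch under stated assumptions rather than TVM code: `transform`, `padding_predicate`, and the nested-list buffer `A` are hypothetical stand-ins, and the predicate is derived by enumerating which transformed indices are reachable from `0 <= i < 14`.

```python
def transform(i):
    # Layout transformation i -> [i//4, i%4].
    return (i // 4, i % 4)


logical_extent = 14
pad_value = -1

# Transformed indices that correspond to some logical index.
covered = {transform(i) for i in range(logical_extent)}


def padding_predicate(io, ii):
    # True for transformed indices with no corresponding logical index.
    return (io, ii) not in covered


A = [[None] * 4 for _ in range(4)]

# Main loop: same values, written to the new locations.
for i in range(logical_extent):
    io, ii = transform(i)
    A[io][ii] = i

# Postlude: write pad_value wherever the predicate holds.
for io in range(4):
    for ii in range(4):
        if padding_predicate(io, ii):
            A[io][ii] = pad_value

padding = sorted((io, ii) for io in range(4) for ii in range(4)
                 if padding_predicate(io, ii))
```

Only `(3, 2)` and `(3, 3)` satisfy the predicate, so the last row of `A` ends up as `[12, 13, -1, -1]`.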
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform this simplification yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+A new S-TIR transformation `Schedule.sequential_buffer_access` should
+be introduced, which rewrites iteration loops according to the access
+pattern of a buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
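That the rewritten loop nest stores the same values can be checked with a small plain-Python sketch. It is independent of TVM, and the names `A_flat_loop` and `A_sequential` are hypothetical.

```python
# Loop over the original index, writing through the layout transformation.
A_flat_loop = [[None] * 4 for _ in range(4)]
for i in range(16):
    A_flat_loop[i // 4][i % 4] = i

# Loop directly over the transformed indices, as after
# sequential_buffer_access; the stored value is the inverse map 4*io + ii.
A_sequential = [[None] * 4 for _ in range(4)]
for io in range(4):
    for ii in range(4):
        A_sequential[io][ii] = 4 * io + ii
```

The two buffers hold identical contents. The difference is only in how the iteration is expressed: the second form exposes `io` and `ii` as separate loop variables.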
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed

Review Comment:
   Thank you, and I had been going back and forth on that while drafting.  I had been considering the primitives as defining not just the extent of a search space, but also how many steps it would take for an optimizer to identify a transformation in the search space.
   
   I've updated this section from an independent schedule primitive to a utility function.





[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r895028334


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
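The padding locations quoted in the comments above can be reproduced by enumerating the transformed indices. This is a minimal sketch, not part of the RFC: the helper `padding_of` is hypothetical, and it assumes a one-dimensional source buffer whose transformed indices start at zero in every dimension.

```python
from itertools import product


def padding_of(extent, index_map):
    """Return (transformed_shape, padded_indices) for a 1-d buffer."""
    # Transformed indices reachable from some logical index.
    covered = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(next(iter(covered)))
    # Smallest shape that holds every reachable transformed index.
    shape = tuple(max(idx[d] for idx in covered) + 1 for d in range(ndim))
    # Everything inside that shape that is not reachable is padding.
    padding = [idx for idx in product(*(range(n) for n in shape))
               if idx not in covered]
    return shape, padding


no_offset = lambda i: [i // 8, i % 8]
with_offset = lambda i: [(i + 2) // 8, (i + 2) % 8]

shape_a, pad_a = padding_of(16, no_offset)
shape_b, pad_b = padding_of(14, no_offset)
shape_b_off, pad_b_off = padding_of(14, with_offset)
shape_a_off, pad_a_off = padding_of(16, with_offset)
```

For example, `pad_b` holds the two trailing slots of the `[2,8]` buffer, while `pad_a_off` holds two leading slots plus six trailing slots of the `[3,8]` buffer.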
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```

Review Comment:
   The proposed construction seems to have a very interesting relation with s-tir block design!
   - For `Block`
       - With `T.axis.*` and `T.predicate`,  it specifies mapping from loop space to computation instances space (block iter space).
       - With `T.reads`, `T.writes`, it specifies mapping from block iter space to buffer access space.
   
   - For `BufferConstraint`
       - additionally specifies the buffer access space and behavior out of iter space before padding.





[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r890650624


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the conditional can be proven redundant and removed
+automatically.  Whether trading branching for overcompute is
+beneficial depends on the schedule, so these options are exposed as
+two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Thanks @Lunderberg.
   
   It would be great to discuss ideas to simplify this further. On one hand, it is useful to introduce additional semantics to the IR itself; on the other, doing so generally adds complexity to analysis and transformation.
   
   It would be great for us to explore such capabilities without introducing additional complexity in the IR. Back to our goal of introducing padding: it should be possible to have an explicit buffer transformation stage that copies the data into the target padded buffer (with predication), then runs the computation on the padded values.
   
   There are certainly some tradeoffs here, but decoupling the padding behavior into a separate stage of IR computation should allow us to reuse more primitives without having to specialize for BufferConstraint.
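
As a rough illustration of the explicit-stage idea in the comment above, here is a minimal sketch in plain Python rather than TIR. The names `pad_stage` and `row_sum_padded` are hypothetical and are not TVM APIs; the sketch only assumes the `[14] -> [4, 4]` mapping `i -> [i//4, i%4]` from the RFC and a zero pad value.

```python
# Hypothetical sketch in plain Python (not TIR/TVMScript).
# pad_stage and row_sum_padded are illustrative names, not TVM APIs.

def pad_stage(src, tile=4, pad_value=0.0):
    """Copy a 1-D buffer into a [n_outer, tile] buffer, with predication."""
    n = len(src)
    n_outer = (n + tile - 1) // tile  # ceil-divide: 14 elements -> 4 rows of 4
    padded = [[pad_value] * tile for _ in range(n_outer)]
    for io in range(n_outer):
        for ii in range(tile):
            i = tile * io + ii  # inverse of i -> [i // tile, i % tile]
            if i < n:  # predicate: this index exists in the source buffer
                padded[io][ii] = src[i]
            else:  # padding location: write the chosen pad value
                padded[io][ii] = pad_value
    return padded


def row_sum_padded(padded):
    """Branch-free sum; valid only because the padding holds zero."""
    total = 0.0
    for row in padded:
        for value in row:
            total += value
    return total


src = [float(i) for i in range(14)]
padded = pad_stage(src)
result = row_sum_padded(padded)
```

The predicate lives only in the copy stage, so the compute stage can iterate over the full padded extent. The cost is an extra pass over the data, which is the tradeoff the comment refers to.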
   





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension rewriting and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and to materialize a function that legalizes the layout where a contradiction is found. 
   
   The main difference I see is that A0 pushes much more work into manual hardware-specific optimization (compute definitions + patterns/rewriters) that is not as easily repurposed for other hardware targets, whereas A1 provides the infrastructure for more general compiler simplification to proceed from the hardware semantics the user provides about the buffer when transforming the layout at schedule time. 
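
To make the A1 "constraint flowing" step above more concrete, here is a small sketch in plain Python. It is an assumption-laden illustration, not the RFC's design: `Op`, `flow_constraints`, and the `fill_padding(...)` node are hypothetical names, and the graph is simplified to a linear chain in which each operator optionally declares the padding value it leaves behind and the value it expects on its input.

```python
# Hypothetical sketch of A1-style constraint flowing, in plain Python.
# Op, flow_constraints and "fill_padding" are illustrative, not TVM APIs.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Op:
    name: str
    # Value this operator leaves in its output padding (None = unspecified).
    produces_pad: Optional[float]
    # Value this operator assumes in its input padding (None = no assumption).
    requires_pad: Optional[float]


def flow_constraints(chain: List[Op]) -> List[str]:
    """Walk a linear chain; materialize a conversion only on contradiction."""
    plan = []
    available = None  # padding value known to be present in the current buffer
    for op in chain:
        if op.requires_pad is not None and op.requires_pad != available:
            # Producer and consumer disagree: insert an explicit fill.
            plan.append(f"fill_padding({op.requires_pad})")
        plan.append(op.name)
        available = op.produces_pad
    return plan


chain = [
    Op("conv", produces_pad=0.0, requires_pad=None),
    Op("row_sum", produces_pad=None, requires_pad=0.0),  # agrees with conv
    Op("maxpool", produces_pad=None, requires_pad=float("-inf")),  # contradicts
]
plan = flow_constraints(chain)
```

In this toy chain no conversion is inserted between `conv` and `row_sum` because the constraints agree, while a fill is materialized before `maxpool`. An A0-style flow would instead insert a conversion around every operator and rely on pattern matching to cancel the redundant ones.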





[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894958067


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`14//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, because the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (io-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(shape=(16, 14), dtype="float32")
+    B = T.match_buffer(shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since trading branching for overcompute
+may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Thanks for the discussions so far :) I think we all agree that having an additional information variant in the buffer interface would result in powerful expressiveness, just as having both while loops and for loops in a high-level language results in more power.
   
   As stated in the last statement, we could indeed consider putting some constraints at the interface level. That would need some more thought on semantics, how they would interact with the graph, and structural complexity (whether padding is something that is worth the IR complexity).
   
   On the other hand, I would still encourage us to think about whether such complexity is worthwhile only for padding. As a non-local, back-to-back transformation, padding is relatively easy and still achieves the goal. 





[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since trading branching for overcompute
+may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   @Lunderberg That's true. If the reduction dimension is padded, we will need to insert a hint in the graph to assert it was previously padded by 0. From the graph rewriting POV, we can also see this as a transformation done at the graph level (it doesn't rely on arithmetic simplifications).
   
   Example
   ```
   X: R.Tensor[16]
   F: R.Const[3]
   Y: R.Tensor[18] = conv1d(X, F, pad=2)
   Z: R.Tensor[20] = conv1d(Y, F, pad=2)
   ```
   Inserting padding and crop:
   ```
   X: R.Tensor[16]
   F: R.Const[3]
   X_pad = pad(X, before=2, after=6)
   Y = conv1d(X_pad, F, pad=0)
   assert(Y[18:] == 0)
   Y_crop = crop(Y[0:18])
   Y_crop_pad = pad(Y_crop, before=2, after=4)
   Z = conv1d(Y_crop_pad, F, pad=0)
   Z_crop = crop(Z[0:20])
   ```
   Then we can propagate the padding information and combine
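
A minimal sketch of the example above in plain Python, to check that the pad/crop rewrite leaves the results unchanged. The helpers `pad` and `conv1d` are hypothetical stand-ins written for this illustration (they are not Relax or TVM APIs), `conv1d` is assumed to be a stride-1 correlation with symmetric zero padding, and the input values are made up.

```python
def pad(xs, before, after):
    """Zero-pad a 1-D list (hypothetical helper, not a TVM API)."""
    return [0] * before + list(xs) + [0] * after


def conv1d(xs, f, pad_amount=0):
    """Stride-1 correlation after symmetric zero padding (assumed semantics)."""
    p = pad(xs, pad_amount, pad_amount)
    n_out = len(p) - len(f) + 1
    return [sum(p[k + j] * f[j] for j in range(len(f))) for k in range(n_out)]


# Made-up integer inputs with the shapes from the example: X has 16
# elements and F has 3.
X = list(range(1, 17))
F = [1, 2, 3]

# Reference: the original graph, with padding done inside each conv1d.
Y_ref = conv1d(X, F, pad_amount=2)      # 18 elements
Z_ref = conv1d(Y_ref, F, pad_amount=2)  # 20 elements

# Rewritten graph: explicit pad, conv1d with pad=0, then crop.
X_pad = pad(X, before=2, after=6)       # 24 elements
Y = conv1d(X_pad, F)                    # 22 elements
assert all(v == 0 for v in Y[18:])      # the hint: the tail is zero
Y_crop = Y[0:18]
Y_crop_pad = pad(Y_crop, before=2, after=4)
Z = conv1d(Y_crop_pad, F)
Z_crop = Z[0:20]

print(Y_crop == Y_ref, Z_crop == Z_ref)
```

Because `Y[18:]` is already zero, the `crop` followed by `pad(..., after=4)` only rewrites zeros with zeros in that region, which is what would let a graph-level pass combine the two steps.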
   





[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r895028334


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since trading branching for overcompute
+may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```

Review Comment:
   The proposed construction seems to have a very interesting relation with s-tir block design!
   - For `Block`
       - With `T.axis.*` and `T.predicate`,  it specifies mapping from loop space to computation instances space (block iter space).
       - With `T.reads`, `T.writes`, it specifies mapping from block iter space to buffer access space.
   
   - For `BufferConstraint`
       - Additionally specifies the buffer access space and behavior out of iter space before padding.
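
To make the `BufferConstraint` side of this comparison concrete, here is a rough plain-Python model of a predicate/value pair for the RFC's `[14]` to `[4,4]` example. It is only a sketch: `ConstraintSketch` and `padding_indices` are hypothetical names invented here, the predicate `4*io + ii >= 14` and the zero pad value are assumptions chosen to match the guide-level examples, and none of this is the proposed TIR data structure.

```python
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional


@dataclass
class ConstraintSketch:
    """Hypothetical stand-in for a (predicate, value) buffer constraint."""
    predicate: Callable[..., bool]          # True where an index is padding
    value: Optional[Callable[..., float]]   # known value there, if any


def padding_indices(shape, constraint):
    """List the transformed indices for which the predicate holds."""
    grid = product(*(range(extent) for extent in shape))
    return [idx for idx in grid if constraint.predicate(*idx)]


# Original buffer of shape [14], transformed with i -> [i//4, i%4].
original = [float(i + 1) for i in range(14)]
shape = (4, 4)
transformed = [[None] * shape[1] for _ in range(shape[0])]
for i, v in enumerate(original):
    transformed[i // 4][i % 4] = v

# Assumed constraint: indices past the 14 original elements hold 0.0.
constraint = ConstraintSketch(
    predicate=lambda io, ii: 4 * io + ii >= 14,
    value=lambda io, ii: 0.0,
)

# Producer-side postlude: write the pad value wherever the predicate holds.
padded = padding_indices(shape, constraint)
for io, ii in padded:
    transformed[io][ii] = constraint.value(io, ii)

# A consumer may now sum over the full 4x4 grid without a branch.
total = sum(sum(row) for row in transformed)
print(padded, total == sum(original))
```

In this model the predicate plays the role of describing the part of the buffer access space that lies outside the original iteration space, and the value describes what a consumer may assume when reading there.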





[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891736820


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer: the transformed
+coordinates `(3,2)` and `(3,3)` do not have a corresponding index in the
+workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (ii-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the no-op can be proven, allowing the conditional to be
+removed automatically.  Since whether trading branching for
+overcompute is beneficial depends on the schedule, both directions
+are exposed as additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Thanks @Lunderberg for the dissected discussion; this is helpful.
   
   Besides Goal 1, there is one additional implied goal:
   
   - Goal 7: The compositionality of primitives with the buffer layout constraints. Specifically, how composable the existing and new primitives (such as split/reorder/tensorization and reduction factorization) are with the buffer layout constraints.
   
   When building abstractions, we are actually trying to strike a balance between two things: simplicity/compositionality and the range of things we can support.
   
   - It is quite natural that a more complicated impl would hit more marks initially.
   - On the other hand, there is always a consideration of added complexity and how composable our additions are with existing constructs. 
   
   In our case, we are facing an N * M problem, where N is the number of primitives and M is the number of possible IR variations (like layout constraints) we introduce to the IR. An additional field in the IR effectively means we either have to (a) introduce a specific codepath to handle layout constraints, or (b) generalize all relevant primitives to take that into account. The N * M problem will grow as N and M increase.
   
   To manage our complexity, our current rationale is to keep M as stable as possible and grow N with primitives that can compose with each other.
   
   It is also useful to come back to the high-level goal besides these goals for a single function. Our high-level goal is to enable effective end to end models under a good native layout(which involves padding and layout transformation). And it would actually be really nice to have an example at the e2e level to show how the set of transformations affect our optimizations.
   
   Among the existing goals listed, Goal 6 is certainly a very important one. Goal 3 is primarily an implementation difference, in terms of different ways of building pattern matching. Goal 4 is not necessarily a need, as many optimizations actually benefit from reduced complexity (e.g. tensorization in physical memory).
   
   Goal 6 is an important one that indeed touches the high-level (e2e) goal itself. Along that direction, a smart variant of Impl A (by interacting with the graph) would actually enable a simpler realization of Goal 6 (which is important) by lifting the transformations of input/output out, and then cancelling them out between operators, while preserving the information.
   





[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891760089


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer: the transformed
+coordinates `(3,2)` and `(3,3)` do not have a corresponding index in the
+workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
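The shapes and padding locations stated in the comments above can be checked by brute force.  The following is a minimal sketch in plain Python, not TVM code: `padded_layout` is a hypothetical helper that applies an index map to every logical index, takes the smallest bounding shape, and reports the transformed coordinates that no logical index maps to.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (shape, padding) for a 1-d buffer of size `extent`.

    Hypothetical helper for illustration only; it is not part of TVM.
    """
    # Transformed coordinates that hold a value from the original buffer.
    used = {tuple(index_map(i)) for i in range(extent)}
    # Smallest shape that contains every used coordinate.
    ndim = len(next(iter(used)))
    shape = [max(c[d] for c in used) + 1 for d in range(ndim)]
    # Every coordinate inside that shape with no logical index is padding.
    padding = sorted(set(product(*map(range, shape))) - used)
    return shape, padding


no_offset = lambda i: [i // 8, i % 8]
with_offset = lambda i: [(i + 2) // 8, (i + 2) % 8]

shape_a, pad_a = padded_layout(16, no_offset)            # tensor A, shape [16]
shape_b, pad_b = padded_layout(14, no_offset)            # tensor B, shape [14]
shape_b_off, pad_b_off = padded_layout(14, with_offset)  # padding moves to the front
shape_a_off, pad_a_off = padded_layout(16, with_offset)  # offset adds a third row
```

Under these assumptions, `A` without an offset has no padding, `B` is padded at `[1,6]` and `[1,7]`, and the offset map pads `B` at `[0,0]` and `[0,1]` while growing `A` to shape `[3,8]` with eight padded elements.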
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (ii-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the no-op can be proven, allowing the conditional to be
+removed automatically.  Since whether trading branching for
+overcompute is beneficial depends on the schedule, both directions
+are exposed as additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
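The equivalence that this section relies on can be illustrated outside of TIR.  The sketch below is plain Python under stated assumptions: nested lists stand in for buffers, and `row_sum_branching` and `row_sum_overcompute` are hypothetical names mirroring the two `row_summation` variants above.  It is meant only to show why the zero-filled padding makes the branch removable, not how the TVM passes are implemented.

```python
import copy

ROWS, COLS, TILE = 16, 14, 4
N_TILES = -(-COLS // TILE)  # ceil(14 / 4) == 4

# Logical input: A[i, j] for 0 <= j < 14.
logical = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]

# Transformed input A[i, j//4, j%4], with the padding pre-filled with 0.0.
padded = [[[0.0] * TILE for _ in range(N_TILES)] for _ in range(ROWS)]
for i in range(ROWS):
    for j in range(COLS):
        padded[i][j // TILE][j % TILE] = logical[i][j]


def row_sum_branching(a):
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(N_TILES):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:  # skip the padding
                    acc += a[i][jo][ji]
        out.append(acc)
    return out


def row_sum_overcompute(a):
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(N_TILES):
            for ji in range(TILE):
                acc += a[i][jo][ji]  # also accumulates the padding
        out.append(acc)
    return out


reference = [sum(row) for row in logical]
branching = row_sum_branching(padded)
overcompute = row_sum_overcompute(padded)

# If the padding holds anything other than zero, dropping the branch
# changes the result, which is why the padding value must be known.
garbage = copy.deepcopy(padded)
garbage[0][3][2] = 1.0
overcompute_garbage = row_sum_overcompute(garbage)
```

With zero-filled padding both variants agree with the reference sums; with a non-zero value in the padding only the branching variant does.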
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
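The predicate/value semantics described above can be modelled in a few lines.  This is a sketch only: `BufferConstraintSketch` and `known_value` are hypothetical Python stand-ins, not the proposed C++ `BufferConstraintNode`, and the parameters of the Python callables play the role of the `indices` field.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass
class BufferConstraintSketch:
    """Python stand-in for BufferConstraintNode (illustration only)."""

    # Maps transformed indices to True where the constraint applies.
    predicate: Callable[..., bool]
    # Maps transformed indices to the known value; None forbids access.
    value: Optional[Callable[..., float]]


def known_value(constraint, indices: Tuple[int, ...]):
    """Return the value implied by the constraint at `indices`.

    Returns None when the constraint says nothing about these indices,
    and raises when the indices are padding that may not be accessed.
    """
    if not constraint.predicate(*indices):
        return None
    if constraint.value is None:
        raise IndexError(f"access to padding at {indices} is forbidden")
    return constraint.value(*indices)


# Shape [14] transformed by lambda i: [i//4, i%4] gives shape [4,4],
# with padding wherever 4*io + ii >= 14.
is_padding = lambda io, ii: 4 * io + ii >= 14

zero_padded = BufferConstraintSketch(is_padding, lambda io, ii: 0.0)
no_access = BufferConstraintSketch(is_padding, None)
```

Here `known_value(zero_padded, (3, 2))` yields `0.0`, a lookup at `(3, 1)` yields `None` because that element holds real data, and any lookup into the padding of `no_access` raises.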
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
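The producer-side rewrite described above can be mimicked outside of TVM. The following plain-Python sketch is only an illustration: `transform_with_padding` is a hypothetical helper (not a TVM API), and it assumes a 1-D buffer mapped through `i -> [i // inner, i % inner]`. It first writes the original values to their new locations, then runs a second loop that writes `pad_value` wherever the padding predicate holds.

```python
def transform_with_padding(data, inner, pad_value):
    """Hypothetical helper: map i -> [i // inner, i % inner] and fill padding."""
    n = len(data)
    outer = -(-n // inner)  # ceiling division
    out = [[None] * inner for _ in range(outer)]

    # First loop: same values, written to the transformed locations.
    for i in range(n):
        out[i // inner][i % inner] = data[i]

    # Second loop: write the padding value wherever no original index lands.
    for io in range(outer):
        for ii in range(inner):
            if inner * io + ii >= n:  # padding predicate
                out[io][ii] = pad_value
    return out


transformed = transform_with_padding(list(range(14)), inner=4, pad_value=-1)
for row in transformed:
    print(row)
```

For a buffer of 14 elements and `inner=4`, the predicate `4*io + ii >= 14` selects the same two locations as `io == 3 and ii >= 2`.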
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
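As a sanity check on the guarded loop nest, the sketch below emulates both loop structures in plain Python and compares the resulting buffers. It makes no claim about how `Schedule.sequential_buffer_access` would be implemented; the function names are hypothetical, and `None` stands in for locations that are never written.

```python
N, INNER = 14, 4
OUTER = -(-N // INNER)  # ceiling division, 4 rows


def run_original():
    """Loop over the logical index, writing to transformed locations."""
    buf = [[None] * INNER for _ in range(OUTER)]
    for i in range(N):
        buf[i // INNER][i % INNER] = i
    return buf


def run_sequential():
    """Loop over the transformed indices, guarded to skip the padding."""
    buf = [[None] * INNER for _ in range(OUTER)]
    for io in range(OUTER):
        for ii in range(INNER):
            if INNER * io + ii < N:
                buf[io][ii] = INNER * io + ii
    return buf


original = run_original()
sequential = run_sequential()
print(original == sequential)
```

Without the guard, the second version would also write to `[3, 2]` and `[3, 3]`, which the original loop never touches.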
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
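The index simplifications in the two options rely on floor-division semantics, for example `(4*io + ii - f) // 4 == io + (ii - f) // 4` even when `ii - f` is negative. The plain-Python sketch below (Python's `//` and `%` are also floor-based) emulates both rewritten loop nests and compares them against the untransformed convolution. The input data and function names are made up for illustration, and `None` marks padding that is never written.

```python
A = [(3 * i) % 7 - 3 for i in range(16)]  # arbitrary test data
F = [1, -2, 3]


def to_2d(flat, fill=None):
    """Lay a 1-D list out as 4x4 using i -> [i // 4, i % 4]."""
    out = [[fill] * 4 for _ in range(4)]
    for i, value in enumerate(flat):
        out[i // 4][i % 4] = value
    return out


def conv_reference(a, f):
    return [sum(f[k] * a[i + k] for k in range(3)) for i in range(14)]


def conv_option1(a2, f):
    """Loops follow B's layout."""
    b2 = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                b2[io][ii] = 0
                for k in range(3):
                    b2[io][ii] += f[k] * a2[io + (ii + k) // 4][(ii + k) % 4]
    return b2


def conv_option2(a2, f):
    """Loops follow A's layout; the initialization is a separate loop nest."""
    b2 = [[None] * 4 for _ in range(4)]
    for i in range(14):
        b2[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for k in range(3):
                if 0 <= 4 * io + ii - k < 14:
                    b2[io + (ii - k) // 4][(ii - k) % 4] += f[k] * a2[io][ii]
    return b2


expected = to_2d(conv_reference(A, F))
result1 = conv_option1(to_2d(A), F)
result2 = conv_option2(to_2d(A), F)
print(result1 == expected, result2 == expected)
```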
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  chains of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals that
+  do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
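The merging rules above can be explored with a small plain-Python sketch. It is only an illustration: a real simplifier would prove these relations symbolically, whereas the hypothetical helpers `identical` and `complementary` below simply enumerate the loop domain. The second half reproduces the data-dependent case and shows that the merged form can give a different result.

```python
from itertools import product


def identical(cond_a, cond_b, domain):
    """True if the two conditions agree at every point of the domain."""
    return all(cond_a(*p) == cond_b(*p) for p in domain)


def complementary(cond_a, cond_b, domain):
    """True if exactly one of the two conditions holds at every point."""
    return all(cond_a(*p) != cond_b(*p) for p in domain)


domain_1d = [(i,) for i in range(16)]
domain_2d = list(product(range(4), range(4)))

same = identical(lambda i: i < 8, lambda i: i // 8 == 0, domain_1d)
opposite = complementary(
    lambda i, j: 4 * i + j < 14,
    lambda i, j: i == 3 and j >= 2,
    domain_2d,
)
print(same, opposite)


def separate_conditionals(a):
    """Two data-dependent conditionals, evaluated one after the other."""
    if a < 0.0:
        a = a + 1.0
    if a < 0.0:
        a = 0.0
    return a


def merged_conditionals(a):
    """Illegal merge: the second condition is no longer re-evaluated."""
    if a < 0.0:
        a = a + 1.0
        a = 0.0
    return a


print(separate_conditionals(-0.5), merged_conditionals(-0.5))
```

For an input of `-0.5`, the first conditional raises the value to `0.5`, so the second conditional should not fire. The merged version overwrites it with `0.0` anyway.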
+
+### New Transform - Hoist Expression

Review Comment:
   @wrongtest I remember you also proposed imperative loop partitioning in https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807. Could you comment on how this relates to the one you proposed?





[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163616169

   > For example, we may introduce explicit cache stage to add the padding, and mark this block for later processing.
   
   Wouldn't that require a "remove entirely" annotation that was suggested against [here](https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163019805)?  I could see how we could mark a transformation to be hoisted out later, but when some simplifications require the constraint to be expressed in the producer, and others in the consumer, exposing it to both `PrimFuncs` for local simplifications would require either duplication of the block, or maintaining non-local information only for a single pass.  If the stage is duplicated, all but one of the duplicates would need to be marked as temporary.  If the information is only retained for a single pass, then any scheduling/optimization of a single subgraph would require walking through the entire end-to-end model.




[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
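The shapes and padding locations quoted in the comments above can be checked by brute force. The sketch below is plain Python and uses a hypothetical helper, `padded_layout`, which assumes the transformed shape is the smallest zero-based box containing every transformed index, and treats every location in that box that no original index maps to as padding.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (shape, padding) for a 1-D buffer with the given extent."""
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    all_locations = set(product(*(range(s) for s in shape)))
    return shape, sorted(all_locations - used)


# A has shape [16], B has shape [14], as in the examples above.
shape_a, pad_a = padded_layout(16, lambda i: [i // 8, i % 8])
shape_b, pad_b = padded_layout(14, lambda i: [i // 8, i % 8])
shape_b_off, pad_b_off = padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8])
shape_a_off, pad_a_off = padded_layout(16, lambda i: [(i + 2) // 8, (i + 2) % 8])

print(shape_a, pad_a)
print(shape_b, pad_b)
print(shape_b_off, pad_b_off)
print(shape_a_off, pad_a_off)
```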
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
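The equivalence between the branching and overcompute forms can be demonstrated with a plain-Python emulation. This is a sketch under stated assumptions, not the proposed TIR transforms: the input data is made up, the function names are hypothetical, and the padding of the transformed buffer is assumed to be pre-filled with `0.0`.

```python
ROWS, COLS, INNER = 16, 14, 4
OUTER = -(-COLS // INNER)  # ceiling division, 4

# Arbitrary input data; small integer-valued floats keep the sums exact.
A = [[float((3 * i + 5 * j) % 11 - 5) for j in range(COLS)] for i in range(ROWS)]

# Transformed layout [i, j // 4, j % 4], with padding pre-filled with zero.
A_t = [[[0.0] * INNER for _ in range(OUTER)] for _ in range(ROWS)]
for i in range(ROWS):
    for j in range(COLS):
        A_t[i][j // INNER][j % INNER] = A[i][j]


def row_sum_branching(a_t):
    """Skip the padding with a conditional."""
    out = [0.0] * ROWS
    for i in range(ROWS):
        for jo in range(OUTER):
            for ji in range(INNER):
                if INNER * jo + ji < COLS:
                    out[i] += a_t[i][jo][ji]
    return out


def row_sum_overcompute(a_t):
    """Read the padding as well, relying on it being zero."""
    out = [0.0] * ROWS
    for i in range(ROWS):
        for jo in range(OUTER):
            for ji in range(INNER):
                out[i] += a_t[i][jo][ji]
    return out


branching = row_sum_branching(A_t)
overcompute = row_sum_overcompute(A_t)
print(branching == overcompute)
```

If the padding held any value other than zero, the overcompute version would no longer match, which is why the constraint on the padding has to be known before the branch can be dropped.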
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension rewriting and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   * A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   * A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and materialize a function to legalize between layers only when necessary. 
   
   The main difference I see is that A0 pushes more work into manual hardware specific optimization (compute definitions + patterns/rewriters) that is not as easily repurposable for other hardware targets; whereas A1 provides the infrastructure for more general compiler simplification to proceed from hardware semantics the user provides about the buffer when transforming the layout at schedule time. 





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io+ii-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, or to reducing branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether trading branching for
+overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension rewriting and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference-like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists.
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   * A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   * A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and materialize a function to legalize between layers when necessary. 
   
   The main difference I see is that A0 pushes more work into manual, hardware-specific optimization (compute definitions + patterns/rewriters) that is not as easily repurposed for other hardware targets, whereas A1 provides the infrastructure for more general compiler simplification to proceed from the hardware semantics that the user provides about the buffer when transforming the layout at schedule time.
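
A minimal sketch of the constraint-flowing idea in A1, assuming each operator declares the padding value it writes to its output and the padding value it assumes on its input. The `ops` table, the `flow_constraints` helper, and the constraint values are made up for illustration; they are not the `tir.Buffer` annotations proposed in the RFC.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical per-operator constraints on the transformation padding.
# "assumes" is the pad value required on the input, "produces" is the pad
# value written to the output.  None means "no constraint".
ops: Dict[str, Dict[str, Optional[float]]] = {
    "conv2d": {"assumes": 0.0, "produces": 0.0},
    "relu": {"assumes": None, "produces": 0.0},
    "maxpool": {"assumes": float("-inf"), "produces": None},
}


def flow_constraints(chain: List[str]) -> List[Tuple[str, str, float]]:
    """Return the (producer, consumer, required pad value) edges on which
    the producer's guarantee contradicts the consumer's assumption."""
    fixups = []
    for producer, consumer in zip(chain, chain[1:]):
        provided = ops[producer]["produces"]
        required = ops[consumer]["assumes"]
        if required is not None and provided != required:
            fixups.append((producer, consumer, required))
    return fixups


fixups = flow_constraints(["conv2d", "relu", "maxpool"])
```

Only the `relu` to `maxpool` edge is reported, so a conversion that rewrites the padding would be materialized there and nowhere else.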





[GitHub] [tvm-rfcs] vinx13 commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163535262

   Indeed, if a buffer is used in an annotation value, that will change the semantics of a node. However, there are different ways to represent this, as long as it can be reconstructed later. For example, we may introduce an explicit cache stage to add the padding, and mark this block for later processing.




[GitHub] [tvm-rfcs] wrongtest commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1152928725

   Thanks for all the great discussions! It is exciting that we will have a more powerful way to handle things like padding and imperfect tiles.
   
   Since our team relies on the s-tir code path, we are extremely interested in the story for s-tir. I would appreciate some details on s-tir padding. I would like to use a [127, 127, 127] matmul to illustrate my questions :)
   
   ```python
   @T.prim_func
   def matmul(A: T.Buffer[(127, 127), "float32"], B: T.Buffer[(127, 127), "float32"], C: T.Buffer[(127, 127), "float32"]):
       for i, j, k in T.grid(127, 127, 127):
           with T.block("compute"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.0
               C[vi, vj] += A[vi, vk] * B[vk, vj]
   ```
   
   In the current state of s-tir, we can construct padded loops and buffers with existing primitives using the "split and then fuse" trick:
   ```python
   s = tvm.tir.Schedule(matmul)
   blk = s.get_block("compute")
   i, j, k = s.get_loops(blk)
   s.fuse(*s.split(i, factors=[4, 32]))
   s.fuse(*s.split(j, factors=[4, 32]))
   s.fuse(*s.split(k, factors=[4, 32]))
   s.transform_layout(blk, "A", lambda i,k: ((i // 32) * 32 + i % 32, (k // 32) * 32 + k % 32))
   s.transform_layout(blk, "B", lambda k,j: ((k // 32) * 32 + k % 32, (j // 32) * 32 + j % 32))
   s.transform_layout(blk, "C", lambda i,j: ((i // 32) * 32 + i % 32, (j // 32) * 32 + j % 32))
   ```
   We will get (if simplified)
   ```python
   @T.prim_func
   def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
       for i_0_i_1_fused, j_0_j_1_fused, k_0_k_1_fused in T.grid(128, 128, 128):
           with T.block("compute"):
               vi = T.axis.spatial(127, i_0_i_1_fused)
               vj = T.axis.spatial(127, j_0_j_1_fused)
               vk = T.axis.reduce(127, k_0_k_1_fused)
               T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)
               T.reads(A[vi, vk], B[vk, vj])
               T.writes(C[vi, vj])
               with T.init():
                   C[vi, vj] = T.float32(0)
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```
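
The padded extent that comes out of the "split and then fuse" trick can be checked with plain Python arithmetic. The helper below is hypothetical (it is not a TVM API) and assumes the product of the factors is at least the original extent.

```python
from typing import List, Tuple


def split_then_fuse(extent: int, factors: List[int]) -> Tuple[int, int]:
    """Return (padded_extent, guarded_iterations) for a loop of `extent`
    that is split by `factors` and then fused back into one loop."""
    padded = 1
    for factor in factors:
        padded *= factor
    # Iterations at or beyond the original extent must be masked by a predicate.
    guarded = sum(1 for fused in range(padded) if not fused < extent)
    return padded, guarded


padded, guarded = split_then_fuse(127, [4, 32])
```

For an extent of 127 and factors `[4, 32]`, the fused loop has 128 iterations, one of which falls outside the original range in each dimension.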
   Then the only thing left is the condition for padding: `T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)`. I believe this brings us to the point in the current RFC about the tradeoff between over-computation and branching. Below are some of my questions:
   
   1. What happens when we change to `s.transform_layout(..., pad_value=0)`? (if we want over-computation)
      - (possible behavior 1) Insert padding-filling code as a producer block of `compute`.
        - Since the effect is immediate, maybe we do not need `BufferConstraint` annotations afterwards?
      - (possible behavior 2) Annotate buffers and let lowering passes handle it.
        - We may require `BufferConstraint` to direct the lowering passes.
      - (possible behavior 3) Pass `BufferConstraint` upwards to the graph level.
        - Thus, assume the param buffer matches the constraint and do not write edge values.
      
   2.  For (1.2) and (1.3), it seems that encoding the `BufferConstraint` into the buffer object is not the only choice.
       - For s-tir, correct me if I am wrong, at least for common cases the constraint could be treated as local to the transformed block. What if we encode the constraint just into the block, as one of its memory access properties?
         We found previously that the block memory annotations `T.reads` and `T.writes` (`BufferRegion`) have some limitations, in that they lose conditional access information. Maybe we can also combine `BufferConstraint` with `BufferRegion`?
   
       - For graph-level annotations, IIUC, it conceptually uses "Tensor"-typed values instead of "Buffer". Maybe we still need another construct instead of a `Buffer` with a `BufferConstraint` field?
         We could also consider instantiating the graph-level transformation explicitly. This is our current solution: https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807/4.
   
       - Nevertheless, if we finally decide to extend the buffer node structure, I hope we can have an explicit lifetime for the `BufferConstraint` in the TIR lowering, so that later storage-related passes do not need to handle it, especially customized passes developed by vendors.
   
   3. For the reduce axis padding, mentioned in https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301
       - At the TIR level, since a schedule primitive should preserve semantic correctness, how do we prove that the padding for the `k` dimension can only be zero? Especially when we do not know that it is a "matmul" op in general. I think this is important if we want to use padded `transform_layout` in auto-scheduling applications.
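
To illustrate why the pad value matters for question 3, the plain-Python sketch below compares a guarded reduction with an unguarded one on a 3x3 matmul padded to 4x4. It is only a numerical check on one input, not a proof, and all helper names are hypothetical.

```python
def pad(M, n, n_pad, value):
    """Pad an n x n matrix to n_pad x n_pad, filling the new cells with `value`."""
    return [[M[i][j] if i < n and j < n else value for j in range(n_pad)]
            for i in range(n_pad)]


def matmul_guarded(A, B, n, n_pad):
    """Branching version: loop over the padded k extent but skip padded k."""
    return [[sum(A[i][k] * B[k][j] for k in range(n_pad) if k < n)
             for j in range(n)] for i in range(n)]


def matmul_overcompute(A, B, n, n_pad):
    """Overcompute version: loop over the padded k extent with no guard."""
    return [[sum(A[i][k] * B[k][j] for k in range(n_pad))
             for j in range(n)] for i in range(n)]


n, n_pad = 3, 4
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
B = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [3.0, 0.0, 1.0]]

# The guarded result does not depend on what is stored in the padding.
expected = matmul_guarded(pad(A, n, n_pad, 5.0), pad(B, n, n_pad, 5.0), n, n_pad)

# Zero padding makes the extra k iteration a no-op; a pad value of 1.0 does not.
zero_ok = matmul_overcompute(pad(A, n, n_pad, 0.0), pad(B, n, n_pad, 0.0), n, n_pad) == expected
ones_ok = matmul_overcompute(pad(A, n, n_pad, 1.0), pad(B, n, n_pad, 1.0), n, n_pad) == expected
```

With zeros in the padding, the unguarded loop reproduces the guarded result; with ones it does not, so dropping the predicate is only valid when the pad value is known.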
   
   cc @Lunderberg @tqchen @vinx13 @Hzfengsy 




[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r909777012


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
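
As a rough illustration of the producer postlude, the sketch below mimics it in plain Python rather than TIR. The function `transform_with_padding` and the sentinel `UNWRITTEN` are hypothetical names, and `pad_value` is restricted here to a constant or a callable on the transformed indices.

```python
from itertools import product

UNWRITTEN = object()  # hypothetical marker for padding that was never written


def transform_with_padding(data, transform, shape, pad_value=None):
    """Copy `data` into a dict keyed by transformed indices, then run the postlude."""
    out = {idx: UNWRITTEN for idx in product(*(range(s) for s in shape))}
    for i, value in enumerate(data):
        out[tuple(transform(i))] = value

    if pad_value is not None:
        fill = pad_value if callable(pad_value) else (lambda *indices: pad_value)
        # Postlude: every index without a pre-image receives the pad value.
        for idx, value in out.items():
            if value is UNWRITTEN:
                out[idx] = fill(*idx)
    return out


B = [float(i) for i in range(14)]
B_zero = transform_with_padding(B, lambda i: [i // 4, i % 4], (4, 4), pad_value=0.0)
B_none = transform_with_padding(B, lambda i: [i // 4, i % 4], (4, 4))
```

With `pad_value=None` the two padding slots are left unwritten, which corresponds to the case where access of the padding is forbidden.
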
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when the loop is bound to a thread index.  If the padding
+in `A` is pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op,
+and can be performed without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the compiler can prove that the padding contains zero,
+allowing the conditional to be removed automatically.  Whether
+trading branching for overcompute is beneficial depends on the
+schedule, so both directions are exposed as additional
+transformations, `tir.transform.RemoveBranchingThroughOvercompute`
+and `tir.transform.RemoveOvercomputeThroughBranching`.
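
The equivalence between the branching and overcompute forms can be checked numerically. This is a plain-Python sketch of the two loopnests above, not TIR; integer data is used so the comparison is exact, and the names `row_sum_branching` and `row_sum_overcompute` are illustrative only.

```python
ROWS, COLS, TILE = 16, 14, 4
N_TILES = -(-COLS // TILE)  # ceil(14 / 4) == 4

# Original [16, 14] data (integers, so the comparison below is exact).
A = [[i * COLS + j for j in range(COLS)] for i in range(ROWS)]

# Transformed [16, 4, 4] layout, with the padding pre-filled with zero.
A_t = [[[0] * TILE for _ in range(N_TILES)] for _ in range(ROWS)]
for i in range(ROWS):
    for j in range(COLS):
        A_t[i][j // TILE][j % TILE] = A[i][j]


def row_sum_branching(A_t):
    B = [0] * ROWS
    for i in range(ROWS):
        for jo in range(N_TILES):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:  # skip the padding
                    B[i] += A_t[i][jo][ji]
    return B


def row_sum_overcompute(A_t):
    B = [0] * ROWS
    for i in range(ROWS):
        for jo in range(N_TILES):
            for ji in range(TILE):
                B[i] += A_t[i][jo][ji]  # also adds the zero padding
    return B


expected = [sum(row) for row in A]
```

Both functions return `expected` when the padding holds zero. If the padding held any other value, only the branching form would remain correct, which is why the pad value must be known before the conditional can be dropped.
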
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to be
+  identical.  For example, the expression `undef - undef` may not be
+  simplified to zero.  If this behavior is desired, the `undef` may be
+  bound to a variable in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
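
The simplification described above can be confirmed by enumerating the padding predicate for the `[14]` to `[4, 4]` transformation. This is a brute-force check in plain Python, standing in for whatever symbolic analysis an implementation would actually use.

```python
from itertools import product

# Indices visited by the padding-writing loop in its default form:
# a full T.grid(4, 4) guarded by the padding predicate.
full_loop = [(io, ii) for io, ii in product(range(4), range(4)) if 4 * io + ii >= 14]

# Simplified form: io is fixed to 3, and ii runs over 2 <= ii < 4.
reduced_loop = [(3, ii) for ii in range(2, 4)]
```

Both forms visit the same two indices in the same order, so the reduced loop writes exactly the same padding.
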
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
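
To see that Option 1 and Option 2 compute the same result, both loopnests can be transcribed into plain Python and compared against the original. This is only a sketch: it assumes floor-division semantics for `//` and `%` (as in Python), and the helper names `reference`, `to_tiled`, `option_1` and `option_2` are hypothetical.

```python
A = list(range(1, 17))   # shape [16]
F = [2, -1, 3]           # shape [3]


def reference(A, F):
    """Original loopnest over the untransformed [14] output."""
    return [sum(F[f] * A[i + f] for f in range(3)) for i in range(14)]


def to_tiled(flat, n_outer=4, n_inner=4, fill=0):
    """Lay out a 1-d list as [i // 4, i % 4]; `fill` stands in for padding."""
    tiled = [[fill] * n_inner for _ in range(n_outer)]
    for i, value in enumerate(flat):
        tiled[i // n_inner][i % n_inner] = value
    return tiled


def option_1(A_t, F):
    """Loops follow B's layout."""
    B_t = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                for f in range(3):
                    B_t[io][ii] += F[f] * A_t[io + (ii + f) // 4][(ii + f) % 4]
    return B_t


def option_2(A_t, F):
    """Loops follow A's layout; B is initialized in a separate loopnest."""
    B_t = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    # Python's // and % floor toward negative infinity,
                    # so (ii - f) // 4 is -1 when ii < f.
                    B_t[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A_t[io][ii]
    return B_t


A_t = to_tiled(A)
expected = to_tiled(reference(A, F))
```

The two options accumulate the products in a different order. That is harmless for the integer data used here, but for floating-point buffers the results could differ in the last bits.
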
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` statements write to the same
+  buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals should not be merged.
+  Only conditionals whose conditions do not depend on mutable values,
+  such as buffer contents, should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, this should be used for
+  later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on a value
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
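
The conditional-merging rules in the list above can be illustrated in plain Python. The sketch below uses brute-force enumeration over the loop domain as a stand-in for the proofs an arithmetic analyzer would perform; `identical`, `complementary`, `two_conditionals` and `illegally_merged` are hypothetical names.

```python
from itertools import product


def identical(cond_a, cond_b, domain):
    """True if the two conditions agree at every point of the loop domain."""
    return all(cond_a(*p) == cond_b(*p) for p in domain)


def complementary(cond_a, cond_b, domain):
    """True if exactly one of the two conditions holds at every point."""
    return all(cond_a(*p) != cond_b(*p) for p in domain)


grid = list(product(range(4), range(4)))
in_bounds = lambda i, j: 4 * i + j < 14
is_padding = lambda i, j: i == 3 and j >= 2


def two_conditionals(a):
    """Data-dependent conditionals, evaluated one after the other."""
    if a < 0.0:
        a = a + 1.0
    if a < 0.0:
        a = 0.0
    return a


def illegally_merged(a):
    """Both bodies placed under the first condition."""
    if a < 0.0:
        a = a + 1.0
        a = 0.0
    return a
```

For example, `in_bounds` and `is_padding` are complementary over the 4x4 grid, while `i < 8` and `i // 8 == 0` are identical over `0 <= i < 16`. For the data-dependent case, an input of `-0.5` gives `0.5` from `two_conditionals` but `0.0` from `illegally_merged`, showing that the merge changes behavior.
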
+
+### New Transform - Hoist Expression
+
+A new utility `HoistExpression` is introduced, which is a
+generalization of the current `HoistIfThenElse` pass.  The
+transformation `HoistExpression` would apply to the entire body of
+the `PrimFunc`, and would be used to avoid duplicating functionality
+between `HoistIfThenElse` and `HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist
+    any let bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
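
Because the enum values are bit flags, options can be combined with bitwise OR. The sketch below mirrors the two C++ enums using Python's standard `enum.IntFlag`; how these flags would actually be exposed to Python is not specified here, so the class definitions are an assumption made purely for illustration.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    kNone = 0
    kIfElseStmt = 1 << 0
    kIfElseExpr = 1 << 1
    kBooleanExpression = 1 << 2


class HoistLetBinding(IntFlag):
    kNone = 0
    kRequiredByCondition = 1 << 0
    kLetStmt = 1 << 1
    kLetExpr = 1 << 2


# Hoist conditions from both statement and expression forms, but not
# arbitrary boolean expressions.
conditionals = HoistConditional.kIfElseStmt | HoistConditional.kIfElseExpr

# Only hoist the let bindings that a hoisted conditional depends on.
bindings = HoistLetBinding.kRequiredByCondition
```

Membership can then be tested with `in`, for example `HoistConditional.kIfElseStmt in conditionals`.
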
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
+
+### New Transform - Reduce Loop Extents

Review Comment:
   That's a good point, and would avoid having the special case pass.
   
   Looking at the implementation of `LoopPartition` for edge cases, and making some notes to myself on the steps required.
   
   * Apply a `T.likely` annotation when checking if it's a pad value (e.g. `if !T.likely(4*io + ii < 14)`), since `LoopPartition` uses this to identify partitionable conditions.
   * Maintain the `T.likely` annotation when hoisting part of a conditional (e.g. when hoisting `io==3` out of `!T.likely(io==3 and ii>=2)`).
   * Look into relaxing the restriction against partitioning a constant loop, currently only allowed by a pass config.  If we're generating loops that we know should be partitioned, it would be strange to also require the user to opt in.  I don't know the history of this restriction, so this would require some investigation.  (Perhaps we could allow the additional partition only if the loop is a serial loop and all but one of the partitions are no-ops.)
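
As a rough illustration of the effect these steps aim for, the sketch below models the padded `4x4` loop nest in plain Python rather than TIR.  It is not TVM code and does not call `LoopPartition`; the names `guarded`, `partitioned`, and `N_VALID` are made up for this example.  It only shows that splitting the loop into regions where the condition is known gives the same result as checking the condition on every iteration.

```python
# Plain-Python model (not TVM code) of partitioning a padded 4x4 loop nest.
N_VALID = 14  # logical extent; the transformed shape is [4, 4]


def guarded(data):
    """Original form: every iteration checks the padding predicate."""
    total = 0
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < N_VALID:
                total += data[io][ii]
    return total


def partitioned(data):
    """Partitioned form: the condition is resolved once per region."""
    total = 0
    # io in [0, 3): the condition is always true, so no check is needed.
    for io in range(3):
        for ii in range(4):
            total += data[io][ii]
    # io == 3: only ii in [0, 2) is valid.  The region ii in [2, 4) is a
    # no-op and is dropped entirely.
    for ii in range(2):
        total += data[3][ii]
    return total


# Fill the transformed buffer so that element (io, ii) holds 4*io + ii.
data = [[4 * io + ii for ii in range(4)] for io in range(4)]
assert guarded(data) == partitioned(data) == sum(range(N_VALID))
```

In this model all but one of the partitions of the `io == 3` row are no-ops, which is the case the last bullet suggests allowing without a user opt-in.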





[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r890708951


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original 14, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether trading branching for
+overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`

Review Comment:
   Could we name it "undef"?  Do we need to specify the behavior when the `arbitrary` value is involved in computations, as LLVM does for its undef and poison values?
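
To make the question concrete: in LLVM, an `undef` value may take a different value at each use, so an expression such as `undef - undef` cannot be folded to zero unless the value is first bound to a single name.  The toy model below is plain Python, not TVM or LLVM code, and the class name `UndefSource` is hypothetical.  It uses a counter in place of an arbitrary value so that the behavior is deterministic.

```python
import itertools


class UndefSource:
    """Toy model: each read of an undef-like value may yield a different result."""

    def __init__(self):
        # A counter stands in for "some arbitrary value" so that the
        # example is reproducible.
        self._values = itertools.count()

    def read(self):
        return next(self._values)


undef = UndefSource()

# Two separate uses are two separate reads, so their difference is not
# guaranteed to be zero.
diff_two_reads = undef.read() - undef.read()

# Binding the value once (analogous to a let binding) makes both operands
# identical, so the difference is zero.
x = undef.read()
diff_bound = x - x

assert diff_two_reads != 0
assert diff_bound == 0
```

Whether the proposed op should follow this per-use model, or something closer to poison, is the behavior that would need to be written down.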





[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891831175


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original 14, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether trading branching for
+overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
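
As a rough check of the two loop forms discussed here, the following plain-Python sketch (nested lists standing in for the `[4,4]` buffer, all function names hypothetical) derives the padding cells from the transformation itself and confirms that the full-grid loop and the reduced loop over `2 <= ii < 4` write the same values.

```python
def padding_cells(extent=14, shape=(4, 4)):
    """Cells of the transformed buffer that no original index maps to."""
    used = {(i // 4, i % 4) for i in range(extent)}
    every = {(io, ii) for io in range(shape[0]) for ii in range(shape[1])}
    return every - used


def write_padding_full_grid(buf, pad_value=-1):
    # Default form: loop over the whole transformed buffer, guarded by
    # the padding predicate.
    for io in range(4):
        for ii in range(4):
            if io == 3 and ii >= 2:
                buf[io][ii] = pad_value


def write_padding_reduced(buf, pad_value=-1):
    # Simplified form: the `io` loop is removed and `ii` is restricted.
    for ii in range(2, 4):
        buf[3][ii] = pad_value


def produce(write_padding):
    buf = [[None] * 4 for _ in range(4)]
    for i in range(14):
        buf[i // 4][i % 4] = i
    write_padding(buf)
    return buf
```

Both producers yield a last row of `[12, 13, -1, -1]`; the full-grid form is simply easier to merge with a neighbouring loop over the same `T.grid(4,4)`.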
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
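
The effect of the inserted conditional can be checked outside of TIR. The sketch below is plain Python with hypothetical helper names; it records which transformed cells are written, and with what value, before and after the loop rewrite.

```python
def before_rewrite():
    # Loop over the original index, storing into the transformed layout.
    touched = []
    for i in range(14):
        touched.append(((i // 4, i % 4), i))
    return touched


def after_rewrite():
    # Loop over the transformed indices, guarded so that the two padding
    # cells (3, 2) and (3, 3) are never written.
    touched = []
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                touched.append(((io, ii), 4 * io + ii))
    return touched
```

Both versions visit the same 14 cells in the same order, which is what makes the rewritten loop a sequential traversal of the transformed buffer.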
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
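
The index simplifications in both options can be sanity-checked numerically. The following is a sketch in plain Python rather than TVMScript, with hypothetical function names; it relies on Python's `//` and `%` rounding toward negative infinity, which matters for Option 2 where `ii - f` can be negative.

```python
def conv_reference(A, F):
    # Original function: flat indexing, B has 14 elements.
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def conv_option1(A2, F):
    # Loops follow B's layout; A2 is A stored as a 4x4 buffer.
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv_option2(A2, F):
    # Loops follow A's layout; B is zero-initialized in a separate step.
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2
```

With integer inputs, flattening either result and dropping the two padding cells reproduces `conv_reference`, even though Option 2 accumulates the terms in a different order.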
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
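
To make the store-level rules concrete, here is a toy model in plain Python. It is a sketch only: `Store` and `remove_no_ops` are hypothetical names, each store is reduced to the location it writes plus the set of locations its value reads, and the block-level rule is not modelled.

```python
from typing import FrozenSet, NamedTuple, Tuple

Location = Tuple[str, int]  # (buffer name, index)


class Store(NamedTuple):
    """Toy stand-in for a BufferStore (not the TVM class)."""
    label: str
    target: Location
    reads: FrozenSet[Location]
    is_copy_of_target: bool = False  # value is exactly a load of `target`


def remove_no_ops(stores):
    # Repeat until nothing changes, so that strings of no-ops are found.
    changed = True
    while changed:
        kept = []
        for pos, store in enumerate(stores):
            # Reading a value and immediately writing it back is a no-op.
            if store.is_copy_of_target:
                continue
            nxt = stores[pos + 1] if pos + 1 < len(stores) else None
            # A store is a no-op if the next store overwrites the same
            # location without reading the value just written.
            if (nxt is not None and nxt.target == store.target
                    and store.target not in nxt.reads):
                continue
            kept.append(store)
        changed = len(kept) != len(stores)
        stores = kept
    return stores
```

The fixed-point loop is what lets `A[0] = 1; A[0] = A[0]; A[0] = 2` collapse to the final store: the write-back is removed first, which exposes the first store as overwritten.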
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals
+  that do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
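
For conditions that depend only on loop variables with small static extents, "identical" and "complementary" can be checked by enumerating the loop domain. The sketch below does this by brute force in plain Python; the helper names are hypothetical, and an actual implementation would presumably rely on the arithmetic analyzer rather than enumeration.

```python
from itertools import product


def are_identical(cond_a, cond_b, extents):
    """True if the two conditions agree at every point of the loop domain."""
    domain = product(*map(range, extents))
    return all(cond_a(*idx) == cond_b(*idx) for idx in domain)


def are_complementary(cond_a, cond_b, extents):
    """True if exactly one of the two conditions holds at every point."""
    domain = product(*map(range, extents))
    return all(cond_a(*idx) != cond_b(*idx) for idx in domain)
```

On `0 <= i < 16`, `i < 8` and `i//8 == 0` agree everywhere, while on the 4x4 grid `4*i + j < 14` and `i==3 and j>=2` are complementary, so both pairs qualify for merging.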
+
+### New Transform - Hoist Expression

Review Comment:
   These are different alternatives, and they can be combined on certain workloads. There is a discussion on the performance issue of matmul when we just change a dimension from 128 to 127: https://discuss.tvm.apache.org/t/te-vectorize-do-we-have-plan-to-support-vectorize-for-non-divisible-split/12469
   I think it might be a good working example. Below is what the user gets with the loop split of `j`: 127 -> (4, 32)
   
   ```python
   for i in range(127):
       for k in range(127):
           for j.outer in range(4):
                for j.inner in T.vectorized(32):
                    if T.likely(j.outer * 32 + j.inner < 127, dtype="bool"):
                        C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
   ```
   
   The issue is that a complex condition has to be introduced to maintain the program semantics. It hurts performance, and in general we cannot vectorize a program with control flow.
   
   Now I understand we have different alternatives to handle this:
   - Loop partition
   We can already annotate the loop var with a hint, using non-imperative loop partition.
   
     `for j.outer in range(4, annotations={"pragma_loop_partition_hint": 1})`
   
     After `LoopPartition` pass (and simplify) it becomes:
     ```python
     for i in range(127):
         for k in range(127):
             # j.outer in [0, 3)
             for j.outer in range(3):
                  for j.inner in T.vectorized(32):
                       # condition is const true, optimize out
                       C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
             # j.outer in [3, 4), optimize out
             for j.inner in T.vectorized(31):
                  # condition becomes j.inner < 31, hoisted with loop
                  C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
     ```
     Then the condition branch gets eliminated in the different loop parts, which makes them friendlier to performance optimizations like vectorization. The "imperative" partition just proposes that we can partition in the schedule phase when one wants to schedule the parts differently, such as giving them different vectorization widths.
   
   - Loop padding
   
     With the current RFC, I understand we can pad the innermost dimension of `C` and `B` to 128, and drop the condition directly somehow. Then it directly becomes the following (IIUC, we may also insert some "arbitrary" value filling code on the edges and then optimize it out?)
     ```python
     for i in range(127):
         for k in range(127):
             for j.outer in range(4):
                  for j.inner in T.vectorized(32):
                      C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
     ```
   
   In this particular case, I believe padding is the better choice, since we get very neat code with minimal over-computation. And we can also use the padding trick for the different loop parts in alternative (1).
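
As a quick sanity check of the padding alternative, here is a sketch in plain Python. It uses small sizes instead of 127/128, 2-D lists instead of flat buffers, and made-up function names, so it only illustrates the idea: over-computing into zero-filled padding leaves the valid region unchanged.

```python
def matmul_guarded(A, B, n, tile):
    # Split j into (outer, inner) with a bounds check on every iteration.
    C = [[0] * n for _ in range(n)]
    outer = -(-n // tile)  # ceil(n / tile)
    for i in range(n):
        for k in range(n):
            for jo in range(outer):
                for ji in range(tile):
                    j = jo * tile + ji
                    if j < n:
                        C[i][j] += A[i][k] * B[k][j]
    return C


def matmul_padded(A, B, n, tile):
    # Pad the innermost dimension of B and C up to a multiple of `tile`,
    # so that the inner loop needs no condition.
    padded = -(-n // tile) * tile
    Bp = [row + [0] * (padded - n) for row in B]
    Cp = [[0] * padded for _ in range(n)]
    for i in range(n):
        for k in range(n):
            for jo in range(padded // tile):
                for ji in range(tile):
                    j = jo * tile + ji
                    Cp[i][j] += A[i][k] * Bp[k][j]
    # Drop the padding again to compare against the guarded version.
    return [row[:n] for row in Cp]
```

With `n=7` and `tile=4`, the padded version does one extra column of work per row and produces the same 7x7 result.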





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r909639315


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer: the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
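
The shapes and padding locations quoted in the comments above can be reproduced with a few lines of plain Python. This is a sketch, not how TVM derives them: the helper name is hypothetical, and it assumes every transformed index starts at zero, so that the shape is the per-axis maximum plus one.

```python
from itertools import product


def transformed_shape_and_padding(extent, transform):
    """Return the transformed shape and the sorted list of padding indices.

    `transform` maps an original index to a tuple of transformed indices.
    """
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    padding = sorted(set(product(*map(range, shape))) - used)
    return shape, padding
```

For example, `transformed_shape_and_padding(14, lambda i: (i // 8, i % 8))` reports shape `(2, 8)` with padding at `(1, 6)` and `(1, 7)`, matching the second case above.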
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether the tradeoff between
+branching and overcompute is beneficial depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
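
The dependence on the padding value can be seen in a small plain-Python sketch of the row summation (nested lists stand in for the `[16, 4, 4]` buffer, and all function names are hypothetical). The unguarded loop matches the guarded one only when the padding holds zeros.

```python
def to_padded_layout(A, pad_value):
    # Layout [i, j] -> [i, j//4, j%4] for 14 columns, filling the padding.
    out = []
    for row in A:
        flat = list(row) + [pad_value] * 2
        out.append([flat[4 * jo:4 * jo + 4] for jo in range(4)])
    return out


def row_sum_branching(A3):
    # Guarded accumulation: padding cells are never read.
    B = [0.0] * len(A3)
    for i, row in enumerate(A3):
        for jo in range(4):
            for ji in range(4):
                if 4 * jo + ji < 14:
                    B[i] += row[jo][ji]
    return B


def row_sum_overcompute(A3):
    # Unguarded accumulation: correct only if the padding holds zeros.
    B = [0.0] * len(A3)
    for i, row in enumerate(A3):
        for jo in range(4):
            for ji in range(4):
                B[i] += row[jo][ji]
    return B
```

With any other padding value the overcompute version silently adds the padding into each row sum, which is why the simplification needs the annotation to be legal.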
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to be
+  identical.  For example, the expression `undef - undef` may not be
+  simplified to zero.  If this behavior is desired, the `undef` may be
+  assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying out of the write cache, the loop iteration remains
+    # the same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
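
As a sanity check of the loop-range reduction described above, the padding predicate from the example can be enumerated in plain Python.  This is only an illustrative sketch outside of TIR, and the variable names are hypothetical.

```python
# Enumerate the transformed indices of the [4, 4] buffer that satisfy the
# padding predicate used in the example above (4*io + ii >= 14).
padding = [(io, ii) for io in range(4) for ii in range(4) if 4 * io + ii >= 14]

# Collect the distinct values of each loop variable that reach the store.
padding_rows = sorted({io for io, _ in padding})
padding_cols = sorted({ii for _, ii in padding})

print(padding, padding_rows, padding_cols)
```

Only `io == 3` reaches the store, so the loop over `io` is redundant, and the surviving values of `ii` are exactly `2 <= ii < 4`.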
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
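
The effect of the conditional can be checked with a small plain-Python emulation of the last two functions above.  This is a sketch rather than TIR: the buffer is modeled as a dictionary keyed by transformed index, and the function names are hypothetical.

```python
def after_transform_layout():
    # Loop still runs over the logical index i, writing to [i//4, i%4].
    A = {}
    for i in range(14):
        A[(i // 4, i % 4)] = i
    return A


def after_sequential_buffer_access():
    # Loop runs over the transformed indices; the conditional skips the
    # two padding elements that the original loop never touched.
    A = {}
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                A[(io, ii)] = 4 * io + ii
    return A


before = after_transform_layout()
after = after_sequential_buffer_access()
print(before == after, len(after))
```

Both versions write the same 14 elements with the same values, and neither touches `(3, 2)` or `(3, 3)`.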
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
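
The index simplifications in both options can be checked numerically.  The following is a plain-Python sketch (nested lists instead of TIR buffers, with hypothetical helper names) that compares each rewritten loopnest against the original 1-d convolution.  It relies on Python's floor-based `//` and `%`, which matters for the negative values of `ii - f` in Option 2.

```python
def conv1d_reference(A, F):
    # Original loopnest over the untransformed buffers.
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def to_4x4(flat):
    # Apply the layout i -> [i//4, i%4]; padding is left as None.
    out = [[None] * 4 for _ in range(4)]
    for i, value in enumerate(flat):
        out[i // 4][i % 4] = value
    return out


def conv1d_option1(A2, F):
    # Loops follow B's layout.
    B2 = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B2[io][ii] = 0
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv1d_option2(A2, F):
    # Loops follow A's layout; initialization is a separate loopnest.
    B2 = [[None] * 4 for _ in range(4)]
    for i in range(14):
        B2[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = list(range(16))
F = [1, 10, 100]
expected = to_4x4(conv1d_reference(A, F))
print(conv1d_option1(to_4x4(A), F) == expected)
print(conv1d_option2(to_4x4(A), F) == expected)
```

Integer inputs are used so that the different accumulation orders of the two options give exactly equal results.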
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:

Review Comment:
   If the merging of complementary conditionals is valid, then which condition is kept doesn't matter for correctness.  For two conditions `A` and `B`, if `A` implies `!B` and `B` implies `!A`, then `A` and `!B` are different functional forms of the same expression.
   
   That said, I'd probably keep the first conditional, as it allows for the simplification to be viewed as a specific case of a more general transformation.  Given a conditional that is followed by another statement outside the conditional, it is valid to move the statement inside the conditional, placed at the end of both the `then_case` and `else_case`.  If the statement being moved is itself a conditional, then it may be simplified.  In this case, the intermediate step would look as follows.
   
   ```python
   @T.prim_func
   def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
       for i,j in T.grid(4,4):
           if 4*i + j < 14:
               A[i, j] = 0.0
               if i==3 and j>=2:
                   B[i, j] = 2.0
               else:
                   B[i, j] = 3.0
           else:
               A[i, j] = 1.0
               if i==3 and j>=2:
                   B[i, j] = 2.0
               else:
                   B[i, j] = 3.0
   ```
   
   I wouldn't want to generate the intermediate state in all cases, because it may not always lead to useful simplifications, which is why it would only be applied in the special cases of identical conditions and complementary conditions.
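
For this particular example, the complementarity of the two conditions can be confirmed by brute force over the loop domain.  The sketch below is plain Python rather than TIR, and the helper names are hypothetical; it also checks that the merged form stores the same values as the two sequential conditionals.

```python
from itertools import product


def first(i, j):
    return 4 * i + j < 14


def second(i, j):
    return i == 3 and j >= 2


def sequential(i, j):
    # Two independent conditionals, as in the "before" form.
    a = 0.0 if first(i, j) else 1.0
    b = 2.0 if second(i, j) else 3.0
    return a, b


def merged(i, j):
    # Single conditional on `first`; the second conditional's else_case
    # is appended to the then_case, and vice versa.
    if first(i, j):
        return 0.0, 3.0
    return 1.0, 2.0


grid = list(product(range(4), range(4)))
complementary = all(first(i, j) != second(i, j) for i, j in grid)
same_result = all(sequential(i, j) == merged(i, j) for i, j in grid)
print(complementary, same_result)
```

A real implementation would need to prove this symbolically; enumeration is only feasible here because the loop extents are small constants.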





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r910098544


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and
+`(3,3)` do not have a corresponding index on the workload range
+`0 <= i < 14`.  The final result in these locations is not determined
+by the compute definition, so we have flexibility in what to store in
+the padding that is introduced by the transformation, and what
+assumptions can be made when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
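
The shapes and padding locations quoted in the comments above can be reproduced with a short enumeration.  This is a plain-Python sketch, not part of the proposed API: `padded_layout` is a hypothetical helper, and it assumes that each transformed axis starts at zero and has an extent of one past its largest index.

```python
import itertools


def padded_layout(extent, transform):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    `transform` maps a logical index to a list of transformed indices.
    """
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    shape = [max(idx[d] for idx in used) + 1 for d in range(ndim)]
    all_indices = itertools.product(*(range(s) for s in shape))
    padding = [idx for idx in all_indices if idx not in used]
    return shape, padding


def no_offset(i):
    return [i // 8, i % 8]


def with_offset(i):
    return [(i + 2) // 8, (i + 2) % 8]


print(padded_layout(16, no_offset))
print(padded_layout(14, no_offset))
print(padded_layout(14, with_offset))
print(padded_layout(16, with_offset))
```

The four calls correspond, in order, to the four `transform_layout` examples in this section.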
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+may be a constant, or a function that maps from transformed indices
+to an `Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
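
The producer postlude and the matching consumer-side guarantee can be modelled in plain Python. This is a minimal sketch, not TVM code: `transform_and_pad` and `satisfies_padding_contract` are hypothetical names, the buffer is a nested list, and the transformation is fixed to `lambda i: [i//4, i%4]`.

```python
def transform_and_pad(data, pad_value):
    """Model of `lambda i: [i//4, i%4]` followed by the padding postlude."""
    extent = len(data)
    rows = (extent + 3) // 4
    buf = [[None] * 4 for _ in range(rows)]

    # Main body of the producer: write every original value.
    for i, value in enumerate(data):
        buf[i // 4][i % 4] = value

    # Postlude: write pad_value wherever the padding predicate holds.
    for io in range(rows):
        for ii in range(4):
            if 4 * io + ii >= extent:
                buf[io][ii] = pad_value(io, ii)
    return buf


def satisfies_padding_contract(buf, extent, pad_value):
    """Consumer-side check: each element is original data or pad_value."""
    return all(
        4 * io + ii < extent or buf[io][ii] == pad_value(io, ii)
        for io in range(len(buf))
        for ii in range(4)
    )


zero_padded = transform_and_pad([float(i) for i in range(14)], lambda io, ii: 0.0)
```

A consumer that relies on zero padding may only do so for buffers that pass this check; a buffer padded with any other value would fail it.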
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would incur less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the additional iterations can be proven to be no-ops,
+allowing the conditional to be removed automatically.  Whether
+trading branching for overcompute is beneficial depends on the
+schedule, so these options are exposed as two additional
+transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
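
The equivalence of the two forms depends entirely on the contents of the padding. A minimal plain-Python sketch of both variants follows; the function names and the nested-list buffers are illustrative assumptions, not TVM code.

```python
def make_input(pad):
    """Build A with shape (16, 4, 4); each row holds 14 values plus 2 padding."""
    return [
        [
            [float(i + 4 * jo + ji) if 4 * jo + ji < 14 else pad for ji in range(4)]
            for jo in range(4)
        ]
        for i in range(16)
    ]


def row_sum_branching(A):
    B = [0.0] * 16
    for i in range(16):
        for jo in range(4):
            for ji in range(4):
                if 4 * jo + ji < 14:  # skip the padding
                    B[i] = B[i] + A[i][jo][ji]
    return B


def row_sum_overcompute(A):
    B = [0.0] * 16
    for i in range(16):
        for jo in range(4):
            for ji in range(4):
                B[i] = B[i] + A[i][jo][ji]  # also reads the padding
    return B


zero_padded = make_input(pad=0.0)
one_padded = make_input(pad=1.0)
```

With zero padding the two functions agree, so the conditional may be dropped. With any other padding value the overcompute form changes the result, which is why the padding value must be known before the branch is removed.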
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is analogous to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression within the assumption.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to be
+  identical.  For example, the expression `undef - undef` may not be
+  simplified to zero.  If this behavior is desired, the `undef` may
+  be bound in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', index_map=lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', index_map=lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains
+    # the same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', index_map=lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value (A[io, ii] == -1).
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
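+    # The computation, doubling the input value.  The loop iteration
+    # is unchanged, but accesses use the transformed indices.
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]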
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
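
The rewritten loop nest can be checked against the original in plain Python. This is a minimal sketch using nested lists in place of TIR buffers; the function names are illustrative and are not part of the proposed API.

```python
def write_original_order():
    """Loop over i, writing through the transformed indices."""
    A = [[None] * 4 for _ in range(4)]
    for i in range(14):
        A[i // 4][i % 4] = i
    return A


def write_sequential_order():
    """Loop over (io, ii), recovering i by inverting the transformation."""
    A = [[None] * 4 for _ in range(4)]
    visited = []
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:  # guard against touching the padding
                A[io][ii] = 4 * io + ii
                visited.append((io, ii))
    return A, visited


original = write_original_order()
sequential, visit_order = write_sequential_order()
```

Both functions produce the same buffer contents, and the second visits the transformed indices in row-major order. The two padding elements stay untouched because of the guard.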
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
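
Both loop rewrites can be checked against the untransformed convolution in plain Python. The sketch below is an illustration rather than TVM code: buffers are nested lists, the helper names are hypothetical, and Python's floor division and modulo are assumed to match the floor semantics of `//` and `%` used in the index expressions above.

```python
def conv1d_reference(A, F):
    """The original, untransformed computation."""
    return [sum(F[f] * A[i + f] for f in range(3)) for i in range(14)]


def tile(flat, fill=0):
    """Apply `lambda i: [i//4, i%4]` to a flat list, filling any padding."""
    rows = (len(flat) + 3) // 4
    out = [[fill] * 4 for _ in range(rows)]
    for i, value in enumerate(flat):
        out[i // 4][i % 4] = value
    return out


def conv1d_follow_b(A2, F):
    """Option 1: loops follow B's layout."""
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv1d_follow_a(A2, F):
    """Option 2: loops follow A's layout, with B initialized separately."""
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = list(range(16))
F = [1, 2, 3]
expected = tile(conv1d_reference(A, F))
```

In the second variant, `(ii - f) // 4` evaluates to `-1` whenever `ii < f`, which moves the write to the previous row of `B`. The guard on `4*io + ii - f` keeps that row index within bounds.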
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to the `tvm::arith::StmtSimplifier` mutator, used
+in the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals that do
+  not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, the assumed condition
+  should be used for later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on values
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
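
The restriction on merging data-dependent conditionals, described in the list above, can be demonstrated in plain Python. This is a minimal sketch with illustrative function names; it mirrors the two TVMScript functions from that example and shows an input on which they disagree.

```python
def separate_conditionals(values):
    A = list(values)
    for i in range(len(A)):
        if A[i] < 0.0:
            A[i] = A[i] + 1.0
        # This condition is re-evaluated after the first body has run.
        if A[i] < 0.0:
            A[i] = 0.0
    return A


def merged_conditionals(values):
    A = list(values)
    for i in range(len(A)):
        if A[i] < 0.0:
            A[i] = A[i] + 1.0
            A[i] = 0.0  # runs even when the first body made A[i] positive
    return A


inputs = [-2.5, -0.5, 0.5]
result_separate = separate_conditionals(inputs)
result_merged = merged_conditionals(inputs)
```

For the input `-0.5`, the first body raises the value to `0.5`, so the second condition no longer holds in the unmerged form. The merged form clamps the value to zero regardless.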
+
+### New Transform - Hoist Expression
+
+A new transformation, `HoistExpression`, generalizes the current
+`HoistIfThenElse` pass.  As a transformation, `HoistExpression` would
+apply to the entire body of the `PrimFunc`, and would be used to
+avoid duplication of functionality between `HoistIfThenElse` and
+`HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist
+    any let bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
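
Because the values are bit flags, several options can be combined into one configuration value. The sketch below mirrors the two C++ enums in Python using `enum.IntFlag`. The Python classes and the `is_enabled` helper are hypothetical and are not part of the proposed API; they only illustrate how the flags compose.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    kNone = 0
    kIfElseStmt = 1 << 0
    kIfElseExpr = 1 << 1
    kBooleanExpression = 1 << 2


class HoistLetBinding(IntFlag):
    kNone = 0
    kRequiredByCondition = 1 << 0
    kLetStmt = 1 << 1
    kLetExpr = 1 << 2


def is_enabled(config, flag):
    """True if `flag` is non-zero and every bit of it is set in `config`."""
    return flag != 0 and (config & flag) == flag


# Hoist both statement-level and expression-level conditionals, and
# only those let bindings that a hoisted condition requires.
conditionals = HoistConditional.kIfElseStmt | HoistConditional.kIfElseExpr
let_bindings = HoistLetBinding.kRequiredByCondition
```

Treating `kNone` as the absence of all flags, rather than as a flag that can be tested for, keeps `is_enabled(config, kNone)` from being trivially true.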
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
+
+### New Transform - Reduce Loop Extents
+
+Reduce the extent of loops based on conditionals present in the body
+of the loop.
+
+For any non-vectorized `tir::For` loop (`ForKind::kSerial` or
+`ForKind::kParallel`), if the body is a conditional and the
+conditional's `else_case` is empty, determine if the condition is of
+the form `(loop_var $CMP_OP const) && (...)`.  If so, use the
+comparison operator to reduce the loop extent, such that the loop
+skips values for which the comparison is provably false.
+
+TODO: Double-check that this isn't already implemented elsewhere.
+
+TODO: Check if it is implementable using `IntSetAnalyzer`.
+
+Below is an example of how this can work alongside `HoistExpression`
+to simplify the initialization of padding.
+
+```python
+# Original function.
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i, j in T.grid(4, 4):
+        if i == 0 and j < 2:
+            A[i, j] = 0.0
+
+
+# After hoisting with HoistConditional::kBooleanExpression
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i in T.serial(4):
+        if i == 0:
+            for j in T.serial(4):
+                if j < 2:
+                    A[i, j] = 0.0
+
+
+# After reducing the extents of serial loops
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    i = 0
+    for j in T.serial(2):
+        A[i, j] = 0.0
+```
+
+
+
+### Utility - Merge Adjacent Loops
+
+If it does not impact the resulting computation, loops may be merged
+together.  This is a valid transformation if both loops are serial
+loops, the loops have the same indices, and if the merging respects
+data dependencies.  This would be exposed as a metaschedule primitive,
+which takes as input the `LoopRV`s to be merged.
+
+For adjacent loops, to prove that there is no data dependency, two
+conditions must hold.
+
+1. For all loop indices `i` and `j` where `i > j`, the set of indices
+   written by the first loop in iteration `i` is distinct from the set
+   of indices accessed by the second loop in iteration `j`.  That is,
+   merging the loops wouldn't cause the second loop body to read
+   partial values, nor would it cause the first loop body to overwrite
+   a value produced by the second loop body.
+
+2. For all loop indices `i` and `j` where `i > j`, the set of indices
+   read by the first loop in iteration `i` is distinct from the set
+   of indices written by the second loop in iteration `j`.  That is,
+   merging the loops wouldn't cause the second loop body to overwrite
+   values that are still required by the first loop body.
+
+Element-wise loops do not have any data dependencies, and adjacent
+element-wise loops may be merged.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+
+    for i in T.serial(16):
+        A[i] = 1.0
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j].
+#    c. intersection([i], [j]) = [i] if i==j else [].
+#    d. If i>j, the intersection is empty
+#
+# 2. a. In iteration i, loop 1 reads from index [].
+#    b. In iteration j, loop 2 writes to index [j]
+#    c. intersection([], [j]) = []
+#    d. For all i,j, the intersection is empty
+#
+# Therefore, this merger is valid
+
+# After merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+        A[i] = 1.0
+```
+
+The second loop may read indices that were written in an earlier
+iteration.  Merging would not impact the result.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+
+    for i in T.serial(16):
+        if i > 0:
+            A[i] = A[i - 1] + 1.0
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j,j-1].
+#    c. i>j implies that i!=j and i!=j-1.
+#    d. For all i,j where i>j, the intersection is empty.
+#
+# 2. a. In iteration i, loop 1 reads from index [].
+#    b. In iteration j, loop 2 writes to index [j]
+#    c. For all i,j, intersection([], [j]) = [].
+#
+# Therefore, this merger is valid
+
+
+# After merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+        if i > 0:
+            A[i] = A[i - 1] + 1.0
+```
+
+The second loop may not read indices that were written in a later
+iteration of the first loop.  In this case, merging would impact the
+output values.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = i
+
+    for i in T.serial(16):
+        if 0 < i < 15:
+            A[i] = A[i - 1] + A[i] + A[i + 1]
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j-1,j,j+1].
+#    c. If i==j+1, then intersection([j+1], [j-1,j,j+1]) = [j+1],
+#       which is non-empty.
+#
+# Therefore, this merger is not valid.
+```
+
+### New Primitive - Remove Branching Through Overcompute
+
+A new transform which attempts to reduce branching by allowing
+overcompute.  It takes an argument to specify which block it should be
+applied within.
+
+For each `IfThenElse` statement, check if the
+`IfThenElseNode::else_case` is a simplified form of the
+`IfThenElseNode::then_case`.  This check is done by simplifying
+`then_case`, under the assumption that `condition` is false, and
+substituting the known value in a `BufferConstraint` in any
+`BufferLoad` for which the predicate can be proven to be true.  If
+this simplified form is identical to the `else_case`, then the entire
+if/else block can be replaced with `then_case`.  Otherwise, the check
+is repeated with the roles reversed, to see if the `else_case` can be
+simplified down to the `then_case` under the assumption that
+`condition` is true.  If neither simplification holds, then no change
+is made.
+
+For example, consider the following function.  This is a 1-d
+convolution, where both the input and output buffers have a layout
+transformation applied.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(14,), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + A[i + f]
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(14,), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + A[(i + f) // 4, (i + f) % 4]
+
+
+# sched.transform_layout(block='compute', buffer='B', index_map=lambda i: [i//4, i%4], pad_value=0.0)
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(4, 4), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + A[(i + f) // 4, (i + f) % 4]
+
+        for io,ii in T.grid(4,4):
+            if io==3 and ii>=2:
+                B[io,ii] = 0.0
+
+
+# sched.sequential_buffer_access(block='compute', buffer='B')
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(4, 4), "float32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 0 <= 4*io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    B[io, ii] = B[io, ii] + A[io + (ii + f) // 4, (ii + f) % 4]
+
+        for io,ii in T.grid(4,4):
+            if io==3 and ii>=2:
+                B[io,ii] = 0.0
+```
+
+
+We'd like to remove the conditional `if 0 <= 4*io + ii < 14` in the
+compute loop.  In order to do so, we need to prove that the body of
+the conditional is a no-op in the case where the conditional is false.
+
+Using the [updated `DomainTouched`
+utility](#enhancement-remove-no-op), the body can be shown to be a
+no-op in that case.  It is a write to `B[io,ii]` predicated on
+`4*io+ii >= 14`, followed by a write to `B[io,ii]` predicated on
+`io==3 and ii>=2`, without a read in between.  Since these predicates
+are equivalent, the first write is a no-op.
+
+```python
+# sched.remove_branching_through_overcompute(block='compute')

Review Comment:
   > Does this only apply to outputs?
   
   This specific reasoning, of identifying the overcompute as a no-op by virtue of being overwritten later, is specific to outputs.  In general, overcompute may only be introduced if the new statements can be shown to be no-ops.
   
   * A write that is overwritten without being read is a no-op. (Used to introduce overcompute of outputs.)
   * A write to a location that is deallocated without being read is a no-op. (Used to introduce overcompute of local caches.)
   * A write of the same value that is already at the write location is a no-op.  (Used to introduce overcompute based on known facts about the input buffer, such as `Output[0] = Output[0] + Input[i]` being a no-op if `Input[i]` is known to be zero.)
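
   As a rough illustration of the first rule, the sketch below scans a trace of reads and writes and reports writes that are overwritten before being read.  It is a toy model in plain Python, not TVM's analysis; the function name `find_dead_writes` and the event format are assumptions made only for this example.

```python
def find_dead_writes(events):
    """Return positions of writes that are overwritten before being read.

    `events` is a list of ("read" | "write", index) pairs in program order.
    Hypothetical helper for illustration only; not part of TVM.
    """
    pending = {}  # index -> position of the latest write not yet read
    dead = []
    for position, (kind, index) in enumerate(events):
        if kind == "read":
            # The latest write to this index is observed, so it is live.
            pending.pop(index, None)
        elif kind == "write":
            if index in pending:
                # The previous write was never read before this overwrite.
                dead.append(pending[index])
            pending[index] = position
        else:
            raise ValueError(f"unknown event kind: {kind!r}")
    return dead


# Overcompute writes B[3, 2]; the padding postlude then overwrites it
# before any consumer reads the location.
overcompute_then_pad = [("write", (3, 2)), ("write", (3, 2)), ("read", (3, 2))]
print(find_dead_writes(overcompute_then_pad))
```

   Under this model the first write is reported as dead, which is the property that would allow it to be introduced as overcompute.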
   
   > I think we should per-buffer directive that indicates that out-of-bounds access is allowed.
   
   Agreed.  This is specified using the `pad_value` argument in a transformation, and is exposed to local analysis either using the proposed `T.assume` intrinsic for input buffers, or through `BufferStore` for local caches.  That way, the out-of-bounds access within an existing statement can be used to know that other out-of-bounds accesses are safe to use.  There's an example [here](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#apply-operator-element-wise-over-the-transformation-padding), where `T.assume(4*io+ii < 14 or A[io,ii] == T.undef())` is used to know that it is safe to insert reads to indices for which `4*io + ii >= 14`.  If no `pad_value` is specified, then there is no previous read/write from those locations, and so a later transformation may not introduce a new read/write.
   
   > The user can add padding -INF to inputs to maxpool, but how does the maxpool compute know that it can use the out-of-bounds values?
   
   This would be the role of the `T.assume` intrinsic.  It wouldn't have any effect on its own, and would be removed as part of lowering, but would expose information to the function that could be used as part of optimizations.  In this case, the statement `T.assume(not indices_are_padding or buf[indices] == -INF)` could let maxpool know that those values can be used.
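
   To make the maxpool case concrete, here is a minimal sketch in plain Python (not TIR) of why `-INF` padding lets the bounds check be dropped.  The function names and the 14-element buffer padded to 16 are assumptions chosen to match the RFC's running example.

```python
import math


def max_with_branch(padded, num_valid):
    """Reduce with an explicit bounds check that skips the padding."""
    result = -math.inf
    for i, value in enumerate(padded):
        if i < num_valid:
            result = max(result, value)
    return result


def max_without_branch(padded):
    """Reduce over every element, relying on the padding value."""
    result = -math.inf
    for value in padded:
        result = max(result, value)
    return result


# 14 valid (all negative) values, padded to 16 elements.
data = [float(i % 5) - 7.0 for i in range(14)]
padded_inf = data + [-math.inf, -math.inf]
padded_zero = data + [0.0, 0.0]

print(max_with_branch(padded_inf, 14), max_without_branch(padded_inf))
print(max_with_branch(padded_zero, 14), max_without_branch(padded_zero))
```

   With `-INF` padding both versions agree.  With zero padding the branch-free version returns the padding value instead, which is why the consumer needs the assumption to be stated rather than guessed.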
   
   > As to whether to actually utilize this should probably be left to the compiler. Auto-scheduling should not be a replacement for compiler optimizations.
   
   For transforms that could either be used in auto-scheduling or when compiling, I had been seeing compiler optimizations as the preferred place to implement transforms that are universally beneficial, and auto-scheduling as the preferred place to implement functionality that is conditionally beneficial.  In this case, because the overcompute may be very large for some pathological cases, I think it is better exposed as a scheduling decision, as the cost of overcompute may not always be worth the benefit of avoiding a branch.
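
   As a back-of-the-envelope sketch of how the cost can vary, the helper below computes the fraction of extra iterations when a 1-d extent is padded up to a multiple of a tile size.  The name `overcompute_fraction` is hypothetical, and this only counts iterations; it says nothing about the actual cost of a branch on a given target.

```python
def overcompute_fraction(extent, tile):
    """Fraction of extra iterations when `extent` is padded to a multiple of `tile`."""
    if extent <= 0 or tile <= 0:
        raise ValueError("extent and tile must be positive")
    padded_extent = -(-extent // tile) * tile  # round up to a multiple of tile
    return (padded_extent - extent) / extent


# The RFC's running example: shape [14] tiled by 4 pads to 16 elements.
print(overcompute_fraction(14, 4))
# A pathological case: shape [17] tiled by 16 pads to 32 elements.
print(overcompute_fraction(17, 16))
```

   In the first case the overcompute is about 14% of the original work, while in the second it is close to 90%, which is the kind of spread that makes this a scheduling decision.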





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension reordering and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and when to materialize a function to legalize. 
   
   The main difference I see is that A0 pushes much more work into hardware specific optimization (compute definitions + patterns/rewriters) that is not as easily re-purposable for other hardware targets; whereas A1 provides the infrastructure for more general compiler simplification to proceed from hardware semantics the user provides about the buffer when transforming the layout at schedule time. 
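
   A very rough sketch of the constraint-flowing idea in A1, in plain Python: walk producer/consumer edges and only record a layout conversion where the provided and required buffer constraints disagree.  Everything here (the function name `plan_layout_conversions`, the dictionary representation of a constraint, the operator names) is a hypothetical stand-in, not an existing TVM pass or data structure.

```python
def plan_layout_conversions(edges):
    """Return the (producer, consumer) edges that need a layout conversion.

    Each edge is (producer, provided, consumer, required), where `provided`
    and `required` describe a buffer constraint, and `required` is None if
    the consumer accepts any layout and padding.
    """
    conversions = []
    for producer, provided, consumer, required in edges:
        if required is None:
            continue  # No requirement, so nothing can contradict.
        if provided != required:
            conversions.append((producer, consumer))
    return conversions


zero_padded = {"layout": "[i//4, i%4]", "pad_value": 0.0}
neg_inf_padded = {"layout": "[i//4, i%4]", "pad_value": float("-inf")}

edges = [
    ("conv", zero_padded, "sum", zero_padded),        # constraints agree
    ("sum", zero_padded, "maxpool", neg_inf_padded),  # contradiction
    ("maxpool", neg_inf_padded, "relu", None),        # no requirement
]
print(plan_layout_conversions(edges))
```

   Only the `sum` to `maxpool` edge is flagged here, since that is the one place where the producer's padding value contradicts what the consumer assumes.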





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891323645


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this shape has 16 elements, two more than the original buffer: the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
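
The arithmetic above can be checked with a short sketch. It assumes the simple tiling case where index `i` maps to `[i // tile, i % tile]`; the helper name `tiled_shape` is hypothetical and is not part of TVM.

```python
def tiled_shape(extent, tile):
    """Shape and padding count when index i maps to [i // tile, i % tile]."""
    outer = -(-extent // tile)  # ceiling division
    return (outer, tile), outer * tile - extent


print(tiled_shape(14, 4))  # ((4, 4), 2)
print(tiled_shape(16, 8))  # ((2, 8), 0)
```

A shape of `[14]` tiled by 4 needs two padding elements, while `[16]` tiled by 8 needs none.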
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
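
The padding locations claimed in the comments above can be reproduced in plain Python. This is only an illustrative sketch: `layout_grid` is a hypothetical helper, not a TVM API, and it assumes a 1-d buffer mapped to two non-negative transformed indices.

```python
def layout_grid(extent, transform):
    """Map each transformed coordinate to its original index, or None for padding."""
    coords = {tuple(transform(i)): i for i in range(extent)}
    rows = max(r for r, _ in coords) + 1
    cols = max(c for _, c in coords) + 1
    return [[coords.get((r, c)) for c in range(cols)] for r in range(rows)]


def offset(i):
    return ((i + 2) // 8, (i + 2) % 8)


grid_b = layout_grid(14, offset)  # tensor B, shape [14]
grid_a = layout_grid(16, offset)  # tensor A, shape [16]

print(grid_b[0])  # [None, None, 0, 1, 2, 3, 4, 5]
print(grid_a[2])  # [14, 15, None, None, None, None, None, None]
```

Entries of `None` mark transformed indices that no value of `i` maps to, which is the padding.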
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+may be a constant, or a function that maps from transformed indices
+to an `Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
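
As a rough illustration of the postlude, the following plain-Python sketch models the transformed buffer as a dict keyed by index tuples. The names `write_padding`, `padding_predicate` and `pad_value` mirror the pseudocode above but are assumptions for this example, not TVM functions.

```python
from itertools import product


def write_padding(buf, shape, padding_predicate, pad_value):
    """Postlude: overwrite every padding element of `buf`."""
    for idx in product(*(range(s) for s in shape)):
        if padding_predicate(*idx):
            buf[idx] = pad_value(*idx)


shape = (4, 4)
# Producer output for logical indices 0..13; the padding is not yet written.
buf = {(i // 4, i % 4): float(i) for i in range(14)}

write_padding(
    buf,
    shape,
    padding_predicate=lambda io, ii: 4 * io + ii >= 14,
    pad_value=lambda io, ii: 0.0,
)
```

After the call, `buf` holds all 16 elements, with zeros at `(3, 2)` and `(3, 3)`.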
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the no-op can be proven, allowing the conditional to be
+removed automatically.  Whether trading branching for overcompute is
+beneficial depends on the schedule, so
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
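
The equivalence of the two forms can be checked outside TVM. The sketch below uses plain Python lists rather than TIR and assumes the padding has been pre-filled with zero; the function names are illustrative only.

```python
import random

ROWS, COLS, TILE = 16, 14, 4
OUTER = -(-COLS // TILE)  # 4 tiles per row

random.seed(0)
A = [[random.random() for _ in range(COLS)] for _ in range(ROWS)]

# Transformed layout [i, j//4, j%4], with the padding pre-filled with zero.
A_t = [[[0.0] * TILE for _ in range(OUTER)] for _ in range(ROWS)]
for i in range(ROWS):
    for j in range(COLS):
        A_t[i][j // TILE][j % TILE] = A[i][j]


def row_sum_branching(buf):
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(OUTER):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:  # skip the padding
                    acc += buf[i][jo][ji]
        out.append(acc)
    return out


def row_sum_overcompute(buf):
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(OUTER):
            for ji in range(TILE):
                acc += buf[i][jo][ji]  # also adds the zero padding
        out.append(acc)
    return out
```

Both functions return identical results here because adding `0.0` leaves a finite accumulator unchanged. With any other padding value the overcompute form would be incorrect.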
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
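
To make the predicate/value semantics concrete, here is a Python stand-in for the proposed node. It is a sketch under assumptions: the `indices` field is folded into the lambda parameters, and `known_value` is a hypothetical helper rather than anything in TVM.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class BufferConstraint:
    """Python stand-in for the proposed BufferConstraintNode (hypothetical)."""

    predicate: Callable[..., bool]         # true where the constraint applies
    value: Optional[Callable[..., float]]  # known contents, or None if access is forbidden


def known_value(constraint, *indices):
    """Return the value a read at `indices` is known to produce, if any."""
    if not constraint.predicate(*indices):
        return None  # not padding: this constraint says nothing
    if constraint.value is None:
        raise IndexError(f"access to padding at {indices} is forbidden")
    return constraint.value(*indices)


zero_pad = BufferConstraint(lambda io, ii: 4 * io + ii >= 14, lambda io, ii: 0.0)
no_access = BufferConstraint(lambda io, ii: 4 * io + ii >= 14, None)
```

With `zero_pad`, a read at `(3, 2)` is known to produce `0.0`. With `no_access`, the same read is rejected.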
+
+### New TIR Op, `tir::builtin::arbitrary`

Review Comment:
   I'd been avoiding "undefined", as that could cause confusion with the C++ notion of "undefined behavior".  Whereas any undefined behavior in C++ causes the entire program to have undefined behavior, use of `T.arbitrary` would only propagate to expressions that use `T.arbitrary` as an input.
   
   That said, LLVM's `undef` maps quite well to the concept I was thinking of, so I agree with the name change.  As far as I can tell, the main difference between the proposed `tir::builtin::undef` and LLVM's `undef` is that LLVM's `undef` is tracked at the level of individual bits, whereas `tir::builtin::undef` would propagate to entire values.
   
   I've changed the name, added a link to the Prior Art section for LLVM's `undef`, and updated this section to define the computations that use `undef`.
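
The propagation rule described in this comment can be sketched in plain Python. The `Undef` sentinel and `add` helper are hypothetical stand-ins, and the sketch assumes whole-value propagation as stated above.

```python
class Undef:
    """Sentinel for a value with no defined contents (hypothetical stand-in)."""

    def __repr__(self):
        return "undef"


UNDEF = Undef()


def add(a, b):
    # Any expression with an undef input is itself undef.
    if a is UNDEF or b is UNDEF:
        return UNDEF
    return a + b


padded_row = [12.0, 13.0, UNDEF, UNDEF]

# Element-wise use: defined elements stay defined.
doubled = [add(x, x) for x in padded_row]

# Reduction over the padding: the result becomes undef.
total = 0.0
for x in padded_row:
    total = add(total, x)
```

Only expressions that consume the undef padding become undef. The rest of the computation keeps its meaning.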





[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r890650624


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   
   It would be great to discuss ideas to simplify this further.  On one hand, it is useful to introduce additional semantics to the IR itself; on the other, doing so generally increases the complexity of analysis and transformation.
   
   It would be great for us to explore such capabilities without introducing additional complexity in the IR.  Back to our goal of introducing padding: it should be possible to have an explicit buffer transformation stage that copies the data into the target padded buffer (with predication), then runs the computation on the padded values.
   
   There are certainly some tradeoffs here, but decoupling the padding behavior into a separate stage of IR computation should allow us to reuse more primitives without having to specialize for BufferConstraint.
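
One possible reading of this suggestion, sketched in plain Python rather than TIR: a predicated copy stage produces the padded buffer, and the compute stage then runs with no bounds check. The function names are assumptions made for illustration.

```python
def pad_stage(src, tile, pad_value):
    """Copy a 1-d list into an [outer, tile] layout, filling the remainder with pad_value."""
    outer = -(-len(src) // tile)  # ceiling division
    dst = [[pad_value] * tile for _ in range(outer)]
    for io in range(outer):
        for ii in range(tile):
            i = tile * io + ii
            if i < len(src):  # predicated copy
                dst[io][ii] = src[i]
    return dst


def compute_stage(padded):
    """Sum over the padded buffer without any bounds check."""
    return sum(x for row in padded for x in row)


src = [float(i) for i in range(14)]
padded = pad_stage(src, 4, 0.0)
```

In this form the branch lives only in the copy stage, at the cost of an extra pass over the data.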
   





[GitHub] [tvm-rfcs] areusch commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
areusch commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891581768


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since trading branching for overcompute may
+or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   i agree it's helpful to think through alternatives here. could we consider some example transformations we may want to make (e.g. eliding or moving the operations which write to the padding) or pattern-matching on such operations and reducing them to hardware intrinsics (e.g. perhaps there is a way to tell the hardware how much padding to include when the value is always constant and a particular operation is in use). 
   
   on the one hand, modeling the padding computation explicitly in TIR is a more logical reuse of existing TIR. on the other hand, it may be more expensive to match this and the compiler may be slower. 
   
   i'm not necessarily in favor of any one solution, but i think this is the sort of thing we should discuss to try and inform that decision.
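
To make the "model the padding computation explicitly" alternative more concrete, here is a minimal plain-Python sketch (no TVM required) of a producer with an explicit padding postlude, using the `[14]` to `[4,4]` example from the Motivation section. The names `transform`, `padding_predicate`, and `produce` are hypothetical and chosen only for illustration; they are not part of any TVM API, and the predicate here is found by enumeration rather than by the symbolic analysis a real implementation would need.

```python
# Plain-Python sketch of a producer postlude that writes padding.
# All names are illustrative assumptions, not TVM APIs.

ORIGINAL_EXTENT = 14
TRANSFORMED_SHAPE = (4, 4)


def transform(i):
    """Layout transformation from the RFC example: i -> [i//4, i%4]."""
    return (i // 4, i % 4)


# Transformed indices that at least one original index maps onto.
occupied = {transform(i) for i in range(ORIGINAL_EXTENT)}


def padding_predicate(io, ii):
    """True where no original index maps to (io, ii), i.e. the padding."""
    return (io, ii) not in occupied


def produce(pad_value):
    """Write the computed values, then run the explicit padding postlude."""
    rows, cols = TRANSFORMED_SHAPE
    buf = [[None] * cols for _ in range(rows)]

    # Main computation, written to the transformed locations.
    for i in range(ORIGINAL_EXTENT):
        io, ii = transform(i)
        buf[io][ii] = i

    # Postlude: an ordinary loop nest that a later pass could match,
    # elide, or merge with the loop above.
    for io in range(rows):
        for ii in range(cols):
            if padding_predicate(io, ii):
                buf[io][ii] = pad_value
    return buf


padding = sorted(
    (io, ii)
    for io in range(TRANSFORMED_SHAPE[0])
    for ii in range(TRANSFORMED_SHAPE[1])
    if padding_predicate(io, ii)
)
result = produce(pad_value=-1)
```

In this form the padding write is just another loop nest, so eliding or moving it becomes a question of matching that loop, which is the pattern-matching cost raised above.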





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891387773


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements rather than 14, because the transformed
+coordinates `(3,2)` and `(3,3)` do not have a corresponding index in
+the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since trading branching for overcompute may
+or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform---remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.

Review Comment:
   Thank you, and there was a similar comment from @Hzfengsy .  I've updated this section to instead propose a utility function to generate the appropriate split/fuse/reorder, rather than being itself a primitive.  Taking another look at https://github.com/apache/tvm/pull/11485 and the block iter bindings, I think the utility might be as simple as applying `transform_block_layout` with a mapping defined based on the block iter bindings of spatial dimensions.
   
   The main uniform usage coming to mind would be applying the transformation to all three layouts uniformly.  Though, that would only be well-defined if all three are already uniform, so that wouldn't catch cases where it changes from scatter to gather or vice versa.  I'll think it over a bit more there, thank you!





[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163436177

   > Indeed it is important to avoid having a separate compute definition for each workload on a new target. In this particular case, all computation definition would start with the original layout. Then there is a "schedule transformation" like transform layout which will generate the new stage as part of the scheduling process.
   
   Thank you, and that is roughly how I'm seeing it as well.  That everything starts with the base compute definition and is modified from there.  If I understand correctly, the main differences are below.
   
   * Option A: Layout transformations of inputs are allowed, but only during initial graph-level optimization.  When optimizing an individual PrimFunc, layout transformations of inputs and outputs are not allowed.
   
   * Option B: Layout transformations of inputs and outputs are not allowed.  If this is desired, it should be done by first introducing a cache stage in TIR, then transforming the layout of the cache, and finally by a graph-level transformation that inspects each PrimFunc and hoists the cache stage out.
   
   > The particular stage can be marked, which contains effectively the same information as BufferConstraint, except that it does not introduce new data structures. During global layout reflowing, such information can be used to guide the reflowing to reconstruct a data structure like BufferConstraint or other Layout mappings and use that to serve the same purpose.
   
   So long as the constraints can be statically searched for, this approach makes sense to me.  I would be more concerned about adding additional semantics to existing nodes, such as an AttrStmt node, since it then requires passes to be aware not only of the existence of the constraint, but also that it must be reconstructed from the existing data structure.  This approach would make it much more difficult for a static analysis tool to identify locations where the constraints must be updated.
   
   As a way to potentially find a way forward, what if we start by implementing pad values only for buffers that are allocated internally to a function?  This would be allowed behavior under both Option A and Option B, and would help determine how difficult reconstruction of the constraints would be from the transformation block without any additional annotation.  This could help motivate whether additional annotations are necessary, regardless of whether they are stored alongside the Buffer itself or in a separate attribute/annotation.
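
As a rough illustration of the "internal buffer only" starting point, the sketch below uses plain Python (no TVM) to mimic a function that allocates its own padded cache of `A`, following the `row_summation` example from the RFC. The helper names (`make_padded_cache`, `row_sum_branching`, `row_sum_overcompute`) are hypothetical and exist only for this sketch. Because the cache is allocated and filled inside the function, the pad value is known locally and the function's external interface is unchanged.

```python
# Plain-Python sketch: zero-padded internal cache allows the branch in
# the reduction to be dropped.  All names are illustrative assumptions.

ROWS, COLS, TILE = 16, 14, 4
OUTER = (COLS + TILE - 1) // TILE  # 4 tiles per row, the last one partial


def make_padded_cache(A, pad_value=0.0):
    """Copy A[i][j] into cache[i][j//TILE][j%TILE]; padding gets pad_value."""
    cache = [[[pad_value] * TILE for _ in range(OUTER)] for _ in range(ROWS)]
    for i in range(ROWS):
        for j in range(COLS):
            cache[i][j // TILE][j % TILE] = A[i][j]
    return cache


def row_sum_branching(cache):
    """Reduction that never reads the padding, at the cost of a branch."""
    B = [0.0] * ROWS
    for i in range(ROWS):
        for jo in range(OUTER):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:
                    B[i] += cache[i][jo][ji]
    return B


def row_sum_overcompute(cache):
    """Branch-free reduction; only correct if the padding holds 0.0."""
    B = [0.0] * ROWS
    for i in range(ROWS):
        for jo in range(OUTER):
            for ji in range(TILE):
                B[i] += cache[i][jo][ji]
    return B


A = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]

zero_padded = make_padded_cache(A, pad_value=0.0)
one_padded = make_padded_cache(A, pad_value=1.0)
```

With the zero-padded cache the two reductions agree, while with any other pad value the branch-free version picks up the two padding elements of each row. That is the property a simplifier would need to reconstruct from the cache-writing block if no separate annotation is stored.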




[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1165713753

   > Talking about “constraints”, it is also useful to talk about categories of them, roughly we can divide them into three categories.
   
   I like this breakdown, and agree.  In this categorization, what I've been calling "constraints" would be "assumptions".  Double-checking in `builtin.h`, it looks like we don't currently have a TIR equivalent of `__builtin_assume`.
   
   For usage of assumptions, I think the key would be to insert an assumption whenever the information that could otherwise prove it is hoisted out of the PrimFunc.  That would provide non-local information that could be used by the PrimFunc to allow local simplifications.
   
   > transformation of PrimFunc do not change the PrimFunc interface: this is really important so we can transform a PrimFunc without worrying about how the graph interacts with it(as the interface remains the same, we can lift out the blocks earlier)
   
   I don't think we can make this strong of a statement, as it would also forbid fusing operators together or hoisting a stage out of a PrimFunc.  In both cases, the signature of the resulting PrimFunc may be different than it was before.  This shows up in the example, as the interface of `grow` is different from the transformed `grow_packed`.
   
   As a slightly less general statement, I would say that transformations of a PrimFunc *in isolation* may not change the PrimFunc's interface. So an optimization search to improve the performance of a single subgraph may not change the layout of its own arguments, nor may it change assumptions of what is present in the padding, as those would change its interface.  However, a graph-level transform would be allowed to fuse subgraphs, to hoist stages out of a PrimFunc, to alter the layout of a PrimFunc's input, or to alter the assumptions provided about the inputs.  In general, a PrimFunc's interface could only be changed when calls into the PrimFunc are also modified to remain compatible.
   
   Is there a better term than "scheduling primitive" to describe layout transformations that impact input/output buffers?  I think the difference is between context-independent transformations that may be performed on a PrimFunc without changing its interface, as opposed to context-dependent transformations that may only be performed as part of a graph-level transformation.
   
   
   
   > Each function needs to have its own TIR analysis of how it flows things back, for example, in the case of `addone`, we can safely flow PadMapping back, changing `addone` to `addone_packed` by analyzing the TIR. If the addone is elemwise exp however, we need to insert a select operator(because `exp(0)=1` ) the message to input becomes `PadMapping(constraint, pad_value=undef)`.
   
   Would this handle cases where there are multiple different options for how an operator could be implemented?  Otherwise, I'm not sure how this would handle cases where multiple different sets of layouts/constraints could be inferred from different TIR-level schedules of the same operator.  As examples, the drop-down has 6 different implementations of `addone`, each of which would allow different hoistable pad/crop operations.
   
   <details>
   <summary>Click to expand</summary>
   <br>
   
   ```python
   # Implementation 1, no preproc/postproc are present.
   #
   # No hoistable layout transformations.  Could be fused with a layout
   # transformation, but doesn't otherwise provide any constraints.
   @T.prim_func
   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       for i in T.serial(14):
           with T.block("compute"):
               B[i] = A[i] + 1
   
   
   # Implementation 2, pad input/output, but never access the padding of
   # either input or output.
   #
   # In back-propagation of constraints, the T.undef() that is cropped
   # from BC could be narrowed to a known value provided from the
   # successor.  However, AC's padding is never written to, so could
   # propagate T.undef() back to preceding function.
   @T.prim_func
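   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       # Padded intermediates, allocated locally so that AC and BC are defined.
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")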
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               if 4 * io + ii < 14:
                   AC[io, ii] = A[4 * io + ii]
   
       for i in T.serial(14):
           with T.block("compute"):
               BC[i // 4, i % 4] = AC[i // 4, i % 4] + 1
   
       for i in T.serial(14):
           with T.block():
               T.block_attr("postproc", ["crop", T.undef()])
               B[i] = BC[i // 4, i % 4]
   
   
   # Implementation 3, pad input with known value, but never access
   # padding of output.
   #
   # In back-propagation of constraints, the T.undef() that is cropped
   # from BC could be narrowed to a known value provided from the
   # successor.  AC's padding is written to, so this would propagate
   # `PadMapping(predicate, pad_value=0)` to the previous operator.
   @T.prim_func
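   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       # Padded intermediates, allocated locally so that AC and BC are defined.
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")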
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   
       for i in T.serial(14):
           with T.block("compute"):
               BC[i // 4, i % 4] = AC[i // 4, i % 4] + 1
   
       for i in T.serial(14):
           with T.block():
               T.block_attr("postproc", ["crop", T.undef()])
               B[i] = BC[i // 4, i % 4]
   
   
   # Implementation 4, pad input with arbitrary value, provide no
   # guarantees in output.
   #
   # In back-propagation of constraints, the T.undef() that is cropped
   # from BC could be narrowed to a known value provided from the
   # successor.  AC's padding is written to, so this would propagate
   # `PadMapping(predicate, pad_value=BC_pad_value - 1)` to the
   # previous operator.
   @T.prim_func
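   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       # Padded intermediates, allocated locally so that AC and BC are defined.
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")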
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], T.undef())
   
       for io, ii in T.grid(4, 4):
           with T.block("compute"):
               BC[io, ii] = AC[io, ii] + 1
   
       for i in T.serial(14):
           with T.block():
               T.block_attr("postproc", ["crop", T.undef()])
               B[i] = BC[i // 4, i % 4]
   
   
   # Implementation 5, pad input with known value, analysis of TIR
   # successfully propagates pad value through to provide assumption when
   # cropping.
   #
   # In back-propagation of constraints, the output assumption is fixed.
   # Unless the operator following addone has included the constraint 1
   # as the required value in its padding, the crop/pad pair wouldn't be
   # able to be removed.  AC's padding is written to, and would propagate
   # `PadMapping(predicate, pad_value=0)` to the previous operator.
   @T.prim_func
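   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       # Padded intermediates, allocated locally so that AC and BC are defined.
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")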
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   
       for io, ii in T.grid(4, 4):
           with T.block("compute"):
               BC[io, ii] = AC[io, ii] + 1
   
       for i in T.serial(14):
           with T.block():
               T.block_attr("postproc", ["crop", 1])
               B[i] = BC[i // 4, i % 4]
   
   
   # Implementation 6, pad input with known value, analysis of TIR can't
   # successfully propagate pad value through to the output.
   #
   # In back-propagation of constraints, the output assumption is fixed.
   # Since we don't provide an assumption of what will be returned, the
   # graph-level pair of `crop(T.undef())` followed by `pad(x)` could
   # only be canceled out if `x` is `T.undef()`.  AC's padding is written
   # to, and would propagate `PadMapping(predicate, pad_value=0)` to
   # the previous operator.
   @T.prim_func
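   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       # Padded intermediates, allocated locally so that AC and BC are defined.
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")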
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   
       for io, ii in T.grid(4, 4):
           with T.block("compute"):
               BC[io, ii] = AC[io, ii] + 1
   
       for i in T.serial(14):
           with T.block():
               T.block_attr("postproc", ["crop", T.undef()])
               B[i] = BC[i // 4, i % 4]
   ```
   
   </details>
   
   I think the main change is that the temporary stages with annotation will need to allow multiple possibilities, rather than a single definitive layout.  These options could then be searched at the graph-level to decide on the appropriate layout.  After that is decided, the temporary stage could be selected and the transformations hoisted.
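
As a rough illustration of the pad/compute/crop structure shared by implementations 2 through 6, the sketch below models implementation 5 in plain Python rather than TIR. The helper names `pad_to_tiles` and `crop_from_tiles` are hypothetical and exist only for this example. It shows that padding the input with 0 and computing over all 16 elements leaves a known value of 1 in the padding of the intermediate output, which is the assumption recorded by the `["crop", 1]` annotation.

```python
# Plain-Python model of implementation 5; this is not TVM code.
# pad_to_tiles and crop_from_tiles are hypothetical helper names.

def pad_to_tiles(flat, tile, pad_value):
    """Map index i to [i // tile, i % tile], filling unused slots with pad_value."""
    n_tiles = (len(flat) + tile - 1) // tile
    out = [[pad_value] * tile for _ in range(n_tiles)]
    for i, value in enumerate(flat):
        out[i // tile][i % tile] = value
    return out


def crop_from_tiles(tiled, length):
    """Inverse of pad_to_tiles: read back the first `length` logical elements."""
    tile = len(tiled[0])
    return [tiled[i // tile][i % tile] for i in range(length)]


A = list(range(14))

# "preproc" stage: pad the 14 inputs into a 4x4 layout, padding with 0.
AC = pad_to_tiles(A, tile=4, pad_value=0)

# "compute" stage: overcompute across all 16 elements, padding included.
BC = [[x + 1 for x in row] for row in AC]

# "postproc" stage: crop back to the 14 logical outputs.
B = crop_from_tiles(BC, len(A))

# Values left in the padding of BC, at transformed indices (3,2) and (3,3).
padding_of_BC = [
    BC[io][ii] for io in range(4) for ii in range(4) if 4 * io + ii >= len(A)
]
```

With these inputs, `B` holds `A[i] + 1` for each of the 14 logical elements, while the two padding slots of `BC` hold `0 + 1`. A successor that requires a different pad value could not reuse this padding directly.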
   
   
   > But extra amount of care is needed when we attempt to move `crop_with_pad_assume`, as it really depends on the value property of its input.
   
   Completely agreed.  I think this is true at both the TIR and graph levels, that allowing assumptions means ensuring that the assumption isn't changed after it is used for simplifications.  The advantage of writing the assumptions at the graph level is that specific pairs of functions (such as `crop_with_pad_assume(pad_value)` followed by `pad_with_value(pad_value)`) can be identified as no-ops, without needing a full proof of it.
   
   I think the main rules that would need to be followed when handling assumptions would be the following three.
   
   1. An assumption may be inserted wherever it can be statically proven, or asserted by a user about user-supplied input.
      
   2. An assumption may be removed only if it can be statically proven. Assertions from a user about user-supplied input may never be removed, as they may have already been used to perform irreversible simplifications.
      
   3. Static provers must reset all assumptions about a variable when `T.undef()` is assigned to it, even though these assignments are removed during lowering.
   
   The restriction against changing a PrimFunc's interface falls out directly from rule #1.  Since an assumption that restricts the values of an input cannot be proven, these assumptions may not be modified.
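
To make the graph-level no-op identification concrete, here is a minimal sketch in plain Python. The op names `crop_with_pad_assume` and `pad_with_value` come from the discussion above, but the tuple encoding of the graph, the `UNDEF` sentinel, and the `cancel_crop_pad` function are assumptions made purely for illustration and are not part of TVM. A pair is removed only when the value assumed by the crop matches the value required by the pad, so no proof about the buffer contents is needed.

```python
# Hypothetical graph-level sketch; not TVM code.
# Each op is encoded as a (name, pad_value) tuple for illustration only.

UNDEF = object()  # stands in for T.undef()


def cancel_crop_pad(ops):
    """Remove adjacent crop_with_pad_assume(v) / pad_with_value(v) pairs."""
    result = []
    for op in ops:
        previous = result[-1] if result else None
        if (
            previous is not None
            and previous[0] == "crop_with_pad_assume"
            and op[0] == "pad_with_value"
            and previous[1] == op[1]  # assumed value matches required value
        ):
            result.pop()  # the crop/pad pair is a no-op, drop both
        else:
            result.append(op)
    return result


graph = [
    ("addone", None),
    ("crop_with_pad_assume", 1),
    ("pad_with_value", 1),
    ("maxpool", None),
]
simplified = cancel_crop_pad(graph)

# A crop that provides no assumption cannot cancel a pad that needs a value.
unmatched = cancel_crop_pad(
    [("crop_with_pad_assume", UNDEF), ("pad_with_value", 0)]
)
```

In this sketch the first graph reduces to `addone` followed by `maxpool`, while the unmatched pair is left in place, mirroring the comment on implementation 6 that `crop(T.undef())` followed by `pad(x)` cancels only when `x` is `T.undef()`.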




[GitHub] [tvm-rfcs] vinx13 merged pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 merged PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77




[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1182157349

   Thank you very much for the comments, suggestions, and discussion. I'm quite happy with how the design evolved over the course of the discussions!




[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   @Lunderberg That's true. If the reduction dimension is padded, we will need to insert a hint in the graph to assert that it was previously padded by 0. From the graph rewriting point of view, we can also see this as a transformation done at the graph level (it doesn't rely on arithmetic simplifications).
   
   Example
   ```
   X: R.Tensor[16]
   F: R.Const[16]
   Y = conv1d(X, F, pad=2)
   Z = conv1d(Y, F, pad=2)
   ```
   Inserting padding and crop:
   ```
   X: R.Tensor[16]
   F: R.Const[16]
   X_pad = pad(X, before=2, after=6)
   Y = conv1d(X_pad, F, pad=0)
   assert(Y[18:] == 0)
   Y_crop = crop(Y[0:18])
   Y_crop_pad = pad(Y_crop, before=2, after=4)
   Z = conv1d(Y_crop_pad, F, pad=0)
   Z_crop = crop(Z[0:20])
   ```
   Then we can propagate the padding information and combine:
   





[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891831175


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
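The shapes and padding locations quoted in the comments above can be cross-checked by brute force. The following is a minimal sketch in plain Python, not part of the proposed TVM API: the helper name `padded_layout` is hypothetical, and it assumes that every transformed index is non-negative, so that each extent is the largest index plus one.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Enumerate a 1-D buffer of size `extent` through `index_map`.

    Returns the transformed shape, and the transformed indices that no
    original index maps to (i.e. the padding).
    """
    used = {tuple(index_map(i)) for i in range(extent)}
    rank = len(next(iter(used)))
    # Assumption: indices start at zero, so extent = max index + 1.
    shape = [max(idx[dim] for idx in used) + 1 for dim in range(rank)]
    padding = [idx for idx in product(*(range(s) for s in shape)) if idx not in used]
    return shape, padding


# Shape [14] with the un-offset transformation: padding at the end.
print(padded_layout(14, lambda i: [i // 8, i % 8]))  # ([2, 8], [(1, 6), (1, 7)])
```

Running the same helper with the offset transformation `lambda i: [(i+2)//8, (i+2)%8]` reproduces the other two cases: padding at the start for shape `[14]`, and a grown shape of `[3, 8]` for shape `[16]`.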
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
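The producer postlude described above can be emulated outside of TIR. This is a minimal sketch using nested Python lists in place of a buffer; the function name `write_with_padding` is hypothetical, and the padding predicate `width*io + ii >= extent` is an assumption that only holds for the specific transformation `i -> [i//width, i%width]`.

```python
def write_with_padding(values, width, pad_value):
    """Store `values` in a [rows, width] layout, then fill the padding."""
    extent = len(values)
    rows = (extent + width - 1) // width
    buf = [[None] * width for _ in range(rows)]

    # Main computation: each original value goes to its transformed index.
    for i, value in enumerate(values):
        buf[i // width][i % width] = value

    # Postlude: write pad_value wherever the padding predicate holds.
    for io in range(rows):
        for ii in range(width):
            if width * io + ii >= extent:
                buf[io][ii] = pad_value(io, ii)
    return buf


B = write_with_padding(list(range(14)), 4, lambda io, ii: 0)
print(B[3])  # [12, 13, 0, 0]
```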
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this equivalence can be proven, allowing the conditional to
+be removed automatically.  Since trading branching for overcompute
+may or may not be beneficial depending on the schedule, these options
+are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
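The trade-off above can be checked numerically for a single row. This is a minimal sketch in plain Python, assuming the `[j//4, j%4]` layout from the example; the helper names `pad_row`, `row_sum_branching` and `row_sum_overcompute` are hypothetical and do not correspond to TVM functions.

```python
def pad_row(row, width=4, pad=0.0):
    """Lay out one logical row as [j//width, j%width], filling padding with `pad`."""
    n_chunks = (len(row) + width - 1) // width
    flat = list(row) + [pad] * (n_chunks * width - len(row))
    return [flat[k * width:(k + 1) * width] for k in range(n_chunks)]


def row_sum_branching(chunks, extent):
    """Guarded sum: never reads the padding."""
    width = len(chunks[0])
    total = 0.0
    for j_outer, chunk in enumerate(chunks):
        for j_inner, value in enumerate(chunk):
            if width * j_outer + j_inner < extent:
                total += value
    return total


def row_sum_overcompute(chunks):
    """Branch-free sum: reads the padding as well."""
    total = 0.0
    for chunk in chunks:
        for value in chunk:
            total += value
    return total


row = [float(j) for j in range(14)]
zero_padded = pad_row(row, pad=0.0)
print(row_sum_branching(zero_padded, 14), row_sum_overcompute(zero_padded))  # 91.0 91.0
```

The two versions agree only because the padding holds the additive identity. With any other pad value the branch-free version changes the result, which is why the padding value has to be known before the conditional may be dropped.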
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` field holds variables that represent the indices used
+to access the buffer.  Both `predicate` and `value` are expressed in
+terms of the variables stored in `indices`.  If `predicate` is true
+for a given value of the indices, then the buffer has contents of
+`value` at those indices.  If `value` is empty, then any indices that
+match the predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to be
+  identical.  For example, the expression `undef - undef` may not be
+  simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet,
+as this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
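The equivalence of the two loop nests above can be checked by emulating them directly. This is a minimal sketch in plain Python with nested lists standing in for the `[4, 4]` buffer; the function names are hypothetical, and the guard `4*io + ii < extent` is the conditional that the rewrite is expected to introduce.

```python
from itertools import product


def fill_original(extent=14):
    """Loop over the original index, writing to the transformed layout."""
    A = [[None] * 4 for _ in range(4)]
    for i in range(extent):
        A[i // 4][i % 4] = i
    return A


def fill_sequential(extent=14):
    """Loop over the transformed indices, guarding against the padding."""
    A = [[None] * 4 for _ in range(4)]
    for io, ii in product(range(4), range(4)):
        if 4 * io + ii < extent:
            A[io][ii] = 4 * io + ii
    return A


print(fill_sequential()[3])  # [12, 13, None, None]
```

Both versions leave the padding untouched, which is the property the introduced conditional is there to preserve.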
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
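The index simplifications used in Option 1 and Option 2 above can be sanity-checked by emulation. This is a minimal sketch in plain Python, with hypothetical function names; it relies on Python's `//` and `%` being floor division and floor modulo, which is what makes `(ii - f) // 4` and `(ii - f) % 4` correct for negative `ii - f`.

```python
from itertools import product


def to_2d(flat):
    """Apply i -> [i//4, i%4] to a flat list of 16 values."""
    return [flat[4 * r:4 * r + 4] for r in range(4)]


def flatten(buf, extent=14):
    """Read the first `extent` logical values back out of the [4, 4] layout."""
    return [buf[i // 4][i % 4] for i in range(extent)]


def conv1d_reference(A, F):
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def conv1d_by_b_layout(A2, F):
    """Option 1: iterate sequentially over B's transformed indices."""
    B2 = [[0] * 4 for _ in range(4)]
    for io, ii in product(range(4), range(4)):
        if 4 * io + ii < 14:
            for f in range(3):
                B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv1d_by_a_layout(A2, F):
    """Option 2: iterate sequentially over A's transformed indices."""
    B2 = [[0] * 4 for _ in range(4)]  # separate initialization of B
    for io, ii in product(range(4), range(4)):
        for f in range(3):
            if 0 <= 4 * io + ii - f < 14:
                B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = [(7 * i) % 5 for i in range(16)]
F = [1, 2, 3]
reference = conv1d_reference(A, F)
print(flatten(conv1d_by_b_layout(to_2d(A), F)) == reference)
print(flatten(conv1d_by_a_layout(to_2d(A), F)) == reference)
```

Integer arithmetic is used here so that the change in accumulation order in Option 2 cannot affect the comparison; with floating-point data the two orders may differ by rounding.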
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks whether
+  a condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals that
+  do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
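The notions of "identical" and "complementary" conditions used in the Simplify section can be illustrated by exhaustive enumeration over the loop domain. This is only a sketch in plain Python with hypothetical helper names; an actual implementation would presumably prove these properties symbolically rather than enumerate them.

```python
from itertools import product


def identical(cond_a, cond_b, domain):
    """True if both conditions agree at every point of the loop domain."""
    return all(bool(cond_a(*p)) == bool(cond_b(*p)) for p in domain)


def complementary(cond_a, cond_b, domain):
    """True if exactly one of the conditions holds at every point."""
    return all(bool(cond_a(*p)) != bool(cond_b(*p)) for p in domain)


loop_1d = [(i,) for i in range(16)]
grid_2d = list(product(range(4), range(4)))

# `i < 8` and `i//8 == 0` have different forms but are identical on 0 <= i < 16.
print(identical(lambda i: i < 8, lambda i: i // 8 == 0, loop_1d))  # True

# `4*i + j < 14` and `i==3 and j>=2` are complementary on the 4x4 grid.
print(complementary(lambda i, j: 4 * i + j < 14,
                    lambda i, j: i == 3 and j >= 2, grid_2d))  # True
```

Note that `i < 8` and `i//8 == 1` are complementary rather than identical on the same domain, so merging them would pair each `then_case` with the other conditional's `else_case`.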
+### New Transform - Hoist Expression

Review Comment:
   They can be different alternatives, or they can be combined on certain workloads. There is a discussion of a matmul performance issue that appears when a dimension is simply changed from 128 to 127: https://discuss.tvm.apache.org/t/te-vectorize-do-we-have-plan-to-support-vectorize-for-non-divisible-split/12469
   I think it might be a good working example. Below is what the user gets after splitting loop `j`: 127 -> (4, 32)
   
   ```python
   for i in range(127):
       for k in range(127):
           for j.outer in range(4):
                for j.inner in T.vectorized(32):
                    if T.likely(j.outer * 32 + j.inner < 127, dtype="bool"):
                        C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
   ```
   
   The issue is that a complex condition has to be introduced to preserve the program semantics. It hurts performance, and in general we cannot vectorize a program that contains control flow.
   
   Now I understand we have different alternatives to handle this:
   - Loop partition

     We can already annotate the loop var with a hint, using non-imperative loop partition.
   
     `for j.outer in range(4, annotations={"pragma_loop_partition_hint": 1})`
   
     After the `LoopPartition` pass (and simplification) it becomes:
     ```python
     for i in range(127):
         for k in range(127):
             # j.outer in [0, 3)
             for j.outer in range(3):
                  for j.inner in T.vectorized(32):
                       # condition is const true, optimize out
                       C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
             # j.outer in [3, 4): loop of extent 1 is optimized out, j.outer == 3
             for j.inner in T.vectorized(31):
                  # condition becomes j.inner < 31, hoisted with loop
                  C[i*127 + 3*32 + j.inner] += A[i*127 + k] * B[k*127 + 3*32 + j.inner]
     ```
     The conditional branch is then eliminated in each loop partition, which makes the code friendlier to performance optimizations such as vectorization. The "imperative" partition just proposes that we partition during the schedule phase when one wants to schedule the parts differently, for example by giving them different vectorization widths.
   
   - Loop padding
   
     With the current RFC, I understand we can pad the innermost dimension of `C` and `B` to 128, and drop the condition directly somehow. (IIUC, we may also insert some code that fills "arbitrary" values on the edges and then optimize it out?) In this particular case, I believe padding is the better choice, since we get very neat code with minimal overcompute. The loop then becomes:
     ```python
     for i in range(127):
         for k in range(127):
             for j.outer in range(4):
                 for j.inner in T.vectorized(32):
                     C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
     ```
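
The loop-partition alternative discussed in the comment above can be illustrated without TVM. This is a minimal sketch in plain Python that only records which flat `j` indices each version visits; the function names are hypothetical, and the actual `LoopPartition` pass operates on TIR rather than on Python loops.

```python
def visit_guarded(extent, factor):
    """Split loop with a per-iteration bounds check (the original form)."""
    visited = []
    n_outer = (extent + factor - 1) // factor
    for j_outer in range(n_outer):
        for j_inner in range(factor):
            if j_outer * factor + j_inner < extent:
                visited.append(j_outer * factor + j_inner)
    return visited


def visit_partitioned(extent, factor):
    """Same iteration space, partitioned into a full part and a tail."""
    visited = []
    n_full = extent // factor
    # Full tiles: the guard is always true, so it is dropped.
    for j_outer in range(n_full):
        for j_inner in range(factor):
            visited.append(j_outer * factor + j_inner)
    # Tail tile: the guard becomes a shorter loop extent.
    for j_inner in range(extent - n_full * factor):
        visited.append(n_full * factor + j_inner)
    return visited


print(visit_guarded(127, 32) == visit_partitioned(127, 32))  # True
```

Neither inner loop of the partitioned version contains a branch, at the cost of duplicating the loop body. The padding approach keeps a single loop body but performs one extra iteration per row.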





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r892661161


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer: the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Thank you, and especially with the explicit Goal 7, that helps to clarify where I had been thinking about it differently.
   
   > It is quite natural that a more complicated impl would hit more marks initially.  On the other hand, there is always a consideration of added complexity and how composable our additions are with existing constructs.
   
   Definitely a good point.  (My goal in the RFC was to improve existing utilities wherever possible, such as the changes to `StmtSimplifier` and `RemoveNoOp`, rather than introducing too many single-use passes.)  My concern was that attempting to describe a constraint in terms of existing TIR constructs and `AssertStmt` would be extremely fragile (e.g. loop over an input buffer and assert the constraint), and would require other passes to be aware of the composite structure in order to avoid breaking it.
   
   On the composability side, based on your and @Hzfensy's feedback, it sounds like the proposed scheduling primitives are one of the main issues, as they go against the typical design.  Since they are simplification-dependent, make multiple changes, and in some cases are composable from existing functionality, perhaps it would help to change them from stand-alone scheduling primitives into utilities that analyze a function and return a list of primitives that together have the desired effect.
   
   > Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   Enabling this interaction was my intent, so perhaps that means I have an error in how I'm picturing the relay/TIR interactions.  I had been picturing the `BufferConstraint` as the data structure in which to preserve the information and by which to pass the information between levels of abstraction.  When optimizing a single operator/PrimFunc, the layout and constraints of the arguments must remain constant, as otherwise it would be incompatible with neighboring operators.  The relay-level transforms would decide on buffer layout/constraints, and would specify those constraints to the TIR-level through `BufferConstraint`.
   
   > And it would actually be really nice to have an example at the e2e level to show how the set of transformations affect our optimizations.
   
   While not entirely end-to-end, I think the optimizations that would require coordination between multiple operators are briefly touched on in the example sections on [producer/consumer pairs](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#explicitly-write-next-operators-desired-default-at-end-of-function) and on [chainable `conv1d`](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#implicitly-write-default-value-of-next-operator).





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a corresponding
+index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   > Re: @vinx13: To add a discussion point, I'd like to ask whether the semantic like over computation and writing default value of next operate, can be achieved with graph level rewriting.
   > Re: @tqchen: Along that direction, a smart variant of Impl A (by interacting with graph) would actually enable simpler realization of goal 6 (which is important) by lifting the transformations of input/output out, and then cancel out between operators, while preserving the information.
   
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension rewriting and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
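   
   As a rough illustration of step 2 of A1, the following is a minimal sketch of constraint flowing along a linear chain of operators. It is an assumption for discussion, not part of the RFC: the function `flow_layouts`, the layout labels `"X"` and `"Y"`, and the list-based graph representation are all hypothetical, and a real pass would operate on `tir.Buffer` constraints over a full graph.

```python
from typing import List, Optional, Tuple


def flow_layouts(required: List[Optional[str]]) -> Tuple[List[Optional[str]], List[int]]:
    """Flow layout constraints forward along a linear chain of operators.

    `required[i]` is the layout that operator i demands, or None if the
    operator is unconstrained.  Returns the layout assigned to each operator
    and the operator indices before which a conversion must be materialized.
    """
    assigned: List[Optional[str]] = []
    conversions: List[int] = []
    current: Optional[str] = None
    for idx, req in enumerate(required):
        # A conversion is only needed when two constraints contradict.
        if req is not None and current is not None and req != current:
            conversions.append(idx)
        if req is not None:
            current = req
        # Unconstrained operators simply inherit the flowing layout.
        assigned.append(current)
    return assigned, conversions


# Hypothetical chain: constrained, unconstrained, constrained (agrees),
# constrained (contradicts).
assigned, conversions = flow_layouts(["X", None, "X", "Y"])
```

   In this sketch only the final operator disagrees with the flowing layout, so a single conversion is materialized rather than one before and after every operator.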
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and to materialize a function to legalize where they contradict. 
   
   The main difference I see is that A0 pushes much more work into hardware-specific optimization (compute definitions + patterns/rewriters) that is not as easily re-purposable for other hardware targets, whereas A1 provides the infrastructure for more general compiler simplification to proceed from hardware semantics the user provides about the buffer when transforming the layout at schedule time. 





[GitHub] [tvm-rfcs] csullivan commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1153227651

   Thanks for sharing the contextual pointers for the community @vinx13. Agreed the approaches discussed are both valid. I would actually like to argue the stronger point that they are complementary and only appear to be contrary because we are considering too narrow a scope. 
   
   It can be helpful to share an overview of common handlings of layout transformations in ML compilers. Most of my arguments against A0 (local cancellation in the graph only) being sufficient stem from prior experiences in graph layout optimizations. The evolution of the graph approaches for optimizing layouts I've seen followed the below trajectory:
    
   1) The first graph approach that is often taken is what has been argued so far in A0, local back to back cancellation. It works reasonably when the data flow and op variety in a model are simple.
   
   Local-only cancellation tends to fail in models which still have simple data flow, but more variety in the sequence of operators, each with different valid implementations. Consider, 
   
   ```
   -> transformX -> conv2d -> (inv_transform_X) -> pool -> (transformX) -> conv2d -> inv_transform_X
   ```
   Here `pool` can be replaced by any sequence of operations that are layout agnostic or for which multiple implementations exist, so their choice of layout is unconstrained, whereas the convolutions are layout constrained. As you can see, even for a simple model, the approach discussed in A0 already needs to be modified to support non-local layout analysis and folding. 
   
   2) The typical second approach is then to still utilize A0, but to first apply a pre-processing pass that sinks layout transforming operations in the graph along the path of data flow [[1](https://github.com/pytorch/glow/blob/56249602c9ec93fa586cea2ce8ab315003478eed/lib/Optimizer/GraphOptimizerPipeline/FunctionPassPipeline.cpp#L69), [2](https://github.com/NervanaSystems/ngraph/blob/f677a119765ca30636cf407009dabd118664951f/src/ngraph/pass/reshape_sinking.cpp#L542)]. The above case then becomes, 
   
   ```
   -> transformX -> conv2d -> pool ->  (inv_transform_X -> transformX) -> conv2d -> inv_transform_X
   ```
   Then apply the method discussed in A0 and do local cancellation. 
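   
   To make the difference concrete, here is a minimal sketch of sinking followed by local cancellation on the chain above. It is an illustration under stated assumptions rather than how any of the linked passes are implemented: the op names, the `LAYOUT_AGNOSTIC` set, and both helper functions are hypothetical, and the graph is reduced to a flat list.

```python
# Hypothetical op names: "T" is a layout transform and "T_inv" its inverse.
# Ops in LAYOUT_AGNOSTIC are assumed to compute correctly in either layout.
LAYOUT_AGNOSTIC = {"pool", "relu"}


def sink_transforms(chain):
    """Move each T_inv past any layout-agnostic ops that directly follow it."""
    ops = list(chain)
    changed = True
    while changed:
        changed = False
        for i in range(len(ops) - 1):
            if ops[i] == "T_inv" and ops[i + 1] in LAYOUT_AGNOSTIC:
                ops[i], ops[i + 1] = ops[i + 1], ops[i]
                changed = True
    return ops


def cancel_adjacent(chain):
    """Remove back-to-back (T_inv, T) pairs, which compose to the identity."""
    out = []
    for op in chain:
        if out and out[-1] == "T_inv" and op == "T":
            out.pop()
        else:
            out.append(op)
    return out


chain = ["T", "conv2d", "T_inv", "pool", "T", "conv2d", "T_inv"]
local_only = cancel_adjacent(chain)
sunk_then_cancelled = cancel_adjacent(sink_transforms(chain))
```

   With local cancellation alone, nothing is removed because `pool` separates the inverse pair. After sinking, the pair becomes adjacent and the chain keeps only the outermost transforms.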
   
   The above method works well for models with relatively simple data flow but for models with more branching the method has limitations. A simple consideration is sinking a transform through an operation with multiple inputs. The process of doing so requires materialization of the inverse transform on the other operands. 
   
   For the sake of simplicity consider matrix multiplication: ${A^{T}}B = {(B^{T}A)}^T$. In this case the final state of sinking the transpose on A was to materialize two transposes rather than one, one on B and one on the matmul. Sinking alone isn't sufficient to guarantee a globally optimal layout because it still only treats the propagation of transforms locally/greedily.
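   
   The identity can be checked numerically with a small pure-Python sketch. The matrices and the helpers `transpose` and `matmul` are made up for illustration; the point is only that the right-hand side needs two transposes where the left-hand side needs one.

```python
def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(row) for row in zip(*m)]


def matmul(x, y):
    """Naive matrix product of two lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


A = [[1, 2], [3, 4], [5, 6]]                               # shape (3, 2)
B = [[7, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]    # shape (3, 4)

lhs = matmul(transpose(A), B)               # A^T B: one transpose, on A
rhs = transpose(matmul(transpose(B), A))    # (B^T A)^T: transposes on B and on the result
```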
   
   3) A modification to sinking (downward along data flow) is to introduce upward flowing [[3](https://github.com/NervanaSystems/ngraph/blob/f677a119765ca30636cf407009dabd118664951f/src/ngraph/pass/reshape_sinking.cpp#L156)]. It can help by flowing transforms along the poisoned operands (e.g. B in the above matrix multiply) by propagating the transform up as far as possible, hopefully to a constant where it can be folded. 
   
   For inference graphs I've seen this approach work well. But the approach is still greedy and suboptimal choices can occur. For training graphs this approach works less well due to the data flow complexity involved with branching from the forward to backward graph and the optimizer's in-place update of weights. I omit a specific example in this case for brevity, but encourage the review of the graphs from @t-vi's application of TVM to PyTorch training for BERT and the long chains of transposes and reshapes that occur within the forward and backward multi-head attention layers [[4](https://github.com/apache/tvm-site/blob/85c7e4ebf6d9ed221075e38e5e5e1a0052693acc/_posts/2020-07-14-bert-pytorch-tvm.md)]. 
   
   4) Finally, to arrive at a closer to globally optimal solution for layout, different constraint-based approaches are considered. Constraints from operations which are layout constrained can be flowed across unconstrained parts of the graph until an approximate global optimum is reached. 
   
   An example implementation I have seen included layout sources (e.g. operators like conv2d on an NPU with distinct layout constraints) and layout sinks (e.g. operations which involve data movement by DMA engines or in-memory compute which allow zero-cost data layout rearrangement during store). A constraint solver in this case flows layout constraints from sources toward sinks that can absorb aggregated/merged layout transform constraints. 
   ____
   
   Coming back to the present discussion, I believe our design should be focused on ensuring that one or more of the non-local approaches discussed above in 2-4 are achievable. Any of these cases require the following components:
   
   C0) The ability to track constraints on a buffer.
   
   C1) The ability to roundtrip between an IR representation and the producer/consumer constraint representations.
   
   C2) The ability to merge/fold constraints - flowing is just merging a constraint with an unconstraint. 
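
   As a rough illustration of C2, and not a proposal for the actual data structure, a constraint can be modeled as an optional layout string, where `None` means unconstrained. The names `merge` and `flow` below are hypothetical. Flowing along a linear chain of buffers then reduces to merging a constraint into its unconstrained neighbors.

   ```python
   from typing import List, Optional

   # A layout constraint is modeled as a plain string such as "NHWC";
   # None stands for an unconstrained buffer.
   Layout = Optional[str]


   def merge(a: Layout, b: Layout) -> Layout:
       """Merge two constraints; merging with None just flows the other one."""
       if a is None:
           return b
       if b is None or a == b:
           return a
       raise ValueError(f"conflicting layouts {a!r} and {b!r}: a transform is needed")


   def flow(chain: List[Layout]) -> List[Layout]:
       """Flow constraints into unconstrained buffers along a linear chain."""
       result = list(chain)
       for i in range(1, len(result)):            # downward, along data flow
           if result[i] is None:
               result[i] = merge(result[i], result[i - 1])
       for i in range(len(result) - 2, -1, -1):   # upward, against data flow
           if result[i] is None:
               result[i] = merge(result[i], result[i + 1])
       return result


   layouts = flow([None, "NHWC", None, "NCHW"])
   ```

   In this sketch the downward pass runs first, so the unconstrained buffer between the two constrained ones takes the upstream layout. That ordering is a greedy choice of the kind discussed above, not a global optimum.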
   
   Even for the purely local (back-to-back) case discussed in A0, components C1 and C2 are helpful, with the caveat that the constraints inferred from the IR only exist within the local context of a single producer-consumer pair in a pass. 
   
   Thus both A0 and A1 can benefit from these components, and the delta that exists between A0 and A1 is clearer:
   
   * Delta 1: In A1 buffer constraints are maintained per buffer in the graph globally (non-local); and therefore can be optimized by any of the methods 2-4 discussed. 
   
   * Delta 2: In addition to inferring buffer constraints from IR (one half of C1), A1 proposes that constraints expressed about the memory during scheduling be maintained for some time.
   
   
   
   




[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163620046

   Writing out some of my thoughts, to see if there's a way to express the constraints while only using existing TIR features.  The main goals would be as follows.
   
   1. Allow simplification of expressions based on the values present in the padding.
   2. Allow local simplifications to take advantage of non-local constraints, without requiring a full end-to-end analysis.
   3. Specify the non-local constraints in some deducible manner that doesn't impose a runtime performance penalty.
      
   Next, working through various options for how the constraints could be stored. In the examples below, sketching out how these would apply to the element-wise operation which starts as below.
   
   ```python
   @T.prim_func
   def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       for i in T.serial(14):
           B[i] = 2 * A[i]
   ```
   
   1. Apply layout transforms on local caches.  Here, the full lifetime of a buffer is known.  All TIR optimizations are done prior to hoisting the cache and layout transformation into the graph level.
      
      - For read caches, pad value is whatever gets conditionally written to the padding while generating it.  In example below, `AC` could be recognized as being padded.
        
        ```python
        @T.prim_func
        def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
            AC = T.alloc_buffer([4, 4], "int32")
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    AC[io, ii] = A[4 * io + ii]
                else:
                    AC[io, ii] = 0
        
            for i in T.serial(14):
                B[i] = 2 * AC[i // 4, i % 4]
        ```
        
      - For write caches, pad value is whatever is in the padding after the last write to the cache.  In example below, `BC` could be recognized as being padded.
   
        ```python
        @T.prim_func
        def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
            BC = T.alloc_buffer([4, 4], "int32")
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    BC[io, ii] = 2 * A[4*io + ii]
                else:
                    BC[io, ii] = 0
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                     B[4 * io + ii] = BC[io, ii]
        ```
   
      - Downside, either of the `else` statements could be eliminated as a no-op, since they don't contribute to the output `B` value. After that elimination, there wouldn't be any way to reconstruct the pad value.
        
   2. When hoisting an allocation+transformation, write the pad value to the buffer at the start of the function from which it was hoisted. This way, the pad value can still be used in local reasoning.
      
      - No change needed in producers, since they would already write the pad value to the buffer.
      
      - For consumers, would be represented as writing `pad_value` into the padding at the start of the function.
      
        ```python
        @T.prim_func
        def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    AC[io, ii] = 0
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = 2 * AC[io, ii]
        ```
        
      - Downside, repeated unnecessary effort at the beginning of each consumer.  Avoiding it with this representation would require knowing that the producer had written `pad_value` already, which is exactly the information we're trying to avoid.
        
   3. When hoisting an allocation+transformation, write the pad value to the buffer at the start of the function from which it was hoisted, and write `T.undef()` at the end.  This way, the pad value can still be used in local reasoning, and no-op removal can remove the repeated writing when lowering.
      
      - No change needed in producers, since they would already write the pad value to the buffer.
        
      - For consumers, would be like option 2, but with an additional write of `T.undef()` at the end of the function.  When lowering, the write of `T.undef()` would allow the first write to be removed as a no-op because it is overwritten.  The `T.undef()` can then be removed as described in the RFC.
      
        ```python
        @T.prim_func
        def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    AC[io, ii] = 0
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = 2 * AC[io, ii]
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    AC[io, ii] = T.undef()
        ```
        
      - Downside, no way to distinguish between "can assume the pad value is zero" and "can overwrite the pad value at will".  The writing of `T.undef()` would allow any writes to the padding to be inserted as a no-op.
        
      - Downside, wouldn't actually simplify out in cases where the pad value is used.  The first in a pair of repeated writes to the same location can only be removed if there are no reads between the writes.  After using the pad value to eliminate `if 4 * io + ii < 14` from the compute, the dummy loop that writes the padding could no longer be removed.
        
   4. Use `AssertStmt` in a loop to declare known information about the buffers.
   
      - No change needed in producers, since the pad value is already written out.
        
      - For consumers, would have an initial loop that asserts the pad value is correct.
   
        ```python
        @T.prim_func
        def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    assert AC[io, ii] == 0, "padding"
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = 2 * AC[io, ii]
        ```
        
      - Downside, assert statements have target-dependent handling.  In `CodeGenLLVM` and `CodeGenSPIRV`, they are treated as no-ops.  In `CodeGenCPU` and `CodeGenC`, they generate asserts.  In `CodeGenCUDA`, they aren't handled at all and would error out.
        
        Could work around this with a lowering pass, but identifying these conditions would require having a special string in the message, and packing structured data into strings makes me wary.
        
   5. Use `AssertStmt` with implicitly-defined variables to declare known information about the buffers.
      
      ```python
      @T.prim_func
      def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
          a = T.var("int32")
          b = T.var("int32")
          assert (
              AC[a, b] == 0 or (4 * a + b < 14) or (a < 0) or (a >= 4) or (b < 0) or (b >= 4)
          ), "padding"
      
          for io, ii in T.grid(4, 4):
              if 4 * io + ii < 14:
                  B[4 * io + ii] = 2 * AC[io, ii]
      ```
      
      - Can apply to clamped texture memory, since the variables in the assertion aren't restricted to the bounds.
        
      - Would need to recognize specific pattern of `BufferLoad` being used to define variables used in constraint.
        
      - The implicitly-defined variables can be written in current TIR, but  variables would ensure that this isn't something that ever makes it into generated code at runtime.
      
      - Downside, implicitly-defined variables are something of a red flag.
   
   6. Store constraints in the function attributes, either as a dictionary or as a structured object.
      
      ```python
      @T.prim_func
      def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
          T.func_attr(
              "buffer_constraints",
              [
                  {
                      "buffer": AC,
                      "predicate": lambda io, ii: 4 * io + ii < 14,
                      "pad_value": lambda io, ii: 0,
                  },
              ],
          )
      
          for io, ii in T.grid(4, 4):
              if 4 * io + ii < 14:
                  B[4 * io + ii] = 2 * AC[io, ii]
      ```
      
      - Downside, requires transformations that change a buffer to be aware that other structures will also need to be replaced.
        
      - Downside, requires simplifications to either be passed the entire `PrimFunc`, or to be explicitly passed the `"buffer_constraints"` list.
        
      - Downside, would break expectations of `IRMutatorWithAnalyzer`. The current entry point of any `Stmt` or `Expr` would need to have additional information of the `"buffer_constraints"`.
        
   
   7. Store constraints in the `Buffer` object, either as a dictionary or as a structured object.
      
      ```python
      @T.prim_func
      def func(ac: T.handle, B: T.Buffer[14, "int32"]):
          AC = T.match_buffer(
              ac,
              shape=(4, 4),
              dtype="int32",
              constraints=[T.BufferConstraints(predicate=lambda io, ii: 4 * io + ii < 14, pad_value=0)],
          )
      
          for io, ii in T.grid(4, 4):
              if 4 * io + ii < 14:
                  B[4 * io + ii] = 2 * AC[io, ii]
      ```
      
      - Downside, introduces additional data structure in TIR.
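
   To make goal 1 concrete, here is a small pure-Python simulation, not TIR. The `BufferConstraint` dataclass is a hypothetical stand-in for the structure sketched in options 6 and 7, and `produce`, `satisfies`, and the two consumers are illustrative names. It uses a sum reduction rather than the element-wise example above, since a reduction is a case where knowing that the padding holds zero lets the consumer drop its branch.

   ```python
   from dataclasses import dataclass
   from typing import Callable, List

   ROWS, COLS, N = 4, 4, 14  # transformed shape [4, 4], original shape [14]


   @dataclass
   class BufferConstraint:
       """Hypothetical stand-in for the structure sketched in options 6 and 7."""
       predicate: Callable[[int, int], bool]  # True where real data lives
       pad_value: int                         # value promised everywhere else


   def produce(a: List[int], pad_value: int) -> List[List[int]]:
       """Producer: apply i -> [i // 4, i % 4] and write pad_value into the padding."""
       ac = [[pad_value] * COLS for _ in range(ROWS)]
       for i, value in enumerate(a):
           ac[i // COLS][i % COLS] = value
       return ac


   def satisfies(buf: List[List[int]], c: BufferConstraint) -> bool:
       """Check that every padding element holds the promised value."""
       return all(
           c.predicate(io, ii) or buf[io][ii] == c.pad_value
           for io in range(ROWS)
           for ii in range(COLS)
       )


   def sum_with_branch(ac: List[List[int]]) -> int:
       """Consumer that guards every access, as in the unsimplified loops."""
       return sum(ac[io][ii] for io in range(ROWS) for ii in range(COLS) if 4 * io + ii < N)


   def sum_without_branch(ac: List[List[int]]) -> int:
       """Consumer after simplification: valid only if the padding holds zero."""
       return sum(ac[io][ii] for io in range(ROWS) for ii in range(COLS))


   constraint = BufferConstraint(predicate=lambda io, ii: 4 * io + ii < N, pad_value=0)
   A = list(range(1, N + 1))
   AC = produce(A, pad_value=0)
   assert satisfies(AC, constraint)
   assert sum_with_branch(AC) == sum_without_branch(AC) == sum(A)
   ```

   If the producer wrote any other pad value, `satisfies` would fail and the branch-free consumer would give a different result, which is the non-local dependency the options above try to express.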




[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1165889831

   > Our design principle at TIR level ideally we start with one instance of possibility, then use probabilistic space of meta-schedule to represent multiple choices.
   
   For this, would the layout re-flowing occur periodically during optimization?  Otherwise, including transformations in the performance benchmarking of candidates would unfairly penalize candidates that add a transformation step, while excluding transformations would unfairly bias toward transformations, even when sequential operators require separate layout transformations.
   
   Representing different options as different allowed steps in a search space makes sense to me, so long as the candidates are reasonably exposed to the optimizer.
   
   > In our particular example, however, the idea is that the schedule primitive do not modify the input/output buffer, but introduce preproc and postproc stages with clear hint that they should be lifted out (aka we are doing the same thing in two steps)
   
   I think I understand.  That would effectively be treating the preproc/postproc stages as separate function bodies, but ones which happen to exist within the same TIR PrimFunc for ease of use.
   
   With this representation, I think the biggest part would be determining when to fix a previously free parameter, in order to expose it as an assumption to another TIR PrimFunc.  Maybe in the "Step 2: Reflowing of Layouts", this isn't used to cancel any statements out, but instead to create a dynamic performance penalty if an assertion is no longer held, with the performance penalty equal to the time required to do the transformation.
   
   > As a quick intermediate middle ground. For most intermediate stage(really like add or exp), we would ideally not insert any layout decisions and allow decisions from other key ops(conv and matmul) to backprop their decision to them.
   
   I'd agree, though would phrase it somewhat differently.  The element-wise operations impose a constraint that the input and output have identical layouts.




[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163019805

   > I'm still a bit confused with this approach, specifically how one would avoid having a separate compute definition for each workload on a new target
   
   Indeed it is important to avoid having a separate compute definition for each workload on a new target. In this particular case, all computation definitions would start with the original layout. Then there is a "schedule transformation" like transform layout which will generate the new stage as part of the scheduling process.
   
   The particular stage can be marked, which contains effectively the same information as BufferConstraint, except that it does not introduce new data structures. During global layout reflowing, such information can be used to guide the reflowing to reconstruct a data structure like `BufferConstraint` or other Layout mappings and use that to serve the same purpose.
   
   > Is there an existing annotation to indicate that a stage should be removed entirely during lowering? 
   
   Ideally we should not introduce annotation to indicate a stage should be removed, as that breaks the interface of the code itself (ideally the computation should remain the same).
   
   However, we can hint to the compiler that this particular stage is a layout transformation that should be lifted and resolved through the global constraint reflowing. Additionally, such an annotation can be used to guide benchmarking, such that the overall tuning should only look at the non-rewriting part (and we can leverage the transform block to generate input examples correctly).
   
   
   As a high level summary, the main message is to allow enough info in the TIR (as part of the transform block) such that we can reconstruct a `BufferConstraint`-like auxiliary data structure in global reflowing, while still making the TIR part self-contained enough so it is sufficient to construct such a data structure.
   
   This also helps in cases where there are other graph-level layout rewritings (e.g. transpose) that can be fused with those additional transformation stages.
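
   One way to picture the reconstruction described above, as a sketch and not the actual reflowing pass: treat the marked transform stage as a function from a logical index to transformed indices, enumerate it, and recover which transformed locations are padding. The name `reconstruct_padding` and its arguments are hypothetical.

   ```python
   from itertools import product


   def reconstruct_padding(logical_extent, transform, transformed_shape):
       """Return the transformed indices that no logical index maps to."""
       written = {tuple(transform(i)) for i in range(logical_extent)}
       every = set(product(*(range(n) for n in transformed_shape)))
       return sorted(every - written)


   # The stage i -> [i // 4, i % 4] applied to a buffer of shape [14].
   padding = reconstruct_padding(14, lambda i: [i // 4, i % 4], (4, 4))
   assert padding == [(3, 2), (3, 3)]
   ```

   A real pass would derive the same information symbolically from the loop nest in the transform block instead of enumerating indices, but the recovered predicate is the same.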
   
   
   
   




[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1162392893

   > Introducing changes to TIR would needs some additional thoughts that deserves some extra consideration. Due to the N*M complexity (where N is the TIR possibilities and M is the number of primitives to be supported) that needs to be handled in implementation (by backend implementers and primitive implementers)
   
   This was part of the design consideration, to minimize the impact of the proposed changes to primitives, lowering transformations, and backends.
   
   * The `BufferConstraint` annotations do not need specific handling at the codegen level, as it is only present to enable compile-time optimizations.
     
   * Use of the `BufferConstraint` hints would occur within existing utilities, primarily as additional information available in `arith::Analyzer` utilities.  This minimizes the need for other primitives/transforms to be aware of the buffer constraints, while still benefiting from them.
     
   * The `T.undef()` built-in does not need specific handling at the codegen level, as it is removed during lowering.
     
   * The `T.undef()` built-in does not require specific handling from other primitives, as stores of `T.undef()` can be treated the same as stores of any other value.
     
   > Right now it is possible to do non-local constraint rewriting flowings as part of the graph pass. Note that while E1 is indeed less "compact" on one hand, we can use it to reconstruct the desirable compact data structure(something like BufferConstraint that represents the layout mapping) that we can use to flow the decisions across the graph node during the pass.
     
   I definitely agree that graph-level transforms are where the layouts and constraints should be decided.  The `BufferConstraint` annotations are not intended as a way to override in TIR what was already decided at the graph level, but rather a way to communicate to TIR transformations what has been decided at the graph level.
   
   > E1: Composing a stage that transforms the layout(a loop that represents the mapping)
   
   I'm still a bit confused with this approach, specifically how one would avoid having a separate compute definition for each workload on a new target (Initially brought up by @csullivan [here](https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372).) In my mind, if I'm going to compose a layout transformation stage, it would need to be followed by a compute stage that takes a transformed layout as input.  So rather than having a single conv2d that can be generalized over layouts, each transformed layout would still need to have a compute stage for it.
   
   > Note that intiially such data structure do not need to live beyond the life of a pass, because they can be reconstructed at anytime from the other representation.
   
   How would this be represented while optimizing the performance of a subgraph?  My concern would be how to express the non-local constraints while keeping a small search space for optimization.
   
   * Ensure that the producer and consumer stages are within the same subgraph.  Since the constraints provided to a consumer depend not only on the producer, but also on the constraints provided to the producer, this might require fusing the entire end-to-end model into a single monolithic kernel.
     
     My understanding is that this would result in a search space that is too large to effectively optimize, though I haven't explicitly tested it.
     
   * Insert a transformation stage into the subgraph, in which the constraint is written.  Later portions of the subgraph could then rely on the constraint without examining other subgraphs.
     
     Would need to have some way to indicate that the transformation stage shouldn't be altered during optimization, nor should it be part of the performance timing.
     
   * Express the graph-level constraints to a subgraph, so that it can optimize using those constraints.
     
     This was my intent with the `BufferConstraint` annotations, since then the subgraphs could take advantage of
     
   > E1 also enables some additional capabilities (e.g.) expressing future memory remappings that do not necessarily fit into padding/packing.
   
   Is there an existing annotation to indicate that a stage should be removed entirely during lowering?  That might be an effective way to allow more general usage by annotating a stage that can be assumed to have been performed prior to the subgraph.  This would be a way to express the second option of an extra transformation stage, while still providing enough information to remove the transformation stage during lowering.




[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1167249134

   > For this, would the layout re-flowing occur periodically during optimization?
   
   This is a point where different variations of some sort of search algorithm would likely be necessary. Our first step would be to allow the TIR level to give such feedback to the global level (via a probabilistic space), so that the search can be done more smartly.
   
   >  The element-wise operations impose a constraint such that input and output layouts, that the input and output have identical layouts.
   
   Agree, this is something that we can do




[GitHub] [tvm-rfcs] csullivan commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1181142880

   Thanks everyone for the very fruitful discussions! We indeed have a good path forward and are aligned on the principles that for the end to end optimization we will maintain function interface invariance and achieve graph level layout optimization via a combination of local decisions, reconstruction with assumptions, and rewriting based on the result of graph level analysis and planning. 
   
   I would ask that we move this discussion into a final comment period as we would like to soon open a tracking issue for the items described in the RFC. 




[GitHub] [tvm-rfcs] Hzfengsy commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r890793224


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original 14, because the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   +1. I agree that we are solving a challenging and important problem, one that needs additional data structures and transformations. But as @tqchen mentioned, the IR data structure needs to be stable. Once it changes, even slightly, we may need a lot of effort to review all existing primitives and transformations.
   
   I agree we can enhance the IR semantics when "necessary", if we have no other way to go. Before that, let's think carefully and look for alternative paths.
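
For readers following the thread, the padding that the quoted section would record as a predicate can be made concrete in plain Python. This is only a sketch, not TVM code: `padding_indices` is a hypothetical helper, it uses no TVM APIs, and it assumes a 1-d source buffer whose transformed indices start at zero.

```python
from itertools import product


def padding_indices(shape, transform):
    """Return (transformed_shape, padding) for a 1-d buffer.

    `padding` lists the transformed indices that no original index maps to.
    Hypothetical helper for illustration; not part of TVM.
    """
    (extent,) = shape
    # Every transformed index that holds a value from the original buffer.
    mapped = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(mapped)))
    # Smallest shape (starting at zero) that contains all mapped indices.
    transformed_shape = tuple(
        max(idx[d] for idx in mapped) + 1 for d in range(ndim)
    )
    all_indices = product(*(range(n) for n in transformed_shape))
    padding = sorted(idx for idx in all_indices if idx not in mapped)
    return transformed_shape, padding


# Shape [14] with i -> [i//4, i%4], as in the Motivation section.
print(padding_indices((14,), lambda i: [i // 4, i % 4]))
```

For the `[14]` buffer this reports shape `(4, 4)` with padding at `(3, 2)` and `(3, 3)`, which is the set that the predicate `io == 3 and ii >= 2` describes.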



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements rather than 14, because the transformed coordinates
+`(3,2)` and `(3,3)` do not have a corresponding index on the workload
+range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial dependent on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform---remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+A new S-TIR transformation `Schedule.sequential_buffer_access` should
+be introduced, which rewrites iteration loops according to the access
+pattern of a buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):

Review Comment:
   This is not a typical S-TIR function, which is expected to be:
   ```python
   @T.prim_func
   def func(A: T.Buffer[(16,), "int32"]):
       for i in T.serial(16):
           with T.block('compute'):
               vi = T.axis.S(16, i)
               A[vi] = vi
   ```
   
   The same mistakes appear in most of the examples in this RFC:
   1. Loops are usually outside the block, and one block contains only a single statement.
   2. An `if-else` branch inside a block is represented by `T.where`.
   3. Please write the block vars explicitly, since they are important during scheduling.
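
The branch that point 2 refers to is the guard in the padded `row_summation` example quoted above. However it is spelled in TVMScript, what it computes can be checked in plain Python. The sketch below is not TVM code and both function names are hypothetical; it assumes the padding of the `(4, 4)` layout has been pre-filled with zero, and compares the guarded loop with the unguarded one.

```python
def row_sum_branching(row):
    """Sum one 14-element row, guarding out the two padded positions."""
    total = 0.0
    for j_outer in range(4):
        for j_inner in range(4):
            j = 4 * j_outer + j_inner
            if j < 14:  # the predicate that a guard such as T.where would carry
                total += row[j]
    return total


def row_sum_overcompute(padded_row):
    """Sum all 16 positions of a (4, 4) row whose padding holds 0.0."""
    total = 0.0
    for j_outer in range(4):
        for j_inner in range(4):
            total += padded_row[j_outer][j_inner]
    return total


row = [float(j) for j in range(14)]
# Layout transform j -> [j//4, j%4], with the padding filled with 0.0.
padded = [
    [row[4 * jo + ji] if 4 * jo + ji < 14 else 0.0 for ji in range(4)]
    for jo in range(4)
]
print(row_sum_branching(row), row_sum_overcompute(padded))
```

Both loops give the same sum because adding `0.0` leaves the accumulator unchanged, which is the condition under which the guard can be dropped.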



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements rather than 14, because the transformed coordinates
+`(3,2)` and `(3,3)` do not have a corresponding index on the workload
+range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the compiler can prove that the conditional is unnecessary and
+remove it automatically.  Since trading branching for overcompute may
+or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform---remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+A new S-TIR transformation `Schedule.sequential_buffer_access` should
+be introduced, which rewrites iteration loops according to the access
+pattern of a buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed

Review Comment:
   I agree that `sequential_buffer_access` is beneficial, but I'd like to make it sugar rather than a new primitive. That is, when users call `sch.sequential_buffer_access`, it will implicitly call a set of `split` and `reorder` primitives, and the schedule trace will be:
   ```python
   sch.split(...)
   sch.reorder(...)
   ```
   
   The reasons are:
   1. A `Primitive` is expected to be irreplaceable. If a transformation can be represented by a set of existing transformations, we won't create a new primitive for it (but sugar is fine).
   2. Reusing the split/fuse/reorder primitives may reduce the maintenance cost. We won't need to fix the same bug (if one exists) twice.
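
The claim that the reordered loop nest is reachable through existing primitives can be checked on the RFC's own 16-element example. The following is a minimal plain-Python sketch, not TVM code: the helper names `original_writes` and `split_loop_writes` are hypothetical, and the loops are modeled as lists of `(index, value)` writes. For this 1-D case a single split is enough, so the reorder step would be a no-op.

```python
from itertools import product


def original_writes(extent=16, factor=4):
    """Writes made by `for i in range(extent): A[i // factor, i % factor] = i`."""
    return [((i // factor, i % factor), i) for i in range(extent)]


def split_loop_writes(extent=16, factor=4):
    """Writes made after splitting `i` into (io, ii), with i = factor * io + ii.

    Assumes `factor` evenly divides `extent`, so no padding is involved.
    """
    outer = extent // factor
    return [
        ((io, ii), factor * io + ii)
        for io, ii in product(range(outer), range(factor))
    ]


# Same locations, same values, same order: the split loop nest is
# equivalent to the original loop over the transformed buffer.
assert original_writes() == split_loop_writes()
```

When the factor does not evenly divide the extent, the split loop nest would additionally need the padding predicate discussed in the RFC; this sketch does not cover that case.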





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the compiler can prove that the conditional is unnecessary and
+remove it automatically.  Since trading branching for overcompute may
+or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   To summarize the two approaches being discussed, I see them as:
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension reordering and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and to materialize a function that legalizes the layout where needed. 
   
   The main difference I see is that A0 pushes much more work into hardware-specific optimization, whereas A1 allows for more compiler simplification through information the user provides as hardware semantics that are true about the buffer. 
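
To make the constraint-flowing idea in A1 concrete, here is a toy plain-Python sketch. It is not a TVM pass: `PadConstraint`, `needs_conversion`, `flow_constraints`, and `example_chain` are hypothetical names, and the model assumes a linear chain of operators where each one states the padding it requires on its input and the padding it produces on its output. A conversion is materialized only where those two disagree.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PadConstraint:
    """Hypothetical stand-in for a buffer constraint: a layout label and pad value."""
    layout: str
    pad_value: Optional[float]  # None means the padding may not be read


def needs_conversion(produced: PadConstraint, required: PadConstraint) -> bool:
    """Return True if the produced buffer contradicts what the consumer requires."""
    if produced.layout != required.layout:
        return True
    if required.pad_value is None:
        # The consumer never reads the padding, so any stored value is acceptable.
        return False
    return produced.pad_value != required.pad_value


def flow_constraints(chain):
    """chain: list of (op_name, required_input, produced_output) tuples.

    Returns the names of the operators that need a layout conversion
    materialized in front of them.
    """
    conversions = []
    current = None
    for name, required, produced in chain:
        if current is not None and required is not None:
            if needs_conversion(current, required):
                conversions.append(name)
        current = produced
    return conversions


layout = "[i//4, i%4]"
example_chain = [
    ("producer", None, PadConstraint(layout, 0.0)),
    ("row_sum", PadConstraint(layout, 0.0), PadConstraint(layout, 0.0)),
    ("maxpool", PadConstraint(layout, float("-inf")), PadConstraint(layout, None)),
]
print(flow_constraints(example_chain))
```

In this example the row summation agrees with the zero padding it receives, while the maxpool wants `-INF` in the padding (as in the RFC's Motivation section), so only that edge would need a conversion.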





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891387773


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
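The shapes and padding locations quoted in the comments above can be checked by brute force. The following is a minimal sketch in plain Python, not TVM code: the helper `transformed_shape_and_padding` is a hypothetical name, and it assumes the transformed shape is simply one more than the largest index reached in each dimension.

```python
from itertools import product


def transformed_shape_and_padding(extent, transform):
    """Apply a 1-d -> N-d index transform to every i in range(extent).

    Returns the smallest shape that holds every mapped index, plus the
    sorted list of transformed indices that no logical index maps to.
    This is a brute-force stand-in, not how TVM derives the shape.
    """
    mapped = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(mapped)))
    shape = [max(idx[d] for idx in mapped) + 1 for d in range(ndim)]
    all_indices = set(product(*(range(s) for s in shape)))
    return shape, sorted(all_indices - mapped)


no_offset = lambda i: [i // 8, i % 8]
with_offset = lambda i: [(i + 2) // 8, (i + 2) % 8]

# A has shape [16], B has shape [14], as in the examples above.
shape_a, pad_a = transformed_shape_and_padding(16, no_offset)
shape_b, pad_b = transformed_shape_and_padding(14, no_offset)
shape_b_off, pad_b_off = transformed_shape_and_padding(14, with_offset)
shape_a_off, pad_a_off = transformed_shape_and_padding(16, with_offset)

for name, shape, pad in [
    ("A, no offset", shape_a, pad_a),
    ("B, no offset", shape_b, pad_b),
    ("B, offset 2", shape_b_off, pad_b_off),
    ("A, offset 2", shape_a_off, pad_a_off),
]:
    print(name, shape, pad)
```

Enumerating the index space like this is only practical for small static shapes, but it is a convenient way to sanity-check a proposed transformation before scheduling with it.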
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+may be a constant, or a function that maps from transformed indices
+to an `Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14) % 4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
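The producer postlude described above can be emulated outside of TVM. This is a minimal sketch in plain Python using nested lists; `transform_with_padding` is a hypothetical helper, and it approximates the padding predicate as "no logical index was written here".

```python
def transform_with_padding(values, transform, shape, pad_value):
    """Emulate a producer that writes `values` into a transformed 2-d
    layout, then fills every untouched location using `pad_value`."""
    rows, cols = shape
    out = [[None] * cols for _ in range(rows)]
    written = set()

    # Main body: each logical index i lands at transform(i).
    for i, v in enumerate(values):
        io, ii = transform(i)
        out[io][ii] = v
        written.add((io, ii))

    # Postlude: write the pad value wherever the padding predicate holds.
    for io in range(rows):
        for ii in range(cols):
            if (io, ii) not in written:
                out[io][ii] = pad_value(io, ii)
    return out


original = [float(i) for i in range(14)]
padded = transform_with_padding(
    original,
    transform=lambda i: (i // 4, i % 4),
    shape=(4, 4),
    pad_value=lambda io, ii: 0.0,
)
print(padded[3])
```

A consumer reading `padded` may then rely on the last two entries of the final row being zero, which is the kind of guarantee the later simplifications depend on.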
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the extra additions can be proven to be no-ops, allowing the
+conditional to be removed automatically.  Since the tradeoff between
+branching and overcompute may or may not be beneficial depending on the
+schedule, both directions are exposed as additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
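The equivalence between the branching and overcompute forms can be checked numerically. The sketch below uses plain Python lists in place of TIR buffers; the helper names (`to_tiled`, `row_sum_branching`, `row_sum_overcompute`) are illustrative only. It also shows that the overcompute form is only valid when the padding really holds zero.

```python
ROWS, COLS, TILE = 16, 14, 4
N_TILES = (COLS + TILE - 1) // TILE  # 4 tiles of width 4

# Logical input of shape [16, 14].
A = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]


def to_tiled(a, pad):
    """Re-lay out [i, j] as [i, j//4, j%4], filling padding with `pad`."""
    a_t = [[[pad] * TILE for _ in range(N_TILES)] for _ in range(ROWS)]
    for i in range(ROWS):
        for j in range(COLS):
            a_t[i][j // TILE][j % TILE] = a[i][j]
    return a_t


def row_sum_branching(a_t):
    out = []
    for i in range(ROWS):
        acc = 0.0
        for j_outer in range(N_TILES):
            for j_inner in range(TILE):
                # Guard skips the two padding elements of each row.
                if TILE * j_outer + j_inner < COLS:
                    acc = acc + a_t[i][j_outer][j_inner]
        out.append(acc)
    return out


def row_sum_overcompute(a_t):
    out = []
    for i in range(ROWS):
        acc = 0.0
        for j_outer in range(N_TILES):
            for j_inner in range(TILE):
                # No guard: padding is read and added as well.
                acc = acc + a_t[i][j_outer][j_inner]
        out.append(acc)
    return out


reference = [sum(row) for row in A]
zero_padded = to_tiled(A, pad=0.0)
one_padded = to_tiled(A, pad=1.0)

print(row_sum_overcompute(zero_padded) == reference)  # padding is zero
print(row_sum_overcompute(one_padded) == reference)   # padding is not zero
```

The branching form gives the reference result regardless of the padding contents, which is why removing the branch requires a proof about what the padding holds.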
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
+
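The semantics of a predicate/value pair can be modelled in a few lines. This is a hedged Python sketch of the behavior described above, not the proposed C++ data structure: the `BufferConstraint` dataclass and `constrained_value` helper are hypothetical, and plain callables stand in for `PrimExpr`.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class BufferConstraint:
    # True for indices that lie in the transformation padding.
    predicate: Callable[..., bool]
    # Known contents at those indices, or None if access is forbidden.
    value: Optional[Callable[..., float]]


def constrained_value(constraint, *indices):
    """Return the known value at `indices`, or None if nothing is known.

    Raises IndexError if the indices are in padding that may not be
    accessed (predicate holds, but no value was provided).
    """
    if not constraint.predicate(*indices):
        return None  # Not padding: the value must come from a real load.
    if constraint.value is None:
        raise IndexError(f"access to padding at {indices} is forbidden")
    return constraint.value(*indices)


# Padding predicate for shape [14] transformed by i -> [i//4, i%4].
in_padding = lambda io, ii: 4 * io + ii >= 14

zero_pad = BufferConstraint(predicate=in_padding, value=lambda io, ii: 0.0)
forbidden = BufferConstraint(predicate=in_padding, value=None)

print(constrained_value(zero_pad, 3, 2))  # known to be zero
print(constrained_value(zero_pad, 0, 0))  # nothing known
```

A simplifier could use a lookup of this kind to replace a load from padding with its known value, while treating the forbidden case as an error in the input program.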
+### New TIR Op, `tir::builtin::arbitrary`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+primarily used to allow simplifications in a producer.  See the [section
+on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage, and the [section on
+`tir.transform.RemoveArbitraryStore`](#new-lowering-transform---remove-tarbitrary)
+for its removal.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-primitive---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
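The simplification of the padding loop mentioned above (dropping the loop over `io` and restricting `ii` to `2 <= ii < 4`) can be checked against the guarded full-grid form. This is a small sketch in plain Python rather than TIR, with illustrative function names; the padding locations are `(3,2)` and `(3,3)`, i.e. `io == 3 and ii >= 2`.

```python
def make_buffer():
    """Write the 14 logical values into a [4, 4] transformed layout."""
    buf = [[None] * 4 for _ in range(4)]
    for i in range(14):
        buf[i // 4][i % 4] = i
    return buf


def write_padding_full_grid(buf):
    # Guarded loop over the entire transformed index space.
    for io in range(4):
        for ii in range(4):
            if io == 3 and ii >= 2:
                buf[io][ii] = -1


def write_padding_reduced(buf):
    # Simplified form: io is fixed at 3, and ii covers only 2 <= ii < 4.
    for ii in range(2, 4):
        buf[3][ii] = -1


full = make_buffer()
write_padding_full_grid(full)

reduced = make_buffer()
write_padding_reduced(reduced)

print(full == reduced, full[3])
```

Both forms leave the buffer in the same state; the full-grid form just keeps the loop structure aligned with other loops over the transformed buffer, which is what makes later merging possible.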
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Primitive - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.

Review Comment:
   Thank you, and there was a similar comment from @Hzfengsy .  I've updated this section to instead propose a utility function to generate the appropriate split/fuse/reorder, rather than being itself a primitive.  Taking another look at https://github.com/apache/tvm/pull/11485 and the block iter bindings, I think the utility might be as simple as applying `transform_block_layout` with a mapping defined based on the block iter bindings of spatial dimensions.
   
   The main uniform usage coming to mind would be applying the transformation to all three layouts uniformly.  Though, that would only be well-defined if all three are already uniform, so that wouldn't catch cases where it changes from scatter to gather or vice versa.





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r892802660


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Some notes from a conversation with @vinx13 @csullivan @tqchen @junrushao1994.  Wording is mine, attempting to summarize statements made.
   
   * Hoist transformations into graph-level vs apply transformations in TIR.
   
     From @vinx13: The RFC proposal models layout transformations in TIR primarily focused on the TIR side.  However, this isn't strictly necessary, as padding could instead be introduced and reasoned about at the relay/relax level.
   
     Example: A single operator that performs a conv1d could be replaced by a sequence of three operators, to transform the layout, perform a conv1d, then apply the inverse transform.  If two adjacent conv1d use the same transform, then the transforms could be canceled out.
   
     From @Lunderberg: Uncertain how generic this could be for padding, as it would require a pre-existing implementation of the operator applied to a padded buffer, and would require graph-level knowledge of what padding can be applied to different buffers without effect.
   
     Buffer padding can also be trickier to reason about.  A layout transform of NCHW to NHWC followed by a layout transformation of NHWC to NCHW cancel each other out.  If a buffer is first cropped, then padded with zeros, these only cancel each other out if the contents of the cropped locations can be proven to have held zero.
   
   * Optimization of PrimFuncs
   
     From @Lunderberg: Had been visualizing the layout transformations not solely as transformations that would be made in isolation, but as transformations that would need to be done in conjunction with the calling scope.
   
     When optimizing a PrimFunc in isolation, `transform_layout` would only be allowed to be applied on internal buffers, not buffers passed in as arguments.
   
   * Graph-level optimization of buffer layouts.
   
     From @Lunderberg: Combining the previous two, this sounds like something that will be useful and doable in Relax, where a transformation could change both the input shape, and also the calling convention of all TIR functions that use the buffer.
   
     In relay, this would be trickier, but still possible.  Would need some way to query what layouts/assumptions an operator would find useful, either through manual tagging or some automatic search. Relay would first query the operators, then decide the layouts, then pass these down to the operators to use.
   
   * Method to write compute definitions, especially for new hardware.
   
     From @csullivan: Modeling the layout transformation within TIR would have huge benefits when supporting new hardware targets.  When writing a new operator, there isn't a specific "graph-level" that it is being designed for.  The current state is quite difficult for hardware vendors, who must write a new compute definition and schedule for each shape of an operator.
   
     This is a benefit that wouldn't be present in the graph-level layout transformations, which would still require each supported layout to have a different implementation.
   
     From @Lunderberg: A goal would be to write a function that outputs a PrimFunc following a specific layout, as a reasonable starting point for optimization.  Writing it in terms of a series of changes made to an initial PrimFunc would be the reasonable way to go about it.  Whether that is best described as a schedule primitive, a function that applies a schedule primitive, or a function that returns a series of schedule primitives to be applied is less important than the functionality itself.
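
The cancellation caveat in the notes above (a crop followed by zero-padding only cancels out if the cropped locations already held zero) can be illustrated with a small sketch. This is plain Python on flat lists, with hypothetical helper names `crop` and `pad_zeros`; it is not tied to any relay/relax operator.

```python
def crop(buf, n):
    """Keep only the first n elements."""
    return buf[:n]


def pad_zeros(buf, size):
    """Extend buf with zeros up to `size` elements."""
    return buf + [0.0] * (size - len(buf))


logical = [float(i + 1) for i in range(14)]

# Physical buffers of 16 elements: one with zeros in the padding,
# one with leftover non-zero values in the padding.
clean = logical + [0.0, 0.0]
dirty = logical + [7.0, 7.0]

print(pad_zeros(crop(clean, 14), 16) == clean)  # round trip is an identity
print(pad_zeros(crop(dirty, 14), 16) == dirty)  # round trip changes contents
```

A graph-level pass would therefore need a proof about the padding contents before eliding such a crop/pad pair, whereas a pure permutation such as NCHW to NHWC and back needs no such proof.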





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894941508


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements rather than 14, because the transformed coordinates
+`(3,2)` and `(3,3)` do not have a corresponding index in the workload
+range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
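The padding locations quoted in the comments above can be checked with a small plain-Python sketch. This is not TVM code: `padded_layout` is a hypothetical helper written only for illustration, and it assumes a 1-d source buffer and an index map whose outputs are non-negative, so the minimal shape along each axis is one more than the largest index produced.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    `extent` is the length of the original buffer and `index_map` maps a
    flat index to a list of transformed indices.
    """
    # Every transformed index that corresponds to a real element.
    used = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Minimal shape: one past the largest index seen along each axis.
    shape = [max(idx[d] for idx in used) + 1 for d in range(ndim)]
    # Anything inside the shape that no original index maps to is padding.
    padding = [
        idx for idx in product(*(range(s) for s in shape)) if idx not in used
    ]
    return shape, padding


shape_a, pad_a = padded_layout(16, lambda i: [i // 8, i % 8])
shape_b, pad_b = padded_layout(14, lambda i: [i // 8, i % 8])
shape_b_off, pad_b_off = padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8])
shape_a_off, pad_a_off = padded_layout(16, lambda i: [(i + 2) // 8, (i + 2) % 8])

print(shape_a, pad_a)
print(shape_b, pad_b)
print(shape_b_off, pad_b_off)
print(shape_a_off, pad_a_off)
```

Enumerating the image of the index map is only practical for small static shapes; it is used here to confirm the examples, not as a suggestion for how the padding predicate would be derived in the compiler.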
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
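To make the producer postlude concrete, here is a minimal plain-Python simulation for the `[14]` to `[4,4]` case. It is a sketch, not TVM code: `apply_layout_with_padding` is a hypothetical helper, the buffer is modeled as a list of lists, and `pad_value` is given the buffer explicitly as a first argument so that a wrap-around value can read from it. The wrap-around expression is written in terms of the flattened index `4*io + ii`, which is an assumption about the intended behavior.

```python
def apply_layout_with_padding(data, pad_value, tile=4):
    """Apply i -> [i // tile, i % tile] to `data`, then run the postlude."""
    n = len(data)
    rows = (n + tile - 1) // tile
    out = [[None] * tile for _ in range(rows)]

    # Main body: every original value lands at its transformed location.
    for i, v in enumerate(data):
        out[i // tile][i % tile] = v

    # Postlude: write pad_value wherever the padding predicate holds.
    for io in range(rows):
        for ii in range(tile):
            if tile * io + ii >= n:  # padding predicate
                out[io][ii] = pad_value(out, io, ii)
    return out


data = [10.0 + i for i in range(14)]

# Padding filled with zeros.
zeros = apply_layout_with_padding(data, lambda buf, io, ii: 0.0)

# Padding wraps around to the beginning of the buffer.
wrapped = apply_layout_with_padding(
    data, lambda buf, io, ii: buf[0][(4 * io + ii - 14) % 4]
)

print(zeros[3])
print(wrapped[3])
```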
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then the extra additions `B[i] = B[i] + 0.0` are no-ops, and can
+be performed without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the guarded and unguarded forms can be proven equivalent,
+allowing the conditional to be removed automatically.  Since whether
+overcompute is preferable to branching depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
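The equivalence that this simplification relies on can be checked numerically. The sketch below is plain Python rather than TIR, with the padded `(16, 4, 4)` buffer modeled as nested lists; the function names are illustrative only. It assumes the padding has already been filled with zeros, which is exactly the condition the buffer annotation is meant to record.

```python
ROWS, COLS, TILE = 16, 14, 4
N_OUTER = (COLS + TILE - 1) // TILE  # 4 tiles of 4 cover the 14 columns

# Original (16, 14) data and its padded (16, 4, 4) layout, padding = 0.0.
A = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]
A_padded = [[[0.0] * TILE for _ in range(N_OUTER)] for _ in range(ROWS)]
for i in range(ROWS):
    for j in range(COLS):
        A_padded[i][j // TILE][j % TILE] = A[i][j]


def row_sum_branching(buf):
    """Row sums with a guard that skips the padding."""
    out = []
    for i in range(ROWS):
        total = 0.0
        for jo in range(N_OUTER):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:
                    total += buf[i][jo][ji]
        out.append(total)
    return out


def row_sum_overcompute(buf):
    """Row sums with no guard; the padding is read and added as well."""
    out = []
    for i in range(ROWS):
        total = 0.0
        for jo in range(N_OUTER):
            for ji in range(TILE):
                total += buf[i][jo][ji]
        out.append(total)
    return out


branching = row_sum_branching(A_padded)
overcompute = row_sum_overcompute(A_padded)
print(branching == overcompute)
```

If the padding held anything other than the additive identity, the two functions would disagree, which is why the guard can only be dropped when the padding value is known.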
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes

Review Comment:
   > @[tqchen](https://github.com/tqchen) replied [yesterday](https://github.com/apache/tvm-rfcs/pull/77#discussion_r893932157)
   > ...
   > Under certain scenarios we could indeed consider put some constraints at the interface level. That would need some more thoughts on semantics, how would they interact with graph, and structural complexity (whether padding is something that worth the IR complexity).
   >  ...
   
   @tqchen Our hope is to have these thoughts and discussion in this RFC, and we welcome your and others' analysis of the semantics and the specific complexity it would introduce.





[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r892859018


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements rather than 14, because the transformed coordinates
+`(3,2)` and `(3,3)` do not have a corresponding index in the workload
+range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then the extra additions `B[i] = B[i] + 0.0` are no-ops, and can
+be performed without changing the final result.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the guarded and unguarded forms can be proven equivalent,
+allowing the conditional to be removed automatically.  Since whether
+overcompute is preferable to branching depends on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
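The meaning of the three fields can be illustrated with a small Python model. This is only a sketch of the semantics described above, not the proposed C++ structure: `BufferConstraintSketch` and `known_value` are hypothetical names, the `indices` field is represented implicitly by the arguments of the callables, and the predicate shown is the one for the `[14]` to `[4,4]` example.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class BufferConstraintSketch:
    # True at indices covered by this constraint (here: the padding).
    predicate: Callable[..., bool]
    # Known contents at those indices, or None if access is not permitted.
    value: Optional[Callable[..., float]]


def known_value(constraint, indices):
    """Return what the constraint says about a read at `indices`.

    Returns None if the constraint says nothing about this location,
    the known value if one is recorded, and raises if the location is
    padding that may not be accessed.
    """
    if not constraint.predicate(*indices):
        return None
    if constraint.value is None:
        raise IndexError(f"access to padding at {indices} is not permitted")
    return constraint.value(*indices)


def in_padding(io, ii):
    # Padding predicate for shape [14] under i -> [i//4, i%4].
    return 4 * io + ii >= 14


zero_padded = BufferConstraintSketch(in_padding, lambda io, ii: 0.0)
no_access = BufferConstraintSketch(in_padding, None)

print(known_value(zero_padded, (3, 2)))
print(known_value(zero_padded, (0, 0)))
```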
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience: a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
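
As a sanity check on the two rewritten loopnests above, the following plain-Python sketch (not TVM code; the function names and the test data are made up for illustration) replays the original 1-d convolution and both rewrites over nested lists and compares the results. It assumes that `//` and `%` in the TIR examples have floor semantics, as Python's operators do, which is what keeps Option 2's indices valid when `ii - f` is negative.

```python
def conv1d_reference(A, F):
    """Original loopnest: B[i] = sum over f of F[f] * A[i + f], 0 <= i < 14."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def conv1d_follow_b(A2, F):
    """Option 1: iterate (io, ii) over B's transformed [4, 4] layout."""
    B2 = [[None] * 4 for _ in range(4)]  # None marks untouched padding
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B2[io][ii] = 0
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv1d_follow_a(A2, F):
    """Option 2: iterate (io, ii) over A's transformed [4, 4] layout."""
    B2 = [[None] * 4 for _ in range(4)]
    for i in range(14):  # initialization split into its own loopnest
        B2[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    # Floor division and modulo keep these indices valid
                    # even when ii - f is negative.
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = list(range(1, 17))  # illustrative input data
F = [2, -1, 3]  # illustrative filter
A2 = [A[4 * r : 4 * r + 4] for r in range(4)]  # layout i -> [i // 4, i % 4]
expected = conv1d_reference(A, F)
flat_b = [v for row in conv1d_follow_b(A2, F) for v in row]
flat_a = [v for row in conv1d_follow_a(A2, F) for v in row]
```

In both rewrites the first 14 flattened elements agree with the reference, and the two padding positions `(3,2)` and `(3,3)` are never written.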
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals
+  that do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
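
The notions of "identical" and "complementary" conditions used above can be illustrated with a brute-force check over the loop domain. This is only a sketch: an actual simplifier would presumably prove these relations symbolically rather than by enumeration, and the helper names `identical` and `complementary` are hypothetical.

```python
from itertools import product


def identical(cond_a, cond_b, domain):
    """True if the two conditions agree at every point of the domain."""
    return all(cond_a(*p) == cond_b(*p) for p in domain)


def complementary(cond_a, cond_b, domain):
    """True if exactly one of the two conditions holds at every point."""
    return all(cond_a(*p) != cond_b(*p) for p in domain)


loop_1d = [(i,) for i in range(16)]
grid_2d = list(product(range(4), range(4)))

# `i < 8` and `i // 8 == 0` select the same iterations of a 16-wide loop,
# while `i < 8` and `i // 8 == 1` select opposite iterations.
same = identical(lambda i: i < 8, lambda i: i // 8 == 0, loop_1d)
flipped = complementary(lambda i: i < 8, lambda i: i // 8 == 1, loop_1d)

# On a 4x4 grid, `4*i + j < 14` fails exactly where `i == 3 and j >= 2`.
opposite = complementary(
    lambda i, j: 4 * i + j < 14, lambda i, j: i == 3 and j >= 2, grid_2d
)
```

Enumeration only works for small constant loop extents, and it says nothing about data-dependent conditions, which is why the restriction to conditions over immutable values still applies.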
+
+### New Transform - Hoist Expression

Review Comment:
   > (IIUC, we may also insert some "arbitrary" value filling code on edges and optimize them out then?)
   
   Yup, the loop that writes `T.undef()` into the padding would be present as an intermediate.  This allows `RemoveNoOp` to be much more general, since it only needs to look for two sequential writes to the same indices to conclude that the first is a no-op.  As a result, a matching `else_case` would be a no-op, and therefore safe to insert without impacting the final result.
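
The "two sequential writes to the same indices" rule can be sketched on a toy statement list. This is not the `NoOpRemover` implementation: `Store` and `remove_overwritten_stores` are hypothetical stand-ins, and the model ignores loops, conditionals and aliasing.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Store:
    """Toy stand-in for a BufferStore: write `value` to `buffer[index]`."""

    buffer: str
    index: tuple
    value: str
    reads: frozenset = frozenset()  # (buffer, index) pairs read by `value`


def remove_overwritten_stores(stores):
    """Drop each store that is overwritten before its value is read."""
    kept = []
    for pos, store in enumerate(stores):
        target = (store.buffer, store.index)
        overwritten = False
        for later in stores[pos + 1 :]:
            if target in later.reads:
                break  # the stored value is observed, so keep the store
            if (later.buffer, later.index) == target:
                overwritten = True  # replaced before any read
                break
        if not overwritten:
            kept.append(store)
    return kept


program = [
    Store("B", (3, 2), "undef"),  # padding filled with an arbitrary value
    Store("B", (3, 2), "0.0"),  # later write of a concrete default
    Store("B", (0, 0), "1.0"),
    Store("B", (0, 0), "B[0,0] + 1.0", frozenset({("B", (0, 0))})),
]
simplified = remove_overwritten_stores(program)
```

Only the first store is removed: the write of the arbitrary value is overwritten before being read, whereas the store to `B[0,0]` is read by the statement that follows it and must stay.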





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
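
The shapes and padding locations quoted in the comments above can be reproduced with a small plain-Python helper. This is only an illustrative sketch: `padded_layout` is a hypothetical name, it handles 1-d source buffers only, and it assumes every transformed axis starts at 0.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    The shape is the smallest box (starting at 0 on every axis) that holds
    every transformed index; padding is every position in that box that no
    source index maps to.
    """
    used = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    shape = [max(idx[d] for idx in used) + 1 for d in range(ndim)]
    padding = sorted(set(product(*(range(s) for s in shape))) - used)
    return shape, padding


# The four transformations discussed above.
shape_a, pad_a = padded_layout(16, lambda i: [i // 8, i % 8])
shape_b, pad_b = padded_layout(14, lambda i: [i // 8, i % 8])
shape_b_off, pad_b_off = padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8])
shape_a_off, pad_a_off = padded_layout(16, lambda i: [(i + 2) // 8, (i + 2) % 8])
```

For the offset transformation of the `[16]` buffer, this yields shape `[3, 8]` with eight padded positions: two at the start of the first row and six at the end of the last row.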
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Whether trading branching for overcompute
+is beneficial depends on the schedule, so these options are exposed
+as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
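
The equivalence that this section relies on can be checked outside of TVM with nested Python lists. The sketch below is illustrative only (the helper names and the input data are made up): it shows that dropping the guard is safe when the padding of `A` holds zero, and that it changes the result for any other padding value.

```python
def row_sum_branching(A):
    """Guarded loopnest: skips the two padded positions in each row."""
    B = [0.0] * 16
    for i in range(16):
        for j_outer in range(4):
            for j_inner in range(4):
                if 4 * j_outer + j_inner < 14:
                    B[i] += A[i][j_outer][j_inner]
    return B


def row_sum_overcompute(A):
    """Unguarded loopnest: also accumulates the padded positions."""
    B = [0.0] * 16
    for i in range(16):
        for j_outer in range(4):
            for j_inner in range(4):
                B[i] += A[i][j_outer][j_inner]
    return B


def make_padded_input(pad_value):
    """Build a [16, 4, 4] input whose padding holds `pad_value`."""
    A = [[[pad_value] * 4 for _ in range(4)] for _ in range(16)]
    for i in range(16):
        for j in range(14):
            A[i][j // 4][j % 4] = float(i + j)  # illustrative data
    return A


zero_padded = make_padded_input(0.0)
other_padded = make_padded_input(7.0)
matches_with_zero = row_sum_branching(zero_padded) == row_sum_overcompute(zero_padded)
matches_with_other = row_sum_branching(other_padded) == row_sum_overcompute(other_padded)
```

With zero padding the two loopnests agree exactly; with a padding value of 7.0 each row sum of the unguarded version is off by 14.0, which is why the simplification needs the annotated padding value to be provable.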
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension reordering and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and to materialize a function that legalizes the layout where a contradiction exists. 
   
   The main difference I see is that A0 pushes much more work into hardware specific optimization (compute definitions + patterns/rewriters) that is not as easily re-purposable for other hardware targets; whereas A1 provides the infrastructure for more general compiler simplification to proceed from hardware semantics the user provides about the buffer when transforming the layout at schedule time. 
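
To make the A1 idea of "only materialize a layout conversion when a contradiction exists" concrete, here is a deliberately simplified sketch. It is not a proposed design: `plan_conversions`, the edge tuples, and the operator names are all hypothetical, and real buffer constraints would be predicates over indices rather than a single padding value.

```python
def plan_conversions(edges):
    """Decide, per producer -> consumer edge, whether a conversion is needed.

    Each edge is (producer, provided, consumer, required), where `provided`
    is the padding value the producer guarantees and `required` is the value
    the consumer assumes.  None means "no guarantee" or "no requirement".
    """
    conversions = []
    for producer, provided, consumer, required in edges:
        if required is None or provided == required:
            continue  # constraints agree, nothing to materialize
        conversions.append((producer, consumer, required))
    return conversions


edges = [
    ("conv", 0.0, "row_sum", 0.0),  # agreement
    ("row_sum", 0.0, "maxpool", float("-inf")),  # contradiction
    ("maxpool", None, "relu", None),  # consumer has no requirement
]
needed = plan_conversions(edges)
```

Only the edge whose producer guarantee contradicts the consumer requirement yields a conversion; the other two edges are left untouched.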





[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893932157


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this shape has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not correspond to any index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since whether trading branching for
+overcompute is beneficial depends on the schedule, these options
+are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   I think that, in terms of reusability, A0 and A1 are not that different. Notably, A0 brings reusability in the sense of utility simplifications with respect to folding etc., while A1 brings the simplifications at the TIR level.
   
   Note that compute defs in A0 can remain the same; the main difference is how the preferable schedules are derived, as informed by the hw target, when they are scheduled.
   
   From a pure impl pov, considering the e2e goal, having padding sorted out at the graph level actually still simplifies the scheduling layer.
   
   Note that A0 still does not preclude us from doing constraint matching; one can view that as inserting a pre-proc stage that "must simplify". At a high level it is only a repr difference (putting it in the buffer decl vs. in stages).
   
   Under certain scenarios we could indeed consider putting some constraints at the interface level. That would need some more thought on semantics, how they would interact with the graph, and structural complexity (whether padding is something that is worth the IR complexity).
   
   For most e2e goals, perhaps having padding throughout the graph level is not a bad way to reduce that part of the complexity.
   





[GitHub] [tvm-rfcs] areusch commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
areusch commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1151684472

   we discussed this at the June 8 [community meeting](https://discuss.tvm.apache.org/t/next-tvm-community-meeting-june-8-2022/12900). a significant chunk of the meeting was spent presenting the RFC, and we had about 15 minutes of discussion at the end.
   
   i think there is more to be discussed here. if we'd like to discuss in high-bandwidth, we can bring this back up at future community meetings. here are notes:
   
   @kparzysz-quic : 
   - aside from transform_layout, the immediate application i see from this is vectorization of variable-length loops. we should separate the transformation and optimization parts because those two things are logically independent. the transform_layout will generate TIR, and then that TIR is optimized using a set of other passes/techniques.
     - @Lunderberg agrees. this is the motivation behind splitting this into "transforms" and "more generic operations." HoistExpression does a large part of what is needed for variable-length loop vectorization by splitting out parts that do depend on a dynamic size from the parts that don't. 
   - KP is worried that it'll take quite a while to implement enough transforms to get to overcompute (e.g. it's hard to determine whether overcompute can be applied). can we have something that transforms the layout, then allow the user to provide a compute statement that is attested by them to work on the transformed layout without any verification?
     - @Lunderberg i think that on its own (assuming ops are fused together by providing a tensorization that defines "this entire fused operation can be replaced with x followed by y"), can be done by 
     - don't have a good way to express "turn off all additional safeties" but proceed to perform those optimizations.
     - could imagine having something analogous to the `undef` (where that is the "least-convenient value") except as the "most convenient value." if it's most convenient to presume a value is 0, then where this value is present, it's legal to assume that the value is 0 and move forward.
     - there's also a [partway condition](https://github.com/apache/tvm-rfcs/pull/77/files#diff-a5740745158592278e549c62bd8c7ccb5b6317deb56d1164d8bf845ee4db5e75R1919) that doesn't require any of the overcompute proving, but does get to a useful intermediate using only expression hoisting and insertion of existing if/then's that happen for loop rewrites. after everything's been hoisted and simplified, what falls out naturally is an outer loop that splits up into two inner loops:
        -  a slow one that handles the edges
        - a fast one that handles the interior
     this might allow us to get to the point of adding the branchless/vectorizable piece even if it's not the only thing there.
   - @tqchen notes that one of the reasons we have complexity here is that we are trying to decompose the problem into more general predicates. if we try to go for less complexity, we could introduce transformations that do more transformations at once and thus require less proving.
     - the question remains how we might remove additional unnecessary steps added by the initial layout_transform. on GPUs it might be possible to pad while loading shared memory. in other cases we may need to consult the graph-level model to determine how much padding is needed.
   - @Lunderberg notes much of this complexity came from "how can we prove the additional steps are unnecessary?" there are also some additional parts where the constraints written in the copy stage may need to flow upwards from something downstream in the data-dependency graph in order to properly state it. 
     - between explicitly specifying options over N buffers with different pre-existing layouts and identifying whether a layout transformation would require branching loops to handle the edge, a lot of it boils down to the level of abstraction at which the layout is decided and how that is exposed to lower levels of abstraction.
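
The "partway condition" in the notes above (an outer loop that splits into a slow edge loop and a fast interior loop) can be illustrated with a plain-Python stand-in for the RFC's `row_summation` example. This is only a sketch under stated assumptions: it is not TIR and not the output of any TVM pass, and `row_sum_branching`, `row_sum_split`, `TILE`, and `N_VALID` are hypothetical names chosen for illustration.

```python
TILE = 4       # inner extent of the transformed layout [j//4, j%4]
N_VALID = 14   # number of logical elements per row; the rest is padding


def row_sum_branching(row):
    """Sum one padded row, checking the padding predicate on every element."""
    total = 0.0
    for j_outer in range(len(row)):
        for j_inner in range(TILE):
            if TILE * j_outer + j_inner < N_VALID:
                total += row[j_outer][j_inner]
    return total


def row_sum_split(row):
    """Same result, with the conditional hoisted out of the interior tiles."""
    total = 0.0
    n_full = N_VALID // TILE  # tiles that contain no padding at all

    # Fast interior loop: every element is valid, so no conditional is needed.
    for j_outer in range(n_full):
        for j_inner in range(TILE):
            total += row[j_outer][j_inner]

    # Slow edge loop: only the last, partially filled tile keeps the check.
    for j_outer in range(n_full, len(row)):
        for j_inner in range(TILE):
            if TILE * j_outer + j_inner < N_VALID:
                total += row[j_outer][j_inner]
    return total


# One row of shape (4, 4); the last two entries stand in for unknown padding.
example_row = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [9.0, 10.0, 11.0, 12.0],
    [13.0, 14.0, 999.0, 999.0],
]
branching_result = row_sum_branching(example_row)
split_result = row_sum_split(example_row)
```

Neither version reads the padding, so this intermediate form does not depend on any overcompute proof; only the edge tile retains the conditional, and the interior loop is the branchless piece that could later be vectorized.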




[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894879301


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   I was trying to work through what the transformation hoisting would look like at the graph level, and I think it runs into the same requirement of tracking the buffer constraints.  The derivation uses a toy model, consisting of the two sequential 1-d convolutions from [this section](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#implicitly-write-default-value-of-next-operator).  Layout transformations are introduced in read/write caches, and those layout transformations are hoisted into separate functions, so that they can be manipulated at the graph level.  The adjacent transformations are then fused and simplified.
   
   (The full derivation is in [this gist](https://gist.github.com/Lunderberg/7dcd4edbdd7bedfb08072037792aa585), as including it here made the comment unreadably long.)
   
   The key result from the derivation is that, for a padded transformation `f` that maps from the logical layout to the physical layout, `f_inv(f(X))` can be simplified to `X` at the graph level, but `f(f_inv(X))` cannot.  Instead of simplifying into a memcpy, the transformations result in the following:
   
   ```python
   @T.prim_func
   def fused_inv_transform_Y_transform_Y(
       Y_write_cache: T.Buffer[(3, 8), "float32"],
       Y_read_cache: T.Buffer[(3, 8), "float32"],
   ):
       for io, ii in T.grid(3, 8):
           i = 8 * io + ii - 2
           if 0 <= i < 18:
               Y_read_cache[io, ii] = Y_write_cache[io, ii]
           else:
               Y_read_cache[io, ii] = 0.0
   ```
   
   This expression could be simplified by using the original constraint on `Y_write_cache` provided in the layout transformation, but reconstructing that constraint couldn't be done by local analysis of any single function.
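
The asymmetry between `f_inv(f(X))` and `f(f_inv(X))` can be reproduced outside of TVM. The following is a minimal plain-Python sketch, assuming the same toy layout as the function above (18 logical elements, an offset of 2, physical shape `(3, 8)`, and a padding value of `0.0`). The helper names `f` and `f_inv` mirror the notation in this comment; the constants and variable names are illustrative and are not TVM API.

```python
LOGICAL_SIZE = 18
OFFSET = 2
ROWS, COLS = 3, 8


def f(logical, pad_value=0.0):
    """Logical -> physical layout; padding locations are filled with pad_value."""
    physical = [[pad_value] * COLS for _ in range(ROWS)]
    for i, value in enumerate(logical):
        physical[(i + OFFSET) // COLS][(i + OFFSET) % COLS] = value
    return physical


def f_inv(physical):
    """Physical -> logical layout; whatever is stored in the padding is dropped."""
    return [
        physical[(i + OFFSET) // COLS][(i + OFFSET) % COLS]
        for i in range(LOGICAL_SIZE)
    ]


# f_inv(f(X)) is the identity for any logical buffer X.
X = [float(i) for i in range(1, LOGICAL_SIZE + 1)]
logical_round_trip_ok = f_inv(f(X)) == X

# f(f_inv(P)) is NOT the identity for an arbitrary physical buffer P,
# because f overwrites whatever P held in its padding with 0.0.
P_arbitrary = [[100.0 + COLS * io + ii for ii in range(COLS)] for io in range(ROWS)]
arbitrary_round_trip_ok = f(f_inv(P_arbitrary)) == P_arbitrary

# It becomes the identity once P is known to hold 0.0 in its padding,
# which is exactly the constraint that local analysis cannot recover.
P_constrained = f(X)
constrained_round_trip_ok = f(f_inv(P_constrained)) == P_constrained
```

In this sketch only the middle case fails, which corresponds to the `else` branch that cannot be removed from `fused_inv_transform_Y_transform_Y` without knowing the constraint on `Y_write_cache`.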





[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a corresponding
+index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since whether trading branching for
+overcompute is beneficial depends on the schedule, these options are
+exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   @Lunderberg That's true. If the reduction dimension is padded, we will need to insert a hint in the graph to assert that it was previously padded with 0. From the graph rewriting point of view, we can also see this as a transformation done at the graph level (one that doesn't rely on arithmetic simplifications).
   
   Example
   ```
   X: R.Tensor[16]
   F: R.Const[16]
   Y = conv1d(X, F, pad=2)
   Z = conv1d(Y, F, pad=2)
   ```
   Inserting padding and crop:
   ```
   X: R.Tensor[16]
   F: R.Const[16]
   X_pad = pad(X, before=2, after=6)
   Y = conv1d(X_pad, F, pad=0)
   assert(Y[18:] == 0)
   Y_crop = crop(Y[0:18])
   Y_crop_pad = pad(Y_crop, before=2, after=4)
   Z = conv1d(Y_crop_pad, F, pad=0)
   Z_crop = crop(Z[0:20])
   ```
   Then we can propagate the padding information and combine
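
To make the role of a hint like `assert(Y[18:] == 0)` concrete, here is a minimal pure-Python sketch (not TIR or Relax). It assumes a 14-element row padded to a multiple of 4, as in the RFC's row-summation example, rather than the conv1d shapes above; `pad_to_multiple` and `tiled_sum` are hypothetical helper names. It only illustrates that a reduction which reads the padding without a bounds check matches the unpadded result when the padding is known to hold zero.

```python
def pad_to_multiple(values, multiple, pad_value):
    """Append pad_value until len(values) is a multiple of `multiple`."""
    n_pad = (multiple - len(values) % multiple) % multiple
    return list(values) + [pad_value] * n_pad


def tiled_sum(padded, tile):
    """Sum a buffer tile by tile, with no bounds check on the padding."""
    assert len(padded) % tile == 0
    total = 0.0
    for outer in range(len(padded) // tile):
        for inner in range(tile):
            total += padded[outer * tile + inner]
    return total


row = [float(i) for i in range(14)]  # 14 logical elements
reference = sum(row)                 # result required by the compute definition

zero_padded = pad_to_multiple(row, 4, 0.0)
junk_padded = pad_to_multiple(row, 4, 123.0)  # stand-in for unknown padding

sum_zero = tiled_sum(zero_padded, 4)  # padding contributes nothing
sum_junk = tiled_sum(junk_padded, 4)  # padding leaks into the result
```

Without the knowledge that the padding is zero, the branch-free loop cannot be substituted for the original reduction, which is the information the graph-level hint would carry.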
   





[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894941045


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Of course, for the sake of discussion the example is limited to two convolutions. In the general case of multiple (N) back-to-back contractions with padded transformations, handling this at the graph level requires similar non-local information/hints across the sequence of operators, to Nth order.





[GitHub] [tvm-rfcs] vinx13 commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1152992143

   Thanks for the discussion. To provide more context, the A0 approach we discussed is TIR-Relax layout rewriting https://github.com/tlc-pack/relax/issues/162 (the general idea is to lift such transformations in TIR scheduling into the graph, and then cancel out redundant intermediate transformations by either proving that fusing the pair of post-compute and pre-compute transformations produces an identity TIR function, or using high-level operator semantics). I think this is very similar to the [graph-level solution](https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807/4) mentioned by @wrongtest.
   In general, both A0 and A1 are valid approaches. It is mainly about how we would like to handle the complexity in simplifications.
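
As a rough illustration of the cancellation idea, here is a pure-Python sketch (not TIR). It assumes a logical size of 14, a tile of 4, and a zero pad value; `pre_compute` and `post_compute` are hypothetical stand-ins for the lifted layout transformations. Under those assumptions, the fused pair is an identity on the transformed buffer only when the padding already holds the pad value.

```python
TILE = 4
LOGICAL_SIZE = 14
N_TILES = -(-LOGICAL_SIZE // TILE)  # ceiling division, gives 4


def pre_compute(logical, pad_value=0.0):
    """Logical [14] -> transformed [4][4], filling the padding with pad_value."""
    out = [[pad_value] * TILE for _ in range(N_TILES)]
    for i, value in enumerate(logical):
        out[i // TILE][i % TILE] = value
    return out


def post_compute(transformed):
    """Transformed [4][4] -> logical [14], dropping the padding."""
    return [transformed[i // TILE][i % TILE] for i in range(LOGICAL_SIZE)]


logical = [float(i) for i in range(LOGICAL_SIZE)]
clean = pre_compute(logical, pad_value=0.0)   # padding holds the agreed value
dirty = pre_compute(logical, pad_value=-1.0)  # padding holds something else

# Crop after pad is always an identity on the logical values.
roundtrip_logical = post_compute(pre_compute(logical))

# Pad after crop is an identity only if the padding was already zero.
roundtrip_clean = pre_compute(post_compute(clean))
roundtrip_dirty = pre_compute(post_compute(dirty))
```

In this sketch, dropping the crop/pad pair between two operators would be valid for `clean` but would change the contents of `dirty`, so some proof or hint about the padding is needed before cancelling.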




[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1156407362

   Adding some additional discussion with @csullivan .
   
   We agree that:
   - There are different ways to encode layout and padding decisions:
       - E0: BufferConstraint (as an element in the IR)
       - E1: Composing a stage that transforms the layout (a loop that represents the mapping)
   - Non-local rewrites are needed to be able to propagate the layout and padding decisions throughout the entire model through constraint solving.
   
   Right now we have some healthy discussions about ways to encode layout and padding decisions. 
   
   Some of my thoughts:
   
   Introducing changes to TIR needs some additional thought and deserves extra consideration, due to the N*M complexity (where N is the number of TIR possibilities and M is the number of primitives to be supported).
   
   Right now it is possible to do non-local constraint rewriting as part of a graph pass. Note that while E1 is indeed less "compact", we can use it to reconstruct THE compact data structure (that represents the layout decision) that we can use to flow the decisions across graph nodes.  E1 also enables some additional capabilities, e.g. expressing future memory remappings that do not necessarily fit into padding/packing.
   
   Starting from the graph level allows us to capture learnings, then use some e2e goals to make an informed decision on TIR level change later if needed.
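
A small pure-Python sketch of how a compact description might be recovered from an E1-style mapping. This is brute-force enumeration under assumed shapes, not an actual TIR analysis, and `describe_layout` is a hypothetical helper name. Given the index mapping and the logical extent, it derives the transformed shape and the set of padding coordinates.

```python
from itertools import product


def describe_layout(index_map, logical_extent):
    """Return (transformed_shape, padding_coordinates) for a 1-d logical buffer."""
    # Every transformed coordinate that some logical index maps onto.
    touched = {tuple(index_map(i)) for i in range(logical_extent)}
    ndim = len(next(iter(touched)))
    # Smallest shape (starting at 0 on each axis) that holds all touched coordinates.
    shape = tuple(max(idx[d] for idx in touched) + 1 for d in range(ndim))
    # Anything inside that shape with no logical counterpart is padding.
    all_coords = set(product(*(range(extent) for extent in shape)))
    padding = sorted(all_coords - touched)
    return shape, padding


# Mapping from the RFC's motivation section: shape [14], i -> [i//4, i%4].
shape_a, padding_a = describe_layout(lambda i: [i // 4, i % 4], 14)

# Mapping with an offset: shape [16], i -> [(i+2)//8, (i+2)%8].
shape_b, padding_b = describe_layout(lambda i: [(i + 2) // 8, (i + 2) % 8], 16)
```

The first call reproduces the two padded coordinates described in the RFC's motivation, and the second reproduces the padding at both ends of the offset example.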
   
   
   
   
   
   
   
   
   




[GitHub] [tvm-rfcs] Lunderberg commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r909639315


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this shape has 16 elements, and the transformed coordinates `(3,2)` and
+`(3,3)` do not correspond to any index in the workload range
+`0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
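The shapes and padding locations quoted in the comments above can be checked by brute force. The sketch below is plain Python rather than TVM; `padded_layout` is a hypothetical helper introduced only for illustration, and it assumes the transformed shape is the smallest box that contains every transformed index.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    Hypothetical helper for illustration only; it is not part of TVM.
    """
    # Transformed indices that correspond to an element of the original buffer.
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Smallest shape that can hold every used index.
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    # Any location inside that shape that no original index maps to is padding.
    padding = sorted(set(product(*(range(s) for s in shape))) - used)
    return shape, padding


no_offset = lambda i: [i // 8, i % 8]
with_offset = lambda i: [(i + 2) // 8, (i + 2) % 8]

for extent, transform in [(16, no_offset), (14, no_offset),
                          (14, with_offset), (16, with_offset)]:
    print(extent, padded_layout(extent, transform))
```

Enumerating every index is only practical for small static shapes like these; it is meant as a sanity check on the comments, not as a description of how `transform_layout` derives the padding.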
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
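As a minimal sketch of the producer postlude described above, the following plain-Python model stores a buffer as a dictionary keyed by transformed indices. The helper name `write_padding` and the dictionary representation are assumptions made for illustration; they do not correspond to TVM APIs.

```python
from itertools import product


def write_padding(buf, shape, is_padding, pad_value):
    """Producer postlude: fill every padding location with its pad value."""
    for idx in product(*(range(s) for s in shape)):
        if is_padding(*idx):
            buf[idx] = pad_value(*idx)


# B has logical shape [14] and is stored as [4, 4] via i -> [i//4, i%4].
shape = (4, 4)
B = {idx: None for idx in product(range(4), range(4))}

# The producer's ordinary computation only touches the 14 logical elements.
for i in range(14):
    B[i // 4, i % 4] = float(i)

# Before the producer returns, the padding must hold pad_value (here 0.0).
write_padding(
    B,
    shape,
    is_padding=lambda io, ii: 4 * io + ii >= 14,
    pad_value=lambda io, ii: 0.0,
)
print(B[3, 2], B[3, 3])
```

A consumer that receives `B` can then rely on the two padding slots holding `0.0`, which is the fact that later simplifications use.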
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Whether trading branching for overcompute
+is beneficial depends on the schedule, so both directions are exposed
+as additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
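The equivalence between the branching and overcompute forms can be illustrated outside of TIR. The sketch below is plain Python with nested lists standing in for the `[16, 4, 4]` buffer; the function names are hypothetical and chosen only for this example.

```python
ROWS, COLS, TILE = 16, 14, 4


def make_padded_input(pad):
    """Build A with layout [i, j//4, j%4]; `pad` fills the two padding slots per row."""
    return [
        [
            [float(i * COLS + 4 * jo + ji) if 4 * jo + ji < COLS else pad
             for ji in range(TILE)]
            for jo in range(4)
        ]
        for i in range(ROWS)
    ]


def row_sum_branching(A):
    """Skip the padding with a conditional, as in the guarded loop nest."""
    out = []
    for row in A:
        total = 0.0
        for jo in range(4):
            for ji in range(TILE):
                if 4 * jo + ji < COLS:
                    total += row[jo][ji]
        out.append(total)
    return out


def row_sum_overcompute(A):
    """Sum every stored element, including the padding."""
    return [sum(row[jo][ji] for jo in range(4) for ji in range(TILE)) for row in A]


zero_padded = make_padded_input(0.0)
print(row_sum_branching(zero_padded) == row_sum_overcompute(zero_padded))
```

The two versions agree only because the padding holds the additive identity. With any other padding value the overcompute form produces a different result, which is why the simplification needs the `pad_value` annotation to be provable.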
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed (only
+`io == 3` contains padding), and the range of the loop over `ii` can
+be reduced to `2 <= ii < 4`.  However, the default implementation
+should not perform these simplifications yet, as this form is useful
+for [merging loopnests](#utility---merge-adjacent-loops) after
+[rewriting for sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
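To make the expected simplification of the padding loop concrete, the sketch below compares the full guarded loop nest against the reduced loop for the `[14] -> [4, 4]` example. It is plain Python that records which locations each form writes; both function names are hypothetical.

```python
def write_padding_full(pad_value=-2):
    """Default form: iterate over the whole transformed shape with a guard."""
    written = {}
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii >= 14:
                written[(io, ii)] = pad_value
    return written


def write_padding_reduced(pad_value=-2):
    """Simplified form: io is fixed at 3 and ii is restricted to 2 <= ii < 4."""
    written = {}
    io = 3
    for ii in range(2, 4):
        written[(io, ii)] = pad_value
    return written


print(write_padding_full() == write_padding_reduced())
```

Both forms write exactly the same two locations. The full form keeps the same `T.grid(4, 4)` iteration space as a loop nest rewritten for sequential access, which is what makes it a candidate for merging with that loop nest.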
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
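The index algebra in the two options above can be sanity-checked against the original 1-d convolution. The sketch below is plain Python using nested lists for the `[4, 4]` buffers; Python's `//` and `%` are floor-based, which is the behaviour the simplified indices rely on when `ii - f` is negative. All function names here are hypothetical.

```python
def conv1d_reference(A, F):
    """Original loop nest over flat buffers: B[i] = sum_f F[f] * A[i + f]."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def to_tiles(flat):
    """Store a flat list in the [i//4, i%4] layout, leaving padding as 0."""
    tiles = [[0] * 4 for _ in range(4)]
    for i, value in enumerate(flat):
        tiles[i // 4][i % 4] = value
    return tiles


def conv1d_by_output(A, F):
    """Option 1: loops follow B's layout."""
    B = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                for f in range(3):
                    B[io][ii] += F[f] * A[io + (ii + f) // 4][(ii + f) % 4]
    return B


def conv1d_by_input(A, F):
    """Option 2: loops follow A's layout; B is zero-initialised separately."""
    B = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A[io][ii]
    return B


A_flat = list(range(1, 17))
F = [1, -2, 3]
expected = to_tiles(conv1d_reference(A_flat, F))
print(conv1d_by_output(to_tiles(A_flat), F) == expected)
print(conv1d_by_input(to_tiles(A_flat), F) == expected)
```

Integer inputs are used so that the comparison is exact regardless of the order in which the two loop nests accumulate into `B`.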
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:

Review Comment:
   If the merging of complementary conditionals is valid, then which condition is kept doesn't matter for correctness.  For two conditions `A` and `B`, if `A` implies `!B` and `B` implies `!A
   
   That said, I'd probably keep the first conditional, as it allows for the simplification to be viewed as a specific case of a more general transformation.  Given a conditional that is followed by another statement outside the conditional, it is valid to move the statement inside the conditional, placed at the end of both the `then_case` and `else_case`.  If the statement being moved is itself a conditional, then it may be simplified.  In this case, the intermediate step would look as follows.
   
   ```python
   @T.prim_func
   def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
       for i,j in T.grid(4,4):
           if 4*i + j < 14:
               A[i, j] = 0.0
               if i==3 and j>=2:
                   B[i, j] = 2.0
               else:
                   B[i, j] = 3.0
           else:
               A[i, j] = 1.0
               if i==3 and j>=2:
                   B[i, j] = 2.0
               else:
                   B[i, j] = 3.0
   ```
   
   I wouldn't want to generate the intermediate state in all cases, because it may not always lead to useful simplifications, which is why it would only be applied in the special cases of identical conditions and complementary conditions.
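
As a sanity check on the claim above, the sketch below enumerates the 4x4 grid in plain Python (no TVM involved) and confirms that the two conditions from the example are complementary, and that the merged form yields the same values as the two separate conditionals. The helper names `first`, `second`, `separate`, and `merged` are made up for this illustration, and the functions return the values that would be stored rather than writing to buffers.

```python
from itertools import product


def first(i, j):
    # Condition of the first conditional in the example.
    return 4 * i + j < 14


def second(i, j):
    # Condition of the second conditional in the example.
    return i == 3 and j >= 2


grid = list(product(range(4), range(4)))

# Complementary: exactly one of the two conditions holds at every point.
complementary = all(first(i, j) != second(i, j) for i, j in grid)


def separate(i, j):
    # Values stored by the two independent conditionals.
    a = 0.0 if first(i, j) else 1.0
    b = 2.0 if second(i, j) else 3.0
    return a, b


def merged(i, j):
    # Values stored after merging, keeping the first condition.
    if first(i, j):
        return 0.0, 3.0
    return 1.0, 2.0


same_result = all(separate(i, j) == merged(i, j) for i, j in grid)
print(complementary, same_result)
```

An exhaustive check like this is only feasible because the iteration domain is tiny; a real pass would presumably need to prove the implications symbolically.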





[GitHub] [tvm-rfcs] areusch commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
areusch commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1149027473

   this is on the agenda for tomorrow's [community meeting](https://discuss.tvm.apache.org/t/next-tvm-community-meeting-june-8-2022/12900). Perhaps we could discuss in higher bandwidth there?




[GitHub] [tvm-rfcs] csullivan commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r893701372


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   To summarize the two approaches being discussed, I see them as 
   
   **A0. Compute definition based with pruning.** 
   1) Describe all data layout transformations (which can include dimension reordering and padding) as part of the workload compute definition (hardware dependent). 
   2) Hoist layout operations into the graph and rely on pattern matching to do cancelation or folding when possible. 
   
   **A1. Schedule based with constraint flowing.** 
   1) Describe a generic workload compute definition (hardware independent) and apply scheduling primitives that inject information about the hardware support for layouts and padding that allow for compiler simplification. 
   2) Run a type-inference like pass at the graph level to flow the tir.Buffer constraints and only materialize a layout conversion when a contradiction exists. 
   
   It seems clear to me these approaches essentially stem from the two canonical compiler approaches, 
   1) Aggressively insert legalization (layout transforms before and after every operation) and prune.
   2) Constraint flowing. 
   
   In addition there are also implications of taking either of the compute or schedule based approaches:
   
   A0 requires compute definitions and schedules written for every workload, for every hardware, and for every layout. Additionally, because data layout is intimately tied to optimal use of the microarchitecture, the layout transformation patterns will also be hardware specific. Thus, A0 requires the hardware-specific decomposition of hardware semantics to IR (compute definition) as well as hardware-specific recomposition of IR into hardware semantics (pattern matching) that can be used to rewrite/remove IR which are mergeable/no-ops. 
   
   A1 requires generic workload compute definitions (not hardware specific) and schedules written for every hardware and for every layout. From the expression of constraints on the buffer layout, simplification of the schedule can proceed in a hardware-agnostic fashion. During constraint flowing, the layout constraints on a buffer can be used to determine when agreement or contradictions exist, and when to materialize a function to legalize. 
   
   The main difference I see is that A0 pushes much more work into hardware-specific optimization, whereas A1 allows for more compiler simplification through information the user provides as hardware semantics that are true about the buffer. 
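
To make the constraint-flowing idea in A1 a bit more concrete, here is a toy sketch in plain Python. It is only an illustration under simplifying assumptions: the graph is a linear chain, each operator carries a fixed annotation for the padding value it requires on its input and the one it leaves in its output, and the names `Op`, `flow_constraints`, and `set_padding` are hypothetical rather than anything that exists in TVM.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Op:
    name: str
    # Padding value the operator needs in its input (None: no requirement).
    requires_pad: Optional[float]
    # Padding value known to be in its output (None: nothing is known).
    produces_pad: Optional[float]


def flow_constraints(ops: List[Op]) -> List[str]:
    """Walk the chain and insert a conversion only on a contradiction."""
    plan = []
    available = None  # Nothing is known about the padding of the graph input.
    for op in ops:
        if op.requires_pad is not None and op.requires_pad != available:
            # Producer and consumer disagree, so materialize a conversion.
            plan.append(f"set_padding({op.requires_pad})")
        plan.append(op.name)
        available = op.produces_pad
    return plan


ops = [
    Op("conv", requires_pad=0.0, produces_pad=0.0),
    Op("relu", requires_pad=None, produces_pad=0.0),
    Op("sum", requires_pad=0.0, produces_pad=None),
    Op("maxpool", requires_pad=float("-inf"), produces_pad=None),
]
plan = flow_constraints(ops)
print(plan)
```

In this toy chain, `relu` and `sum` run without any conversion because the constraint left by the producer already agrees with what the consumer needs, while `conv` and `maxpool` each get a `set_padding` step inserted in front of them.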





[GitHub] [tvm-rfcs] wrongtest commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
wrongtest commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891831175


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2540 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since whether overcompute is preferable to
+branching depends on the schedule, these options are exposed as two
+additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
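
As an illustrative check that does not depend on TVM, the plain-Python sketch below models the transformed buffer as nested lists and compares the branching and overcompute forms of the row summation.  The helper names (`make_padded`, `row_sum_branching`, `row_sum_overcompute`) are hypothetical and exist only for this example.  It assumes the padding is pre-filled by whoever produces the buffer.

```python
# Plain-Python model of the transformed `row_summation`; no TVM required.
ROWS, COLS, TILE = 16, 14, 4
OUTER = (COLS + TILE - 1) // TILE  # 4 tiles of 4; the last 2 elements are padding


def make_padded(data, pad_value):
    """Lay out data[i][j] as A[i][j // 4][j % 4], filling the padding."""
    padded = [[[pad_value] * TILE for _ in range(OUTER)] for _ in range(ROWS)]
    for i in range(ROWS):
        for j in range(COLS):
            padded[i][j // TILE][j % TILE] = data[i][j]
    return padded


def row_sum_branching(A):
    """Guarded version: never reads the padding."""
    B = [0.0] * ROWS
    for i in range(ROWS):
        for jo in range(OUTER):
            for ji in range(TILE):
                if TILE * jo + ji < COLS:
                    B[i] += A[i][jo][ji]
    return B


def row_sum_overcompute(A):
    """Unguarded version: also reads the padding."""
    B = [0.0] * ROWS
    for i in range(ROWS):
        for jo in range(OUTER):
            for ji in range(TILE):
                B[i] += A[i][jo][ji]
    return B


data = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]

# Zero padding: dropping the conditional does not change the result.
zero_padded = make_padded(data, pad_value=0.0)
assert row_sum_overcompute(zero_padded) == row_sum_branching(zero_padded)

# Any other pad value makes the unguarded version incorrect.
junk_padded = make_padded(data, pad_value=7.0)
assert row_sum_overcompute(junk_padded) != row_sum_branching(junk_padded)
```

The second assertion is the reason the pad value has to be recorded on the buffer: the conditional may only be dropped when the value in the padding is known to be the identity of the reduction.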
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
+
+`BufferNode` has a new member `std::vector<BufferConstraint>
+constraints` that describes known properties of this buffer.  Any
+transformation that introduces padding will also add a buffer
+constraint.
+
+```c++
+struct BufferConstraintNode {
+  Array<Var> indices;
+  PrimExpr predicate;
+  Optional<PrimExpr> value;
+};
+```
+
+The `indices` holds variables that represent the index being used to
+access the buffer.  Both `predicate` and `value` are in terms of the
+variables stored in `indices`.  If `predicate` is true for a given
+value of the indices, then the buffer has contents of `value` at those
+indices.  If `value` is empty, then any indices that match the
+predicate may not be accessed.
+
+The `indices` field is automatically populated based on the
+post-transformation indices.  The `predicate` field is automatically
+determined based on the transformation, and is true for any index
+corresponding to the transformation padding.  The `value` field is
+defined by the user input in `pad_value`.
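
To make the intended meaning of a predicate/value pair concrete, here is a minimal Python model of the constraint.  It is a sketch only: `BufferConstraintModel` and `satisfies` are hypothetical names, the buffer is represented as a dict keyed by index tuples, and nothing here reflects the actual C++ implementation.

```python
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional, Tuple


@dataclass
class BufferConstraintModel:
    """Hypothetical Python stand-in for BufferConstraintNode."""

    predicate: Callable[..., bool]
    # None models an empty `value`: matching indices may not be accessed.
    value: Optional[Callable[..., int]]


def satisfies(buffer, shape: Tuple[int, ...], constraint: BufferConstraintModel) -> bool:
    """Return True if every index matching the predicate holds `value`."""
    if constraint.value is None:
        return True  # no contents are promised, so there is nothing to check
    for idx in product(*(range(extent) for extent in shape)):
        if constraint.predicate(*idx) and buffer[idx] != constraint.value(*idx):
            return False
    return True


# A [14] buffer transformed by i -> [i//4, i%4], padded with -1.
constraint = BufferConstraintModel(
    predicate=lambda io, ii: 4 * io + ii >= 14,
    value=lambda io, ii: -1,
)
buf = {(i // 4, i % 4): i for i in range(14)}
buf[(3, 2)] = buf[(3, 3)] = -1
assert satisfies(buf, (4, 4), constraint)

buf[(3, 3)] = 5  # a producer that violates the constraint
assert not satisfies(buf, (4, 4), constraint)
```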
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  This is
+intended for use as `BufferConstraintNode::value`, to indicate that it
+is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  This is primarily used to
+allow simplifications in a producer, as any partial computations
+written to this space (e.g. by vectorized operations) may be left
+as-is.
+
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform-remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+### Buffer Annotation of Layout Transforms
+
+TODO: Should a buffer remember which layout transforms have been
+applied to it?  It would be useful for generating converters between
+logical/transformed/physical layout.  As it is, users must provide
+inputs that have the transformed layout.
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, Callable]]`.  This provides the `value`
+field of the `BufferConstraintNode`.
+
+For buffer consumers, the buffer constraint is updated, and no further
+changes are required based on the padding value.  For buffer
+producers, the buffer constraint is updated, and an additional loop is
+added to write `pad_value` to the padding that has been introduced.
+
+```python
+# Before transforming A
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    for i in T.serial(14):
+        A[i] = i
+
+# After applying transform_layout(lambda i: [i//4, i%4], pad_value=-1)
+@T.prim_func
+def func(A: T.Buffer[(4,4), "int32"]):
+    # This loop writes the same values, but to the new locations in
+    # `A`.
+    for i in T.serial(14):
+        A[i//4, i%4] = i
+
+    # This loop writes the padding values.  In this case, `io==3 and
+    # ii>=2` is the predicate, and `-1` is the value.
+    for io,ii in T.grid(4,4):
+        if io==3 and ii>=2:
+            A[io, ii] = -1
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility-merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility-reorder-loops-according-to-buffer).
+
+In TE, the producer is the stage that outputs the transformed tensor.
+In TIR, the producer is the block that writes to all values of the
+pre-transformation tensor.
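
The producer-side behavior can be simulated in plain Python.  The sketch below is an assumption-laden model, not TVM code: the buffer is a dict, `transform` and `is_padding` are hypothetical helpers, and the predicate is taken to be `io == 3 and ii >= 2`, matching the loop range `2 <= ii < 4` given above.

```python
def transform(i):
    """Layout transformation i -> [i//4, i%4]."""
    return (i // 4, i % 4)


def is_padding(io, ii):
    """Padding predicate for a [14] buffer stored as [4, 4]."""
    return io == 3 and ii >= 2


def producer(pad_value=-1):
    A = {}
    # Same values as before the transformation, written to new locations.
    for i in range(14):
        A[transform(i)] = i
    # Postlude that writes the pad value wherever the predicate holds.
    for io in range(4):
        for ii in range(4):
            if is_padding(io, ii):
                A[(io, ii)] = pad_value
    return A


A = producer()
assert len(A) == 16  # every element of the [4, 4] buffer has been written
assert [A[(3, ii)] for ii in range(4)] == [12, 13, -1, -1]

# The predicate marks exactly the indices that no original element maps to.
mapped = {transform(i) for i in range(14)}
assert all(
    is_padding(io, ii) == ((io, ii) not in mapped)
    for io in range(4)
    for ii in range(4)
)
```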
+
+
+
+### New Primitive - Add buffer constraint
+
+Similar to `Schedule.set_axis_separators`, this adds an annotation to
+an existing buffer, and can be used independently of
+`transform_layout`.  This can be useful for hardware that provides a
+default value for out-of-bounds reads (e.g. texture memory clamping on
+a GPU).
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
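
A quick way to convince oneself that the rewritten loopnest touches the same elements is to enumerate both iteration orders in plain Python.  `rewritten_order` is a hypothetical helper for this check and is not part of any TVM API.

```python
def rewritten_order(extent, tile):
    """Flat indices visited by the (io, ii) loopnest with its guard."""
    outer = (extent + tile - 1) // tile
    visited = []
    for io in range(outer):
        for ii in range(tile):
            if tile * io + ii < extent:  # guard introduced by the rewrite
                visited.append(tile * io + ii)
    return visited


# Evenly divisible extent: the guard is always true and could be dropped.
assert rewritten_order(16, 4) == list(range(16))

# Extent of 14: the guard skips the two padding elements (3, 2) and (3, 3).
assert rewritten_order(14, 4) == list(range(14))
```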
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
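
The index simplifications in both options can be checked numerically.  The following plain-Python sketch transcribes the original loopnest and the two rewritten loopnests, then compares the results.  It assumes that `//` and `%` in the listings follow floor-division semantics, as Python's operators do; the input values and the helper names are made up for the test.

```python
F = [2, -1, 5]
A_flat = [3 * i + 1 for i in range(16)]
A = [[A_flat[4 * io + ii] for ii in range(4)] for io in range(4)]


def reference():
    """Original 1-d convolution over the flat buffers."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A_flat[i + f]
    return B


def option_1():
    """Loops follow B's layout; None marks padding that is never written."""
    B = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B[io][ii] = 0
                for f in range(3):
                    B[io][ii] += F[f] * A[io + (ii + f) // 4][(ii + f) % 4]
    return B


def option_2():
    """Loops follow A's layout; initialization is a separate loopnest."""
    B = [[None] * 4 for _ in range(4)]
    for i in range(14):
        B[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A[io][ii]
    return B


def flatten(B):
    return [B[i // 4][i % 4] for i in range(14)]


assert flatten(option_1()) == reference()
assert flatten(option_2()) == reference()
# Neither option writes to the padding of B.
assert option_1()[3][2:] == [None, None]
assert option_2()[3][2:] == [None, None]
```

Option 2 relies on negative operands, for example `(ii - f) // 4 == -1` when `ii < f`, so the check would fail under truncating division.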
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-op.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals (those whose conditions
+  read buffer values) should not be merged.  Only conditionals that
+  do not depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
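
For small static loop extents, the "identical" and "complementary" relations can be illustrated by brute force over the iteration domain.  This is only a sketch of the definitions: `classify` is a hypothetical helper, and an actual implementation would presumably rely on the arithmetic analyzer to prove these relations symbolically rather than enumerate.

```python
from itertools import product


def classify(cond_a, cond_b, extents):
    """Compare two loop-variable conditions over a rectangular domain."""
    domain = list(product(*(range(extent) for extent in extents)))
    if all(cond_a(*idx) == cond_b(*idx) for idx in domain):
        return "identical"
    if all(cond_a(*idx) != cond_b(*idx) for idx in domain):
        return "complementary"
    return "unrelated"


# Different functional forms, same truth value for 0 <= i < 16.
assert classify(lambda i: i < 8, lambda i: i // 8 == 0, (16,)) == "identical"

# Each condition proves the other false.
assert classify(lambda i: i < 8, lambda i: i // 8 == 1, (16,)) == "complementary"
assert (
    classify(lambda i, j: 4 * i + j < 14, lambda i, j: i == 3 and j >= 2, (4, 4))
    == "complementary"
)

# Neither relation holds, so the conditionals must stay separate.
assert classify(lambda i: i < 8, lambda i: i % 2 == 0, (16,)) == "unrelated"
```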
+
+### New Transform - Hoist Expression

Review Comment:
   These are different alternatives, and they can also be combined on certain workloads. There is a discussion of a matmul performance issue that appears when a dimension is simply changed from 128 to 127: https://discuss.tvm.apache.org/t/te-vectorize-do-we-have-plan-to-support-vectorize-for-non-divisible-split/12469
   I think it would make a good working example. Below is what the user gets after splitting loop `j`: 127 -> (4, 32)
   
   ```python
   for i in range(127):
       for k in range(127):
           for j.outer in range(4):
                for j.inner in T.vectorized(32):
                    if T.likely(j.outer * 32 + j.inner < 127, dtype="bool"):
                        C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
   ```
   
   The issue is that a complex condition has to be introduced to preserve the program semantics. It hurts performance, and in general we cannot vectorize a program with control flow.
   
   As I understand it, we have different alternatives to handle this:
   - Loop partition

     We can already annotate the loop var with a hint, using non-imperative loop partition.
   
     `for j.outer in range(4, annotations={"pragma_loop_partition_hint": 1})`
   
     After the `LoopPartition` pass (and simplify) it becomes:
     ```python
     for i in range(127):
         for k in range(127):
             # j.outer in [0, 3)
             for j.outer in range(3):
                  for j.inner in T.vectorized(32):
                       # condition is const true, optimize out
                       C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
             # j.outer in [3, 4): the loop over j.outer is optimized out, so j.outer == 3
             for j.inner in T.vectorized(31):
                  # condition becomes j.inner < 31, hoisted with loop
                  C[i*127 + 3*32 + j.inner] += A[i*127 + k] * B[k*127 + 3*32 + j.inner]
     ```
     With the conditional branch eliminated in each loop part, the code becomes friendlier to performance optimizations such as vectorization. As for "imperative" partition, the proposal is just that we can partition in the schedule phase when one wants to schedule the parts differently, for example by giving them different vectorization widths.
   
   - Loop padding
   
     With the current RFC, I understand we can pad the innermost dimension of `C` and `B` to 128 and somehow drop the condition directly, so that it becomes the code below. (IIUC, we may also insert some code that fills the edges with "arbitrary" values and then optimize it out?) In this particular case, I believe padding is the better choice, since we get very neat code with minimal overcompute.
     ```python
     for i in range(127):
         for k in range(127):
             for j.outer in range(4):
                 for j.inner in T.vectorized(32):
                     C[i*127 + j.outer*32 + j.inner] += A[i*127 + k] * B[k*127 + j.outer*32 + j.inner]
     ```




[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r892675903


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
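
The shapes and padding locations quoted in the comments above can be reproduced by brute force.  The sketch below is plain Python and makes no claim about how TVM derives them; `padded_layout` is a hypothetical helper that takes the transformed shape to be the per-axis maximum index plus one, which is only sensible for small static shapes.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (transformed_shape, padding_indices) for a 1-d buffer."""
    mapped = [tuple(transform(i)) for i in range(extent)]
    ndim = len(mapped[0])
    shape = tuple(max(idx[d] for idx in mapped) + 1 for d in range(ndim))
    occupied = set(mapped)
    padding = [
        idx for idx in product(*(range(s) for s in shape)) if idx not in occupied
    ]
    return shape, padding


def tiled(i):
    return [i // 8, i % 8]


def offset(i):
    return [(i + 2) // 8, (i + 2) % 8]


# A has shape [16], B has shape [14].
assert padded_layout(16, tiled) == ((2, 8), [])
assert padded_layout(14, tiled) == ((2, 8), [(1, 6), (1, 7)])
assert padded_layout(14, offset) == ((2, 8), [(0, 0), (0, 1)])

shape, padding = padded_layout(16, offset)
assert shape == (3, 8)
assert padding == [(0, 0), (0, 1)] + [(2, ii) for ii in range(2, 8)]
```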
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the additions over the padding can be proven to be no-ops,
+allowing the conditional to be removed automatically.  Whether trading
+branching for overcompute is beneficial depends on the schedule, so
+both directions are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   Thanks @Lunderberg for providing the details. To add a discussion point, I'd like to ask whether semantics such as overcompute and writing the default value of the next operator can be achieved with graph-level rewriting. I agree that a certain feedback loop between TIR and the Relay graph is needed. In Relay, this was previously achieved by a registration mechanism such as `AlterOpLayout`, and we can use a similar mechanism to guide the graph rewriting. For example, in chained elemwise operators, we don't care about the undef value in the padding region. If we rewrite the graph so that each operator has a shape that fits the layout requirement perfectly (which will directly generate the final TIR [here](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#apply-operator-element-wise-over-the-transformation-padding)), would that help avoid the complexity of each individual operator dealing with the padding block and the conditions?
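
   The chained element-wise case can be checked with a small sketch. This is plain Python rather than TVM code, and the helper names (`pad_to_multiple`, `chain`) are hypothetical; it only assumes that the padding sits at the end of the buffer and that the operators are purely element-wise.

```python
def pad_to_multiple(values, tile, fill):
    """Return a copy of `values` extended with `fill` up to a multiple of `tile`."""
    padded_len = -(-len(values) // tile) * tile  # ceiling division
    return list(values) + [fill] * (padded_len - len(values))


def chain(values):
    """Two chained element-wise operators: y = 2*x + 1, then z = y*y."""
    shifted = [2.0 * x + 1.0 for x in values]
    return [y * y for y in shifted]


# Logical buffer of shape [14], as in the RFC examples.
logical = [float(i) for i in range(14)]

# The fill value stands in for undefined padding; it is an arbitrary choice.
padded = pad_to_multiple(logical, tile=4, fill=123.0)

reference = chain(logical)      # computed only over the 14 valid elements
overcomputed = chain(padded)    # computed over all 16 elements, no branching

# The valid region does not depend on what was stored in the padding.
assert overcomputed[:14] == reference
```

   Because each output element depends only on the input element at the same index, whatever ends up in the two padded slots never reaches the valid region, so no per-operator padding condition is needed in this case.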





[GitHub] [tvm-rfcs] tqchen commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891736820


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+may be a constant, or a function that maps from transformed indices
+to an `Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the additions over the padding can be proven to be no-ops,
+allowing the conditional to be removed automatically.  Whether trading
+branching for overcompute is beneficial depends on the schedule, so
+both directions are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs
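
The padded coordinates quoted in the hunk above can be reproduced with a short sketch. This is plain Python rather than the TVM API, and `padded_layout` is a hypothetical helper; it assumes a one-dimensional buffer with a positive extent, non-negative transformed indices, and a transformed shape taken as the largest index along each axis plus one.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (shape, padding) for a 1-d buffer of `extent` elements.

    `shape` is the transformed shape, and `padding` lists the transformed
    coordinates that no logical index maps to.
    """
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    padding = [idx for idx in product(*(range(s) for s in shape)) if idx not in used]
    return shape, padding


# Shape [14] with i -> [i//4, i%4]: padding at the end of the buffer.
shape_end, pad_end = padded_layout(14, lambda i: (i // 4, i % 4))
assert shape_end == (4, 4)
assert pad_end == [(3, 2), (3, 3)]

# Shape [14] with an offset of 2: padding moves to the beginning.
shape_begin, pad_begin = padded_layout(14, lambda i: ((i + 2) // 8, (i + 2) % 8))
assert shape_begin == (2, 8)
assert pad_begin == [(0, 0), (0, 1)]

# Shape [16] with the same offset: padding at both ends.
shape_both, pad_both = padded_layout(16, lambda i: ((i + 2) // 8, (i + 2) % 8))
assert shape_both == (3, 8)
assert pad_both == [(0, 0), (0, 1)] + [(2, k) for k in range(2, 8)]
```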

Review Comment:
   Thanks @Lunderberg for the dissected discussion, this is helpful.
   
   Besides Goal 1, there is one additional implied goal:
   
   - Goal 7: The composability of primitives with buffers that carry layout constraints. Specifically, how composable are the existing and new primitives (such as split/reorder/tensorization and reduction factorization) with the buffer layout constraints?
   
   When building abstractions like this one, we are trying to strike a balance between two things: simplicity/compositionality, and the set of things we can support. It is quite natural that a more complicated impl would hit more marks initially.
   
   On the other hand, there is always the consideration of added complexity. An additional field in the IR effectively means we have to either (a) introduce a specific codepath to handle layout constraints, or (b) generalize all relevant primitives to take that field into account. Option (b) introduces an N * M problem, where N is the number of primitives and M is the number of possible IR variations (like layout constraints) we introduce to the IR. The N * M problem will grow as N and M increase.
   
   It is also useful to come back to the high-level goal, beyond these goals for a single function. Our high-level goal is to enable effective end-to-end models under a good native layout (which involves padding and layout transformation). It would be really nice to have an example at the e2e level to show how the set of transformations affects our optimizations.
   
   Among the existing goals listed, Goal 6 is certainly a very important one. Goal 3 is primarily an implementation difference, in terms of different ways of building pattern matching. Goal 4 is not necessarily a need, as many optimizations actually benefit from reduced complexity (e.g. tensorization in physical memory).
   
   Goal 6 is an important one that does indeed touch the high-level (e2e) goal itself. Along that direction, a smart variant of Impl A (by interacting with the graph) would actually enable a simpler realization of Goal 6 by lifting the transformations of input/output out and then cancelling them between operators, while preserving the information.
   





[GitHub] [tvm-rfcs] vinx13 commented on a diff in pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`14//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, because the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (io-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(shape=(16, 14), dtype="float32")
+    B = T.match_buffer(shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to be
+removed automatically.  Since trading branching for overcompute may or
+may not be beneficial depending on the schedule, these options are
+exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   @Lunderberg That's true. If the reduction dimension is padded, we will need to insert a hint in the graph to assert that it was previously padded with 0.
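
To illustrate why the "padded by 0" hint matters for a padded reduction axis, here is a minimal pure-Python sketch that mirrors the `row_summation` example quoted above. It is not TIR; the helper names `pad_rows`, `row_sum_branching`, and `row_sum_overcompute` are made up for this illustration. The unguarded sum only matches the guarded one when the padding is known to hold zero.

```python
def pad_rows(rows, factor=4, pad_value=0.0):
    """Re-lay out each row from [j] to [j // factor][j % factor], padding the tail."""
    out = []
    for row in rows:
        n_outer = -(-len(row) // factor)  # ceiling division
        flat = list(row) + [pad_value] * (n_outer * factor - len(row))
        out.append([flat[k * factor:(k + 1) * factor] for k in range(n_outer)])
    return out


def row_sum_branching(tiled, n_cols):
    """Sum each row, guarding every access with the original bounds check."""
    result = []
    for row in tiled:
        total = 0.0
        for j_outer, tile in enumerate(row):
            for j_inner, value in enumerate(tile):
                if len(tile) * j_outer + j_inner < n_cols:
                    total += value
        result.append(total)
    return result


def row_sum_overcompute(tiled):
    """Sum each row with no guard, so padding elements are accumulated too."""
    return [sum(value for tile in row for value in tile) for row in tiled]


data = [[float(i * 14 + j) for j in range(14)] for i in range(16)]
expected = [sum(row) for row in data]

zero_padded = pad_rows(data, pad_value=0.0)
one_padded = pad_rows(data, pad_value=1.0)

assert row_sum_branching(one_padded, 14) == expected  # the guard hides the padding
assert row_sum_overcompute(zero_padded) == expected   # zero padding is a no-op
assert row_sum_overcompute(one_padded) != expected    # non-zero padding leaks in
```

The last assertion is the case the hint protects against: without knowing what the producer left in the padding, dropping the branch would change the result.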





[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163510231

   > It doesn't add additional semantic, the computation semantic stays the same, it is a hint to the graph compiler.
   
   My apologies, I had meant the semantics of a node from the perspective of a TIR transformation, not the semantics from the perspective of the computation being described.  For a TIR transformation, if an object is replaced, whatever attributes describe that object must be updated to refer to the new object.  So if constraints are added to the block annotation, I had been thinking of that as a change to the semantics of the `BlockRealizeNode::annotations` from "does not need to be updated when a buffer is replaced" to "must be updated when a buffer is replaced".




[GitHub] [tvm-rfcs] vinx13 commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1182544837

   Thanks everyone for the discussions. We have agreed on the design principles and will continue to explore scheduling options. Let's keep the RFC open for final comments until the end of this week.




[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1180606480

   Following up on this, I think we are in broad-strokes agreement that we can achieve our goals with block/fn attributes in the IR as well as a builtin assume. As a result, my original blocker for the RFC has been resolved. It would still be great to work together to flesh out the details of the schedule primitives and how they interact with the rest of TIR scheduling, but I think that can be done separately, and we don't need to nail down the details of the primitives.
   
   The schedule primitives can be done relatively independently as long as we agree on the principle that:
   - Transformations do not change the function interface behavior
   - We decouple the graph level decisions into two steps: local decision + rewrite
   
   We can explore possible options as long as the IR spec remains stable. If there is a need to update the IR itself or the meaning of an attribute, we can come back and discuss again.




[GitHub] [tvm-rfcs] vinx13 commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1155766862

   Thanks @csullivan for providing the overview. I agree that non-local approaches 2-4 are necessary. From the examples in this RFC I can also see how the components C0-C2 can be used to support these non-local approaches. C0 + C1 allow constraints to be specified during scheduling and propagated back to the graph. Besides them, I would also like to mention another component:
   * C3: ability to specify constraints for each operator.
   
   It seems to me that C0, C1, C3 are actually choices of implementation, as there are multiple ways that require a combination of them to achieve the goal of constraint flowing.
   * C0 + C1 (which imply C3 is satisfied) suggest implementing the constraints at the TIR level using `BufferConstraint`. To propagate the constraints back to the graph, which is `Tensor`-centric, it seems the graph-level counterpart of `BufferConstraint` is not clear, as @wrongtest mentioned.
   * C3 is also feasible purely in the graph, which requires some mechanism to register per-operator constraints. An example I came up with is that each operator can have a list of supported layouts, and the constraint solver can choose a layout for each operator to approximate the global optimum for the graph. This satisfies the need for non-local approaches but doesn't need TIR-level constraints. Padding, instead of explicitly inserting `transform` / `inv_transform`, is also achievable as graph-level constraint flowing.
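
As a rough sketch of the purely graph-level option, the snippet below picks one layout per operator from a per-operator list of supported layouts. Everything in it is an assumption made for illustration: the `SUPPORTED_LAYOUTS` registry, the layout names, the linear operator chain, and the cost model (count the layout changes between neighbours). It uses exhaustive search only because the example is tiny; a real constraint solver would approximate the optimum over a full graph.

```python
from itertools import product

# Hypothetical per-operator registry: operator name -> layouts it supports.
SUPPORTED_LAYOUTS = {
    "conv2d": ["NCHW4c", "NHWC"],
    "relu": ["NCHW", "NCHW4c", "NHWC"],
    "softmax": ["NCHW", "NHWC"],
}

# A linear chain of operators; a real graph would be a DAG.
CHAIN = ["conv2d", "relu", "softmax"]


def count_layout_changes(layouts):
    """Number of adjacent operator pairs that would need an explicit layout transform."""
    return sum(1 for prev, cur in zip(layouts, layouts[1:]) if prev != cur)


def choose_layouts(chain, supported):
    """Exhaustively pick one supported layout per operator, minimizing transforms."""
    candidates = product(*(supported[op] for op in chain))
    return min(candidates, key=count_layout_changes)


best = choose_layouts(CHAIN, SUPPORTED_LAYOUTS)
print(dict(zip(CHAIN, best)), count_layout_changes(best))
```

In this toy setup the only assignment with no layout changes is the one where every operator uses the layout they all share, which is the kind of global decision a local, per-operator choice could miss.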
   
   Back to the discussion of this RFC, I think the main comments about the proposed methods are the IR changes required (which may have greater impacts on the existing TIR and scheduling), and the complexity involved in using the new schedule primitives to reach the final desired state. From my understanding, the intention of these new primitives is to allow arithmetic simplification to perform graph rewriting like over-computation. If this can be achieved as a graph-level rewriting rule (perhaps simpler, as it doesn't need arithmetic manipulations), personally I think that would still be preferred for better maintainability. Also I'd like to mention that modeling such rewriting in the graph doesn't necessarily tie the TIR operator to a specific graph IR implementation. As we are moving to S-TIR scheduling, it is easy to apply some preprocessing steps to derive the PrimFunc in a specific layout from a standard `te.compute` definition.
   
   Finally, I would like to encourage us to focus on the e2e goals. It seems the current approaches, implemented as either A0 or A1 at the graph level, should suffice for the use cases in the inference graph. Though the training graph is probably not an immediate need, if we would like to consider its use cases, having some concrete examples with the desired results can probably guide us to make better decisions.
   




[GitHub] [tvm-rfcs] tqchen commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1170294348

   Thanks @Lunderberg for the update, I think we are moving in a positive direction on the overall IR design. Some additional feedback:
   
   ## Keep Schedule Decisions Local to PrimFunc then Compose
   
   On schedule primitives, to be pragmatic, it would be helpful to have some of the cross-PrimFunc re-flowing done in two steps. Specifically, some of your `transform_layout` examples touch buffers that involve function inputs. One approach is of course to trace up to the producers and then rewrite the producer functions as well (or trace down to the consumer functions). However, the complications here are that:
   
   - There can be multiple consumer/producer TIR functions.
   - In certain cases the producer/consumer may not have consistent requirements.
   - The producers/consumers themselves can have their own local layout preferences that need to be consolidated.
   
   In general it is helpful to first keep schedule decisions local, e.g. introducing a caching stage (AC, BC in the example), then compose with another reflowing pass to bring the decisions to consumers/producers. This is mainly to reduce the overall complexity in implementing such transformations, and it also makes things more modular.
   
   ```python
   @T.prim_func
   def grow(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       AC = T.alloc_buffer([4, 4], "int32")
       BC = T.alloc_buffer([4, 4], "int32")
   
       # Pad the input into the 4x4 cache, writing 0 into the padding.
       for io, ii in T.grid(4, 4):
           with T.block():
               T.block_attr("preproc", "pad")
               AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   
       for i, j in T.grid(4, 4):
           BC[i, j] = 2 * AC[i, j]
   
       # Crop the padded result back to the original 14 elements.
       for i in T.grid(14):
           with T.block():
               T.block_attr("postproc", ["crop", 0])
               B[i] = BC[i // 4, i % 4]
   
   @T.prim_func
   def addone(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       for i in T.grid(14):
           B[i] = A[i] + 1
   
   @R.func
   def main(A: T.Tensor[14, "int32"]):
       lv0 = call_tir(grow, [A], (14))
       # an intermediate stage to show non-local reflowing
       lv1 = call_tir(addone, [lv0], (14))
       lv2 = call_tir(grow, [lv1], (14))
       ...
   ```
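
The pad / compute / crop structure of `grow` can be emulated in plain Python to check that keeping the decision local leaves the function interface unchanged: 14 elements go in and 14 come out, while the 4x4 layout stays internal. This is only a sketch of the idea; `pad_stage`, `compute_stage`, and `crop_stage` are hypothetical names standing in for the three loop nests, not TVM APIs.

```python
def pad_stage(a, tile=4, pad_value=0):
    """Mirror of the "preproc" block: copy `a` into tiles, writing pad_value past the end."""
    n_outer = -(-len(a) // tile)  # ceiling division
    return [
        [a[tile * io + ii] if tile * io + ii < len(a) else pad_value
         for ii in range(tile)]
        for io in range(n_outer)
    ]


def compute_stage(ac):
    """The padded computation: runs over the full tiled buffer with no bounds check."""
    return [[2 * value for value in row] for row in ac]


def crop_stage(bc, n):
    """Mirror of the "postproc" block: read back only the first n logical elements."""
    tile = len(bc[0])
    return [bc[i // tile][i % tile] for i in range(n)]


def grow(a):
    """Same interface as the unpadded operator; the padded layout is internal."""
    return crop_stage(compute_stage(pad_stage(a)), len(a))


a = list(range(14))
assert grow(a) == [2 * value for value in a]
assert grow(grow(a)) == [4 * value for value in a]
```

Because `grow` composes with itself without any layout knowledge leaking out, a later reflowing pass is free to cancel a crop against the following pad, or to leave them in place.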
   
   ## Use IfThenElse expression for Padding.
   
   While it is possible to express padding with one loop that copies the data and another loop that writes the padded value, it is harder to schedule the resulting blocks, as there is more than one producer. Having a single loop that uses `T.if_then_else` expresses the pattern in a single shot and makes future rewriting easier.
   
   
   ```python
   for io, ii in T.grid(4, 4):
       with T.block():
           T.block_attr("preproc", "pad")
           AC[io, ii] = T.if_then_else(4 * io + ii < 14, A[4 * io + ii], 0)
   ```
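
For comparison, the two ways of producing the padded buffer can be written side by side in plain Python. This is a sketch with made-up function names, not scheduling code. The point is that both fill `AC` identically, but the first has two separate writers of `AC` while the second has exactly one.

```python
def pad_two_loops(a, tile=4, pad_value=0):
    """Two producers: one loop copies the data, a second loop fills the padding."""
    n_outer = -(-len(a) // tile)  # ceiling division
    ac = [[None] * tile for _ in range(n_outer)]
    for i in range(len(a)):  # producer 1: valid region only
        ac[i // tile][i % tile] = a[i]
    for io in range(n_outer):  # producer 2: padding region only
        for ii in range(tile):
            if tile * io + ii >= len(a):
                ac[io][ii] = pad_value
    return ac


def pad_single_loop(a, tile=4, pad_value=0):
    """One producer: every element of AC is written exactly once by a select."""
    n_outer = -(-len(a) // tile)
    ac = [[None] * tile for _ in range(n_outer)]
    for io in range(n_outer):
        for ii in range(tile):
            index = tile * io + ii
            # The conditional expression only reads a[index] when it is in bounds.
            ac[io][ii] = a[index] if index < len(a) else pad_value
    return ac


a = list(range(1, 15))  # 14 elements
assert pad_two_loops(a) == pad_single_loop(a)
```

With a single writer, any later analysis that asks "what value does `AC[io, ii]` hold?" has one statement to inspect instead of having to merge two.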
   
   ## Propagate Padding Decisions from the End.
    
   Some of the complications of duplicated conditions (and their simplification) stem from the fact that we do the layout transforms of the output and input separately (each introducing its own conditions, which then need to be simplified). It might be helpful to do a global transformation, usually driven from the output, then "backprop" the implication of that decision to the input. Doing such a transformation in a single shot will likely alleviate the need to generate extra conditions and then simplify them.
   
   




[GitHub] [tvm-rfcs] Lunderberg commented on pull request #77: [RFC] Buffer Layout Padding

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1171290053

   > In general it is helpful to first keep schedule decisions local, e.g. introducing a caching stage (AC, BC in the example), then compose with another reflowing pass to bring the decisions to consumers/producers.
   
   My goal with the latest update wasn't to require global decisions, but to make local changes only, which could be used in different contexts.  For the auto-scheduler, since the context requires maintaining the same PrimFunc interface, local optimization would be restricted to transformations of the caching stage.  For stand-alone usage, such as preparing a single PrimFunc for a unit test, the context allows the interface to change.  That way, the restrictions to the transformations are imposed by the level of abstraction that requires them.
   
   > While it is possible to express padding with one loop that copies the data and another loop that writes the padded value, it is harder to schedule the resulting blocks, as there is more than one producer. Having a single loop that uses `T.if_then_else` expresses the pattern in a single shot and makes future rewriting easier.
   
   I definitely agree that this makes the later analysis/rewrites easier. I had maintained them as two separate loops both to minimize the extent of changes being made in any one scheduling change, and to maintain the current behavior of `Schedule.transform_layout` which does not alter the surrounding loop iterators ([previous conversation](https://github.com/apache/tvm/pull/10538#discussion_r826209815) with @vinx13).
   
   I see four main options on how the loopnests could be handled:
   
   1. When a buffer is transformed, all loops in the producer stage over the buffer's pre-transformation axes are replaced with loops over the buffer's post-transformation spatial dimensions.  It is an error if this replacement cannot be done (e.g. the pre-transformation loops have been fused/split/reordered).
   
      - Pro: Allows the `T.if_then_else` to be inserted at the time of the transformations.
      - Pro: Removes the need for the .
      - Con: May restrict search space, since earlier manipulation of the loop iterators would prevent later buffer transformations.
      - Con: Doesn't help consumers of a transformed buffer.  In a reduction, may be desirable to iterate over the input buffer, but this couldn't be expressed in terms of an output.
      - Con: For buffers whose padding is not written to, must either insert a conditional statement or maintain the pre-transformation loop structure.
   
   2. When a buffer is transformed, an attempt is made to replace all loops in the producer stage over the buffer's pre-transformation axes with loops over the buffer's post-transformation spatial dimensions.  If this replacement cannot be done (e.g. the pre-transformation loops have been fused/split/reordered), and if `pad_value` is not `None`, then an error should be raised.
   
      - Pro: Always valid to apply a transform
      - Pro: Avoids undoing scheduling benefits from previous changes to iterators.
      - Pro: Later calls to `reduce_branching_through_overcompute` could still introduce a value for the padding, if the full life cycle of the buffer is known.
      - Con: Allowing the follow-up stage at all requires just as much analysis to identify as if it were always present.
   
   3. When a buffer is transformed, all loops in the producer stage over the buffer's pre-transformation axes are replaced with loops over the buffer's post-transformation spatial dimensions.  If this replacement cannot be done (e.g. the pre-transformation loops have been fused/split/reordered), then the follow-up stage is inserted.
   
      - Pro: Always valid to apply a transform
      - Pro: Avoids undoing scheduling benefits from previous changes to iterators.
      - Con: Allowing the follow-up stage at all requires just as much analysis to identify as if it were always present.
   
   4. When a buffer is transformed, all loops over spatial dimensions in the producer are replaced with loops over the post-transformation buffer axes.
   
      - Pro: Always valid to apply a transform.
      - Con: May undo scheduling that has previously provided useful performance improvements.
      - Con: Loop iterators over pre-transformation indices may have been fused with reduction axes.  Would need to undo the fusion to apply.
        
   The current proposed version would be option 4, but I think I'd prefer option 2 in order to reduce the number of follow-up simplifications required.
   
   > Some of the complications of duplicated conditions (and their simplification) stem from the fact that we do the layout transforms of the output and input separately (each introducing its own conditions, which then need to be simplified). It might be helpful to do a global transformation, usually driven from the output, then "backprop" the implication of that decision to the input. Doing such a transformation in a single shot will likely alleviate the need to generate extra conditions and then simplify them.
   
   At the TIR level, I suppose I'm unclear on what "'backprop' the implication of that decision to the input" would mean, since changing the layout of one buffer doesn't strictly require changing the layout of other buffers.  Intuitively, I can picture how it would apply to some operators (e.g. perform analogous transformations on the inputs to element-wise functions) and how those could be identified (e.g. track which indices are used for access of each buffer, and identify corresponding shapes from the indices), but I'm unclear as to how a similar intuition would be applied for more complicated functions.  (I'm also not sure if this would require a similarly difficult sequence of proofs as the proposed transforms, just with the goal of proving a preferred layout rather than proving a possible simplification.)
   
   We could allow the user to specify transformations of all buffers simultaneously, but this wouldn't really solve the problem, as the simplifications made would still need to be based on that information provided.
   
   At the graph level, I don't think a single direction of constraint propagation is sufficient.  Backward propagation, starting with the output values returned to the user, could track which indices contribute to that final output, which could be exposed to producers. Forward propagation, starting with the input values provided by the user, could track which indices of intermediate buffers contain known values, which could be exposed to consumers.
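
A toy model may make the two directions concrete. The sketch below is an assumption-heavy illustration, not a proposal for the actual analysis: operators are reduced to a linear chain, `ELEMENTWISE` and `REDUCTIONS` are hypothetical tables, and a reduction is described only by the padding value that would leave its result unchanged. The forward pass tracks which value is known to sit in each buffer's padding; the backward pass tracks whether that padding can reach the final output.

```python
# Hypothetical operator tables for the toy model.
ELEMENTWISE = {"grow": lambda x: 2 * x, "addone": lambda x: x + 1}
REDUCTIONS = {"row_sum": 0}  # op -> padding value that leaves the reduction unchanged


def forward_pad_values(chain, input_pad):
    """pads[k] is the value known to sit in the padding of the buffer feeding chain[k]."""
    pads = [input_pad]
    for op in chain:
        if op in ELEMENTWISE and pads[-1] is not None:
            # Overcomputing an elementwise op maps a known padding value to a known one.
            pads.append(ELEMENTWISE[op](pads[-1]))
        else:
            pads.append(None)  # unknown after a reduction or an unknown input
    return pads


def backward_padding_used(chain):
    """used[k] is True if the padding of the buffer feeding chain[k] can reach the output."""
    used = [False] * (len(chain) + 1)  # the final output's padding is cropped away
    for k in reversed(range(len(chain))):
        # A reduction over the padded axis folds padding into valid outputs.
        used[k] = True if chain[k] in REDUCTIONS else used[k + 1]
    return used


chain = ["grow", "addone", "row_sum"]
pads = forward_pad_values(chain, input_pad=0)
used = backward_padding_used(chain)

for k, op in enumerate(chain):
    if op in REDUCTIONS and used[k]:
        safe = pads[k] == REDUCTIONS[op]
        print(f"{op}: input padding holds {pads[k]}, overcompute is {'safe' if safe else 'unsafe'}")
```

Neither pass alone detects the problem in this chain: the backward pass says the padding feeding `row_sum` matters, and the forward pass says `addone` left a 1 there rather than a 0, so either the producer must rewrite the padding or the consumer must keep its branch.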
   
   With these uncertainties, I'm starting to think of `layout_transform` and `pad_value` not as a complete end-to-end handling in itself, but providing a platform on which the graph-level reasoning can be built. That is, it doesn't itself perform the graph-level reasoning, but can accept the layout/padding requirements given from graph-level reasoning.
   

