Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/28 23:06:25 UTC

[GitHub] [tvm-rfcs] kparzysz-quic commented on a diff in pull request #77: [RFC] Buffer Layout Padding

kparzysz-quic commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r909051867


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index in the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
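
The conflict between operators can be sketched in plain Python, without TVM. This is only an illustration under stated assumptions: the helper `pad_to` and the sample data are hypothetical, and a zero-padded sum stands in for the second operator. A max over the padded buffer is only correct when the padding holds `-inf`, while a sum over the padded buffer is only correct when the padding holds zero.

```python
import math


def pad_to(values, size, pad_value):
    """Return a copy of `values` extended to `size` elements with `pad_value`."""
    return list(values) + [pad_value] * (size - len(values))


# 14 logical values, all negative so that a zero in the padding is visible to max().
data = [-3.0, -1.5, -7.0, -2.0, -4.0, -9.0, -1.0,
        -6.0, -8.0, -5.0, -2.5, -3.5, -4.5, -0.5]

true_max = max(data)
true_sum = sum(data)

neg_inf_padded = pad_to(data, 16, -math.inf)
zero_padded = pad_to(data, 16, 0.0)

# Unguarded reductions over all 16 physical elements.
max_with_neg_inf = max(neg_inf_padded)  # padding never wins the max
max_with_zero = max(zero_padded)        # padding wins, result is wrong
sum_with_zero = sum(zero_padded)        # adding 0.0 changes nothing
sum_with_neg_inf = sum(neg_inf_padded)  # padding poisons the sum
```

Neither padding value satisfies both consumers, which is the constraint between producer and consumer that the rest of the RFC addresses.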
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
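
The shapes and padding locations quoted in the comments above can be checked with a small stand-alone sketch. It does not use TVM: `padded_layout` is a hypothetical helper that assumes a non-empty 1-D buffer and non-negative transformed indices, and it simply enumerates every logical index.

```python
from itertools import product


def padded_layout(extent, transform):
    """Return (transformed_shape, padded_indices) for a 1-D buffer of size `extent`.

    `transform` maps a logical index i to a list of transformed indices.
    The shape is the smallest zero-based box containing every transformed index.
    """
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    shape = [max(idx[d] for idx in used) + 1 for d in range(ndim)]
    padding = [idx for idx in product(*(range(n) for n in shape)) if idx not in used]
    return shape, padding


def no_offset(i):
    return [i // 8, i % 8]


def with_offset(i):
    return [(i + 2) // 8, (i + 2) % 8]


shape_a, pad_a = padded_layout(16, no_offset)        # no padding
shape_b, pad_b = padded_layout(14, no_offset)        # padding at the end
shape_b_off, pad_b_off = padded_layout(14, with_offset)  # padding at the start
shape_a_off, pad_a_off = padded_layout(16, with_offset)  # padding at both ends
```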
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
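
The postlude can be mimicked on a nested Python list to see which elements it touches. This is a sketch only, not TIR: `write_padding` is a hypothetical name, and the predicate `4*io + ii >= 14` is the one implied by the `[14] -> [4, 4]` transformation used in this section.

```python
def write_padding(buf, padding_predicate, pad_value):
    """Write pad_value(io, ii) into every element of a 2-D buffer that is padding."""
    for io, row in enumerate(buf):
        for ii in range(len(row)):
            if padding_predicate(io, ii):
                buf[io][ii] = pad_value(io, ii)


# B has logical shape [14] and transformed shape [4, 4].
B = [[None] * 4 for _ in range(4)]

# The producer's main computation only writes the 14 logical elements.
for i in range(14):
    B[i // 4][i % 4] = float(i)

# Postlude: fill the two padded elements with zeros.
write_padding(B, lambda io, ii: 4 * io + ii >= 14, lambda io, ii: 0.0)
```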
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable for later vectorization, or to reduce branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+be removed automatically.  Because the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
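
The equivalence that justifies dropping the conditional can be checked numerically in plain Python. This sketch is not TIR and does not call either proposed transform; the helper names are hypothetical. It shows that the unguarded loop matches the guarded loop when the padding holds zero, and diverges when the padding holds anything else.

```python
ROWS, COLS, INNER = 16, 14, 4
OUTER = (COLS + INNER - 1) // INNER  # 4 chunks of 4, two padded slots per row


def to_padded_layout(logical, pad):
    """Copy a [ROWS][COLS] list into layout [i][j//4][j%4], filling padding with `pad`."""
    A = [[[pad] * INNER for _ in range(OUTER)] for _ in range(ROWS)]
    for i in range(ROWS):
        for j in range(COLS):
            A[i][j // INNER][j % INNER] = logical[i][j]
    return A


def row_sum_branching(A):
    """Guarded version: skips the padding, valid for any padding contents."""
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(OUTER):
            for ji in range(INNER):
                if INNER * jo + ji < COLS:
                    acc += A[i][jo][ji]
        out.append(acc)
    return out


def row_sum_overcompute(A):
    """Unguarded version: also accumulates the padded elements."""
    out = []
    for i in range(ROWS):
        acc = 0.0
        for jo in range(OUTER):
            for ji in range(INNER):
                acc += A[i][jo][ji]
        out.append(acc)
    return out


logical = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]
expected = [sum(row) for row in logical]
zero_padded = to_padded_layout(logical, 0.0)
one_padded = to_padded_layout(logical, 1.0)
```

With `zero_padded`, both functions return `expected`. With `one_padded`, only the guarded version does, which is why the padding value has to be known before the branch can be removed.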
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying out of the write cache, the loop iteration remains
+    # the same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value (`A[io, ii] == -1`).
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
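
The effect of the guard can be reproduced with nested Python lists standing in for the `[4, 4]` buffer. This is a sketch of the loop rewrite only, not of `Schedule.sequential_buffer_access` itself; the function names are hypothetical, and `None` marks elements that were never written.

```python
def write_original(extent):
    """Loop over the logical index i, writing to the transformed location."""
    A = [[None] * 4 for _ in range(4)]
    for i in range(extent):
        A[i // 4][i % 4] = i
    return A


def write_sequential(extent):
    """Loop over (io, ii) in row-major order, guarded so padding is never written."""
    A = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < extent:
                A[io][ii] = 4 * io + ii
    return A


padded_original = write_original(14)
padded_sequential = write_sequential(14)
exact_original = write_original(16)
exact_sequential = write_sequential(16)
```

For an extent of 16 the guard is always true and could be dropped. For an extent of 14 it is what keeps the last two elements untouched.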
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
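
The index simplifications in both options can be checked against the untransformed convolution using plain Python lists. This is a sketch, not TIR: the helper names and the sample inputs are hypothetical, and it relies on Python's floor semantics for `//` and `%` with negative operands, which match the floor-based simplification shown in the comments above.

```python
def to_2d(flat, fill=None):
    """Lay a flat list out as a 4x4 nested list using i -> [i//4, i%4]."""
    out = [[fill] * 4 for _ in range(4)]
    for i, value in enumerate(flat):
        out[i // 4][i % 4] = value
    return out


def flatten(buf):
    return [value for row in buf for value in row]


def conv_reference(A, F):
    """Original loop nest over the logical index i."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def conv_option1(A2, F):
    """Loops follow B's layout; A's indices are rewritten in terms of io, ii, f."""
    B2 = [[None] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                B2[io][ii] = 0
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv_option2(A2, F):
    """Loops follow A's layout; initialization is split into its own loop nest."""
    B2 = [[None] * 4 for _ in range(4)]
    for i in range(14):
        B2[i // 4][i % 4] = 0
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = list(range(16))
F = [1, 2, 3]
ref = conv_reference(A, F)
opt1 = flatten(conv_option1(to_2d(A), F))[:14]
opt2 = flatten(conv_option2(to_2d(A), F))[:14]
```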
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-op.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks whether a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 1:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0

Review Comment:
   Lines 862 and 865 are swapped.  Either that, or line 851 should read `if i//8 == 0:`.
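
One quick way to confirm the point above is to evaluate both conditions over the loop range. The sketch below is plain Python rather than TVMScript, and its names (`before`, `mismatches`, and so on) are illustrative only. It assumes nothing beyond the loop bounds and branch values quoted in the diff.

```python
# Loop range taken from the quoted example: `for i in T.serial(16)`.
indices = range(16)

first_cond = [i < 8 for i in indices]           # guards the writes to A
second_cond = [i // 8 == 1 for i in indices]    # condition as written in the diff
proposed_cond = [i // 8 == 0 for i in indices]  # condition suggested in the comment

# As written, the second condition is the negation of the first.
is_negation = all(a != b for a, b in zip(first_cond, second_cond))
# With the suggested change, the two conditions agree on every iteration.
is_identical = first_cond == proposed_cond


def before(i):
    """Values stored to (A[i], B[i]) by the un-merged example."""
    a = 0.0 if i < 8 else 1.0
    b = 2.0 if i // 8 == 1 else 3.0
    return a, b


# The merged then-branch, as quoted, stores A[i] = 0.0 and B[i] = 2.0 for i < 8.
mismatches = [i for i in range(8) if before(i) != (0.0, 2.0)]

print(is_negation, is_identical, mismatches)
```

Every iteration with `i < 8` shows up in `mismatches`, which is consistent with either swapping the two stores to `B` or changing the condition to `i//8 == 0`.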



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original 14, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would show less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing this conditional to
+automatically be removed.  Since whether overcompute is preferable to
+branching depends on the schedule, these options are exposed as two
+additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`, and the bound variable used in both places.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+`cache_read` and `cache_write` can be used outside of any loop, with
+the same scope as the uncached buffer.  The layout of the cache can
+then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independent of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but has two key differences.  First, it presents a
+simpler user experience, as a transformed buffer can be accessed
+sequentially without needing to duplicate the information in the
+transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  strings of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
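The first rule in the list above (a store that is overwritten before being read is a no-op) can be modelled outside of TVM.  The following is a minimal plain-Python sketch, not the `NoOpRemover` implementation; the `Store` record and `remove_overwritten_stores` are hypothetical names introduced only for illustration, and only directly adjacent stores are considered.

```python
from dataclasses import dataclass
from typing import List, Tuple

Location = Tuple[str, int]  # (buffer name, index)


@dataclass(frozen=True)
class Store:
    """Toy stand-in for a BufferStore: a write target plus the locations its value reads."""

    target: Location
    reads: Tuple[Location, ...] = ()


def remove_overwritten_stores(stores: List[Store]) -> List[Store]:
    """Drop a store when the next store overwrites the same location without reading it."""
    kept: List[Store] = []
    for pos, store in enumerate(stores):
        following = stores[pos + 1] if pos + 1 < len(stores) else None
        overwritten = (
            following is not None
            and following.target == store.target
            and store.target not in following.reads
        )
        if not overwritten:
            kept.append(store)
    return kept


# B[0] = 0.0; B[0] = 1.0          -> the first store is dead.
dead = [Store(("B", 0)), Store(("B", 0))]
# B[0] = 0.0; B[0] = B[0] + A[0]  -> the first store is read back, so it stays.
live = [Store(("B", 0)), Store(("B", 0), reads=(("B", 0), ("A", 0)))]

print(len(remove_overwritten_stores(dead)))  # 1
print(len(remove_overwritten_stores(live)))  # 2
```

A real pass would also have to reason about symbolic indices and intervening reads, which this sketch deliberately leaves out.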
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, conditionals should not be merged if their conditions
+  depend on buffer values.  Only conditionals whose conditions do not
+  depend on mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, this should be used for
+  later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on values
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
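The "identical" and "complementary" tests described in the list above can be checked by hand for small loop nests.  The sketch below is plain Python and is only an illustration: it enumerates the iteration domain, whereas a simplifier would need to prove the relationship symbolically.  The `classify` helper is a hypothetical name, not an existing TVM function.

```python
from itertools import product
from typing import Callable, Iterable, Tuple

Point = Tuple[int, ...]
Predicate = Callable[..., bool]


def classify(cond_a: Predicate, cond_b: Predicate, domain: Iterable[Point]) -> str:
    """Compare two loop-variable predicates by enumerating a small iteration domain."""
    points = list(domain)
    if all(bool(cond_a(*p)) == bool(cond_b(*p)) for p in points):
        return "identical"
    if all(bool(cond_a(*p)) != bool(cond_b(*p)) for p in points):
        return "complementary"
    return "unrelated"


# `i < 8` and `i//8 == 0` agree on every iteration of a loop over range(16).
print(classify(lambda i: i < 8, lambda i: i // 8 == 0, product(range(16))))

# `4*i + j < 14` and `i == 3 and j >= 2` disagree everywhere on a 4x4 grid.
print(
    classify(
        lambda i, j: 4 * i + j < 14,
        lambda i, j: i == 3 and j >= 2,
        product(range(4), range(4)),
    )
)
```

Only predicates over loop variables are compared here; conditions that read buffer values are excluded, for the reason given in the data-dependent example.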
+
+### New Transform - Hoist Expression
+
+A new utility `HoistExpression`, which is a generalization of the
+current `HoistIfThenElse` pass.  The transformation `HoistExpression`
+would apply to the entire body of the `PrimFunc`, and would be used to
+avoid duplication of functionality between `HoistIfThenElse` and
+`HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist
+    any let bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
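Because the two configuration enums in this section are bit flags, several hoisting behaviours can be requested at once.  The sketch below mirrors the proposed C++ enums with Python's `enum.IntFlag` purely for illustration; the RFC does not specify how they would be exposed in Python, so the class definitions and the `legacy` and `aggressive` names here are assumptions.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    """Python mirror of the proposed C++ bit flags (illustrative only)."""

    kNone = 0
    kIfElseStmt = 1 << 0
    kIfElseExpr = 1 << 1
    kBooleanExpression = 1 << 2


class HoistLetBinding(IntFlag):
    """Python mirror of the proposed C++ bit flags (illustrative only)."""

    kNone = 0
    kRequiredByCondition = 1 << 0
    kLetStmt = 1 << 1
    kLetExpr = 1 << 2


# Rough equivalent of the existing if/else hoisting pass, per the text above.
legacy = (HoistConditional.kIfElseStmt, HoistLetBinding.kNone)

# A more aggressive configuration: hoist both kinds of conditional, plus
# any let bindings that those conditionals depend on.
aggressive = (
    HoistConditional.kIfElseStmt | HoistConditional.kIfElseExpr,
    HoistLetBinding.kRequiredByCondition,
)

# Individual flags are tested with a bitwise AND.
print(bool(aggressive[0] & HoistConditional.kIfElseExpr))         # True
print(bool(aggressive[0] & HoistConditional.kBooleanExpression))  # False
```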
+
+### New Transform - Reduce Loop Extents
+
+Reduce the extent of loops based on conditionals present in the body
+of the loop.
+
+For any non-vectorized `tir::For` loop (`ForKind::kSerial` or
+`ForKind::kParallel`), if the body is a conditional and the
+conditional's `else_case` is empty, determine if the expression is of
+the form `(loop_var $CMP_OP const) && (...)`.  If so, use the comparison
+operator to reduce the loop extent, such that the loop skips values for
+which the comparison is provably false.
+
+TODO: Double-check that this isn't already implemented elsewhere.
+
+TODO: Check if it is implementable using `IntSetAnalyzer`.
+
+Below is an example of how this can work alongside `HoistExpression`
+to simplify the initialization of padding.
+
+```python
+# Original function.
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i, j in T.grid(4, 4):
+        if i == 0 and j < 2:
+            A[i, j] = 0.0

Review Comment:
   This would look something like
   ```python
   for i, j in T.grid(0..1, 0..2):  # I'm making the ranges more verbose for clarity
     A[i, j] = 0.0
   for i, j in T.grid(0..1, 2..4):
     pass                           #  j < 2 is false
   for i, j in T.grid(1..4, 0..2):
     pass                           # i == 0 is false
   for i, j in T.grid(1..4, 2..4):
     pass                           # i == 0 is false, j < 2 is false
   ```
   ```
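The partitioning above can be sanity-checked in plain Python (not TVMScript).  This is only a sketch under the assumption that the loop body has no effect other than the guarded write; `original` and `partitioned` are hypothetical helper names, and each returns the list of indices that would be written.

```python
from itertools import product


def original():
    """Guarded loop nest: writes only where i == 0 and j < 2."""
    written = []
    for i, j in product(range(4), range(4)):
        if i == 0 and j < 2:
            written.append((i, j))
    return written


def partitioned():
    """Same iteration space split into four boxes; only the first has a body."""
    written = []
    for i, j in product(range(0, 1), range(0, 2)):
        written.append((i, j))
    # The other three boxes, (0..1, 2..4), (1..4, 0..2) and (1..4, 2..4),
    # have a provably false condition, so their loops are empty.
    return written


print(original())     # [(0, 0), (0, 1)]
print(partitioned())  # [(0, 0), (0, 1)]
```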



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a
+corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
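The shapes and padding locations quoted in the examples above can be reproduced by enumerating the transformed indices.  The following is a plain-Python sketch, not the `transform_layout` implementation; `padded_layout` is a hypothetical helper, and it assumes a one-dimensional source buffer whose transformed indices are all non-negative.

```python
from itertools import product
from typing import Callable, List, Sequence, Tuple


def padded_layout(
    extent: int, transform: Callable[[int], Sequence[int]]
) -> Tuple[List[int], List[Tuple[int, ...]]]:
    """Return (transformed shape, padding coordinates) for a 1-d buffer."""
    # Every transformed index that corresponds to an original element.
    used = {tuple(transform(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Each axis must be large enough to hold the largest index seen.
    shape = [max(idx[axis] for idx in used) + 1 for axis in range(ndim)]
    # Anything in the transformed index space that was never used is padding.
    padding = [idx for idx in product(*(range(n) for n in shape)) if idx not in used]
    return shape, padding


print(padded_layout(16, lambda i: [i // 8, i % 8]))
# ([2, 8], [])
print(padded_layout(14, lambda i: [i // 8, i % 8]))
# ([2, 8], [(1, 6), (1, 7)])
print(padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8]))
# ([2, 8], [(0, 0), (0, 1)])
```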
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
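The producer postlude described above can be modelled on an ordinary Python list.  This is a sketch under stated assumptions rather than generated TIR: it hard-codes the transform `i -> [i//4, i%4]`, and `transform_with_padding` is a hypothetical helper whose `pad_value` argument plays the role of the `pad_value` callable.

```python
from itertools import product


def transform_with_padding(values, pad_value):
    """Re-lay a 1-d list as rows of 4, then run the padding postlude.

    `pad_value` maps transformed indices (io, ii) to the value to store.
    """
    n = len(values)
    rows = (n + 3) // 4
    out = [[None] * 4 for _ in range(rows)]

    # Producer body: write each original element to its transformed location.
    for i, v in enumerate(values):
        out[i // 4][i % 4] = v

    # Postlude: visit every transformed index and fill only the padding.
    for io, ii in product(range(rows), range(4)):
        if 4 * io + ii >= n:  # padding predicate for this transform
            out[io][ii] = pad_value(io, ii)
    return out


B = transform_with_padding(list(range(14)), pad_value=lambda io, ii: 0)
print(B[3])  # [12, 13, 0, 0]
```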
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Since the tradeoff between branching and
+overcompute may or may not be beneficial depending on the schedule,
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
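The equivalence that the overcompute version relies on can be checked on plain Python lists.  The sketch below is illustrative only; `pad_rows`, `row_sum_branching` and `row_sum_overcompute` are hypothetical helpers.  It shows that dropping the branch is safe when the padding holds zero, and gives a different result when the padding holds anything else.

```python
def pad_rows(A, pad):
    """Re-lay each 14-element row as 4 chunks of 4, filling the tail with `pad`."""
    padded = []
    for row in A:
        flat = list(row) + [pad] * (16 - len(row))
        padded.append([flat[k:k + 4] for k in range(0, 16, 4)])
    return padded


def row_sum_branching(A_padded, valid):
    """Sum each row, guarding every access so that the padding is skipped."""
    out = []
    for row in A_padded:
        total = 0.0
        for j_outer in range(4):
            for j_inner in range(4):
                if 4 * j_outer + j_inner < valid:
                    total += row[j_outer][j_inner]
        out.append(total)
    return out


def row_sum_overcompute(A_padded):
    """Sum each row unconditionally; only correct if the padding holds 0.0."""
    return [sum((x for chunk in row for x in chunk), 0.0) for row in A_padded]


A = [[float(i * 14 + j) for j in range(14)] for i in range(2)]
zero_padded = pad_rows(A, 0.0)
junk_padded = pad_rows(A, 99.0)

print(row_sum_branching(zero_padded, 14) == row_sum_overcompute(zero_padded))  # True
print(row_sum_branching(junk_padded, 14) == row_sum_overcompute(junk_padded))  # False
```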
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [LLVM `__builtin_assume`
+intrinsic](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not result in
+  an invalid expression in the `PrimExpr`.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
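
As a quick check of the simplification described above, the padding locations can be enumerated directly. This is a plain-Python sketch rather than TVMScript; `padding_indices` is a hypothetical helper, and it assumes the `i -> [i // 4, i % 4]` transformation applied to a 14-element buffer.

```python
from itertools import product


def padding_indices(original_extent, factor):
    """Enumerate transformed indices (io, ii) that are padding.

    Assumes the layout transformation i -> [i // factor, i % factor].
    """
    outer_extent = -(-original_extent // factor)  # ceiling division
    return [
        (io, ii)
        for io, ii in product(range(outer_extent), range(factor))
        if factor * io + ii >= original_extent
    ]


padding = padding_indices(14, 4)
print(padding)
```

Every padding location has `io == 3`, and `ii` only takes the values 2 and 3, which matches the reduced loop range stated above.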
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
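
The index simplifications used in Option 2 can be checked by brute force. The sketch below is plain Python, and assumes that Python's floor-based `//` and `%` match the semantics of `//` and `%` as written in the TVMScript above. It also checks that the guard `0 <= 4*io + ii - f < 14` keeps exactly the original `(i, f)` iterations.

```python
from itertools import product

# Check the simplified forms of B's indices for every (io, ii, f).
for io, ii, f in product(range(4), range(4), range(3)):
    i = 4 * io + ii - f
    assert i // 4 == io + (ii - f) // 4
    assert i % 4 == (ii - f) % 4

# Iterations that survive the guard, mapped back to the original
# (i, f) iteration space.
kept = sorted(
    (4 * io + ii - f, f)
    for io, ii, f in product(range(4), range(4), range(3))
    if 0 <= 4 * io + ii - f < 14
)
original = sorted(product(range(14), range(3)))
```

Because `kept` and `original` contain the same pairs, the rewritten loopnest performs each original update exactly once, though in a different order.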
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
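
The first rule in the list above (a store that is overwritten before being read) can be illustrated with a small model. This is a plain-Python sketch, not the `NoOpRemover` implementation; `Store`, `run`, and `remove_overwritten_stores` are hypothetical names, and each store declares the locations it reads explicitly.

```python
from dataclasses import dataclass
from typing import Callable, Dict, FrozenSet, List, Tuple

Location = Tuple[str, int]  # (buffer name, index)
State = Dict[Location, float]


@dataclass(frozen=True)
class Store:
    """A single store: target = compute(state), reading only `reads`."""

    target: Location
    reads: FrozenSet[Location]
    compute: Callable[[State], float]


def run(stmts: List[Store], state: State) -> State:
    """Execute the stores in order on a copy of the state."""
    state = dict(state)
    for stmt in stmts:
        state[stmt.target] = stmt.compute(state)
    return state


def remove_overwritten_stores(stmts: List[Store]) -> List[Store]:
    """Drop a store if the next store overwrites it without reading it."""
    kept = []
    for pos, stmt in enumerate(stmts):
        nxt = stmts[pos + 1] if pos + 1 < len(stmts) else None
        if nxt is not None and nxt.target == stmt.target and stmt.target not in nxt.reads:
            continue
        kept.append(stmt)
    return kept


B0, A0 = ("B", 0), ("A", 0)
program = [
    Store(B0, frozenset(), lambda s: 0.0),  # overwritten below
    Store(B0, frozenset({A0}), lambda s: 2.0 * s[A0]),  # does not read B[0]
]
accumulate = [
    Store(B0, frozenset(), lambda s: 0.0),
    Store(B0, frozenset({A0, B0}), lambda s: s[B0] + s[A0]),  # reads B[0]
]
initial = {A0: 3.0, B0: -1.0}
```

In `program` the first store is removed without changing the final state, while in `accumulate` both stores are kept because the second one reads the value written by the first.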
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals that read buffer values
+  should not be merged.  Only conditionals that do not depend on
+  mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, this should be used for
+  later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on values
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
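
The notions of identical and complementary conditions used in the list above can be checked by enumeration over a small loop domain. This is a plain-Python sketch with hypothetical helper names; an actual implementation would presumably prove these relations with the arithmetic analyzer rather than enumerate points.

```python
from itertools import product


def identical(cond_a, cond_b, domain):
    """True if the conditions agree at every point of the domain."""
    return all(cond_a(*point) == cond_b(*point) for point in domain)


def complementary(cond_a, cond_b, domain):
    """True if exactly one condition holds at every point of the domain."""
    return all(cond_a(*point) != cond_b(*point) for point in domain)


serial_16 = [(i,) for i in range(16)]
grid_4x4 = list(product(range(4), range(4)))

# Different functional forms, same truth value for 0 <= i < 16.
same = identical(lambda i: i < 8, lambda i: i // 8 == 0, serial_16)

# Exactly one of the two holds for every (i, j) in the 4x4 grid.
opposite = complementary(
    lambda i, j: 4 * i + j < 14, lambda i, j: i == 3 and j >= 2, grid_4x4
)

# Neither identical nor complementary: these must not be merged.
unrelated = identical(lambda i: i < 8, lambda i: i % 2 == 0, serial_16)
```

Both relations are defined relative to the loop domain: `i < 8` and `i // 8 == 0` agree for `0 <= i < 16`, but would not agree for negative `i`.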
+
+### New Transform - Hoist Expression
+
+A new utility `HoistExpression`, which is a generalization of the
+current `HoistIfThenElse` pass.  The transformation `HoistExpression`
+would apply to the entire body of the `PrimFunc`, and would be used to
+avoid duplication of functionality between `HoistIfThenElse` and
+`HoistExpression`.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist
+    any let bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
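
To illustrate how such bit flags combine, the sketch below mirrors the two enums in Python using `enum.IntFlag`. The Python class and member names are hypothetical and are not part of the proposal; only the flag values are taken from the C++ definitions above.

```python
from enum import IntFlag


class HoistConditional(IntFlag):
    # Hypothetical Python mirror of the C++ enum above.
    NONE = 0
    IF_ELSE_STMT = 1 << 0
    IF_ELSE_EXPR = 1 << 1
    BOOLEAN_EXPRESSION = 1 << 2


class HoistLetBinding(IntFlag):
    # Hypothetical Python mirror of the C++ enum above.
    NONE = 0
    REQUIRED_BY_CONDITION = 1 << 0
    LET_STMT = 1 << 1
    LET_EXPR = 1 << 2


# Flags are combined with bitwise OR and queried with bitwise AND.
config = HoistConditional.IF_ELSE_STMT | HoistConditional.BOOLEAN_EXPRESSION
hoist_if_stmt = bool(config & HoistConditional.IF_ELSE_STMT)
hoist_if_expr = bool(config & HoistConditional.IF_ELSE_EXPR)
```

Here `config` requests hoisting of `IfThenElse` conditions and boolean expressions, but not of `if_then_else` call conditions.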
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
+
+### New Transform - Reduce Loop Extents
+
+Reduce the extent of loops based on conditionals present in the body
+of the loop.
+
+For any non-vectorized `tir::For` loop (`ForKind::kSerial` or
+`ForKind::kParallel`), if the body is a conditional and the
+conditional's `else_case` is empty, determine if the condition is of
+the form `(loop_var $CMP_OP const) && (...)`.  If so, use the comparison
+operator to reduce the loop extent, such that the loop skips values for
+which the comparison is provably false.
+
+TODO: Double-check that this isn't already implemented elsewhere.
+
+TODO: Check if it is implementable using `IntSetAnalyzer`.
+
+Below is an example of how this can work alongside `HoistExpression`
+to simplify the initialization of padding.
+
+```python
+# Original function.
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i, j in T.grid(4, 4):
+        if i == 0 and j < 2:
+            A[i, j] = 0.0
+
+
+# After hoisting with HoistConditional::kBooleanExpression
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    for i in T.serial(4):
+        if i == 0:
+            for j in T.serial(4):
+                if j < 2:
+                    A[i, j] = 0.0
+
+
+# After reducing the extents of serial loops
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"]):
+    i = 0
+    for j in T.serial(2):
+        A[i, j] = 0.0
+```
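
The extent reduction itself amounts to intersecting the loop range with the range allowed by the comparison. The following is a plain-Python sketch under that reading; `reduce_extent` is a hypothetical helper that handles a single comparison between the loop variable and a constant.

```python
def reduce_extent(loop_min, extent, cmp_op, const):
    """Return (new_min, new_extent) covering only iterations where
    `loop_var cmp_op const` can hold, for a loop over
    loop_min <= loop_var < loop_min + extent.
    """
    lo, hi = loop_min, loop_min + extent  # half-open range [lo, hi)
    if cmp_op == "<":
        hi = min(hi, const)
    elif cmp_op == "<=":
        hi = min(hi, const + 1)
    elif cmp_op == ">":
        lo = max(lo, const + 1)
    elif cmp_op == ">=":
        lo = max(lo, const)
    elif cmp_op == "==":
        lo, hi = max(lo, const), min(hi, const + 1)
    else:
        raise ValueError(f"unsupported comparison: {cmp_op}")
    # An empty intersection is reported as a zero-extent loop.
    return lo, max(0, hi - lo)
```

For the example above, `reduce_extent(0, 4, "<", 2)` gives the `T.serial(2)` loop over `j`, and `reduce_extent(0, 4, "==", 0)` gives a single iteration at `i = 0`.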
+
+
+
+### Utility - Merge Adjacent Loops
+
+If it does not impact the resulting computation, loops may be merged
+together.  This is a valid transformation if both loops are serial
+loops, the loops have the same indices, and if the merging respects
+data dependencies.  This would be exposed as a metaschedule primitive,
+which takes input of the `LoopRV` to be merged.
+
+For adjacent loops, to prove that there is no data dependency, two
+conditions must hold.
+
+1. For all loop indices `i` and `j` where `i > j`, the set of indices
+   written by the first loop in iteration `i` is distinct from the set
+   of indices accessed by the second loop in iteration `j`.  That is,
+   merging the loops wouldn't cause the second loop body to read
+   partial values, nor would it cause the first loop body to overwrite
+   a value produced by the second loop body.
+
+2. For all loop indices `i` and `j` where `i > j`, the set of indices
+   read by the first loop in iteration `i` is distinct from the set
+   of indices written by the second loop in iteration `j`.  That is,
+   merging the loops wouldn't cause the second loop body to overwrite
+   values that are still required by the first loop body.
+
+Element-wise loops do not have any data dependencies, and adjacent
+element-wise loops may be merged.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+
+    for i in T.serial(16):
+        A[i] = 1.0
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j].
+#    c. intersection([i], [j]) = [i] if i==j else [].
+#    d. If i>j, the intersection is empty
+#
+# 2. a. In iteration i, loop 1 reads from index [].
+#    b. In iteration j, loop 2 writes to index [j]
+#    c. intersection([], [j]) = []
+#    d. For all i,j, the intersection is empty
+#
+# Therefore, this merger is valid
+
+# After merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+        A[i] = 1.0
+```
+
+The second loop may read indices that were written in an earlier
+iteration.  Merging would not impact the result.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+
+    for i in T.serial(16):
+        if i > 0:
+            A[i] = A[i - 1] + 1.0
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j,j-1].
+#    c. i>j implies that i!=j and i!=j-1.
+#    d. For all i,j where i>j, the intersection is empty.
+#
+# 2. a. In iteration i, loop 1 reads from index [].
+#    b. In iteration j, loop 2 writes to index [j]
+#    c. For all i,j, intersection([], [j]) = [].
+#
+# Therefore, this merger is valid
+
+
+# After merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = 0.0
+        if i > 0:
+            A[i] = A[i - 1] + 1.0
+```
+
+The second loop may not read indices that were written in a later
+iteration of the first loop.  In this case, merging would impact the
+output values.
+
+```python
+# Before merging adjacent loops
+@T.prim_func
+def func(A: T.Buffer[(16,), "float32"]):
+    for i in T.serial(16):
+        A[i] = i
+
+    for i in T.serial(16):
+        if 0 < i < 15:
+            A[i] = A[i - 1] + A[i] + A[i + 1]
+
+
+# 1. a. In iteration i, loop 1 writes to index [i].
+#    b. In iteration j, loop 2 accesses index [j-1,j,j+1].
+#    c. If i==j+1, then intersection([j+1], [j-1,j,j+1]) = [j+1],
+#       which is non-empty.
+#
+# Therefore, this merger is not valid.
+```
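
The manual analyses above can be reproduced with a brute-force check. This is a plain-Python sketch with a hypothetical `can_merge` helper. Both checks are applied to pairs with `i > j`, which are the pairs whose relative order changes when the loops are merged; that pairing is this sketch's reading of the conditions, not a statement of how the primitive would be implemented.

```python
def can_merge(extent, writes_1, reads_1, accesses_2, writes_2):
    """Brute-force check of the two data-dependency conditions.

    Each argument after `extent` maps an iteration number to the set
    of buffer indices touched in that iteration.
    """
    for i in range(extent):
        for j in range(i):  # all pairs with i > j
            # Loop 1 must not write what loop 2 accessed earlier.
            if writes_1(i) & accesses_2(j):
                return False
            # Loop 2 must not overwrite what loop 1 still needs.
            if reads_1(i) & writes_2(j):
                return False
    return True


def no_reads(i):
    return set()


# Example 1: two element-wise writes.
elementwise = can_merge(
    16, lambda i: {i}, no_reads, lambda j: {j}, lambda j: {j}
)

# Example 2: loop 2 reads the value written in the previous iteration.
running = can_merge(
    16,
    lambda i: {i},
    no_reads,
    lambda j: {j, j - 1} if j > 0 else set(),
    lambda j: {j} if j > 0 else set(),
)

# Example 3: loop 2 reads a value written in a later iteration.
stencil = can_merge(
    16,
    lambda i: {i},
    no_reads,
    lambda j: {j - 1, j, j + 1} if 0 < j < 15 else set(),
    lambda j: {j} if 0 < j < 15 else set(),
)
```

The three results agree with the conclusions of the three examples above: the first two mergers are valid, and the third is not.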
+
+### New Primitive - Remove Branching Through Overcompute
+
+A new transform which attempts to reduce branching by allowing
+overcompute.  It takes an argument to specify which block it should be
+applied within.
+
+For each `IfThenElseStmt`, check if the
+`IfThenElseStmtNode::else_case` is a simplified form of the
+`IfThenElseStmtNode::then_case`.  This check is done by simplifying
+`then_case`, under the assumption that `condition` is false, and
+substituting the known value in a `BufferConstraint` in any
+`BufferLoad` for which the predicate can be proven to be true.  If
+this simplified form is identical to the `else_case`, then the entire
+if/else block can be replaced with `then_case`.  Otherwise, this check
+is repeated with the roles swapped, to see if the `else_case` can be
+simplified down to the `then_case`.  If neither simplification holds,
+then no change is made.
+
+For example, consider the following 1-d convolution, where both the
+input and output buffers have a layout transformation applied.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(14,), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + A[i + f]
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(14,), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0.0
+            for f in T.serial(3):
+                B[i] = B[i] + A[(i + f) // 4, (i + f) % 4]
+
+
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=0.0)
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(4, 4), "float32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0.0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + A[(i + f) // 4, (i + f) % 4]
+
+        for io,ii in T.grid(4,4):
+            if io==3 and ii>=2:
+                B[io,ii] = 0.0
+
+
+# sched.sequential_buffer_access(block='compute', buffer='B')
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "float32"],
+    F: T.Buffer[(3,), "float32"],
+    B: T.Buffer[(4, 4), "float32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 0 <= 4*io + ii < 14:
+                B[io, ii] = 0.0
+                for f in T.serial(3):
+                    B[io, ii] = B[io, ii] + A[io + (ii + f) // 4, (ii + f) % 4]
+
+        for io,ii in T.grid(4,4):
+            if io==3 and ii>=2:
+                B[io,ii] = 0.0
+```
+
+
+We'd like to remove the conditional `if 0 <= 4*io + ii < 14` in the
+compute loop.  In order to do so, we need to prove that the body of
+the conditional is a no-op in the case where the conditional is false.
+
+Using the [updated `DomainTouched`
+utility](#enhancement---predicate-for-domaintouched), this else-block
+would be a no-op.
+It is a write to `B[io,ii]` predicated on `4*io+ii >= 14`, followed by
+a write to `B[io,ii]` predicated on `io==3 and ii>=2`, without a read
+in between.  Since these predicates are equivalent, the first write is
+a no-op.
+
+```python
+# sched.remove_branching_through_overcompute(block='compute')

Review Comment:
   Does this only apply to outputs?  I think we should have a per-buffer directive that indicates that out-of-bounds access is allowed. The only thing in question is how to determine/specify that out-of-bounds reads from inputs are ok.  The user can add -INF padding to inputs to maxpool, but how does the maxpool compute know that it can use the out-of-bounds values?
   
   Whether to actually utilize this should probably be left to the compiler.  Auto-scheduling should not be a replacement for compiler optimizations.
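
The maxpool case mentioned in this comment can be pictured with a plain-Python sketch (not TIR). The function names are hypothetical, and the sketch assumes a 1-d max pool with window 4 over a 14-element input padded to 16 elements. It does not answer how the compute would learn that the padding is safe to read; it only shows why -INF padding makes the unguarded read harmless.

```python
import math


def maxpool1d_branching(data, window):
    """Reference: only read in-bounds elements."""
    out = []
    for start in range(0, len(data), window):
        best = -math.inf
        for offset in range(window):
            if start + offset < len(data):  # bounds check in the inner loop
                best = max(best, data[start + offset])
        out.append(best)
    return out


def maxpool1d_padded(padded, window):
    """No bounds check: relies on the padding holding -inf."""
    return [
        max(padded[start : start + window])
        for start in range(0, len(padded), window)
    ]


data = [float(v) for v in range(14)]
padded = data + [-math.inf] * 2  # pad 14 -> 16 elements
```

Both versions produce the same pooled values, because `-inf` is the identity element for `max`. The second version depends on knowing the contents of the padding, which is the information the comment asks how to convey.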



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer, because the
+transformed coordinates `(3,2)` and `(3,3)` do not have a corresponding
+index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
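
The maxpool case above can be illustrated with a minimal plain-Python sketch. It does not use TVM; the helper names `to_padded` and `max_over_all_elements` are hypothetical and exist only for this illustration. It shows that a branch-free maximum over all 16 stored elements matches the maximum over the 14 logical elements when the padding holds `-inf`, but not when it holds zero.

```python
import math

WIDTH = 4  # inner extent of the [i // 4, i % 4] layout


def to_padded(values, pad_value):
    """Lay a 1-D list out as rows of WIDTH, filling the tail with pad_value."""
    n_rows = -(-len(values) // WIDTH)  # ceiling division
    flat = list(values) + [pad_value] * (n_rows * WIDTH - len(values))
    return [flat[r * WIDTH:(r + 1) * WIDTH] for r in range(n_rows)]


def max_over_all_elements(padded):
    """Branch-free maximum over every stored element, padding included."""
    return max(max(row) for row in padded)


# All 14 logical values are negative, so a zero in the padding would win.
data = [-float(i + 1) for i in range(14)]
true_max = max(data)

max_with_neg_inf = max_over_all_elements(to_padded(data, -math.inf))
max_with_zero = max_over_all_elements(to_padded(data, 0.0))
```

With `-inf` in the padding the consumer can skip the bounds check entirely, whereas the zero-filled padding changes the result and would require a conditional.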
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
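
The shapes and padding locations quoted in the comments above can be reproduced with a short plain-Python sketch. The function `padded_layout` is a hypothetical helper, not part of TVM. It assumes a 1-D source buffer and non-negative transformed indices, and takes the transformed shape to be the smallest bounding box of the indices actually produced.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (transformed_shape, padding_indices) for a 1-D buffer.

    `index_map` maps an original index i to a tuple of transformed indices.
    """
    used = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Smallest bounding shape: one more than the largest index on each axis.
    shape = tuple(max(idx[d] for idx in used) + 1 for d in range(ndim))
    # Every location in the bounding shape with no source index is padding.
    padding = [idx for idx in product(*(range(s) for s in shape))
               if idx not in used]
    return shape, padding


def no_offset(i):
    return (i // 8, i % 8)


def with_offset(i):
    return ((i + 2) // 8, (i + 2) % 8)


shape_a, pad_a = padded_layout(16, no_offset)          # tensor A, no padding
shape_b, pad_b = padded_layout(14, no_offset)          # tensor B, padding at end
shape_b_off, pad_b_off = padded_layout(14, with_offset)  # padding at beginning
shape_a_off, pad_a_off = padded_layout(16, with_offset)  # extra padding
```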
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch divergence
+when bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, this condition can be proven, allowing the conditional to
+be removed automatically.  Whether trading branching for overcompute
+is beneficial depends on the schedule, so these options are exposed
+as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
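
The equivalence between the branching and overcompute forms can be checked with a plain-Python sketch that mirrors the two `row_summation` variants. It does not use TVM; `pad_rows`, `row_sum_branching`, and `row_sum_overcompute` are hypothetical names chosen for this illustration.

```python
WIDTH = 4    # inner extent of the [i, j // 4, j % 4] layout
N_COLS = 14  # logical number of columns before the transformation


def pad_rows(matrix, pad_value):
    """Transform each row from [j] to [j // WIDTH, j % WIDTH], filling padding."""
    padded = []
    for row in matrix:
        n_outer = -(-len(row) // WIDTH)  # ceiling division
        flat = list(row) + [pad_value] * (n_outer * WIDTH - len(row))
        padded.append([flat[o * WIDTH:(o + 1) * WIDTH] for o in range(n_outer)])
    return padded


def row_sum_branching(A):
    """Skip the padding with a conditional, as in the default lowering."""
    out = []
    for row in A:
        total = 0.0
        for j_outer in range(len(row)):
            for j_inner in range(WIDTH):
                if WIDTH * j_outer + j_inner < N_COLS:
                    total += row[j_outer][j_inner]
        out.append(total)
    return out


def row_sum_overcompute(A):
    """Sum every stored element, padding included, with no conditional."""
    out = []
    for row in A:
        total = 0.0
        for j_outer in range(len(row)):
            for j_inner in range(WIDTH):
                total += row[j_outer][j_inner]
        out.append(total)
    return out


original = [[float(i * N_COLS + j) for j in range(N_COLS)] for i in range(16)]
expected = [sum(row) for row in original]

zero_padded = pad_rows(original, 0.0)
garbage_padded = pad_rows(original, 123.0)
```

The overcompute form is only valid when the padding is known to hold zero. With arbitrary padding contents, only the branching form reproduces the original row sums.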
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is analogous to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not leave
+  an invalid expression in the assumption.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be assigned in a `tir::LetStmt`.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
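
The simplification rules listed above for `undef` can be modelled with a toy Python class. This is only a sketch of the stated rules, not the proposed TIR implementation. The class `Undef` is hypothetical, and reusing one Python object stands in for binding a single `undef` in a `tir::LetStmt`.

```python
class Undef:
    """Hypothetical stand-in for a valid but arbitrary value."""

    def _combine(self, other):
        # A pure expression that uses undef may be simplified to undef.
        return Undef()

    __add__ = __radd__ = __rsub__ = _combine

    def __sub__(self, other):
        # The same bound instance minus itself is zero (the LetStmt case).
        # Two separate instances are never assumed to be identical.
        if other is self:
            return 0
        return Undef()

    def __mul__(self, other):
        # 0 * undef may be simplified to zero, for ints and floats alike.
        if not isinstance(other, Undef) and other == 0:
            return other
        return Undef()

    __rmul__ = __mul__


u, v = Undef(), Undef()
zero_times_undef = 0 * u      # simplified to zero
separate_difference = u - v   # remains undef
bound_difference = u - u      # same binding, simplified to zero
```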
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
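
The claimed simplification of the padding loop can be verified with a few lines of plain Python. The function names here are hypothetical; the sketch only enumerates which transformed indices each loop form would write for the `[14]` to `[4, 4]` example.

```python
EXTENT, WIDTH = 14, 4
ROWS = -(-EXTENT // WIDTH)  # 4 rows in the transformed buffer


def padding_writes_full_grid():
    """Indices written by the unsimplified loop: full grid plus a predicate."""
    return [(io, ii)
            for io in range(ROWS)
            for ii in range(WIDTH)
            if WIDTH * io + ii >= EXTENT]


def padding_writes_reduced():
    """Indices written after simplification: io fixed to 3, 2 <= ii < 4."""
    io = ROWS - 1
    return [(io, ii) for ii in range(2, 4)]
```

Both forms touch exactly the two padding locations, so the reduced loop is a valid later simplification of the full-grid form.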
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independent of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
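
The index algebra in Option 1 and Option 2 above can be checked against the original 1-D convolution with a plain-Python sketch. It does not use TVM, and all function names are hypothetical. Python's `//` and `%` are floor division and floor modulo, which is what the simplification of `(ii - f) // 4` for negative `ii - f` relies on.

```python
def conv_original(A, F):
    """Reference: B[i] = sum over f of F[f] * A[i + f], for 0 <= i < 14."""
    B = [0] * 14
    for i in range(14):
        for f in range(3):
            B[i] += F[f] * A[i + f]
    return B


def to_4x4(flat, fill=0):
    """Apply the layout i -> [i // 4, i % 4], filling any padding with `fill`."""
    padded = list(flat) + [fill] * (16 - len(flat))
    return [padded[r * 4:(r + 1) * 4] for r in range(4)]


def conv_option1(A2, F):
    """Loops follow B's layout: io = i // 4, ii = i % 4."""
    B2 = [[0] * 4 for _ in range(4)]
    for io in range(4):
        for ii in range(4):
            if 4 * io + ii < 14:
                for f in range(3):
                    B2[io][ii] += F[f] * A2[io + (ii + f) // 4][(ii + f) % 4]
    return B2


def conv_option2(A2, F):
    """Loops follow A's layout: io = (i + f) // 4, ii = (i + f) % 4."""
    B2 = [[0] * 4 for _ in range(4)]  # separate initialization loopnest
    for io in range(4):
        for ii in range(4):
            for f in range(3):
                if 0 <= 4 * io + ii - f < 14:
                    B2[io + (ii - f) // 4][(ii - f) % 4] += F[f] * A2[io][ii]
    return B2


A = [3 * k * k - 7 * k + 1 for k in range(16)]  # arbitrary integer input
F = [1, -2, 3]
reference = to_4x4(conv_original(A, F))
```

Integer inputs are used so that the different accumulation order in Option 2 cannot introduce rounding differences.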
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there exist two sequential blocks, the buffers/indices written by
+  the second block are a superset of the buffers/indices written by
+  the first block, and the second block does not read the
+  buffer/indices written by the first block, then the first block is a
+  no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  chains of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
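
The first rule in the list above (a store that is immediately overwritten without being read is a no-op) can be sketched on a toy statement list in plain Python. This is an illustration under simplifying assumptions, not the `NoOpRemover` implementation: each statement is a hypothetical `(target, reads)` pair, where `target` is the `(buffer, index)` location written and `reads` is the set of locations the stored value reads.

```python
def remove_overwritten_stores(stmts):
    """Drop each store that the next store overwrites without reading."""
    kept = []
    for pos, (target, reads) in enumerate(stmts):
        nxt = stmts[pos + 1] if pos + 1 < len(stmts) else None
        if nxt is not None and nxt[0] == target and target not in nxt[1]:
            continue  # overwritten before anything could observe it
        kept.append((target, reads))
    return kept


B0 = ("B", 0)
stmts = [
    (B0, frozenset()),                  # B[0] = 0.0
    (B0, frozenset({("A", 0)})),        # B[0] = A[0]
    (B0, frozenset({B0, ("A", 1)})),    # B[0] = B[0] + A[1]
]
simplified = remove_overwritten_stores(stmts)
```

Here the initial `B[0] = 0.0` is removed because the following store does not read `B[0]`, while `B[0] = A[0]` is kept because the accumulation that follows reads it back.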
+
+
+### Enhancement - Simplify
+
+Changes to be made to `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition proves the other false, and assuming either condition
+  false proves the other true.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+              B[i, j] = 3.0
+          else:
+              A[i, j] = 1.0
+              B[i, j] = 2.0
+  ```
+
+  Because the body of one conditional may alter the result of the next
+  conditional, data-dependent conditionals that read buffer values
+  should not be merged.  Only conditionals that do not depend on
+  mutable values should be merged.
+
+  ```python
+  # Data-dependent conditional, may not be merged
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+
+          if A[i] < 0.0:
+              A[i] = 0.0
+
+
+  # INCORRECT result of illegal merging of conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if A[i] < 0.0:
+              A[i] = A[i] + 1.0
+              A[i] = 0.0
+  ```
+
+* When encountering a `T.assume` statement, the assumed condition
+  should be used for later simplifications.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = n//8
+
+  # After simplification.  Because the range of `n` is provided in the
+  # assumption, n//8 can be simplified.
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], n: T.int32):
+      T.assume(n >= 0 and n < 8)
+
+      for i in T.serial(16):
+          A[i] = 0
+  ```
+
+  These assumptions are statements only known to be true at the
+  location of the `T.assume` call.  For assumptions based on a value
+  stored in a buffer, the assumption may be invalidated by later
+  writes to the buffer.
+
+  ```python
+  # Before simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      if A[0] == B[0]:
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+
+  # After simplification
+  @T.prim_func
+  def func(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+      T.assume(B[0] == 0)
+
+      # The first access of B[0] may be replaced with 0 using the
+      # assumption.
+      if A[0] == 0:
+          # These later accesses of B[0] may not be replaced, because
+          # for all loop iterations i!=0, the value stored in B[0] has
+          # been overwritten since the T.assume call.
+          for i in T.serial(16):
+              B[0] = B[0] + A[i]
+  ```
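
Returning to the merging of conditionals described earlier in this section: for small, static loop ranges, the "identical" and "complementary" relations can be cross-checked by brute force. The following is a plain-Python sketch, not TVM code; the helper names `identical` and `complementary` are hypothetical and introduced only for illustration. `complementary` tests that exactly one condition holds at every point, which is the property the merge relies on.

```python
from itertools import product


def identical(cond_a, cond_b, domain):
    """True if the two conditions agree at every point of the domain."""
    return all(cond_a(*point) == cond_b(*point) for point in domain)


def complementary(cond_a, cond_b, domain):
    """True if exactly one of the two conditions holds at every point."""
    return all(cond_a(*point) != cond_b(*point) for point in domain)


# Iteration domains: a serial loop of extent 16, and a 4x4 grid.
loop_1d = [(i,) for i in range(16)]
grid_2d = list(product(range(4), range(4)))

# Different functional forms, same truth value over 0 <= i < 16.
print(identical(lambda i: i < 8, lambda i: i // 8 == 0, loop_1d))

# Over the 4x4 grid, 4*i + j >= 14 holds only for (3,2) and (3,3).
print(complementary(lambda i, j: 4 * i + j < 14,
                    lambda i, j: i == 3 and j >= 2,
                    grid_2d))
```

A real implementation would need to prove these relations symbolically rather than by enumeration, since loop extents are not always small or static.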
+
+### New Transform - Hoist Expression
+
+A new utility `HoistExpression`, which is a generalization of the
+current `HoistIfThenElse` pass.  The transformation `HoistExpression`
+would apply to the entire body of the `PrimFunc`.  `HoistIfThenElse`
+would be implemented in terms of `HoistExpression`, to avoid
+duplicating functionality between the two.
+
+`HoistExpression` would also be exposed as a metaschedule primitive,
+acting within a specified block of the `PrimFunc`, with the
+configuration options given below.
+
+```c++
+enum class HoistConditional {
+  kNone = 0,
+  kIfElseStmt = (1<<0),
+  kIfElseExpr = (1<<1),
+  kBooleanExpression = (1<<2),
+};
+
+enum class HoistLetBinding {
+  kNone = 0,
+  kRequiredByCondition = (1<<0),
+  kLetStmt = (1<<1),
+  kLetExpr = (1<<2),
+};
+```
+
+* The values in `HoistConditional` are bit flags, indicating which
+  conditionals should be hoisted.
+
+  * `HoistConditional::kNone` - Do not hoist conditionals
+
+  * `HoistConditional::kIfElseStmt` - If set, attempt to hoist
+    conditionals that occur within `IfThenElseNode::condition`.
+
+  * `HoistConditional::kIfElseExpr` - If set, attempt to hoist
+    conditionals that occur as the condition of a
+    `builtin::if_then_else` call.
+
+  * `HoistConditional::kBooleanExpression` - If set, attempt to hoist
+    any `PrimExpr` whose data type is `DataType::Bool()`.
+
+* The values in `HoistLetBinding` are bit flags, indicating which
+  bindings should be hoisted.
+
+  * `HoistLetBinding::kNone` - Do not hoist any let bindings.
+
+  * `HoistLetBinding::kRequiredByCondition` - If set, hoist a let
+    binding if it is required in order to hoist a conditional.
+
+  * `HoistLetBinding::kLetStmt` - If set, attempt to hoist
+    any let bindings performed using `LetStmt`.
+
+  * `HoistLetBinding::kLetExpr` - If set, attempt to hoist any let
+    bindings performed using `Let`.
+
+The existing pass `HoistIfThenElse` is roughly equivalent to using
+`HoistExpression` with `HoistConditional::kIfElseStmt` and
+`HoistLetBinding::kNone`.  The one exception is that `HoistIfThenElse`
+occurs after all let bindings have been inlined, and does not check
+let bindings when determining if a condition can be hoisted.  Applied
+to a function that still contains let bindings, it can hoist a
+condition above the binding that defines it, as in the following
+example.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]):
+    for i in T.serial(4):
+        is_in_bounds = i < 3
+        if is_in_bounds:
+            A[i] = 0.0
+
+# Incorrectly hoisted by `HoistIfThenElse`
+@T.prim_func
+def func(A: T.Buffer[(4,), "float32"]) -> None:
+    is_in_bounds = T.var("bool")
+    if is_in_bounds:
+        for i in T.serial(4):
+            is_in_bounds = i < 3
+            A[i] = 0.0
+```
+
+### New Transform - Reduce Loop Extents

Review Comment:
   I don't think this is necessary.  We could simply reuse loop partitioning, and break off pieces of the nest that will never execute.



##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,3090 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [New TIR Op, `tir::builtin::assume`](#new-tir-op-tirbuiltinassume)
+    - [New TIR Op, `tir::builtin::undef`](#new-tir-op-tirbuiltinundef)
+    - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - `cache_read`, `cache_write`](#enhancement---cache_read-cache_write)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Utility - Reorder Loops According to Buffer](#new-utility---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.assume](#new-lowering-transform---remove-tassume)
+    - [New Lowering Transform - Remove T.undef](#new-lowering-transform---remove-tundef)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, and the transformed coordinates `(3,2)` and `(3,3)` do
+not have a corresponding index on the workload range `0 <= i < 14`.  The final
+result in these locations is not determined by the compute definition,
+so we have flexibility in what to store in the padding that is
+introduced by the transformation, and what assumptions can be made
+when reading from those locations.
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values with no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
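
As a rough cross-check of the shapes and padding locations quoted in the comments above, the padding introduced by an index map can be enumerated by brute force. This is a plain-Python sketch rather than TVM code: `padded_layout` is a hypothetical helper introduced only for illustration, and it assumes that the transformed shape is the tightest bounding box that starts at zero along every axis.

```python
from itertools import product


def padded_layout(extent, index_map):
    """Return (transformed_shape, padding_indices) for a 1-d buffer.

    Hypothetical helper for illustration only; it is not part of TVM.
    """
    # Transformed indices that correspond to an element of the original buffer.
    used = {tuple(index_map(i)) for i in range(extent)}
    ndim = len(next(iter(used)))
    # Tightest bounding box that starts at zero along every axis.
    shape = [max(idx[d] for idx in used) + 1 for d in range(ndim)]
    # Any location inside the box that no original index maps to is padding.
    padding = [idx for idx in product(*(range(s) for s in shape)) if idx not in used]
    return shape, padding


# Shape [16], no offset: no padding is introduced.
print(padded_layout(16, lambda i: [i // 8, i % 8]))
# Shape [14], no offset: padding at the end of the buffer.
print(padded_layout(14, lambda i: [i // 8, i % 8]))
# Shape [14], offset of 2: padding at the beginning of the buffer.
print(padded_layout(14, lambda i: [(i + 2) // 8, (i + 2) % 8]))
# Shape [16], offset of 2: padding at both ends.
print(padded_layout(16, lambda i: [(i + 2) // 8, (i + 2) % 8]))
```

Enumeration is only practical for small static shapes; it is meant as a way to sanity-check a layout by hand, not as a description of how the padding predicate would be derived.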
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that
+specifies the value that should be present in the padding.  This
+should be a function that maps from transformed indices to an
+`Optional[PrimExpr]`.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains undefined values.
+sched[B].transform_layout(transform, pad_value=tir.undef(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.undef(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (4*io + ii - 14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, and would have less branch
+divergence when bound to a thread index.  If the padding in `A` is
+pre-filled with zero, then `B[i] = B[i] + 0.0` is a no-op, and can be
+performed without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the added computation can be proven to be a no-op, allowing
+the conditional to be removed automatically.  Whether trading
+branching for overcompute is beneficial depends on the schedule, so
+these options are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
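
The equivalence argued above can be checked numerically with a plain-Python model of the two loop nests. This is a sketch under stated assumptions, not TVM code: the buffers are nested lists, and `make_input`, `row_sum_branching` and `row_sum_overcompute` are hypothetical names used only for illustration.

```python
def make_input(pad_value):
    """Build a [16][4][4] buffer with 14 valid values per row plus padding."""
    A = [[[pad_value] * 4 for _ in range(4)] for _ in range(16)]
    for i in range(16):
        for j in range(14):
            A[i][j // 4][j % 4] = float(i + j)
    return A


def row_sum_branching(A):
    """Sum each row, skipping the two padded elements with a conditional."""
    B = [0.0] * len(A)
    for i, row in enumerate(A):
        for j_outer in range(4):
            for j_inner in range(4):
                if 4 * j_outer + j_inner < 14:
                    B[i] += row[j_outer][j_inner]
    return B


def row_sum_overcompute(A):
    """Sum each row over all 16 stored elements, padding included."""
    B = [0.0] * len(A)
    for i, row in enumerate(A):
        for j_outer in range(4):
            for j_inner in range(4):
                B[i] += row[j_outer][j_inner]
    return B


zero_padded = make_input(0.0)
one_padded = make_input(1.0)
# With zero in the padding, both versions agree.
print(row_sum_branching(zero_padded) == row_sum_overcompute(zero_padded))
# With any other value in the padding, the branch-free version is wrong.
print(row_sum_branching(one_padded) == row_sum_overcompute(one_padded))
```

The branch-free form is only valid because the padding holds the identity element of the reduction. For a different reduction, such as a maximum, a different padding value would be needed.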
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### New TIR Op, `tir::builtin::assume`
+
+A built-in operator that takes a single `PrimExpr` as an argument.  At
+compile-time, an error should be raised if the argument can be
+statically proven to be false at the point of call.  When lowering,
+the `tir::builtin::assume` should be replaced with a no-op.
+`tir::builtin::assume` is similar to the existing `tir::AssertStmt`,
+but does not result in a runtime assertion for conditions that cannot
+be proven.  This is equivalent to the [Clang `__builtin_assume`
+builtin](https://clang.llvm.org/docs/LanguageExtensions.html#builtin-assume).
+
+The primary use of `assume` in this RFC is to allow local
+simplifications within a `PrimFunc` to take advantage of information
+that would otherwise require full end-to-end analysis of a model.
+(See examples in [Points of Communication](#points-of-communication).)
+
+* An assumption may only be inserted if it is statically proven, or if
+  it is asserted by a user about a user-provided value.
+
+* When splitting a PrimFunc into multiple PrimFuncs (e.g. factoring
+  out a subroutine, hoisting an initial preprocessing stage into an
+  independent PrimFunc), an assumption may become separated from the
+  expressions that had initially been used to prove the assumption.
+
+* An assumption may only be removed if it is statically proven.  A
+  user-provided assumption may never be removed, as it may already
+  have been used to perform irreversible simplifications.
+
+* The expression within an assumption should be visited and mutated
+  identically to any other `PrimExpr`.  This ensures that passes that
+  redefine variables (e.g. by inlining a Let binding) do not leave
+  an invalid expression within the assumption.
+
+### New TIR Op, `tir::builtin::undef`
+
+A placeholder that represents a valid, but arbitrary value.  For
+consumers, this is used in `T.assume()` expressions to indicate that
+it is legal to access the address, but that no further constraints are
+placed on the value present in the buffer.  For producers, this is
+used to allow simplifications that change the value stored in the
+output padding and would otherwise be forbidden.  (e.g. Leaving
+partial computations written to padding by vectorized operations,
+rather than zero-ing them out.)
+
+* Multiplication of `0 * undef` may be simplified to zero, for both
+  integer and floating-point types.
+
+* A pure expression that uses `undef` can be simplified to `undef`.
+
+* `undef` may not occur in the indices used to access a buffer.
+
+* Two separate instances of `undef` may not be assumed to
+  be identical.  For example, the expression `undef - undef` may not
+  be simplified to zero.  If this behavior is desired, the `undef` may
+  be bound in a `tir::LetStmt`, and the bound variable used instead.
+
+* Storing a value of `undef` to a buffer is a no-op, and is removed
+  during lowering.  (See [section on
+  `tir.transform.RemoveUndefStore`](#new-lowering-transform---remove-tundef).)
+
+See [section on element-wise
+transformations](#apply-operator-element-wise-over-the-transformation-padding)
+for example usage.
+
+
+## Transformations/Metaschedule Primitives
+
+### Enhancement - `cache_read`, `cache_write`
+
+Can be used outside of any loop, with the same scope as the uncached
+buffer.  The layout of the cache can then be transformed to operate on
+a reshaped buffer without modifying the calling signature of the
+original `PrimFunc`.
+
+TODO: Check if this is already allowed.
+
+
+### Enhancement - transform_layout
+
+The `te.Stage.transform_layout` and `tir.Schedule.transform_layout`
+methods will be updated to take an additional argument `pad_value:
+Optional[Union[int, float, PrimExpr, Callable]]`.
+
+For a transformation that introduces padding and with a defined
+`pad_value`, a new stage is inserted following each write stage of the
+transformed buffer.  This new stage writes `pad_value` to the
+introduced padding.
+
+```python
+# Before transforming A_cache and B_cache
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # A read cache of the input A
+    A_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i] = A[i]
+
+    # The computation itself, doubling the input value
+    B_cache = T.alloc_buffer(14, "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i] = 2 * A_cache[i]
+
+    # Copying from the write cache into the output B
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A_cache', index_map=lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B_cache', index_map=lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    A_cache = T.alloc_buffer((4, 4), "float32")
+
+    # When copying into the read cache, the loop iteration remains the
+    # same, but writes to the transformed locations in `A_cache`.
+    for i in T.serial(14):
+        with T.block("A_cache"):
+            A_cache[i // 4, i % 4] = A[i]
+
+    # Immediately following the stage that produces values in the
+    # transformed A_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("A_cache_padding"):
+            if 4 * io + ii >= 14:
+                A_cache[io, ii] = -1
+
+    # The compute stage is unchanged, other than the updated indices
+    # for A_cache and B_cache.
+    B_cache = T.alloc_buffer((4, 4), "float32")
+    for i in T.serial(14):
+        with T.block("compute"):
+            B_cache[i // 4, i % 4] = 2 * A_cache[i // 4, i % 4]
+
+    # Immediately following the stage that produces values in the
+    # transformed B_cache, a new stage is added that writes the
+    # pad_value to the padding.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_cache_padding"):
+            if 4 * io + ii >= 14:
+                B_cache[io, ii] = -2
+
+    # When copying from the write cache, the loop iteration remains the
+    # same, but reads from the transformed locations in `B_cache`.
+    for i in T.serial(14):
+        with T.block("B_cache"):
+            B[i] = B_cache[i // 4, i % 4]
+```
+
+If `pad_value` is defined and the transformed buffer does not have a
+write stage within the body of the function, then it is an input
+argument.  In this case, a new stage is added at the beginning of the
+function, which calls `T.assume` for each input.
+
+For buffer consumers, the constraint is added to the body as a call to
+the `T.assume` builtin.  For buffer producers, the buffer constraint
+is updated, and an additional loop is added to write `pad_value` to
+the padding that has been introduced.
+
+```python
+# Before transforming A and B
+@T.prim_func
+def func(A: T.Buffer[14, "float32"], B: T.Buffer[14, "float32"]):
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i] = 2 * A[i]
+
+
+# After applying
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4], pad_value=-1)
+# sched.transform_layout(block='compute', buffer='B', index_map=lambda i: [i//4, i%4], pad_value=-2)
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]):
+    # The buffer A does not have a write stage within this function.
+    # Therefore, a new stage is inserted that calls T.assume.  The
+    # assumption provided states that either the transformed indices
+    # correspond to a set of indices in the pre-transformation buffer
+    # (4*io + ii < 14), or the value stored in the buffer is the
+    # pad_value `A[io, ii] == -1`.
+    for io, ii in T.grid(4, 4):
+        T.assume(4 * io + ii < 14 or A[io, ii] == -1)
+
+    # The computation, doubling the input value
+    for i in T.serial(14):
+        with T.block("compute"):
+            B[i // 4, i % 4] = 2 * A[i // 4, i % 4]
+
+    # The buffer B is an argument to the function, but contains a
+    # write stage.  Therefore, we add a stage that writes the
+    # pad_value after the write stage.
+    for io, ii in T.grid(4, 4):
+        with T.block("B_padding"):
+            if 4 * io + ii >= 14:
+                B[io, ii] = -2
+```
+
+It is expected that the loop that writes padding may be simplified
+later.  In this case, the loop over `io` can be removed, and the range
+of the loop over `ii` can be reduced to `2 <= ii < 4`.  However, the
+default implementation should not perform these simplifications yet, as
+this form is useful for [merging
+loopnests](#utility---merge-adjacent-loops) after [rewriting for
+sequential buffer
+access](#new-utility---reorder-loops-according-to-buffer).
+
+In TE, the write stage of a buffer is the stage that outputs the
+transformed tensor.  In TIR, the write stage of a buffer is any block
+that writes to all values of the pre-transformation tensor.
+
+If a transformed buffer is an argument to the PrimFunc, then this
+transformation alters the interface of the PrimFunc.  Whether this is
+allowed strongly depends on the context in which the PrimFunc is being
+used.
+
+* If a PrimFunc must remain compatible with the current calling
+  context, `transform_layout` may not be applied to argument buffers.
+  For example, when creating an optimization candidate of a subgraph,
+  if there is no legalization pass to handle layout disagreements
+  between adjacent subgraphs, the candidate must remain compatible
+  with the calling scope.
+
+* If a PrimFunc is being modified as part of a transformation that
+  also changes the context, `transform_layout` may be applied to
+  argument buffers.  For example, if an end-to-end model is
+  represented within a single `IRModule`, a transformation may alter a
+  subgraph's calling convention and the call into the subgraph at the
+  same time.
+
+* If a PrimFunc is being modified independently of any
+  context, `transform_layout` may be applied to argument buffers.  For
+  example, a PrimFunc that is being prepared for use as a subgraph,
+  but is not yet part of a graph, may be altered.
+
+
+### New Utility - Reorder Loops According to Buffer
+
+By default in S-TIR, `transform_layout` modifies the underlying layout
+of a buffer, but does not re-order loops that iterate over the buffer.
+The loop iterators can be re-written using split/fuse/reorder, but
+doing so requires the user to manually translate the layout
+transformation into the appropriate sequence of schedule primitives.
+
+A new utility method `Schedule.sequential_buffer_access` should be
+introduced, which generates and applies the sequence of
+split/fuse/reorder schedule primitives such that the loop iterators are
+rewritten for sequential access of a specific buffer.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(16,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(16):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            A[io, ii] = 4 * io + ii
+```
+
+This transformation is similar to what can be done using
+split/fuse/reorder, but presents a simpler user experience, as a
+transformed buffer can be accessed sequentially without needing to
+duplicate the information in the transformation.
+
+Similar to `Schedule.split`, if the loop extents do not evenly divide
+the transformation being applied, this primitive must introduce
+conditionals to avoid accessing elements that were not previously
+accessed.
+
+```python
+# Original function
+@T.prim_func
+def func(A: T.Buffer[(14,), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i] = i
+
+
+# sched.transform_layout(block='compute', buffer='A', index_map=lambda i: [i//4, i%4])
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for i in T.serial(14):
+            A[i // 4, i % 4] = i
+
+
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def func(A: T.Buffer[(4, 4), "int32"]):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                A[io, ii] = 4 * io + ii
+```
+
+`Schedule.sequential_buffer_access` can operate on input buffers as
+well as output buffers.
+
+```python
+# Original function
+@T.prim_func
+def func(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i] = 0
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Option 1: Rewriting loops to match B's layout
+# sched.sequential_buffer_access(block='compute', buffer='B')
+#
+# New iterators defined by B's access indices
+# io = i//4
+# ii = i%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for io, ii in T.grid(4, 4):
+            if 4 * io + ii < 14:
+                B[io, ii] = 0
+                for f in T.serial(3):
+                    # A's indices simplify from
+                    #      [(i + f) // 4, (i + f) % 4]
+                    #   => [(4*io + ii + f) // 4, (4*io + ii + f) % 4]
+                    #   => [io + (ii + f) // 4, (ii + f) % 4]
+                    B[io, ii] = B[io, ii] + F[f] * A[io + (ii + f) // 4, (ii + f) % 4]
+
+
+# Option 2: Rewriting loops to match A's layout
+# sched.sequential_buffer_access(block='compute', buffer='A')
+#
+# New iterators defined by A's access indices
+# io = (i+f)//4
+# ii = (i+f)%4
+#
+# Invert to find non-reduction axes to be replaced.
+# i = 4*io + ii - f
+@T.prim_func
+def func(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    # Because the initialization of B[i//4, i%4] does not depend on f,
+    # it cannot be expressed solely in terms of io and ii.  Therefore,
+    # the initialization must be split into a separate loopnest.
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            B[i // 4, i % 4] = 0
+
+    with T.block('compute'):
+        for io,ii in T.grid(4,4):
+            for f in T.serial(3):
+                if 0 <= 4*io + ii - f < 14:
+                    # B's indices simplify from
+                    #      [i // 4, i%4]
+                    #   => [(4*io + ii - f) // 4, (4*io + ii - f)%4]
+                    #   => [io + (ii - f) // 4, (ii - f)%4]
+                    B[io + (ii - f) // 4, (ii - f) % 4] = (
+                        B[io + (ii - f) // 4, (ii - f) % 4] + F[f] * A[io, ii]
+                    )
+```
+
+In some cases, it may not be possible to separate out the
+initialization and computation in order to rewrite the loops for
+sequential buffer access.  In this case,
+`Schedule.sequential_buffer_access` will raise an error.
+
+```python
+# Original function
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(16,), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(14,), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i] = 0
+            else:
+                B[i] = B[i - 1]
+
+            for f in T.serial(3):
+                B[i] = B[i] + F[f] * A[i + f]
+
+
+# After transforming A's layout and B's layout, before rewriting loops
+#
+# sched.transform_layout(block='compute', buffer='A', lambda i: [i//4, i%4])
+# sched.transform_layout(block='compute', buffer='B', lambda i: [i//4, i%4])
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+
+
+# Intermediate formed when attempting to re-order access to be
+# sequential along A's layout.  This is not a legal transformation,
+# because the initialization step requires the previous result of the
+# computation loop.  Therefore, Schedule.sequential_buffer_access will
+# raise an error.
+#
+# sched.sequential_buffer_access(block='compute', buffer='A')
+@T.prim_func
+def conv1d_cumsum(
+    A: T.Buffer[(4, 4), "int32"],
+    F: T.Buffer[(3,), "int32"],
+    B: T.Buffer[(4, 4), "int32"],
+):
+    with T.block('init_compute'):
+        for i in T.serial(14):
+            if i == 0:
+                B[i // 4, i % 4] = 0
+            else:
+                B[i // 4, i % 4] = B[(i - 1) // 4, (i - 1) % 4]
+
+    with T.block('compute'):
+        for i in T.serial(14):
+            for f in T.serial(3):
+                B[i // 4, i % 4] = B[i // 4, i % 4] + F[f] * A[(i + f) // 4, (i + f) % 4]
+```
+
+This utility is not required for the TE interface, as the loopnest of
+an output tensor is automatically rewritten to a row-major traversal.
+
+
+### Enhancement - Predicate for DomainTouched
+
+In `tvm::arith::DomainTouched`, track the condition for which a buffer
+is touched, in addition to the indices that are touched.
+
+### Enhancement - Remove No Op
+
+Changes to be made to `tvm::tir::NoOpRemover`, which implements the
+`tir.transform.RemoveNoOp` transform.
+
+* If two sequential `BufferStore` occur, both of which write to the
+  same buffer/index, and the second value stored does not read out the
+  first value, then the first store is a no-op.
+
+* If there are two sequential blocks such that the buffers/indices
+  written by the second block are a superset of the buffers/indices
+  written by the first block, and the second block does not read the
+  buffers/indices written by the first block, then the first block is
+  a no-op.
+
+* Reading a value then immediately writing it back is a no-op.  A
+  `BufferLoad` that is immediately used as a value to a `BufferStore`,
+  with the same buffer and indices, can be removed.
+
+  This functionality is currently part of
+  `tvm::arith::StmtSimplifier`, but is needed here to recognize
+  sequences of no-ops.  (Thought: Merge the Simplify and RemoveNoOp
+  passes?)
+
+* Writing a value that is known to exist within the buffer is a no-op.
+
+  ```python
+  # Before RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      B[0] = 0.0
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+
+  # After RemoveNoOp
+  @T.prim_func
+  def sum(A: T.Buffer[16, "float32"], B: T.Buffer[1, "float32"]):
+      T.assume(B[0] == 0.0)
+
+      for i in T.serial(16):
+          B[0] = B[0] + A[i]
+  ```
+
+
+### Enhancement - Simplify
+
+Changes to be made to the `tvm::arith::StmtSimplifier` mutator, used in
+the `tir.transform.Simplify` transform.
+
+* When visiting an `IfThenElseStmt`, if the `then_case` and
+  `else_case` are identical, replace with
+  `SeqStmt({Evaluate(condition), then_case})`.
+
+  Currently, the `tvm::arith::StmtSimplifier` mutator checks if a
+  condition can be proven, but doesn't do any checks on the body.
+
+  TODO: Double-check that functionality doesn't already exist.
+
+* If two sequential `IfThenElseStmt` have identical conditions, they
+  should be merged.  Conditions are identical if each condition can be
+  used to prove the other is true, even if they do not have the same
+  functional form.
+
+  ```python
+  # Before merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+          else:
+              A[i] = 1.0
+
+          if i//8 == 0:
+              B[i] = 2.0
+          else:
+              B[i] = 3.0
+
+  # After merging identical conditionals
+  @T.prim_func
+  def func(A: T.Buffer[16, "float32"], B: T.Buffer[16, "float32"]):
+      for i in T.serial(16):
+          if i < 8:
+              A[i] = 0.0
+              B[i] = 2.0
+          else:
+              A[i] = 1.0
+              B[i] = 3.0
+  ```
+
+  Similarly, if two sequential `IfThenElseStmt` have complementary
+  conditions, they should be merged, with the `else_case` of the
+  second conditional appended to the `then_case` of the first, and
+  vice versa.  Conditions are complementary if assuming either
+  condition can be used to prove the other is false.
+
+  (Example usage in [later producer/consumer
+  section](#explicitly-write-next-operators-desired-default-at-end-of-function).)
+
+  ```python
+  # Before merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:
+              A[i, j] = 0.0
+          else:
+              A[i, j] = 1.0
+
+          if i==3 and j>=2:
+              B[i, j] = 2.0
+          else:
+              B[i, j] = 3.0
+
+
+  # After merging complementary conditionals
+  @T.prim_func
+  def func(A: T.Buffer[(4,4), "float32"], B: T.Buffer[(4,4), "float32"]):
+      for i,j in T.grid(4,4):
+          if 4*i + j < 14:

Review Comment:
   Which condition should be kept?  How do we decide?
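
   For context on this question: in the example above, the two
   conditions split the 4x4 loop domain into the same two sets of
   iterations, so either `4*i + j < 14` or the negation of
   `i==3 and j>=2` would be a correct merged condition.  The sketch
   below is a brute-force check in plain Python, not TVM API; the names
   `first_condition` and `second_condition` are hypothetical and only
   mirror the example.

```python
from itertools import product


def first_condition(i, j):
    # Condition guarding the writes to A in the example.
    return 4 * i + j < 14


def second_condition(i, j):
    # Condition guarding the writes to B in the example.
    return i == 3 and j >= 2


# Loop domain of `for i, j in T.grid(4, 4)`.
domain = list(product(range(4), range(4)))

# Complementary: exactly one of the two conditions holds at every point.
complementary = all(
    first_condition(i, j) != second_condition(i, j) for i, j in domain
)

# Iterations on which the second condition holds.
second_true_points = [(i, j) for i, j in domain if second_condition(i, j)]

print(complementary, second_true_points)
```

   Since both forms describe the same split of the domain, the choice
   between them looks like a matter of policy (for example, keeping
   whichever conditional appears first) rather than correctness.  The
   RFC text quoted here does not say which is intended.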


