Posted to commits@tvm.apache.org by tq...@apache.org on 2022/03/07 02:22:20 UTC

[tvm] branch main updated: [TE][TIR] Implement layout transformations, non-flat memory buffers (#9727)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 174d09e  [TE][TIR] Implement layout transformations, non-flat memory buffers (#9727)
174d09e is described below

commit 174d09ee2cef1ea2caab4c84e0bd58d90c09178f
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sun Mar 6 20:21:50 2022 -0600

    [TE][TIR] Implement layout transformations, non-flat memory buffers (#9727)
    
    * [TIR] Added BufferLoadNode::LegalizeDtype
    
    When modifying a BufferLoad object, the return dtype must also be
    updated.  This exposes the legalization function, so that passes that
    use `BufferLoad::CopyOnWrite` to modify the buffer/indices don't need
    to repeat the logic to update the dtype returned.
    
    * Replacing Store/Load in Stmt/Expr Visitor/Mutator
    
    * Removing Store/Load from optimization passes
    
    - UpdatePointerStorageScope
    - UnrollLoop
    - ThreadSync
    - LinearAccessPatternFinder
    - StoragePlanRewriter
    - VectorTypeRewriter
    - VectorTypeAccessChecker
    - NarrowDataType
    - IRConvertSSA
    - CompactBufferRegion
    
    * Removing Store/Load from examples
    
    - ConvertAddToSubtract
    
    * Replacing Store/Load in StorageFlatten
    
    The pass now outputs BufferLoad/BufferStore with a flattened buffer object.
    
    temp commit, replacing Store/Load, BufferBindUnwrapper
    
    temp commit, replacing Store/Load, StorageFlattener
    
    * Replacing Store/Load in utility passes.
    
    - StmtSimplifier
    - IRSubstitute
    - BaseInliner
    - FeatureVisitor
    
    * Replacing Store/Load in analysis functions
    
    - StorageAccessVisitor
    - VarTouchedAnalysis
    - MemoryAccessVerifier
    - InplaceOpVerifier
    - GPUCodeVerifier
    - VarTouchVisitor
    - LCADetector
    - BlockReadWriteDetector
    - InstrumentBoundCheckers
    
    * Replacing Store/Load in lowering/legalization passes.
    
    - MakeCrossThreadReduction
    - CacheReadRewriter/CacheWriteRewriter
    - InjectVirtualThread
    - InjectDoubleBuffer
    - InjectCopyIntrin
    - LowerWarpMemory
    - LowerThreadAllreduce
    - LowerCustomDatatypes
    - LowerTVMBuiltin
    - CoProcSync
    - MergeDynamicSharedMemAllocations
    - VectorizeLoop
    - BF16Legalize
    
    * Replacing Load/Store in codegens.
    
    - Device code generators
      - CodegenC
      - CodegenLLVM
      - CodeGenOpenCL
    
    - Utilities used during codegen
      - ArgBinder
      - MakePackedAPI
      - ReturnRewriter
      - SplitHostDevice
    
    - Execution environments
      - CodeGenStackVM
      - CodeGenHybrid
      - AOTExecutorCodegen
    
    * [UnitTest] Add unit tests to test physical layout remapping.
    
    * Updated tvm::address_of() to hold BufferLoad instead of Load.
    
    * [TIR] Added IndexMap class.
    
    Holds a set of variables representing the input indices, along with
    expressions for the output indices written in terms of those input
    indices (see the sketch after the TODO list below).
    
    TODO:
    
    - Add validation, the index mapping should be invertible.
    - Add helper function, apply mapping to a set of indices.
    - Add helper function, apply mapping to bounds of input indices.
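
    As a rough sketch of the data this class holds (the exact Python
    bindings for constructing an IndexMap are not shown here, and the
    tiling factor is arbitrary), the mapping below reorders a 2-d index
    into a tiled 3-d index:

        import tvm
        from tvm import tir

        # Input index variables and the output index expressions held by the map
        i = tir.Var("i", "int32")
        j = tir.Var("j", "int32")
        initial_indices = [i, j]
        final_indices = [tir.floordiv(i, 4), j, tir.floormod(i, 4)]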
    
    * Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects.
    
    StorageFlatten/FlattenBuffer passes updated to modify the
    buffer/indices directly, rather than using vload/vstore.
    
    - The primary purpose of vstore/vload is to allow IR written in Python to
      define vectorized load/store.  This usage is maintained by returning
      a BufferLoad/BufferStore node whose index is a Ramp (see the sketch
      after this list).
    
    - Previously, vstore/vload was also used to compute the 1-d physical
      index of a location within an N-d tensor.  This usage is no longer
      allowed, since any such use of the buffer would already be
      flattened, which would prevent layout transformations from being
      applied after the schedule definition.
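
    A minimal sketch of the retained vectorized-access use case (the
    buffer shape and lane count are illustrative):

        import tvm

        buf = tvm.tir.decl_buffer((16,), "float32", name="buf")
        # vload with a vector dtype returns a BufferLoad whose index is a Ramp
        load = buf.vload([0], "float32x4")
        # vstore likewise returns a BufferStore with a Ramp index
        store = buf.vstore([4], load)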
    
    * [TE] Added Stage::transform_layout to the C++ TE implementation.
    
    Adds an `Array<IndexMap>` in the stage to define the transformations
    to be applied on the tensor's layout.  As of this commit, this mapping
    isn't propagated into the TIR graph yet.
    
    * Replace Store/Load with BufferStore/BufferLoad in ir_builder
    
    * [TE] Added Stage.transform_layout to the Python TE interface.
    
    Allows users to specify `s[A].transform_layout(mapping)` and
    propagates the mapping into the TE definitions.
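
    For example, a sketch based on the interface added here (the tiling
    factor is arbitrary):

        import tvm
        from tvm import te

        A = te.placeholder((128, 128), name="A")
        B = te.compute(A.shape, lambda i, j: A[i, j] * 2.0, name="B")
        s = te.create_schedule(B.op)
        # Store B with the i axis split into tiles of 4 in the physical layout
        s[B].transform_layout(lambda i, j: [i // 4, j, i % 4])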
    
    * Added pre_flattened_shape/pre_flattened_stride fields to Buffer.
    
    The shape and stride checks performed in ArgBinder::BindDLTensor
    (called from MakePackedAPI) require the tensor shape/strides prior to
    index flattening.  Therefore, though it is no longer used by the
    low-level code generators, we must maintain that information for use
    in MakePackedAPI.
    
    * [UnitTest] Test N-d indices exposed to low-level codegen
    
    When using te.AXIS_SEPARATOR in the call to .transform_layout, this
    should define groups of axes, each of which is flattened to a single
    axis, then exposed to the low-level codegen.
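
    A sketch of the intended usage (shapes and names are illustrative):

        import tvm
        from tvm import te

        A = te.placeholder((16, 64), name="A")
        B = te.compute(A.shape, lambda i, j: A[i, j] + 1.0, name="B")
        s = te.create_schedule(B.op)
        # The sentinel splits the output axes into two groups, [i] and [j];
        # each group is flattened separately, so codegen sees a 2-d buffer.
        s[B].transform_layout(lambda i, j: [i, te.AXIS_SEPARATOR, j])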
    
    * [TIR] Added PrimFunc attribute "layout_transform_map", filled from TE.
    
    Propagated the TE definition of the physical layout into the TIR
    graph.
    
    * Added pre_flattened_type.
    
    If a boolean tensor is backed by an int8 buffer, the check on the
    argument buffer's type should be against the boolean type.
    
    When rebasing this PR, should be placed after the addition of
    pre_flatten_shape/pre_flatten_strides.
    
    * [UnitTest] Added tests for loop iteration order.
    
    After transformation, the iteration order should follow the new
    transformed axes.  In addition, the loop iteration variables should be
    exposed through the TE interface for further manipulation.
    
    * [TIR] Added BufferNode::axis_separators
    
    - Add axis_separators to represent divisions between groups
      of tensor axes, where each group is flattened into a single
      output axis, to be exposed to the low-level code generators.
    
    - Expose axis_separators to the python interface.
    
    - Update existing C++ calls to the Buffer() constructor.
    
    * [TIR] Added ApplyLayoutTransforms as part of StorageFlatten.
    
    For any buffers that have layout transforms defined in the
    "layout_transform_map" attribute of a PrimFunc, rewrite access into
    the buffer such that they use the updated ordering.
    
    * Update usage of ir_builder where necessary.
    
    * [TE] Implement te::Transform
    
    Similar to Fuse and Split, this represents a modification to the
    existing loop iterations.
    
    * [TE] Added Stage::set_axis_separators.
    
    In C++, this is implemented as an `Array<IntImm>`, specifying
    pre-flattening axes after which a new post-flattening axis should be
    started.  The Python interface uses a sentinel value
    `te.AXIS_SEPARATOR` in the call to `transform_layout`, which is then
    used to define the array of axis separators.
    
    * [TIR] Expose tir.transform.ApplyLayoutTransforms for testing
    
    * [TE] Rewrite loop iteration order
    
    After .transform_layout, rewrite leaf_iter_vars to follow the updated
    order.  Use the te::Transform iter_var relationship to track use of
    the transformed variable.
    
    * [TE] Fill BufferNode::axis_separators from StageNode
    
    During ScheduleOps and SchedulePostprocToPrimfunc, the axis separators
    defined in the stage must be passed through to the TIR BufferNode.
    
    * [TE] Return transformed iteration variables
    
    * Moved Buffer's pre-flatten information to PrimFunc.
    
    Since the pre-flatten information is only used for validating user
    inputs, it makes much more sense to store it alongside the buffer_map.
    
    * Updated ethos-u C++ unit tests to remove use of Load/Store.
    
    * Bugfix, layout transformation.
    
    Error occurred during conversion from TE to IRModule, when layout
    transforms were applied to a reader of a `cache_read`.
    
    * In test directory, replacing all instances of T.load.
    
    * Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    Now that loads/stores require buffer objects, allocation should also
    return a buffer object to be used.
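
    A minimal TVMScript sketch of the new convention (the shapes and
    "global" scope are illustrative):

        from tvm.script import tir as T

        @T.prim_func
        def scale(a: T.handle) -> None:
            A = T.match_buffer(a, (64,), "float32")
            # T.allocate now returns a buffer object, indexed directly below
            B = T.allocate([64], "float32", "global")
            for i in T.serial(0, 64):
                B[i] = A[i] * 2.0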
    
    * Added .astype to tvm.script.tir.node.BufferSlice
    
    Since `buf[i]` returns a `BufferSlice`, this lets the TIR examples
    that use `buf[i].astype('out_dtype')` continue functioning.
    
    * Replacing all T.store TIR calls.
    
    * Added LOG(FATAL) in constructor of Store/Load nodes.
    
    * Updated tvmscript parser to report error for Store/Load nodes.
    
    * [TVMScript] Added T.preflattened_buffer stmt
    
    Used to specify `PrimFunc::preflattened_buffer_map`. Takes an argument
    of the postflattened buffer, so that it will work for both simple
    declarations and `T.match_buffer` statements without needing to
    introduce a param handle.  All other arguments are identical to
    `T.match_buffer`.
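
    A sketch of the intended usage (the keyword arguments shown,
    including `data=`, mirror `T.match_buffer` and are assumptions):

        from tvm.script import tir as T

        @T.prim_func
        def copy(a: T.handle, b: T.handle) -> None:
            # Flattened buffers used within the function body
            A = T.match_buffer(a, [1024], dtype="float32")
            B = T.match_buffer(b, [1024], dtype="float32")
            # Argument shape checks use the pre-flattened 32x32 view
            T.preflattened_buffer(A, [32, 32], dtype="float32", data=A.data)
            T.preflattened_buffer(B, [32, 32], dtype="float32", data=B.data)
            for i in T.serial(0, 1024):
                B[i] = A[i]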
    
    * [TVMScript] Updated TVMscript for BufferLoad/BufferStore
    
    - Use `T.preflattened_buffer` calls in TVMScript to represent
      `PrimFunc::preflattened_buffer_map`.
    
    - Remove `T.buffer_decl` for return value of `T.allocate`, now that
      `T.allocate` returns a buffer.
    
    - For buffer access as a different type, make a `T.buffer_decl` for
      those accesses.
    
    * Updated test_tvmscript_roundtrip.py for BufferLoad/BufferStore.
    
    * Updated TIR reference in USMP pool allocation unit tests.
    
    Using let var handles as the data pointer in buffers, rather than just
    as `T.load`/`T.store` arguments, requires annotation as
    `T.Ptr[T.primtype]`, rather than as `T.handle`.
    
    * fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    * fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    * fixup! In test directory, replacing all instances of T.load.
    
    * tir.ComputeInline, correct variable count.
    
    Previously, this metaschedule primitive relied on `tir::UndefinedVars`
    ignoring the data pointer of BufferLoad/BufferStore nodes.  When
    `tir::UndefinedVars` was updated to visit the data pointer, similar to
    the previous behavior when visiting Load/Store nodes, this caused the
    count of undefined variables to be unexpectedly high.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects.
    
    * fixup! In test directory, replacing all instances of T.load.
    
    * fixup! In test directory, replacing all instances of T.load.
    
    * fixup! Replacing all T.store TIR calls.
    
    * Expose Buffer index flattening function to Python.
    
    * Updated test_tir_buffer.py offset tests.
    
    Replacing calls to `Buffer.vload` with `Buffer.offset_of`, when
    testing the index calculations.
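
    For example (a sketch; `offset_of` is the Python name exposed by
    this change):

        import tvm

        buf = tvm.tir.decl_buffer((16, 16), "float32", name="buf")
        # A flat buffer yields a single flattened offset: 3*16 + 4 = 52
        print(buf.offset_of([3, 4]))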
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects.
    
    * fixup! Replacing Store/Load in lowering/legalization passes.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Updated ethos-u C++ unit tests to remove use of Load/Store.
    
    * fixup! Replacing Store/Load in lowering/legalization passes.
    
    Fix linting for inject_double_buffer.cc
    
    * fixup! Updated ethos-u C++ unit tests to remove use of Load/Store.
    
    * fixup! Added .astype to tvm.script.tir.node.BufferSlice
    
    * fixup! In test directory, replacing all instances of T.load.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! In test directory, replacing all instances of T.load.
    
    * fixup! Replacing all T.store TIR calls.
    
    * fixup! Replacing Store/Load in lowering/legalization passes.
    
    * [UnitTests] Added T.preflattened_buffer in expected result
    
    * fixup! In test directory, replacing all instances of T.load.
    
    * [UnitTests] Bound checker update, compare against N-d buffer bounds.
    
    * Fixup, bound checker vectorize test.
    
    * fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    * [UnitTest] Fixed breakage in InjectRollingBuffer test.
    
    Needed a bit more re-writing than usual, because the test was
    explicitly calling lowering passes, then calling `tvm.build`.  Fixed
    by using the standard lowering flow, with preprocessing steps
    inserted via `tir.add_lower_pass`.
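
    A sketch of that flow (the pass body and phase number are
    placeholders):

        import tvm
        from tvm import te, tir

        def _noop(func, mod, ctx):
            # Placeholder preprocessing pass; returns the PrimFunc unchanged
            return func

        custom_pass = tir.transform.prim_func_pass(_noop, opt_level=0)

        A = te.placeholder((16,), name="A")
        B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
        s = te.create_schedule(B.op)

        with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, custom_pass)]}):
            built = tvm.build(s, [A, B], target="llvm")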
    
    * fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    * [UnitTest] Fixed breakage in flatten buffer unit tests.
    
    - Updated pass to allow BufferStore/BufferLoad nodes to be visited
      before the block's alloc buffer.
    
    - Added `T.preflattened_buffer` annotations.
    
    * fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
    
    * [UnitTests] Fixed breakage in test_tir_buffer.py
    
    - Updated vload test for new behavior.
    - Added test for offset_of, testing behavior no longer in vload.
    - Added null check for buffer visitor.
    
    * fixup! Replacing Load/Store in codegens.
    
    * [UnitTest] ComputeInline, opaque access test updates
    
    * [UnitTest] Fixup, allow unit test to use `ib.pointer()[0]`.
    
    * fixup! Replacing Load/Store in codegens.
    
    The updated CodegenLLVM should use the BufferStore/BufferLoad
    convention of indexing by `sizeof(dtype)`, rather than
    `sizeof(dtype.element_of())`.
    
    * fixup! Replacing Store/Load in lowering/legalization passes.
    
    BF16Legalize should also update the preflattened_buffer_map, since it
    is overwriting the `BufferNode::data` stored in the buffer_map.
    
    * fixup! Replacing all T.store TIR calls.
    
    * Fixed failing codegen c host unit tests.
    
    - Generated functions were declaring the array handle for the return
      value as a `uint8_t*` parameter, rather than the earlier `void*`.
    
    - The new parameter type was due to using
      `PointerType(PrimType(DataType::UInt(8)))` as the type annotation, to
      be usable as `BufferNode::data`.
    
    - Changing to `PointerType(PrimType(DataType::Void()))` still allows
      usage as a buffer, and more appropriately expresses the semantics.
    
    - Updated C codegens to allow `void*` types to be generated from
      variables with type annotation, in addition to the previous behavior
      of `DataType::Handle()` variables without type annotation.
    
    * Fixup, StorageFlatten when applied to post-StorageRewrite functions.
    
    Identified in a test that applied `tvm.lower`, then `tvm.build` on the
    result.  If the result of an allocate node is used as the backing
    buffer for multiple buffers, such as the output of the StorageRewrite
    pass, then StorageFlatten would erroneously think that the second
    occurrence was a use without an earlier definition.
    
    * fixup, StorageFlatten
    
    When flattening a boolean buffer, the backing buffer, not the
    preflattened buffer, should have type int8.
    
    * Bugfix, correctly represent void* in LLVM IR.
    
    * Update, replace tir.Load with tir.BufferLoad
    
    * Added TVMScript error check for matching buffer/index dimensionality
    
    Needed for tests/python/unittest/test_tvmscript_error_report.py::test_high_dim_store
    
    * Bugfix, correct return type when lowering custom datatype.
    
    * Bugfix, removed unused primfunc from test_tvmscript_complete.py
    
    * Updated test_meta_schedule_postproc_verify_gpu_code.py TIR
    
    Replaced Load/Store with BufferLoad/BufferStore.
    
    * Allowed Ramp nodes in buffer use analysis.
    
    * Updated tests in test_meta_schedule_postproc_verify_gpu_code.py
    
    Needed dummy writes to prevent buffer resizing, in order to trigger
    the verification failure due to memory limits.
    
    * Updated TIR examples to be compatible with buffer dimension check.
    
    * Corrected section header in docstring.
    
    * Corrected indices size check in CodeGenC.
    
    * Fixed breakage in LowerThreadAllreduce.
    
    Since the AllocateNode is rewritten, any buffers that refer to those
    variables must also be rewritten.
    
    * [UnitTests] Replaced Store/Load in CUDA codegen tests.
    
    * Resolved breakage in C-based codegen for vectorized store/load.
    
    Needed to update to the new convention of using the buffer's element
    type as the stride.
    
    * Bugfix, incorrect LCA for buffer access in root scope.
    
    This had been present before the BufferLoad/BufferStore changes, but
    hadn't triggered on tests using Load/Store nodes.
    
    * Added docstrings for TransformNode member variables.
    
    * Added TODO for future removal of preflattened_buffer_map.
    
    * Fixup, transform layout + cache write tests.
    
    The correct sequence is to first apply any caching as needed, then to
    apply layout transformations, and finally to apply thread bindings for
    the computation step.
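
    A sketch of that ordering (the compute definition and thread axis
    are illustrative):

        import tvm
        from tvm import te

        A = te.placeholder((64, 64), name="A")
        B = te.compute(A.shape, lambda i, j: A[i, j] + 1.0, name="B")
        s = te.create_schedule(B.op)

        BL = s.cache_write(B, "local")                       # 1. apply caching
        io, jo = s[B].transform_layout(lambda i, j: [j, i])  # 2. transform layout
        s[B].bind(io, te.thread_axis("blockIdx.x"))          # 3. bind threads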
    
    * Bugfix, correct element type for scalarized access.
    
    * Bugfix, cuda buffer indexing when declared as different type.
    
    * Cuda codegen, update reference.
    
    * Bugfix, lower allreduce
    
    Loads of the output of the reduction should be replaced for all
    buffers sharing a buffer pointer, not just for the buffer object
    itself.
    
    * Removed obsolete comment.
    
    * Changed PrimFunc constructor preflattened_buffer_map to Optional
    
    * Removed flatten_buffer argument from T.match_buffer.
    
    * Correct call to VarUseDefAnalysis::VisitBuffer
    
    * Reverted unintentional testing change, lanes=2.
    
    * Updated lower_cross_thread_reduction to use buffer in allreduce
    
    * Updated transform_layout test to disable CSE
    
    * Updated CSE unit tests to use BufferStore
    
    * Replaced Store/Load for vta.transform and unit tests.
    
    * Updated unit tests for lower_cross_thread_reduction.
    
    * Updated arange to use scalar tensors.
    
    The start/stop/step tensors are declared as 0-d scalar tensors, but
    were accessed as 1-d tensors.
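
    For example (a sketch of 0-d tensor access in a compute definition):

        import tvm
        from tvm import te

        start = te.placeholder((), dtype="int32", name="start")
        step = te.placeholder((), dtype="int32", name="step")
        # 0-d tensors are accessed with an empty index tuple, e.g. start()
        arange = te.compute((16,), lambda i: start() + i * step(), name="arange")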
    
    * Fix breakage in ethosu constant encoding.
    
    Buffers generated by "ethosu_copy" should have their buffer objects
    rewritten, but shouldn't have their size updated in ethosu-specific
    Call nodes.
    
    * Fix breakage in ethosu call argument checks.
    
    Need to pull out indices from BufferLoad holders, not Load.
    
    * Resolve breakage from mismatched shape/index dimensions
    
    * Split out encoded parameters from preflattened buffer map.
    
    * Updated buffer shape/index dimensions to match in more ethosu tests
    
    * Fixed lint error
    
    * Removed debug code
    
    * Moved arith::Analyzer local variable to class member
    
    * Fixed SSA conversion of allocations.
    
    This can occur if an allocation is inside an unrolled loop.  Added a
    unit test to catch this failure mode.
    
    * Ethos-u index/buffer dimension updates.
    
    * Updated ethosu passes to handle buffer load/store.
    
    * Resolved bug in tvmscript printing of duplicate buffers.
    
    * Fix breakage in ethos-u test_assign_addresses, encode constants
    
    * Apply same changes to T.allocate_const as to T.allocate
    
    Return a buffer when used in TVMScript, allow for aliasing buffers.
    
    * Fix lint errors.
    
    * Further updates for ethos-u tests.
    
    * Updated ethos.u buffer sizes in test.
    
    * Updated tir.BindParams to use BufferLoad instead of Load.
    
    * Updated topi.cuda.scan implementation to follow buffer dimensions.
    
    * Resolved breakage when flattening AllocateConst nodes.
    
    * Resolved breakages from latest merge with main.
    
    * Corrected error in merge.
    
    * Use empty indices for rank-0 tensor.
    
    * Added ir_builder workaround for 1-d indexing.
    
    * Consistent buffer access type in LLVM codegen, to match C codegen
    
    * StorageRewrite, update indices of modified buffers.
    
    * Dynamic relay nodes, access 0-d tensors with 0-d indices.
    
    * BFloat16 legalization, update buffer type.
    
    * Updated meshgrid to use 0-d index for 0-d buffer.
    
    * Corrected boolean handling in Allocate nodes.
    
    * Added workaround to unpack 1-d Tensor indices into N-d buffer indices.
    
    * Resolved a few more failures in relay tests on cuda.
    
    * Resolve linting
    
    * CI bump
    
    * Updated renormalize_split_pattern tests to use BufferLoad/BufferStore
    
    * Fixed cuda codegen checks for BufferStore/Ramp.
    
    * Simplify indices further, needed to avoid cuda register limit.
    
    * fixed dyn onehot shape func accessing 1d buffer with ()
    
    * Fixed codegen indexing for int4 scalar types.
    
    * Temporary workaround for incorrect constant folding.
    
    Need to further investigate vectorized LLVM constants
    
    * s/find_allocate_usage/FindAllocateUsage/g
    
    * Added buffer type consistency TODO.
    
    * Improved comment on address_of Op.
    
    * Rename LegalizeDtype to LegalizeDType, made private.
    
    * fix format and lint errors
    
    * Disable vectorization of AllocateConst buffer in StorageRewrite.
    
    * Pass buffer_map through to the PrimFunc in cmsisnn
    
    * try disabling problematic winograd test case
    
    * try different way of buffer mapping in storage_rewrite
    
    * Removed unnecessary ramp node in ir_builder.
    
    
    * Updated LLVM codegen for buffer indexing.
    
    TVM data arrays are always densely packed.  If the LLVM type
    corresponding to a vectorized TVM datatype contains padding for
    alignment, the array location should be computed based on the
    primitive element type.
    
    
    Co-authored-by: Masahiro Masuda <ma...@gmail.com>
    Co-authored-by: adstraw <as...@octoml.ai>
---
 include/tvm/ir/attrs.h                             |   41 +
 include/tvm/te/operation.h                         |    1 +
 include/tvm/te/schedule.h                          |  123 +-
 include/tvm/tir/buffer.h                           |   44 +-
 include/tvm/tir/builtin.h                          |   11 +-
 include/tvm/tir/expr.h                             |   16 +
 include/tvm/tir/function.h                         |   44 +-
 include/tvm/tir/index_map.h                        |  140 +
 include/tvm/tir/stmt.h                             |   17 +
 include/tvm/topi/transform.h                       |    8 +-
 .../contrib/ethosu/tir/binary_elementwise.py       |    6 +-
 .../backend/contrib/ethosu/tir/convolution.py      |   12 +-
 .../relay/backend/contrib/ethosu/tir/depthwise.py  |   12 +-
 python/tvm/relay/backend/contrib/ethosu/tir/dma.py |   58 +-
 .../relay/backend/contrib/ethosu/tir/identity.py   |   18 +-
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py |  425 +-
 .../relay/backend/contrib/ethosu/tir/pooling.py    |    4 +-
 .../tvm/relay/backend/contrib/ethosu/tir/spec.py   |   15 +-
 .../relay/backend/contrib/ethosu/tir/transform.py  |   11 +-
 .../contrib/ethosu/tir/unary_elementwise.py        |    6 +-
 .../tvm/relay/backend/contrib/ethosu/tir/utils.py  |   28 +-
 .../backend/contrib/ethosu/tir_to_cs_translator.py |   31 +-
 python/tvm/relay/op/_transform.py                  |    6 +-
 python/tvm/relay/op/dyn/_transform.py              |    2 +-
 python/tvm/relay/op/dyn/nn/_nn.py                  |   10 +-
 python/tvm/script/context_maintainer.py            |    3 +
 python/tvm/script/parser.py                        |   65 +-
 python/tvm/script/tir/__init__.pyi                 |    2 +-
 python/tvm/script/tir/node.py                      |    6 +-
 python/tvm/script/tir/scope_handler.py             |   49 +-
 python/tvm/script/tir/special_stmt.py              |   54 +
 python/tvm/te/__init__.py                          |    8 +-
 python/tvm/te/schedule.py                          |  152 +-
 python/tvm/tir/buffer.py                           |   38 +
 python/tvm/tir/function.py                         |   37 +-
 python/tvm/tir/ir_builder.py                       |  112 +-
 python/tvm/tir/transform/transform.py              |   15 +-
 python/tvm/topi/cuda/sparse.py                     |   12 +-
 python/tvm/topi/utils.py                           |   14 +-
 src/arith/rewrite_simplify.cc                      |    6 +-
 src/autotvm/feature_visitor.cc                     |   10 +-
 src/autotvm/feature_visitor.h                      |    4 +-
 src/contrib/hybrid/codegen_hybrid.cc               |    8 +
 src/contrib/hybrid/codegen_hybrid.h                |    2 +
 src/printer/tir_text_printer.cc                    |    3 +
 src/printer/tvmscript_printer.cc                   |  127 +-
 src/relay/backend/aot_executor_codegen.cc          |   14 +-
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  |    2 +-
 .../contrib/example_target_hooks/relay_to_tir.cc   |    8 +-
 src/relay/op/tensor/transform.cc                   |    6 +-
 src/relay/transforms/fold_constant.cc              |    1 +
 src/target/llvm/codegen_cpu.cc                     |    6 +-
 src/target/llvm/codegen_hexagon.cc                 |    6 +-
 src/target/llvm/codegen_llvm.cc                    |  182 +-
 src/target/llvm/codegen_llvm.h                     |    7 +-
 src/target/source/codegen_c.cc                     |  199 +-
 src/target/source/codegen_c.h                      |   11 +-
 src/target/source/codegen_c_host.cc                |    4 +
 src/target/source/codegen_cuda.cc                  |   10 +-
 src/target/source/codegen_cuda.h                   |    3 +-
 src/target/source/codegen_metal.cc                 |    5 +
 src/target/source/codegen_opencl.cc                |   32 +-
 src/target/source/codegen_opencl.h                 |    7 +-
 src/target/source/codegen_source_base.cc           |    4 +
 src/target/spirv/codegen_spirv.cc                  |   40 +-
 src/target/spirv/codegen_spirv.h                   |    4 +-
 src/target/stackvm/codegen_stackvm.cc              |   46 +-
 src/target/stackvm/codegen_stackvm.h               |    2 +
 src/te/operation/cross_thread_reduction.cc         |   46 +-
 src/te/schedule/message_passing.cc                 |  130 +
 src/te/schedule/schedule_lang.cc                   |   96 +
 src/te/schedule/schedule_ops.cc                    |   52 +-
 src/te/schedule/schedule_postproc_to_primfunc.cc   |  276 +-
 src/tir/analysis/block_access_region_detector.cc   |   10 +-
 src/tir/analysis/buffer_access_lca_detector.cc     |   21 +-
 src/tir/analysis/device_constraint_utils.cc        |   27 +-
 src/tir/analysis/var_touch.cc                      |   14 +-
 src/tir/analysis/verify_gpu_code.cc                |   12 +-
 src/tir/analysis/verify_memory.cc                  |   14 +-
 src/tir/ir/buffer.cc                               |  241 +-
 src/tir/ir/expr.cc                                 |   20 +-
 src/tir/ir/expr_functor.cc                         |   12 +-
 src/tir/ir/function.cc                             |   10 +-
 src/tir/ir/index_map.cc                            |  154 +
 src/tir/ir/stmt.cc                                 |   17 +-
 src/tir/ir/stmt_functor.cc                         |   82 +-
 src/tir/schedule/primitive/cache_read_write.cc     |   30 +-
 src/tir/schedule/primitive/compute_inline.cc       |   42 +-
 src/tir/transforms/arg_binder.cc                   |   77 +-
 src/tir/transforms/bf16_legalize.cc                |  108 +-
 src/tir/transforms/bind_params.cc                  |    9 +-
 src/tir/transforms/bound_checker.cc                |  166 +-
 src/tir/transforms/compact_buffer_region.cc        |    9 +-
 src/tir/transforms/coproc_sync.cc                  |   19 +-
 src/tir/transforms/flatten_buffer.cc               |  105 +-
 src/tir/transforms/inject_copy_intrin.cc           |   42 +-
 src/tir/transforms/inject_double_buffer.cc         |   93 +-
 src/tir/transforms/inject_virtual_thread.cc        |  160 +-
 src/tir/transforms/ir_utils.cc                     |  128 +-
 src/tir/transforms/ir_utils.h                      |   16 +-
 src/tir/transforms/lower_cross_thread_reduction.cc |    2 +-
 src/tir/transforms/lower_custom_datatypes.cc       |   78 +-
 src/tir/transforms/lower_match_buffer.cc           |   19 +-
 src/tir/transforms/lower_thread_allreduce.cc       |  262 +-
 src/tir/transforms/lower_tvm_builtin.cc            |  183 +-
 src/tir/transforms/lower_warp_memory.cc            |  107 +-
 src/tir/transforms/make_packed_api.cc              |   79 +-
 .../merge_dynamic_shared_memory_allocations.cc     |   87 +-
 src/tir/transforms/narrow_datatype.cc              |   56 +-
 src/tir/transforms/rewrite_unsafe_select.cc        |   15 +-
 src/tir/transforms/simplify.cc                     |   34 +-
 src/tir/transforms/split_host_device.cc            |   26 +-
 src/tir/transforms/storage_access.cc               |   49 +-
 src/tir/transforms/storage_access.h                |    9 +-
 src/tir/transforms/storage_flatten.cc              |  539 ++-
 src/tir/transforms/storage_rewrite.cc              |  352 +-
 src/tir/transforms/thread_storage_sync.cc          |   71 +-
 src/tir/transforms/unroll_loop.cc                  |    5 +
 src/tir/transforms/update_pointer_storage_scope.cc |   56 +-
 src/tir/transforms/update_pointer_storage_scope.h  |    8 +
 src/tir/transforms/vectorize_loop.cc               |  191 +-
 src/tir/usmp/analysis/extract_buffer_info.cc       |   12 +-
 src/tir/usmp/transform/assign_pool_info.cc         |    4 +-
 .../convert_pool_allocations_to_offsets.cc         |  109 +-
 tests/cpp/tir_analysis_side_effect.cc              |    5 +-
 tests/python/contrib/test_ethosu/infra.py          |   12 +-
 .../contrib/test_ethosu/test_encode_constants.py   |  136 +-
 .../test_ethosu/test_remove_concatenates.py        |   29 +-
 .../contrib/test_ethosu/test_replace_conv2d.py     |  188 +-
 .../contrib/test_ethosu/test_replace_copy.py       |   44 +-
 .../test_ethosu/test_replace_unary_elementwise.py  |    4 +-
 tests/python/contrib/test_ethosu/test_scheduler.py |   32 +-
 .../test_ethosu/test_tir_to_cs_translator.py       |  234 +-
 tests/python/contrib/test_ethosu/test_vela_api.py  |   26 +-
 tests/python/relay/aot/test_crt_aot.py             |   19 +-
 tests/python/unittest/test_lower_build.py          |   36 +-
 .../test_meta_schedule_postproc_verify_gpu_code.py |   88 +-
 .../test_runtime_module_based_interface.py         |    2 +-
 tests/python/unittest/test_runtime_module_load.py  |    2 +-
 tests/python/unittest/test_target_codegen_cuda.py  |   48 +-
 tests/python/unittest/test_target_codegen_llvm.py  |   11 +-
 .../python/unittest/test_target_codegen_vulkan.py  |    5 +-
 .../test_tir_analysis_calculate_workspace.py       |   60 +-
 .../test_tir_analysis_detect_buffer_access_lca.py  |    6 +-
 tests/python/unittest/test_tir_buffer.py           |   30 +-
 tests/python/unittest/test_tir_constructor.py      |   25 +-
 tests/python/unittest/test_tir_intrin.py           |    5 +-
 tests/python/unittest/test_tir_ir_builder.py       |    4 +-
 .../python/unittest/test_tir_lower_match_buffer.py |    4 +-
 tests/python/unittest/test_tir_nodes.py            |   17 +-
 tests/python/unittest/test_tir_ptx_mma.py          |  150 +-
 .../unittest/test_tir_schedule_cache_read_write.py |    6 +-
 .../unittest/test_tir_schedule_compute_inline.py   |   20 +-
 tests/python/unittest/test_tir_schedule_reorder.py |    4 +-
 .../unittest/test_tir_schedule_split_fuse.py       |    6 +-
 .../test_tir_transform_combine_context_call.py     |    6 +-
 .../test_tir_transform_common_subexpr_elim.py      |   30 +-
 .../test_tir_transform_compact_buffer_region.py    |   19 +-
 .../test_tir_transform_convert_for_loops_serial.py |   20 +-
 .../test_tir_transform_extract_constants.py        |    6 +-
 .../unittest/test_tir_transform_flatten_buffer.py  |   56 +-
 .../test_tir_transform_inject_double_buffer.py     |    4 +-
 .../test_tir_transform_inject_rolling_buffer.py    |   62 +-
 .../test_tir_transform_inject_virtual_thread.py    |   41 +-
 ...test_tir_transform_instrument_bound_checkers.py |    6 +-
 .../python/unittest/test_tir_transform_ir_utils.py |    9 +-
 .../unittest/test_tir_transform_loop_partition.py  |   80 +-
 ...t_tir_transform_lower_cross_thread_reduction.py |   24 +-
 .../test_tir_transform_lower_tvm_builtin.py        |    4 +-
 .../unittest/test_tir_transform_narrow_datatype.py |   14 +-
 .../unittest/test_tir_transform_remove_no_op.py    |    2 +-
 ...test_tir_transform_renormalize_split_pattern.py |   49 +-
 .../python/unittest/test_tir_transform_simplify.py |    4 +-
 .../unittest/test_tir_transform_storage_flatten.py |   47 +-
 .../unittest/test_tir_transform_unroll_loop.py     |   27 +-
 .../unittest/test_tir_transform_vectorize.py       |    9 +-
 tests/python/unittest/test_tir_usmp_algo.py        |  124 +-
 .../test_tir_usmp_analysis_extract_bufferinfo.py   |  540 +--
 ...ransform_convert_pool_allocations_to_offsets.py |  381 +-
 tests/python/unittest/test_tir_usmp_utils.py       |   40 +-
 tests/python/unittest/test_transform_layout.py     |  498 ++
 tests/python/unittest/test_tvmscript_complete.py   |    6 -
 .../python/unittest/test_tvmscript_error_report.py |    5 +-
 tests/python/unittest/test_tvmscript_roundtrip.py  | 4939 ++++++++++----------
 vta/python/vta/transform.py                        |   97 +-
 185 files changed, 9634 insertions(+), 5779 deletions(-)

diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index f6c15f9..9a24687 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -382,6 +382,47 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
   return input;
 }
 
+/*!
+ * \brief Copy the function or module, but removes the specified
+ *        attribute.
+ *
+ * \param input The thing to annotate (BaseFunc or IRModule)
+ * \param attr_key The attribute key.
+ *
+ * \tparam TFunc The corresponding function or module type.
+ *
+ * \returns The new function or module with removed attribute.
+ *
+ * \note This function performs copy on write optimization for func and module.
+ *       If we move a uniquely referenced func or module into WithoutAttr,
+ *       then no additional copy will be performed.
+ *
+ *       This is also why we make it as a function instead of a member function
+ *       and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithoutAttr(std::move(func), "key1");
+ *  func = WithoutAttr(std::move(func), "key2");
+ *
+ * \endcode
+ */
+template <typename TFunc>
+inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
+  using TNode = typename TFunc::ContainerType;
+  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+
+  if (input->attrs.defined()) {
+    TNode* node = input.CopyOnWrite();
+    node->attrs.CopyOnWrite()->dict.erase(attr_key);
+    if (node->attrs->dict.size() == 0) {
+      node->attrs = NullValue<DictAttrs>();
+    }
+  }
+  return input;
+}
+
 // Namespace containing detail implementations
 namespace detail {
 using runtime::TVMArgValue;
diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 89074d8..99c86f0 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -268,6 +268,7 @@ class ComputeOp : public Operation {
                     Array<IterVar> axis, Array<PrimExpr> body);
 
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
 };
 
 /*!
diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h
index 17aedbc..8e637b4 100644
--- a/include/tvm/te/schedule.h
+++ b/include/tvm/te/schedule.h
@@ -29,6 +29,7 @@
 #include <tvm/te/tensor.h>
 #include <tvm/te/tensor_intrin.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/index_map.h>
 
 #include <string>
 #include <unordered_map>
@@ -257,6 +258,41 @@ class Stage : public ObjectRef {
    */
   TVM_DLL Stage& rolling_buffer();  // NOLINT(*)
   /*!
+   * \brief Defines a layout transformation to be applied to the buffer.
+   *
+   * The map from initial_index to final_index must be an
+   * invertible affine transformation.
+   *
+   * \param initial_indices An array of variables to represent a
+   * value's location in the tensor, using the pre-transformation
+   * layout.  These variables are used as binding occurrences to
+   * represent the initial indices when applying the initial->final
+   * mapping, and should not occur elsewhere in the
+   * Schedule. (i.e. Pass in newly constructed variables, not the
+   * initial IterVar::var)
+   *
+   * \param final_indices An array of expressions, giving the
+   * value's location in the tensor, using the post-transformation layout.
+   * Expressions should be in terms of the variables given in
+   * initial_indices.
+   *
+   * \param out_iter_vars An optional output location for the updated
+   * loop iteration variables.
+   *
+   * \return reference to self
+   */
+  TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
+                                  const Array<PrimExpr>& final_indices,
+                                  Array<IterVar>* out_iter_vars = nullptr);
+  /*! \brief Defines separators between groups of axes.
+   *
+   * Used to define `BufferNode::axis_separators`, which has
+   * additional details.
+   *
+   * \param axis_separators A list of axis separators.
+   */
+  TVM_DLL Stage& set_axis_separators(const Array<IntImm>& axis_separators);
+  /*!
    * \brief whether the stage has been scheduled.
    * \return whether the stage has been scheduled.
    */
@@ -466,9 +502,27 @@ class StageNode : public Object {
    *  while origin_op remains fixed.
    */
   Operation origin_op;
-  /*! \brief All the nodes in the iter var */
+  /*! \brief All the nodes in the iter var
+   *
+   * Each element of all_iter_vars represents an iteration variable
+   * that may appear within this stage's computation.  Any element
+   * of `all_iter_vars` that is in `leaf_iter_vars` represents a
+   * variable that is directly defined and usable within the stage's
+   * computation.  All other elements of `all_iter_vars` represent
+   * variables whose value must be computed from the variables in
+   * `leaf_iter_vars`.  (e.g. Suppose index k has been split by
+   * ``ko, ki = s.split(k, factor=4)``.  ko and ki will appear in
+   * `leaf_iter_vars`, while k will not, and must be computed as
+   * `4*ko + ki`.)
+   */
   Array<IterVar> all_iter_vars;
-  /*! \brief The current active leaf iter vars in the stage. */
+  /*! \brief The current active leaf iter vars in the stage.
+   *
+   * Each element of leaf_iter_vars will either be replaced with the
+   * bound index (e.g. threadIdx.x), or will be expanded into a loop
+   * over the variable's extent.  `leaf_iter_vars` is a subset of
+   * `all_iter_vars`.
+   */
   Array<IterVar> leaf_iter_vars;
   /*!
    * \brief Specify threads to be launched at the stage.
@@ -500,6 +554,14 @@ class StageNode : public Object {
   bool double_buffer{false};
   /*! \brief Whether apply rolling buffer optimization to this stage */
   bool rolling_buffer{false};
+  /*! \brief Layout transformations to be applied onto the stage's tensors. */
+  Array<IndexMap> layout_transforms;
+  /*! \brief List of axes after which to divide physical axes.
+   *
+   * Used to populate `BufferNode::axis_separators`, which has
+   * additional details.
+   */
+  Array<IntImm> axis_separators;
   /*!
    * \brief The parent group of the current stage.
    *  The stage cannot be assigned to stages outside the group.
@@ -522,6 +584,8 @@ class StageNode : public Object {
     v->Visit("scope", &scope);
     v->Visit("is_output", &is_output);
     v->Visit("double_buffer", &double_buffer);
+    v->Visit("layout_transforms", &layout_transforms);
+    v->Visit("axis_separators", &axis_separators);
     v->Visit("group", &group);
     v->Visit("num_child_stages", &num_child_stages);
   }
@@ -771,6 +835,61 @@ class Singleton : public IterVarRelation {
   TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
 };
 
+/*!
+ * \brief Transform iterator according to some arbitrary expression.
+ */
+class TransformNode : public IterVarRelationNode {
+ public:
+  /*! \brief The loop variables that were replaced by the transformation.
+   *
+   * Prior to applying a layout transformation, these represent the
+   * loops to iterate over a tensor as it is being computed, following
+   * a row-major traversal of the tensor's original shape in the
+   * compute definition.
+   */
+  Array<IterVar> original_variables;
+
+  /*! \brief The variables generated by the transformation.
+   *
+   * After applying a layout transformation, these represent the
+   * loops to iterate over a tensor as it is being computed, following
+   * a row-major traversal of the transformed shape of the tensor.
+   */
+  Array<IterVar> transformed_variables;
+
+  /*! \brief Map from the original variables to the transformed variables.
+   *
+   * Used to determine iterator ranges over the transformed variables.
+   */
+  IndexMap forward_transformation;
+
+  /*! \brief Map from transformed variables to the original variables
+   *
+   * Used to rewrite expressions containing the original loop iterators
+   * in terms of the transformed loop iterators.
+   */
+  IndexMap inverse_transformation;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("original_variables", &original_variables);
+    v->Visit("transformed_variables", &transformed_variables);
+    v->Visit("forward_transformation", &forward_transformation);
+    v->Visit("inverse_transformation", &inverse_transformation);
+  }
+
+  static constexpr const char* _type_key = "Transform";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode);
+};
+
+class Transform : public IterVarRelation {
+ public:
+  TVM_DLL explicit Transform(Array<IterVar> original_variables,
+                             Array<IterVar> transformed_variables, IndexMap forward_transformation,
+                             IndexMap inverse_transformation);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
+};
+
 /*! \brief Container for specialization conditions. */
 class SpecializedConditionNode : public Object {
  public:
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 69453e2..aef82ae 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -55,9 +55,23 @@ class BufferNode : public Object {
   Var data;
   /*! \brief data type in the content of the tensor */
   DataType dtype;
-  /*! \brief The shape of the buffer */
+  /*! \brief The shape of the buffer prior to flattening
+   *
+   * This contains the shape as it is accessed by
+   * BufferLoad/BufferStore nodes, and used by the low-level code
+   * generators.
+   */
   Array<PrimExpr> shape;
   /*!
+   * \brief Separators between input axes when generating flattened output axes
+   *
+   * For buffers representing flat 1-d memory (e.g. any buffer in
+   * RAM), this should be an empty array.  For buffers representing
+   * non-flat memory, each entry in axis_separators should be the
+   * first input axis that is part of a new flattened axis.
+   */
+  Array<IntImm> axis_separators;
+  /*!
    * \brief The strides of each dimension
    *  This can be an empty array, indicating array is contiguous
    */
@@ -89,6 +103,7 @@ class BufferNode : public Object {
     v->Visit("dtype", &dtype);
     v->Visit("shape", &shape);
     v->Visit("strides", &strides);
+    v->Visit("axis_separators", &axis_separators);
     v->Visit("elem_offset", &elem_offset);
     v->Visit("name", &name);
     v->Visit("data_alignment", &data_alignment);
@@ -98,10 +113,11 @@ class BufferNode : public Object {
   }
 
   bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
-    // Use DefEqual as buffer can define variables
-    // in its semantics, skip name as name is not important.
+    // Use DefEqual as buffer can define variables in its semantics,
+    // skip name as name is not important.
     return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
            equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
+           equal.DefEqual(axis_separators, other->axis_separators) &&
            equal.DefEqual(elem_offset, other->elem_offset) &&
            equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
   }
@@ -112,6 +128,7 @@ class BufferNode : public Object {
     hash_reduce.DefHash(shape);
     hash_reduce.DefHash(strides);
     hash_reduce.DefHash(elem_offset);
+    hash_reduce.DefHash(axis_separators);
     hash_reduce(data_alignment);
     hash_reduce(buffer_type);
   }
@@ -127,7 +144,7 @@ class BufferNode : public Object {
    * without adjusting for number of lanes.  (e.g. The number of
    * float16x4 elements in a buffer of type float16x4.)
    */
-  PrimExpr ElemOffset(Array<PrimExpr> index) const;
+  Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const;
 
   static constexpr const char* _type_key = "tir.Buffer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -146,7 +163,7 @@ class Buffer : public ObjectRef {
   // A default value will be picked.
   TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
                  PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
-                 BufferType buffer_type, Span span = Span());
+                 BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span());
 
   /*!
    * \brief Return a new buffer that is equivalent with current one
@@ -187,6 +204,19 @@ class Buffer : public ObjectRef {
   TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
 
   /*!
+   * \brief Get a flattened version of the buffer
+   */
+  Buffer GetFlattenedBuffer() const;
+
+  /*! \brief Determine the offset in the buffer of the given index.
+   *
+   * Returns the buffer offset, in number of elements of type dtype,
+   * without adjusting for number of lanes.  (e.g. The number of
+   * float16x4 elements in a buffer of type float16x4.)
+   */
+  Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const;
+
+  /*!
    * \brief Return the storage scope associated with this buffer.
    */
   TVM_DLL String scope() const;
@@ -201,12 +231,14 @@ class Buffer : public ObjectRef {
  * \param dtype The content data type.
  * \param name The name of the buffer
  * \param storage_scope The storage scope associated with this buffer
+ * \param axis_separators Divisions defining the groups of axes that will be flattened together.
  * \param span The location of this object in the source code.
  * \return The created buffer.
  * \sa Buffer for complete constructor.
  */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
-                           String name = "buffer", String storage_scope = "", Span span = Span());
+                           String name = "buffer", String storage_scope = "",
+                           Array<IntImm> axis_separators = {}, Span span = Span());
 
 /*!
  * \brief Base node for data producers.
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index d8a5ea6..f7e1cfbc 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -105,10 +105,15 @@ TVM_DLL const Op& large_uint_imm();
 TVM_DLL const Op& q_multiply_shift();
 
 /*!
- * \brief See pesudo code
+ * \brief Returns the address of an element in the buffer (see pseudocode below).
+ *
+ * The number of indices should match the dimensionality of the buffer
+ * being accessed.  If this operation occurs after buffer flattening,
+ * the number of indices must be supported by the target (i.e. N>1
+ * only on targets that support non-flat memory buffers).
  *
- *  Handle address_of(Load *op) {
- *     return &op->buffer_var[index];
+ *  Handle address_of(BufferLoad *op) {
+ *     return &op->buffer_var[op->indices[0], op->indices[1], ..., op->indices[N-1]];
  *  }
  */
 TVM_DLL const Op& address_of();
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index f674111..674ff0b 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -630,6 +630,22 @@ class BufferLoadNode : public PrimExprNode {
 
   static constexpr const char* _type_key = "tir.BufferLoad";
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
+
+ private:
+  /*! \brief Set the dtype based on the buffer/indices
+   *
+   * Usually, the BufferLoad's dtype will be the same dtype as the
+   * buffer.  This may have a different number of lanes than the
+   * buffer's dtype if index values have more than 1 lane.
+   *
+   * This function should only be called during construction and after
+   * CopyOnWrite.  Friend class used here to restrict usage.
+   */
+  void LegalizeDType();
+  friend class BufferLoad;
+  friend class CustomDatatypesLowerer;
+  friend class VectorTypeRewriter;
+  friend class Vectorizer;
 };
 
 /*!
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 2b3c4b5..dc7014c 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -91,11 +91,30 @@ class PrimFuncNode : public BaseFuncNode {
    */
   Map<tir::Var, Buffer> buffer_map;
 
+  /*! \brief The buffer map prior to flattening.
+   *
+   * This contains the buffers as they exist prior to flattening, and
+   * is used for validating an input tensor passed into the packed
+   * API.  Any buffer that is present in `buffer_map` but not present
+   * in `preflattened_buffer_map` is assumed to be the same before
+   * and after flattening (e.g. a 1-d tensor that is backed by 1-d
+   * flat memory).
+   *
+   * TODO(Lunderberg): Remove preflattened_buffer_map, and instead
+   * declare each flattened buffer as aliasing the original tensor
+   * shape.  This should include improving the StmtExprMutator to
+   * provide easier interactions with Buffer objects, so that the
+   * bookkeeping of relationships between buffers doesn't need to be
+   * repeated across several transforms.
+   */
+  Map<tir::Var, Buffer> preflattened_buffer_map;
+
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("params", &params);
     v->Visit("body", &body);
     v->Visit("ret_type", &ret_type);
     v->Visit("buffer_map", &buffer_map);
+    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
     v->Visit("attrs", &attrs);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -104,6 +123,7 @@ class PrimFuncNode : public BaseFuncNode {
   bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
     // visit params and buffer_map first as they contains defs.
     return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
+           equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
            equal(ret_type, other->ret_type) && equal(body, other->body) &&
            equal(attrs, other->attrs);
   }
@@ -111,6 +131,7 @@ class PrimFuncNode : public BaseFuncNode {
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce.DefHash(params);
     hash_reduce(buffer_map);
+    hash_reduce(preflattened_buffer_map);
     hash_reduce(ret_type);
     hash_reduce(body);
     hash_reduce(attrs);
@@ -136,16 +157,33 @@ class PrimFunc : public BaseFunc {
  public:
   /*!
    * \brief Constructor
+   *
    * \param params The parameters of the function.
+   *
    * \param body The body of the function.
+   *
    * \param ret_type The return type of the function.
+   *
    * \param buffer_map The buffer map for parameter buffer unpacking.
+   * This contains buffer objects as they appear in the body of the
+   * PrimFunc.  (e.g. a buffer of shape ``[1024]`` originally
+   * generated as a tensor of shape ``[32, 32]``)
+   *
+   * \param preflattened_buffer_map The buffer map for
+   * parameter buffer unpacking.  This contains buffer
+   * objects as they are expected to be passed in by the
+   * caller.  (e.g. a buffer of shape ``[32, 32]`` originally
+   * generated as a tensor of shape ``[32, 32]``)
+   *
    * \param attrs Additional function attributes.
+   *
    * \param span The location of this object in the source code.
    */
-  TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
-                   Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
-                   DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+  TVM_DLL PrimFunc(
+      Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
+      Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
+      Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(),
+      DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
new file mode 100644
index 0000000..2371113
--- /dev/null
+++ b/include/tvm/tir/index_map.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/index_map.h
+ * \brief Defines a remapping of buffer indices
+ *
+ * For use with tvm::tir::Buffer.
+ */
+#ifndef TVM_TIR_INDEX_MAP_H_
+#define TVM_TIR_INDEX_MAP_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/object.h>
+#include <tvm/tir/var.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Defines a mapping between two representations of indices
+ * into a buffer.
+ *
+ * This is primarily used for layout transformations of Buffer
+ * objects.
+ */
+class IndexMapNode : public Object {
+ public:
+  /*! \brief Variables representing the indices prior to remapping.
+   *
+   * If initial_indices is empty, then final_indices should also be
+   * empty, and no mapping is applied.
+   */
+  Array<Var> initial_indices;
+
+  /*!
+   * \brief Expressions defining the indices after remapping.
+   *
+   * These expressions should only be in terms of the initial_indices,
+   * and must be expressible as an IterSumExpr.  The mapping from
+   * initial_indices to final_indices must be injective.
+   *
+   * If final_indices is empty, then initial_indices should also be
+   * empty, and the map is an identity function.
+   */
+  Array<PrimExpr> final_indices;
+
+  /*!
+   * \brief Default constructor
+   *
+   * Defines the mapping as an identity function, with initial_indices
+   * equal to the final indices.
+   */
+  IndexMapNode() {}
+
+  /*!
+   * \brief Map indices to the output space
+   *
+   * \param indices The indices in the input space.  Should contain
+   * one value for each variable in `initial_indices`.
+   *
+   * \returns The indices in the output space.  Contains one value for
+   * each expression in `final_indices`.
+   */
+  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices) const;
+
+  /*! \brief Map a memory range to the output space
+   *
+   * If contiguous memory locations in the input space are not
+   * necessarily contiguous in the output space (e.g. `lambda i:
+   * [8*(i%8) + (i//8)]`), then this will return the smallest range
+   * such that all valid indices are contained within the given range.
+   *
+   * \param ranges The ranges in the input space.  Should contain one
+   * value for each variable in `initial_indices`.
+   *
+   * \returns The ranges in the output space.  Contains one value for
+   * each expression in `final_indices`.
+   */
+  Array<Range> MapRanges(const Array<Range>& ranges) const;
+
+  /*! \brief Map a buffer shape to the output space
+   *
+   * \param shape The buffer shape in the input space.  Should contain
+   * one value for each variable in `initial_indices`.
+   *
+   * \returns The buffer shape in the output space.  Contains one
+   * value for each expression in `final_indices`.
+   */
+  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("initial_indices", &initial_indices);
+    v->Visit("final_indices", &final_indices);
+  }
+
+  TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
+};
+
+class IndexMap : public ObjectRef {
+ public:
+  IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);
+
+  /*! \brief Generate the inverse mapping.
+   *
+   * The range of the input indices is required in order to ensure
+   * that the transformation is bijective over the input domain.
+   *
+   * TODO(Lunderberg): Look into allowing non-bijective
+   * transformations.  If injective, the inverse mapping could still
+   * be generated with some predicate.  If non-injective, could
+   * simplify the implementation of other optimizations (e.g. double
+   * buffering as a map `lambda *indices: [buffer_loop%2, *indices]`).
+   */
+  IndexMap Inverse(Array<Range> initial_ranges) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_INDEX_MAP_H_
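
    As a plain-Python sketch of the semantics documented in this header (this is not
    the new C++ class, just the arithmetic it describes): a map such as
    lambda i: [i // 4, i % 4] sends the logical indices of a 16-element buffer into a
    4x4 layout, and its inverse recovers the original index over that domain.

        def map_indices(i):
            # One initial index -> two final indices.
            return [i // 4, i % 4]

        def map_shape(n):
            # Smallest output shape covering all mapped indices of a length-n buffer.
            return [(n + 3) // 4, 4]

        def inverse(j, k):
            # Well-defined only because the forward map is bijective on 0 <= i < 16.
            return 4 * j + k

        assert map_shape(16) == [4, 4]
        assert all(inverse(*map_indices(i)) == i for i in range(16))
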
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 972f781..9ccab50 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -388,6 +388,7 @@ class BufferRealize : public Stmt {
                                  Span span = Span());
 
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode);
 };
 
 /*!
@@ -585,6 +586,7 @@ class Allocate : public Stmt {
                    Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode);
 };
 
 /*!
@@ -1373,6 +1375,21 @@ constexpr const char* pragma_tensor_core = "pragma_tensor_core";
  */
 constexpr const char* prefetch_scope = "prefetch_scope";
 /*!
+ * \brief Marks the layout transforms to be used for a tensor.
+ *
+ * Only applies to a DataProducer, as it should be made part of the
+ * PrimFunc attributes for TIR.
+ */
+constexpr const char* layout_transforms = "layout_transforms";
+/*!
+ * \brief Marks the physical axis separators
+ *
+ * Only applies to a DataProducer, as it should be made part of the
+ * Buffer definition in a PrimFunc.  See `BufferNode::axis_separators`
+ * for more details.
+ */
+constexpr const char* axis_separators = "axis_separators";
+/*!
  * \brief Marks production of double buffer data
  */
 constexpr const char* double_buffer_scope = "double_buffer_scope";
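
    As an illustrative calculation of what the axis_separators attribute above means
    for flattening (shapes are assumed, not taken from any real workload): a 4-d NHWC
    buffer with a separator before the channel axis flattens to a 2-d physical
    allocation rather than a fully flat 1-d one.

        from functools import reduce
        from operator import mul

        shape = [1, 56, 56, 32]      # logical N, H, W, C
        separators = [3]             # split the physical allocation before axis 3

        def flatten(shape, separators):
            groups, start = [], 0
            for sep in list(separators) + [len(shape)]:
                groups.append(reduce(mul, shape[start:sep], 1))
                start = sep
            return groups

        assert flatten(shape, []) == [1 * 56 * 56 * 32]              # no separators: 1-d
        assert flatten(shape, separators) == [1 * 56 * 56, 32]       # one separator: 2-d
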
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 04027f8..ef36c01 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1566,7 +1566,11 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
         out_shape,
         [&](const Array<Var>& indices) {
           const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
-          Array<PrimExpr> real_indices = {indices[src_index]};
+          auto ndim = inputs[i]->GetShape().size();
+          Array<PrimExpr> real_indices = {};
+          if (ndim > 0) {
+            real_indices = {indices[src_index]};
+          }
           return inputs[i](real_indices);
         },
         name, tag));
@@ -1815,7 +1819,7 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
       [&](const Array<Var>& indices) {
         PrimExpr ret = default_value;
         if (0 == rank_sparse_indices) {
-          ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret);
+          ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
         } else if (1 == rank_sparse_indices) {
           for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
             ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
index 53b46ae..11e4720 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
@@ -77,12 +77,12 @@ def get_binary_elementwise_params(
 
     _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
     op = ignore_cast(inner.value)
-    input_pointer = ignore_cast(op.a).buffer_var
-    input_pointer1 = ignore_cast(op.b).buffer_var
+    input_pointer = ignore_cast(op.a).buffer.data
+    input_pointer1 = ignore_cast(op.b).buffer.data
 
     if reversed_operands:
         input_pointer, input_pointer1 = input_pointer1, input_pointer
-    output_pointer = inner.buffer_var
+    output_pointer = inner.buffer.data
     # Get feature map info
     serial_ifm, _ = get_ifm_params(input_pointer, producers)
     serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
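
    The mechanical change throughout these Ethos-U helpers is from Load/Store (a
    pointer variable plus a single flat index) to BufferLoad/BufferStore (a Buffer
    plus a list of indices). A small sketch of the new accessors, assuming nothing
    beyond a locally declared buffer:

        import tvm
        from tvm import tir

        buf = tir.decl_buffer((16,), "int8", name="ifm")
        load = tir.BufferLoad(buf, [0])

        assert load.buffer.data.same_as(buf.data)   # replaces the old load.buffer_var
        assert len(load.indices) == 1               # replaces the old scalar load.index
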
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
index 50c27cc..bdca6a8 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
@@ -59,8 +59,8 @@ def get_conv2d_params(stmt, producers, consumers):
     loads = get_loads(rc.body)
     # stores = [output]
     stores = get_stores(rc.body)
-    input_pointer = loads[1].buffer_var
-    output_pointer = stores[0].buffer_var
+    input_pointer = loads[1].buffer.data
+    output_pointer = stores[0].buffer.data
     # Get feature map info
     serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
     serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
@@ -75,16 +75,16 @@ def get_conv2d_params(stmt, producers, consumers):
     )
     # Get scale_bias info
     scale_bias_load = loads[3]
-    scale_bias_base = get_base_address(scale_bias_load.index)
+    scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices]
     serial_scale_bias = SerialAddressRange(
-        address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
+        address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base),
         length=SCALE_BIAS_LENGTH * serial_ofm[3],
     )
     # Get weight info
     weight_load = loads[2]
-    weight_base = get_base_address(weight_load.index)
+    weight_base = [get_base_address(index) for index in weight_load.indices]
     serial_weight = SerialAddressRange(
-        address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
+        address=tvm.tir.BufferLoad(weight_load.buffer, weight_base),
         length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent,
     )
     # Get activation info
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
index b1a4ebd..b39ec36 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
@@ -68,8 +68,8 @@ def get_depthwise_conv2d_params(
     loads = get_loads(rw.body)
     # stores = [output]
     stores = get_stores(rw.body)
-    input_pointer = loads[1].buffer_var
-    output_pointer = stores[0].buffer_var
+    input_pointer = loads[1].buffer.data
+    output_pointer = stores[0].buffer.data
     # Get feature map info
     serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
     serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
@@ -84,16 +84,16 @@ def get_depthwise_conv2d_params(
     )
     # Get scale_bias info
     scale_bias_load = loads[3]
-    scale_bias_base = get_base_address(scale_bias_load.index)
+    scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices]
     serial_scale_bias = SerialAddressRange(
-        address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
+        address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base),
         length=SCALE_BIAS_LENGTH * serial_ofm[3],
     )
     # Get weight info
     weight_load = loads[2]
-    weight_base = get_base_address(weight_load.index)
+    weight_base = [get_base_address(index) for index in weight_load.indices]
     serial_weight = SerialAddressRange(
-        address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
+        address=tvm.tir.BufferLoad(weight_load.buffer, weight_base),
         length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1],
     )
     # Get activation info
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py
index 9f82d74..aa4c09f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py
@@ -41,12 +41,12 @@ def get_pad_params(stmt):
     """
     _, body = get_op_attrs(stmt)
     n, h, w, c, _, inner = get_outer_loops(body, "NHWC")
-    output_pointer = inner.buffer_var
+    output_pointer = inner.buffer.data
     pad = SerialPadding(top=0, left=0, bottom=0, right=0)
     if isinstance(inner.value, tvm.tir.Call):
-        input_pointer = inner.value.args[1].buffer_var
+        input_pointer = inner.value.args[1].buffer.data
     else:
-        input_pointer = inner.value.buffer_var
+        input_pointer = inner.value.buffer.data
         return pad, input_pointer, output_pointer
 
     padded_shape = [n.extent, h.extent, w.extent, c.extent]
@@ -94,10 +94,10 @@ def get_upscale_params(stmt):
     _, body = get_op_attrs(stmt)
     _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
     if isinstance(inner.value, tvm.tir.Call):
-        input_pointer = inner.value.args[1].buffer_var
+        input_pointer = inner.value.args[1].buffer.data
     else:
-        input_pointer = inner.value.buffer_var
-    output_pointer = inner.buffer_var
+        input_pointer = inner.value.buffer.data
+    output_pointer = inner.buffer.data
     return (input_pointer, output_pointer)
 
 
@@ -126,11 +126,11 @@ def get_convert_to_nhwc_params(stmt):
     # compute that is deemed unnecessary isn't removed by TVM.
     if attrs["layout"] == "NHCWB16":
         inner = inner.body
-        input_pointer = inner.value.b.buffer_var
+        input_pointer = inner.value.b.buffer.data
     else:
-        input_pointer = inner.value.buffer_var
+        input_pointer = inner.value.buffer.data
 
-    output_pointer = inner.buffer_var
+    output_pointer = inner.buffer.data
     return c.extent, input_pointer, output_pointer
 
 
@@ -154,13 +154,13 @@ def get_convert_to_nhcwb16_params(stmt):
     """
     attrs, body = get_op_attrs(stmt)
     _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"])
-    output_pointer = inner.buffer_var
+    output_pointer = inner.buffer.data
     if isinstance(inner.value, tvm.tir.Call):
         cond = inner.value.args[0]
         out_channels = cond.b.value
-        input_pointer = inner.value.args[1].buffer_var
+        input_pointer = inner.value.args[1].buffer.data
     else:
-        input_pointer = inner.value.buffer_var
+        input_pointer = inner.value.buffer.data
         out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent
 
     return out_channels, input_pointer, output_pointer
@@ -186,12 +186,17 @@ def get_read_params(stmt):
     """
     attrs, body = get_op_attrs(stmt)
     _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"])
-    input_pointer = inner.value.buffer_var
-    output_pointer = inner.buffer_var
+    input_pointer = inner.value.buffer.data
+    output_pointer = inner.buffer.data
+
+    # Needed for stride calculation, can replace with
+    # inner.value.buffer.strides in future.
+    assert len(inner.value.indices) == 1, "Ethos-U DMA expects flattened buffers"
     stride_vars = [h.loop_var, w.loop_var, c.loop_var]
-    strides = get_strides(inner.value.index, stride_vars)
-    base_address = get_base_address(inner.value.index)
-    data_type = inner.buffer_var.type_annotation.element_type.dtype
+    strides = get_strides(inner.value.indices[0], stride_vars)
+
+    base_address = [get_base_address(index) for index in inner.value.indices]
+    data_type = inner.buffer.data.type_annotation.element_type.dtype
     return (
         SerialFeatureMap(
             data_type=data_type,
@@ -201,7 +206,7 @@ def get_read_params(stmt):
             tile_height_0=h.extent,
             tile_height_1=0,
             tile_width_0=w.extent,
-            tile_address_0=tvm.tir.Load(data_type, inner.value.buffer_var, base_address),
+            tile_address_0=tvm.tir.BufferLoad(inner.value.buffer, base_address),
             tile_address_1=0,
             tile_address_2=0,
             tile_address_3=0,
@@ -237,12 +242,17 @@ def get_write_params(stmt):
     """
     attrs, body = get_op_attrs(stmt)
     _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"])
-    input_pointer = inner.value.buffer_var
-    output_pointer = inner.buffer_var
+    input_pointer = inner.value.buffer.data
+    output_pointer = inner.buffer.data
+
+    # Needed for stride calculation, can replace with
+    # inner.value.buffer.strides in future.
+    assert len(inner.indices) == 1, "Ethos-U DMA expects flattened buffers"
     stride_vars = [h.loop_var, w.loop_var, c.loop_var]
-    strides = get_strides(inner.index, stride_vars)
-    base_address = get_base_address(inner.index)
-    data_type = inner.buffer_var.type_annotation.element_type.dtype
+    strides = get_strides(inner.indices[0], stride_vars)
+
+    base_address = [get_base_address(index) for index in inner.indices]
+    data_type = inner.buffer.data.type_annotation.element_type.dtype
     return (
         SerialFeatureMap(
             data_type=data_type,
@@ -252,7 +262,7 @@ def get_write_params(stmt):
             tile_height_0=h.extent,
             tile_height_1=0,
             tile_width_0=w.extent,
-            tile_address_0=tvm.tir.Load(data_type, inner.buffer_var, base_address),
+            tile_address_0=tvm.tir.BufferLoad(inner.buffer, base_address),
             tile_address_1=0,
             tile_address_2=0,
             tile_address_3=0,
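
    For the stride computation in get_read_params/get_write_params above, the single
    flattened index of an NHWC access has the form h*(W*C) + w*C + c, so the strides
    recovered with respect to the h, w, and c loop variables are W*C, C, and 1. A
    purely illustrative check with assumed extents:

        H, W, C = 4, 6, 16                      # assumed loop extents
        stride_h, stride_w, stride_c = W * C, C, 1

        def flat_index(h, w, c):
            return h * stride_h + w * stride_w + c * stride_c

        # Moving one step along a loop variable moves by that variable's stride.
        assert flat_index(1, 0, 0) - flat_index(0, 0, 0) == stride_h
        assert flat_index(0, 1, 0) - flat_index(0, 0, 0) == stride_w
        assert flat_index(0, 0, 1) - flat_index(0, 0, 0) == stride_c
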
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py
index 6dccb5a..40686ac 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py
@@ -59,12 +59,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur
 
     fm_inner = inner.value if fm_type == "ifm" else inner
 
+    # Needed for stride calculation, can replace with
+    # inner.value.buffer.strides in future.
+    assert len(fm_inner.indices) == 1, "Ethos-U passes expect flattened buffers"
     stride_vars = [l.loop_var for l in loops]
-    strides = get_strides(fm_inner.index, stride_vars)
+    strides = get_strides(fm_inner.indices[0], stride_vars)
 
-    base_address = get_base_address(fm_inner.index)
-    data_type = inner.buffer_var.type_annotation.element_type.dtype
-    pointer = fm_inner.buffer_var
+    base_address = [get_base_address(index) for index in fm_inner.indices]
+    data_type = inner.buffer.data.type_annotation.element_type.dtype
 
     serial_feature_map = SerialFeatureMap(
         data_type=data_type,
@@ -74,7 +76,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur
         tile_height_0=loops[0].extent,
         tile_height_1=0,
         tile_width_0=loops[1].extent if len(loops) > 1 else 1,
-        tile_address_0=tvm.tir.Load(data_type, pointer, base_address),
+        tile_address_0=tvm.tir.BufferLoad(fm_inner.buffer, base_address),
         tile_address_1=0,
         tile_address_2=0,
         tile_address_3=0,
@@ -86,7 +88,7 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur
         stride_c=strides[2] if len(strides) > 2 else 1,
     )
 
-    output_pointer = inner.buffer_var
+    output_pointer = inner.buffer.data
 
     return serial_feature_map, output_pointer
 
@@ -130,8 +132,8 @@ def get_identity_params(
     # loads = [input, LUT, LUT]
     loads = get_loads(stmt)
 
-    input_pointer = loads[0].buffer_var
-    output_pointer = stmt.buffer_var
+    input_pointer = loads[0].buffer.data
+    output_pointer = stmt.buffer.data
 
     read = producers[input_pointer]
     write = consumers[output_pointer]
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index c2fff8a..5f0b9fe 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -28,7 +28,7 @@ from .binary_elementwise import get_binary_elementwise_params
 from .identity import get_identity_params
 from .unary_elementwise import get_unary_elementwise_params
 from .transform import get_copy_params
-from .utils import get_weights_pointer, get_scale_bias_pointer
+from .utils import get_weights_buffer, get_scale_bias_buffer
 
 
 def RemoveZeroStores():
@@ -82,8 +82,8 @@ def ReplaceOperators():
         loads = []
 
         def _get_loads(stmt):
-            if isinstance(stmt, tvm.tir.Load):
-                loads.append(stmt.buffer_var)
+            if isinstance(stmt, tvm.tir.BufferLoad):
+                loads.append(stmt.buffer.data)
 
         if isinstance(stmt, tvm.tir.Allocate):
             pointer_to_extents[stmt.buffer_var] = stmt.extents
@@ -94,8 +94,8 @@ def ReplaceOperators():
         elif isinstance(stmt, tvm.tir.AttrStmt):
             if stmt.attr_key == "pragma_op":
                 tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads)
-                for load_buffer in loads:
-                    pointer_to_consumer[load_buffer] = stmt
+                for load_pointer in loads:
+                    pointer_to_consumer[load_pointer] = stmt
 
     def _replace_operator(stmt):
         """Replace operators with call_externs, having derived the parameters
@@ -232,21 +232,26 @@ def DivideConstants(const_dict):
     def _visit(stmt):
         new_args = []
         for i, arg in enumerate(stmt.args):
-            if isinstance(arg, tvm.tir.expr.Load):
+            if isinstance(arg, tvm.tir.expr.BufferLoad):
                 # If we're trying to load a buffer that maps to a constant
-                if arg.buffer_var in buffer_to_const:
-                    const = buffer_to_const[arg.buffer_var]
-                    offset = int(arg.index)
+                if arg.buffer.data in buffer_to_const:
+                    const = buffer_to_const[arg.buffer.data]
+
+                    assert len(arg.indices) == 1, "Ethos-U passes expect flattened buffers"
+
+                    offset = int(arg.indices[0])
                     # Note by convention the arg after a constant read is the length of the read
                     length = int(stmt.args[i + 1])
                     # If it's anything other than a full read, create a new buffer
                     if offset != 0 or len(const) != length:
                         new_consts.append(const[offset : offset + length])
-                        new_buffer = tvm.tir.decl_buffer((length,), arg.dtype)
+                        new_buffer = tvm.tir.decl_buffer(
+                            (length,), arg.dtype, scope=arg.buffer.scope()
+                        )
                         new_buffers.append(new_buffer)
-                        new_args.append(tvm.tir.expr.Load(new_buffer.dtype, new_buffer.data, 0))
+                        new_args.append(tvm.tir.expr.BufferLoad(new_buffer, [0]))
                         continue
-                    keep_buffers.add(arg.buffer_var)
+                    keep_buffers.add(arg.buffer.data)
 
             new_args.append(arg)
 
@@ -278,7 +283,15 @@ def DivideConstants(const_dict):
             new_buffer_map[handle] = new_buffer
             new_const_dict[len(new_params) - 1] = new_consts[i]
 
-        new_f = tvm.tir.PrimFunc(new_params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span)
+        new_f = tvm.tir.PrimFunc(
+            new_params,
+            new_body,
+            f.ret_type,
+            new_buffer_map,
+            f.preflattened_buffer_map,
+            f.attrs,
+            f.span,
+        )
         return new_f
 
     def _divide_constants(mod):
@@ -302,179 +315,232 @@ def EncodeConstants(const_dict):
 
     """
     new_const_dict = {}
-    buffer_to_const = {}
-    pointer_to_buffer = {}
-    rewrite_buffer = {}
-    rewrite_pointer = {}
-    accel_config = vela_api.get_accelerator_config()
-
-    def _align_scale_bias(tir_extern_call, bias):
-        """Align the scale_bias to 16 bytes."""
-        value_bytes = bytearray()
-        value_bytes.extend(bias.tobytes())
-        # Align to 16
-        remainder = (len(value_bytes)) % 16
-        if remainder > 0:
-            value_bytes.extend(bytearray(16 - remainder))
-        value = np.frombuffer(value_bytes, dtype="uint8")
-        return value
-
-    def _encode_weights(tir_extern_call, weights):
-        """Encode the weights for a TIR extern call."""
-        value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config)
-        value = np.frombuffer(value_bytes, dtype="uint8")
-        return value
-
-    def _new_buffer(old_buffer, new_value):
-        """Create a new buffer and add the old buffer and its pointer to the
-        rewriting maps."""
-        if old_buffer in rewrite_buffer:
-            new_buffer = rewrite_buffer[old_buffer]
-        else:
-            new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
-            pointer_to_buffer[new_buffer.data] = new_buffer
-            buffer_to_const[new_buffer] = new_value
-
-        rewrite_buffer[old_buffer] = new_buffer
-        rewrite_pointer[old_buffer.data] = new_buffer.data
-
-    def _visit_encode_pre(stmt):
-        if isinstance(stmt, tvm.tir.Call):
-            # Handle copies as a special-case by propagating the buffer information
-            # from the read to the write pointer.
-            if stmt.args[0] == "ethosu_copy":
-                read_pointer = stmt.args[1].buffer_var
-                if read_pointer in pointer_to_buffer:
-                    write_pointer = stmt.args[3].buffer_var
+
+    def collect_encoding_definitions(stmt, old_buffer_to_const):
+        # Map from copy destination to copy source.
+        copy_map = {}
+        # List of buffer copies that occurred
+        copied_buffers = []
+        # List of encoded buffer information
+        constant_buffer_replacements = []
+
+        def _align_scale_bias(tir_extern_call, bias):
+            """Align the scale_bias to 16 bytes."""
+            value_bytes = bytearray()
+            value_bytes.extend(bias.tobytes())
+            # Align to 16
+            remainder = (len(value_bytes)) % 16
+            if remainder > 0:
+                value_bytes.extend(bytearray(16 - remainder))
+            value = np.frombuffer(value_bytes, dtype="uint8")
+            return value
+
+        accel_config = vela_api.get_accelerator_config()
+
+        def _encode_weights(tir_extern_call, weights):
+            """Encode the weights for a TIR extern call."""
+            value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config)
+            value = np.frombuffer(value_bytes, dtype="uint8")
+            return value
+
+        def _declare_constant_buffer(old_buffer, encoded_constants):
+            """Create a new buffer and add the old buffer and its pointer to the
+            rewriting maps."""
+            new_buffer = tvm.tir.decl_buffer(
+                shape=[len(encoded_constants)],
+                dtype=str(encoded_constants.dtype),
+                name=old_buffer.name + "_encoded",
+                scope=old_buffer.scope(),
+            )
+
+            constant_buffer_replacements.append(
+                {
+                    "old_buffer": old_buffer,
+                    "new_buffer": new_buffer,
+                    "encoded_constants": encoded_constants,
+                }
+            )
+
+        def _visit(stmt):
+            if isinstance(stmt, tvm.tir.Call):
+                # Handle copies as a special-case by propagating the buffer information
+                # from the read to the write pointer.
+                if stmt.args[0] == "ethosu_copy":
+                    read_buffer = stmt.args[1].buffer
+                    write_buffer = stmt.args[3].buffer
                     # Assert writing to the base of the write_var (pre-StorageRewrite)
-                    assert stmt.args[3].index == 0
-                    assert stmt.args[1].index == 0
-                    pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer]
-            else:
-                # Encode the weights
-                weights_pointer = get_weights_pointer(stmt)
-                if weights_pointer is not None:
-                    assert weights_pointer in pointer_to_buffer
-                    weights_buffer = pointer_to_buffer[weights_pointer]
-                    weights_value = buffer_to_const[weights_buffer]
-                    new_weights_value = _encode_weights(stmt, weights_value)
-                    _new_buffer(weights_buffer, new_weights_value)
-                # Align the scale_bias to 16 bytes
-                scale_bias_pointer = get_scale_bias_pointer(stmt)
-                if scale_bias_pointer is not None:
-                    assert scale_bias_pointer in pointer_to_buffer
-                    scale_bias_buffer = pointer_to_buffer[scale_bias_pointer]
-                    scale_bias_value = buffer_to_const[scale_bias_buffer]
-                    new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value)
-                    _new_buffer(scale_bias_buffer, new_scale_bias_value)
-
-    def _visit_encode_post(stmt):
-        # Because encoding may change the data type (e.g. bias to uint8) and type information
-        # is stored in pointer vars, it's necessary to rewrite all the pointers which point
-        # to encoded data.
-        if isinstance(stmt, tvm.tir.Allocate):
-            allocate_pointer = stmt.buffer_var
-            if allocate_pointer in pointer_to_buffer:
-                buffer = pointer_to_buffer[allocate_pointer]
-                if buffer in rewrite_buffer:  # If the pointer needs rewriting
-                    # Create a new pointer var with the type of the new buffer
-                    new_buffer = rewrite_buffer[buffer]
-                    storage_type = tvm.ir.PrimType(new_buffer.dtype)
-                    new_pointer = tvm.tir.Var(
-                        allocate_pointer.name,
-                        tvm.ir.PointerType(storage_type, buffer.scope()),
-                        allocate_pointer.span,
-                    )
-                    # Set the new pointer to resolve to the new buffer
-                    pointer_to_buffer[new_pointer] = new_buffer
-                    # Add the old pointer to the pointer rewriting dict
-                    rewrite_pointer[allocate_pointer] = new_pointer
-
-    def _visit_rewrite(stmt):
-        if isinstance(stmt, tvm.tir.Call):
-            # For extern calls, we need to rewrite pairs of arguments corresponding to
-            # base address load and the length of the load.
-            new_args = [stmt.args[0]]
-            new_buffers = rewrite_buffer.values()
-            for i in range(1, len(stmt.args)):
-                # If the previous argument was a load, the current should be a length
-                if isinstance(stmt.args[i - 1], tvm.tir.Load):
-                    load = stmt.args[i - 1]
-                    pointer = load.buffer_var
-                    if pointer in pointer_to_buffer:
-                        buffer = pointer_to_buffer[pointer]
-                        # Only rewrite the arguments of buffers that have been encoded
-                        if buffer in new_buffers:
-                            new_arg = np.prod(list(pointer_to_buffer[pointer].shape))
-                            new_args.append(new_arg)
-                            continue
-                new_args.append(stmt.args[i])
-
-            return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span)
-        if isinstance(stmt, tvm.tir.Allocate):
-            # Where a pointer needs rewriting, the allocate for it must be rewritten
-            allocate_pointer = stmt.buffer_var
-            if allocate_pointer in pointer_to_buffer:
-                if pointer_to_buffer[allocate_pointer] in rewrite_buffer:
-                    new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]]
-                    new_pointer = rewrite_pointer[allocate_pointer]
+                    assert list(stmt.args[3].indices) == [0]
+                    assert list(stmt.args[1].indices) == [0]
+                    copied_buffers.append({"source": read_buffer, "dest": write_buffer})
+                    copy_map[write_buffer] = read_buffer
+
+                else:
+                    # Encode the weights
+                    weights_buffer = get_weights_buffer(stmt)
+                    if weights_buffer is not None:
+                        if weights_buffer in copy_map:
+                            weights_buffer = copy_map[weights_buffer]
+                        unencoded_weights_value = old_buffer_to_const[weights_buffer]
+                        encoded_weights_value = _encode_weights(stmt, unencoded_weights_value)
+                        _declare_constant_buffer(weights_buffer, encoded_weights_value)
+
+                    # Align the scale_bias to 16 bytes
+                    scale_bias_buffer = get_scale_bias_buffer(stmt)
+                    if scale_bias_buffer is not None:
+                        if scale_bias_buffer in copy_map:
+                            scale_bias_buffer = copy_map[scale_bias_buffer]
+                        scale_bias_value = old_buffer_to_const[scale_bias_buffer]
+                        aligned_scale_bias_value = _align_scale_bias(stmt, scale_bias_value)
+                        _declare_constant_buffer(scale_bias_buffer, aligned_scale_bias_value)
+
+        tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
+
+        return {
+            "copied_buffers": copied_buffers,
+            "constant_buffer_replacements": constant_buffer_replacements,
+        }
+
+    def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const):
+        def _visit_rewrite(stmt):
+            if isinstance(stmt, tvm.tir.Call):
+                # For extern calls, we need to rewrite pairs of arguments corresponding to
+                # base address load and the length of the load.
+                old_args = list(stmt.args)
+
+                new_args = [stmt.args[0]]
+                for prev_arg, arg in zip(old_args[:-1], old_args[1:]):
+                    # If the previous argument was a load from an
+                    # encoded buffer, the current should be a length.
+                    if (
+                        isinstance(prev_arg, tvm.tir.BufferLoad)
+                        and prev_arg.buffer in new_buffer_to_const
+                    ):
+                        arg = np.prod(list(prev_arg.buffer.shape))
+
+                    new_args.append(arg)
+
+                return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span)
+
+            if isinstance(stmt, tvm.tir.Allocate):
+                # Where a pointer needs rewriting, the allocate for it must be rewritten
+                allocate_pointer = stmt.buffer_var
+                if allocate_pointer in var_remap:
+                    new_allocate_pointer = var_remap[allocate_pointer]
+                    new_buffer = pointer_to_buffer[new_allocate_pointer]
+
                     return tvm.tir.Allocate(
-                        new_pointer,
+                        new_buffer.data,
                         new_buffer.dtype,
                         new_buffer.shape,
                         stmt.condition,
                         stmt.body,
                         stmt.span,
                     )
-        # The following rewrites would be better expressed by just rewriting the Vars, however
-        # ir_transform doesn't seem to visit Vars. So instead we do the next best thing and rewrite
-        # the nodes which contain the Vars.
-        if isinstance(stmt, tvm.tir.Load):
-            load_pointer = stmt.buffer_var
-            if load_pointer in rewrite_pointer:
-                new_pointer = rewrite_pointer[load_pointer]
-                element_type = new_pointer.type_annotation.element_type.dtype
-                return tvm.tir.Load(
-                    element_type, new_pointer, stmt.index, stmt.predicate, stmt.span
-                )
-        if isinstance(stmt, tvm.tir.AttrStmt):
-            node_pointer = stmt.node
-            if node_pointer in rewrite_pointer:
-                return tvm.tir.AttrStmt(
-                    rewrite_pointer[node_pointer], stmt.attr_key, stmt.value, stmt.body, stmt.span
-                )
-        return None
+
+            # The following rewrites would be better expressed by just
+            # rewriting the Buffers. However ir_transform doesn't
+            # visit Buffers, so instead we do the next best thing and
+            # rewrite the nodes which contain the Buffers.
+            if isinstance(stmt, tvm.tir.BufferLoad):
+                if stmt.buffer in buf_remap:
+                    return tvm.tir.BufferLoad(buf_remap[stmt.buffer], stmt.indices, stmt.span)
+
+            if isinstance(stmt, tvm.tir.AttrStmt):
+                node_pointer = stmt.node
+                if node_pointer in var_remap:
+                    return tvm.tir.AttrStmt(
+                        var_remap[node_pointer],
+                        stmt.attr_key,
+                        stmt.value,
+                        stmt.body,
+                        stmt.span,
+                    )
+
+            return None
+
+        return tvm.tir.stmt_functor.ir_transform(
+            stmt,
+            None,
+            _visit_rewrite,
+            ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"],
+        )
 
     def _ftransform(f, mod, ctx):
+        # Step 0: Unpack the constant dictionary in terms of the
+        # function's buffers.
+        old_buffer_to_const = {}
         for i, param in enumerate(f.params):
             if i in const_dict:
-                buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten()
-                pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param]
-
-        # First analyse what needs to be rewritten
-        new_body = tvm.tir.stmt_functor.ir_transform(
-            f.body, _visit_encode_pre, _visit_encode_post, ["tir.Call", "tir.Allocate"]
-        )
-        # Then perform the rewrites
-        new_body = tvm.tir.stmt_functor.ir_transform(
-            f.body, None, _visit_rewrite, ["tir.Call", "tir.Allocate", "tir.Load", "tir.AttrStmt"]
+                old_buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten()
+
+        # Step 1: Collect information on the buffers that will be
+        # replaced by encodings.
+        buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const)
+
+        # Step 2: Generate variable/buffer remaps, based on the
+        # collected information.
+        buf_remap = {}
+        new_buffer_to_const = {}
+
+        # Any encoded buffers must be replaced
+        for info in buffer_information["constant_buffer_replacements"]:
+            buf_remap[info["old_buffer"]] = info["new_buffer"]
+            new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"]
+
+        # Any buffers that are copied into from an encoded buffer must
+        # be replaced.
+        for info in buffer_information["copied_buffers"]:
+            copy_source = info["source"]
+            while copy_source in buf_remap:
+                copy_source = buf_remap[copy_source]
+
+            copy_dest = info["dest"]
+
+            if copy_source.shape != copy_dest.shape or copy_source.dtype != copy_dest.dtype:
+                new_dest = tvm.tir.decl_buffer(
+                    shape=copy_source.shape,
+                    dtype=copy_source.dtype,
+                    name=copy_dest.name,
+                    scope=copy_dest.scope(),
+                )
+                buf_remap[copy_dest] = new_dest
+                if copy_source in new_buffer_to_const:
+                    new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source]
+
+        # Define additional dependent lookup tables.
+        var_remap = {old.data: new.data for (old, new) in buf_remap.items()}
+        pointer_to_buffer = {
+            buf.data: buf for (old, new) in buf_remap.items() for buf in [old, new]
+        }
+
+        # Step 3: Then perform the rewrites
+        new_body = transform_stmt(
+            f.body, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const
         )
+
+        # Step 4: Rewrite the buffer map and const dict to instead use the encoded versions
         new_buffer_map = {}
-        # Rewrite the buffer map and const dict to instead use the encoded versions
         for i, param in enumerate(f.params):
             buffer = f.buffer_map[param]
-            if buffer in rewrite_buffer:
-                new_buffer = rewrite_buffer[buffer]
-                new_buffer_map[param] = new_buffer
-                new_value = buffer_to_const[new_buffer]
-                new_const_dict[i] = new_value
-            elif buffer in buffer_to_const:
-                new_const_dict[i] = buffer_to_const[buffer]
-                new_buffer_map[param] = buffer
-            else:
-                new_buffer_map[param] = buffer
-
-        new_f = tvm.tir.PrimFunc(f.params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span)
+            if buffer in buf_remap:
+                buffer = buf_remap[buffer]
+
+            if buffer in new_buffer_to_const:
+                new_const_dict[i] = new_buffer_to_const[buffer]
+            elif buffer in old_buffer_to_const:
+                new_const_dict[i] = old_buffer_to_const[buffer]
+
+            new_buffer_map[param] = buffer
+
+        new_f = tvm.tir.PrimFunc(
+            f.params,
+            new_body,
+            f.ret_type,
+            new_buffer_map,
+            f.preflattened_buffer_map,
+            f.attrs,
+            f.span,
+        )
         return new_f
 
     def _encode_constants(mod):
@@ -706,15 +772,26 @@ def CreatePrimFuncWithoutConstants(const_dict):
     def _ftransform(f, mod, ctx):
         new_params = list()
         new_buffer_map = dict()
+        new_preflattened_buffer_map = dict()
         for param_idx in const_dict.keys():
             # We are using buffer_var to key the constants as
             # PrimFunc params of constants will be removed.
             new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx]
-        for i in range(len(f.params)):
+        for i, param in enumerate(f.params):
             if i not in const_dict.keys():
-                new_params.append(f.params[i])
-                new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]]
-        return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.attrs, f.span)
+                new_params.append(param)
+                new_buffer_map[param] = f.buffer_map[param]
+                if param in f.preflattened_buffer_map:
+                    new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param]
+        return tvm.tir.PrimFunc(
+            new_params,
+            f.body,
+            f.ret_type,
+            new_buffer_map,
+            new_preflattened_buffer_map,
+            f.attrs,
+            f.span,
+        )
 
     def _create_primfunc_without_constants(mod):
         transform_func = tvm.tir.transform.prim_func_pass(
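
    The buffer-remapping rewrite in EncodeConstants above follows a general pattern:
    build a buffer-to-buffer map, then let ir_transform replace any BufferLoad whose
    buffer has been remapped. A minimal, self-contained sketch of that pattern (the
    buffer names and sizes below are made up, and this is not the pass itself):

        import tvm
        from tvm import tir

        old_buf = tir.decl_buffer((16,), "uint8", name="weights")
        new_buf = tir.decl_buffer((24,), "uint8", name="weights_encoded")
        buf_remap = {old_buf: new_buf}

        out_buf = tir.decl_buffer((16,), "uint8", name="out")
        stmt = tir.BufferStore(out_buf, tir.BufferLoad(old_buf, [0]), [0])

        def _rewrite(node):
            # Redirect loads from remapped buffers; leave everything else unchanged.
            if isinstance(node, tir.BufferLoad) and node.buffer in buf_remap:
                return tir.BufferLoad(buf_remap[node.buffer], node.indices, node.span)
            return None

        new_stmt = tir.stmt_functor.ir_transform(stmt, None, _rewrite, ["tir.BufferLoad"])
        assert new_stmt.value.buffer.same_as(new_buf)
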
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py
index e929caa..3b32ef0 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py
@@ -61,8 +61,8 @@ def get_pooling_params(
     loads = get_loads(rw.body)
     # stores = [output]
     stores = get_stores(rw.body)
-    input_pointer = loads[1].buffer_var
-    output_pointer = stores[0].buffer_var
+    input_pointer = loads[1].buffer.data
+    output_pointer = stores[0].buffer.data
     # Get feature map info
     serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
     serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
index f9d38df..d390fc0 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
@@ -93,10 +93,10 @@ class SerialFeatureMap(SerializableFormat):
         tile_height_0: int,
         tile_height_1: int,
         tile_width_0: int,
-        tile_address_0: tvm.tir.expr.Load,
-        tile_address_1: Union[tvm.tir.expr.Load, int],
-        tile_address_2: Union[tvm.tir.expr.Load, int],
-        tile_address_3: Union[tvm.tir.expr.Load, int],
+        tile_address_0: tvm.tir.expr.BufferLoad,
+        tile_address_1: Union[tvm.tir.expr.BufferLoad, int],
+        tile_address_2: Union[tvm.tir.expr.BufferLoad, int],
+        tile_address_3: Union[tvm.tir.expr.BufferLoad, int],
         scale: float,
         zero_point: int,
         layout: str,
@@ -148,7 +148,7 @@ class SerialAddressRange(SerializableFormat):
     """Specialization class to retrieve arguments of a AddressRange
     (similiar to NpuAddressRange of Vela) on a predefined ordering"""
 
-    def __init__(self, address: tvm.tir.expr.Load, length: int):
+    def __init__(self, address: tvm.tir.expr.BufferLoad, length: int):
         self.address = address
         self.length = length
 
@@ -237,7 +237,10 @@ class SerialCopy(SerializableFormat):
     an ethosu.copy tir extern call on a predefined ordering"""
 
     def __init__(
-        self, read_address: tvm.tir.expr.Load, length: int, write_address: tvm.tir.expr.Load
+        self,
+        read_address: tvm.tir.expr.BufferLoad,
+        length: int,
+        write_address: tvm.tir.expr.BufferLoad,
     ):
         self.read_address = read_address
         self.length = length
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py
index 141505a..53e0bd2 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py
@@ -50,17 +50,16 @@ def get_copy_params(stmt, producers, consumers):
     _, body = get_op_attrs(stmt)
     length = body.extent
     write_store = body.body
-    write_base = get_base_address(write_store.index)
+    write_base = [get_base_address(index) for index in write_store.indices]
     read_load = body.body.value
-    read_base = get_base_address(read_load.index)
-    dtype = body.body.value.dtype
+    read_base = [get_base_address(index) for index in read_load.indices]
     return (
         SerialCopy(
-            read_address=tvm.tir.expr.Load(dtype, read_load.buffer_var, read_base),
+            read_address=tvm.tir.expr.BufferLoad(read_load.buffer, read_base),
             length=length,
-            write_address=tvm.tir.expr.Load(dtype, write_store.buffer_var, write_base),
+            write_address=tvm.tir.expr.BufferLoad(write_store.buffer, write_base),
         ),
-        write_store.buffer_var,
+        write_store.buffer.data,
         None,
         True,
     )
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py
index b550b79..9c570d8 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py
@@ -54,11 +54,11 @@ def get_unary_elementwise_params(stmt, producers, consumers):
     input_pointer = None
     if isinstance(inner.value, tir.expr.Select):
         # ABS
-        input_pointer = inner.value.condition.b.buffer_var
+        input_pointer = inner.value.condition.b.buffer.data
     if isinstance(inner.value, tir.expr.Sub):
         # CLZ
-        input_pointer = inner.value.b.args[0].buffer_var
-    output_pointer = inner.buffer_var
+        input_pointer = inner.value.b.args[0].buffer.data
+    output_pointer = inner.buffer.data
     # Get feature map info
     serial_ifm, _ = get_ifm_params(input_pointer, producers)
     serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py
index de1c0ab..506f18b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py
@@ -21,20 +21,20 @@ from tvm import arith
 
 
 # TODO(@mbaret): Formalise this with a specification
-def get_weights_pointer(tir_extern_call):
+def get_weights_buffer(tir_extern_call):
     """Get the weights pointer from a NPU extern call if it exists"""
     supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"]
     if tir_extern_call.args[0] in supported_ops:
-        return tir_extern_call.args[41].buffer_var
+        return tir_extern_call.args[41].buffer
     return None
 
 
 # TODO(@mbaret): Formalise this with a specification
-def get_scale_bias_pointer(tir_extern_call):
+def get_scale_bias_buffer(tir_extern_call):
     """Get the scale_bias pointer from a NPU extern call if it exists"""
     supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"]
     if tir_extern_call.args[0] in supported_ops:
-        return tir_extern_call.args[44].buffer_var
+        return tir_extern_call.args[44].buffer
     return None
 
 
@@ -177,23 +177,23 @@ def get_outer_loops(stmt, layout):
 
 
 def get_loads(stmt):
-    """Get the Load statements.
+    """Get the BufferLoad statements.
 
     Parameters
     ----------
     stmt : tvm.tir.Stmt
-        The statement to get the Loads from.
+        The statement to get the BufferLoads from.
 
     Returns
     -------
-    loads : list of tvm.tir.Load
-        The Loads found.
+    loads : list of tvm.tir.BufferLoad
+        The BufferLoads found.
 
     """
     loads = []
 
     def _visit(s):
-        if isinstance(s, tvm.tir.Load):
+        if isinstance(s, tvm.tir.BufferLoad):
             loads.append(s)
 
     tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
@@ -201,23 +201,23 @@ def get_loads(stmt):
 
 
 def get_stores(stmt):
-    """Get the Store statements.
+    """Get the BufferStore statements.
 
     Parameters
     ----------
     stmt : tvm.tir.Stmt
-        The statement to get the Stores from.
+        The statement to get the BufferStores from.
 
     Returns
     -------
-    stores : list of tvm.tir.Store
-        The Stores found.
+    stores : list of tvm.tir.BufferStore
+        The BufferStores found.
 
     """
     stores = []
 
     def _visit(s):
-        if isinstance(s, tvm.tir.Store):
+        if isinstance(s, tvm.tir.BufferStore):
             stores.append(s)
 
     tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
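
    A small usage sketch of the collection pattern documented in get_loads/get_stores
    above, assuming nothing beyond a hand-built statement:

        import tvm
        from tvm import tir

        buf = tir.decl_buffer((8,), "float32", name="A")
        stmt = tir.Evaluate(tir.BufferLoad(buf, [3]))

        loads = []
        tir.stmt_functor.post_order_visit(
            stmt, lambda node: loads.append(node) if isinstance(node, tir.BufferLoad) else None
        )
        assert len(loads) == 1
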
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index f642f5f..33a22d1 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -122,9 +122,9 @@ def analyze_scratch_memory_acesses(mod: tvm.IRModule, candidate_regions_for_scra
         if isinstance(stmt, tvm.tir.stmt.LetStmt):
             call_address_of = stmt.value
             load = call_address_of.args[0]
-            pool_var = load.buffer_var
+            pool_var = load.buffer.data
             scratch_region_map[stmt.var] = RegionOffset(
-                region=pool_var_region_map[pool_var], offset=int(load.index)
+                region=pool_var_region_map[pool_var], offset=int(load.indices[0])
             )
 
     tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_pool_access)
@@ -334,6 +334,8 @@ def extract_buffer_info(
     primfunc = mod.functions.items()[0][1]
 
     for param, const_data in param_dict.items():
+        if isinstance(param, tvm.tir.Buffer):
+            param = param.data
         buffer_info[param] = BufferInfo(
             const_data, const_data.shape, const_data.dtype, BufferType.constant
         )
@@ -385,6 +387,7 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map):
         This is the dictionary obtained via calling extract_buffer_info.
         The key is the buffer name to BufferInfo
     npu_ops : list
+        A list of Vela NpuOps with tir.BufferLoads for addresses
         A list of Vela NpuOps with tir.Loads for addresses
     scratch_region_map : Dict[tvm.tir.Var, RegionOffset]
         A buffer_var to region and offset map.
@@ -397,14 +400,13 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map):
     """
 
     def replace_npu_fm_with_address(npu_fm):
-        assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.Load)
+        assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.BufferLoad)
         # We currently do not support tiles
         # Change this when tiles are needed
         # (i.e. when using rolling buffers)
         assert npu_fm.tiles.addresses[1:] == [0, 0, 0]
         npu_fm.tiles.addresses[1:] = [0, 0, 0]
-        buffer = npu_fm.tiles.addresses[0].buffer_var
-
+        buffer = npu_fm.tiles.addresses[0].buffer.data
         if buffer in scratch_region_map.keys():
             address = scratch_region_map[buffer].offset
             region = scratch_region_map[buffer].region
@@ -412,8 +414,10 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map):
             assert buffer in buffer_addresses.keys()
             address, buffer_type = buffer_addresses[buffer]
             region = _get_region(buffer_type)
-
-        index = npu_fm.tiles.addresses[0].index * (
+        assert (
+            len(npu_fm.tiles.addresses[0].indices) == 1
+        ), "Ethos-U translation expects flattened buffers"
+        index = npu_fm.tiles.addresses[0].indices[0] * (
             np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8
         )
         npu_fm.tiles.addresses[0] = address + int(index)
@@ -421,10 +425,11 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map):
         return npu_fm
 
     def replace_npu_address_range_with_address(npu_addr_range):
-        assert isinstance(npu_addr_range.address, tvm.tir.Load)
-        buffer = npu_addr_range.address.buffer_var
+        assert isinstance(npu_addr_range.address, tvm.tir.BufferLoad)
+        buffer = npu_addr_range.address.buffer.data
         index = int(
-            npu_addr_range.address.index * (np.iinfo(np.dtype(npu_addr_range.address)).bits // 8)
+            npu_addr_range.address.indices[0]
+            * (np.iinfo(np.dtype(npu_addr_range.address)).bits // 8)
         )
         if buffer in scratch_region_map.keys():
             return vapi.NpuAddressRange(
@@ -446,11 +451,11 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map):
     def classify_io(buffer):
         for _npu_op in npu_ops:
             if issubclass(type(_npu_op), vapi.NpuBlockOperation):
-                if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer_var == buffer:
+                if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer.data == buffer:
                     return BufferType.input
-                if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer_var == buffer:
+                if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer.data == buffer:
                     return BufferType.input
-                if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer_var == buffer:
+                if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer.data == buffer:
                     return BufferType.output
 
         raise ValueError(f"Unused IO : {buffer} in tir module.")
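
    The address computations in assign_addresses above turn a BufferLoad's element
    index into a byte offset using the element width. A worked example of that
    arithmetic (the index and dtype here are illustrative only):

        import numpy as np

        element_index = 10
        dtype = "int16"

        bytes_per_element = np.iinfo(np.dtype(dtype)).bits // 8   # 16 bits -> 2 bytes
        byte_offset = element_index * bytes_per_element
        assert byte_offset == 20
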
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index fbd5c3d..c90fd68 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -200,10 +200,10 @@ _reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
 @script
 def _arange_shape_func(start, stop, step):
     out = output_tensor((1,), "int64")
-    if step[0] < 0:
-        out[0] = int64(ceil_div((int64(start[0]) - int64(stop[0])), int64(-step[0])))
+    if step[()] < 0:
+        out[0] = int64(ceil_div((int64(start[()]) - int64(stop[()])), int64(-step[()])))
     else:
-        out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
+        out[0] = int64(ceil_div((int64(stop[()]) - int64(start[()])), int64(step[()])))
     return out
 
 
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index c909764..d523d43 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -170,7 +170,7 @@ def _onehot_shape_func(dshape, k, axis):
     out = output_tensor((ndim,), "int64")
     for i in const_range(axis):
         out[i] = int64(dshape[i])
-    out[axis] = int64(k[0])
+    out[axis] = int64(k[()])
     for j in const_range(axis + 1, ndim):
         out[j] = int64(dshape[j - 1])
     return out
diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py
index 7277151..ec40665 100644
--- a/python/tvm/relay/op/dyn/nn/_nn.py
+++ b/python/tvm/relay/op/dyn/nn/_nn.py
@@ -78,8 +78,8 @@ def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis):
     out = output_tensor((4,), "int64")
     for i in const_range(4):
         out[i] = int64(dshape[i])
-    out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
-    out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
+    out[height_axis] = int64(round(dshape[height_axis] * scale_h[()]))
+    out[width_axis] = int64(round(dshape[width_axis] * scale_w[()]))
     return out
 
 
@@ -108,9 +108,9 @@ def _upsampling3d_shape_func(
     out = output_tensor((5,), "int64")
     for i in const_range(5):
         out[i] = int64(dshape[i])
-    out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[0]))
-    out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
-    out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
+    out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[()]))
+    out[height_axis] = int64(round(dshape[height_axis] * scale_h[()]))
+    out[width_axis] = int64(round(dshape[width_axis] * scale_w[()]))
     return out
 
 
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
index 149e17b..972e584 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/context_maintainer.py
@@ -127,6 +127,8 @@ class ContextMaintainer:
     """List[Var]: The function parameters"""
     func_buffer_map: Mapping[Var, Buffer] = {}
     """Mapping[Var, Buffer]: The function buffer map"""
+    func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
+    """Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
     func_dict_attr: Mapping[str, Object] = {}
     """Mapping[str, Object]: The function attrs"""
     func_var_env_dict: Mapping[Var, str] = {}
@@ -151,6 +153,7 @@ class ContextMaintainer:
         # function context
         self.func_params = []
         self.func_buffer_map = {}
+        self.func_preflattened_buffer_map = {}
         self.func_dict_attr = {}
         self.func_var_env_dict = {}
         # parser and analyzer
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 587fbe4..17beb81 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -484,6 +484,7 @@ class TVMScriptParser(Transformer):
             body,
             ret_type,
             buffer_map=self.context.func_buffer_map,
+            preflattened_buffer_map=self.context.func_preflattened_buffer_map,
             attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
             span=tvm_span_from_synr(node.span),
         )
@@ -552,7 +553,11 @@ class TVMScriptParser(Transformer):
 
         if isinstance(node.rhs, ast.Call):
             # Pattern 1 & Pattern 4
-            func = self.transform(node.rhs.func_name)
+            if isinstance(node.rhs.func_name, ast.Op):
+                func = None
+            else:
+                func = self.transform(node.rhs.func_name)
+
             if isinstance(func, WithScopeHandler):
                 if not func.concise_scope or not func.def_symbol:
                     self.report_error(
@@ -610,6 +615,12 @@ class TVMScriptParser(Transformer):
         rhs = self.transform(node.params[2])
         rhs_span = tvm_span_from_synr(node.params[2].span)
         if isinstance(symbol, tvm.tir.Buffer):
+            if len(indexes) != len(symbol.shape):
+                self.report_error(
+                    f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, "
+                    f"cannot be indexed by {len(indexes)}-dimensional indices.",
+                    node.params[1].span,
+                )
             # BufferStore
             return tvm.tir.BufferStore(
                 symbol,
@@ -629,15 +640,29 @@ class TVMScriptParser(Transformer):
                     f"Store is only allowed with one index, but {len(indexes)} were provided.",
                     node.params[1].span,
                 )
-            # Store
-            return tvm.tir.Store(
-                symbol,
-                tvm.runtime.convert(rhs, span=rhs_span),
-                indexes[0],
-                tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)),
-                span=tvm_span_from_synr(node.span),
+            self.report_error(
+                "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span
+            )
+
+    def transform_AttrAssign(self, node):
+        """Visitor for statements of the form :code:`x.y = 2`."""
+        obj = self.transform(node.params[0])
+        field = node.params[1]
+        value = self.transform(node.params[2])
+
+        if not hasattr(obj, field.name):
+            self.error(f"Field {field.name} does not exist", field.span)
+
+        var = getattr(obj, field.name)
+
+        if not isinstance(var, tvm.tir.Var):
+            self.error(
+                f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span
             )
 
+        body = self.parse_body(node)
+        return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
+
     def transform_Assert(self, node):
         """Assert visitor
 
@@ -866,13 +891,16 @@ class TVMScriptParser(Transformer):
         """
         # The only builtin operators allowed as statements are subscript assignment (x[1] = 3) and attribute assignment (x.y = 2).
         if isinstance(node.call.func_name, ast.Op):
-            if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign:
-                self.report_error(
-                    "Binary and unary operators are not allowed as a statement", node.span
-                )
-            else:
+            if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign:
                 return self.transform_SubscriptAssign(node.call)
 
+            if node.call.func_name.name == ast.BuiltinOp.AttrAssign:
+                return self.transform_AttrAssign(node.call)
+
+            self.report_error(
+                "Binary and unary operators are not allowed as a statement", node.span
+            )
+
         # handle a regular function call
         func = self.transform(node.call.func_name)
         arg_list = self.parse_arg_list(func, node.call)
@@ -952,15 +980,8 @@ class TVMScriptParser(Transformer):
                     node.span,
                 )
 
-            return call_with_error_reporting(
-                self.report_error,
-                node.span,
-                tvm.tir.Load,
-                "float32",
-                symbol,
-                index,
-                True,
-                span=tvm_span_from_synr(node.span),
+            self.report_error(
+                "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span
             )
         elif isinstance(symbol, tvm.tir.Buffer):
             return BufferSlice(
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index ac4ee30..0593236 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -311,7 +311,7 @@ def allocate(
     scope: str,
     condition: Union[PrimExpr, builtins.bool] = True,
     annotations: Optional[Mapping[str, Object]] = None,
-) -> Var: ...
+) -> Buffer: ...
 def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
 def realize(
     buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
index 8564fc1..4dc78ba 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/tir/node.py
@@ -96,7 +96,8 @@ class BufferSlice(ObjectGeneric):
                 if index < 0:
                     report_error("Negative index is not allowed during buffer access", span)
             elif isinstance(index, PrimExpr):
-                if index.dtype != "int32":
+                element_dtype = index.dtype.split("x", maxsplit=1)[0]
+                if element_dtype != "int32":
                     report_error(
                         "index expected an int32 type PrimExpr but got " + str(index.dtype),
                         index.span,
@@ -153,3 +154,6 @@ class BufferSlice(ObjectGeneric):
 
         indices = [s.start for s in self.slices]
         return BufferLoad(self.buffer, indices, span=self.span)
+
+    def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
+        return self.asobject().astype(dtype, span)
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 07ba204..2da7b78 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -112,10 +112,16 @@ class Allocate(WithScopeHandler):
             condition = tvm.runtime.convert(condition)
             scope = tvm.runtime.convert(scope)
 
+            # Currently, allocate nodes should only occur after buffer
+            # flattening has been applied.  This can be simplified in
+            # the future by having the AllocateNode hold a buffer
+            # object directly.
+            flattened = self.buffer.get_flattened_buffer()
+
             return tvm.tir.Allocate(
-                self.buffer_var,
-                dtype,
-                extents,
+                self.buffer.data,
+                flattened.dtype,
+                flattened.shape,
                 condition,
                 self.body,
                 annotations=annotations,
@@ -123,7 +129,7 @@ class Allocate(WithScopeHandler):
             )
 
         super().__init__(allocate, concise_scope=True, def_symbol=True)
-        self.buffer_var = None
+        self.buffer = None
 
     def enter_scope(
         self,
@@ -147,15 +153,20 @@ class Allocate(WithScopeHandler):
         else:
             raise Exception("Internal Bug")
 
-        def setup_buffer_var(
+        def setup_buffer(
             extents, dtype, scope, condition=True, annotations=None, span: Span = None
         ):
-            """Setup buffer var for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
-            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
+            """Setup buffer object for a given type."""
+            self.buffer = tvm.tir.decl_buffer(
+                shape=extents,
+                dtype=dtype,
+                name=name,
+                scope=scope,
+                span=span,
+            )
 
-        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer_var, node)
+        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
+        context.update_symbol(name, self.buffer, node)
 
 
 @register
@@ -171,11 +182,11 @@ class AllocateConst(WithScopeHandler):
             for i in raw_data:
                 list_data.append(i.value)
             nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
-            n = tvm.tir.AllocateConst(self.buffer_var, dtype, shape, nd_data, self.body, span=span)
+            n = tvm.tir.AllocateConst(self.buffer.data, dtype, shape, nd_data, self.body, span=span)
             return n
 
         super().__init__(allocate_const, concise_scope=True, def_symbol=True)
-        self.buffer_var = None
+        self.buffer = None
 
     def enter_scope(
         self,
@@ -199,13 +210,17 @@ class AllocateConst(WithScopeHandler):
         else:
             raise Exception("Internal Bug")
 
-        def setup_buffer_var(data, dtype, shape, span: Span = None):
+        def setup_buffer(data, dtype, shape, span: Span = None):
             """Setup buffer object for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
-            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
+            self.buffer = tvm.tir.decl_buffer(
+                shape=shape,
+                dtype=dtype,
+                name=name,
+                span=span,
+            )
 
-        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer_var, node)
+        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
+        context.update_symbol(name, self.buffer, node)
 
 
 @register
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index 20161ad..d9c6dbd 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -865,6 +865,60 @@ class FuncAttr(SpecialStmt):
 
 
 @register
+class PreflattenedBufferMap(SpecialStmt):
+    """Special Stmt for declaring the PrimFunc::preflattened_buffer_map
+
+    Example
+    -------
+    .. code-block:: python
+
+         T.preflattened_buffer(A, (128, 128), dtype="float32", data=A.data)
+    """
+
+    def __init__(self):
+        def preflattened_buffer(
+            postflattened,
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="global",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+
+            param = None
+            for key, value in self.context.func_buffer_map.items():
+                if value.same_as(postflattened):
+                    param = key
+
+            assert (
+                param is not None
+            ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
+
+            buffer_name: str = f"{postflattened.name}_preflatten"
+            preflattened = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                buffer_name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+
+            self.context.func_preflattened_buffer_map[param] = preflattened
+
+        super().__init__(preflattened_buffer, def_symbol=False)
+
+
+@register
 class TargetAttrValue(SpecialStmt):
     """Special Stmt for target attr value.
     Example
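
As a hedged sketch of how the new T.preflattened_buffer special statement above
might appear in TVMScript (the function name, buffer name, and shapes here are
assumptions for illustration, not taken from the patch), recording the original
2-d shape of a parameter buffer that StorageFlatten has reduced to 1-d:

    from tvm.script import tir as T

    @T.prim_func
    def zero_fill(a: T.handle) -> None:
        A = T.match_buffer(a, [256], dtype="float32")
        T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
        for i in T.serial(0, 256):
            A[i] = 0.0
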
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index aaad6e1..4c4e223 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -27,7 +27,13 @@ from tvm.tir import isnan, isfinite, isinf
 from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from tvm.tir import comm_reducer, min, max, sum
 
-from .schedule import Schedule, Stage, create_schedule, SpecializedCondition
+from .schedule import (
+    Schedule,
+    Stage,
+    create_schedule,
+    SpecializedCondition,
+    AXIS_SEPARATOR,
+)
 from .tensor import TensorSlice, Tensor
 from .tensor_intrin import decl_tensor_intrin
 from .tag import tag_scope
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index 55d07a5..fdd08f9 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -16,12 +16,16 @@
 # under the License.
 # pylint: disable=unused-import
 """The computation schedule api of TVM."""
+import collections
+import inspect
+from typing import Callable, List
+
 import tvm._ffi
 from tvm._ffi.base import string_types
 
 from tvm.runtime import Object, convert
 from tvm.ir import container as _container
-from tvm.tir import IterVar, Buffer
+from tvm.tir import IterVar, Buffer, Var
 
 from . import tensor as _tensor
 from . import _ffi_api
@@ -519,9 +523,149 @@ class Stage(Object):
         """
         _ffi_api.StageRollingBuffer(self)
 
+    def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr]]):
+        """Defines the layout transformation for the current stage's tensor.
+
+        The map from initial_indices to final_indices must be an
+        invertible affine transformation.  This method may be called
+        more than once for a given tensor, in which case each
+        transformation is applied sequentially.
+
+        If the stage is a ComputeOp, then the iteration order of the
+        compute stage is rewritten to be a row-major traversal of the
+        tensor, and the new loop iteration variables are returned.
+        For all other stages, the loop iteration order is unmodified,
+        and the return value is None.
+
+        Parameters
+        ----------
+        mapping_function : Callable[..., List[tvm.tir.PrimExpr]]
+
+            A callable that accepts N arguments of type tvm.tir.Var,
+            and outputs a list of PrimExpr.  The input arguments
+            represent the location of a value in the current stage's
+            tensor, using the pre-transformation layout.  The return
+            value of the function gives the location of that value in
+            the current stage's tensor, using the post-transformation
+            layout.
+
+        Returns
+        -------
+        new_iter_vars : Optional[List[tvm.tir.IterVar]]
+
+            If the stage is a ComputeOp, then the return will be the
+            updated loop iteration variables over the data array, in
+            the same order as the output values from the
+            `mapping_function`.
+
+            Otherwise, the return value is None.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            # ``A`` is a tensor whose compute definition is in NHWC
+            # format, and should be transformed into NCHWc format.
+
+            s[A].transform_layout(
+                lambda n,h,w,c: [n, c//4, h, w, c%4]
+            )
+
+
+        .. code-block:: python
+
+            # ``A`` is a tensor whose compute definition is in an
+            # arbitrary format, and should be transformed such that
+            # the last index is split, with the slower-changing index
+            # of the split placed at the slowest changing dimension.
+
+            s[A].transform_layout(
+                lambda *indices, i: [i//4, *indices, i%4]
+            )
+
+        .. code-block:: python
+
+            # ``B`` is a tensor defined by te.compute to be a copy of
+            # ``A``, and should be transformed such that ``B``'s layout
+            # is a transpose of ``A``'s layout.  The loop iteration
+            # that computes ``B`` will correspond to ``B``'s memory
+            # layout.
+
+            A = te.placeholder([n,m])
+            B = te.compute(A.shape, lambda i,j: A[i,j])
+            s = te.create_schedule(B.op)
+
+            s[B].transform_layout(lambda i,j: [j,i])
+
+        """
+
+        args = []
+        var_arg_name = None
+        kwargs = collections.OrderedDict()
+        default_index_dtype = "int32"
+
+        # Make a dummy variable for each explicitly named input index.
+        # We may have some keyword-only arguments, if the function has
+        # *args before the last argument.
+        params = inspect.signature(mapping_function).parameters
+        for name, param in params.items():
+            if param.kind in [
+                inspect.Parameter.POSITIONAL_ONLY,
+                inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            ]:
+                args.append(tvm.tir.Var(name, default_index_dtype))
+
+            elif param.kind == inspect.Parameter.VAR_POSITIONAL:
+                var_arg_name = name
+
+            elif param.kind == inspect.Parameter.KEYWORD_ONLY:
+                kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+
+            elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
+                raise ValueError("transform_layout mapping may not have **kwargs")
+
+        ndim = len(self.op.output(0).shape)
+
+        # Now that all the named arguments have been collected,
+        # everything that remains should go to the *args, if
+        # specified.
+        if var_arg_name is not None:
+            num_var_args = ndim - len(args) - len(kwargs)
+            for i in range(num_var_args):
+                args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype))
+
+        initial_indices = args + list(kwargs.values())
+        if len(initial_indices) != ndim:
+            raise ValueError(
+                f"transform_layout mapping accepts {len(initial_indices)} initial indices, "
+                f"but {self.op.name} is {ndim}-dimensional"
+            )
+
+        mapping = mapping_function(*args, **kwargs)
+
+        final_indices = []
+        axis_separators = []
+        for val in mapping:
+            if isinstance(val, tvm.ir.PrimExpr):
+                final_indices.append(val)
+            elif val is AXIS_SEPARATOR:
+                axis_separators.append(len(final_indices))
+            else:
+                raise TypeError(
+                    "Expected mapping function to return list of "
+                    "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR.  "
+                    f"Instead received {val} of type {type(val)}."
+                )
+
+        new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices)
+        _ffi_api.StageSetAxisSeparators(self, axis_separators)
+
+        return new_iter_vars or None
+
 
 @tvm._ffi.register_object
 class SpecializedCondition(Object):
+
     """Specialized condition to enable op specialization."""
 
     def __init__(self, conditions):
@@ -555,4 +699,10 @@ class SpecializedCondition(Object):
         _ffi_api.ExitSpecializationScope(self)
 
 
+# Sentinel value used to separate groups of pre-flattening axes, each of
+# which is flattened into a single post-flattening axis.  See
+# Stage.transform_layout for more details.
+AXIS_SEPARATOR = "axis_separator"
+
+
 tvm._ffi._init_api("schedule", __name__)
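
As a brief, hedged sketch of how the new Stage.transform_layout method and the
te.AXIS_SEPARATOR sentinel defined above might be used together (the tensor
names, shapes, and the particular split are assumptions for illustration):

    import tvm
    from tvm import te

    n = 128
    A = te.placeholder((n, n), name="A")
    B = te.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
    s = te.create_schedule(B.op)

    # Store B as [i, j // 4, j % 4]; the AXIS_SEPARATOR keeps axis i in its own
    # group when the buffer is later flattened for a non-flat memory space.
    new_iter_vars = s[B].transform_layout(
        lambda i, j: [i, te.AXIS_SEPARATOR, j // 4, j % 4]
    )
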
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 6dddd7b..e36a993 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -143,6 +143,33 @@ class Buffer(Object):
         """
         return _ffi_api.BufferStorageScope(self)  # type: ignore
 
+    def get_flattened_buffer(self):
+        """Generate a Buffer that is a flattened version of this buffer.
+
+        Returns
+        -------
+        flattened : Buffer
+            The corresponding flat buffer.
+        """
+        return _ffi_api.BufferGetFlattenedBuffer(self)  # type: ignore
+
+    def offset_of(self, indices):
+        """Determine the offset of the provided indices in the flattened buffer.
+
+        Parameters
+        ----------
+        indices : Union[PrimExpr, List[PrimExpr]]
+
+            The indices of the element in the original buffer.
+
+        Returns
+        -------
+        flattened_indices: List[PrimExpr]
+
+            The offset indices of the element in the flattened buffer.
+        """
+        return _ffi_api.BufferOffsetOf(self, indices)  # type: ignore
+
 
 def decl_buffer(
     shape,
@@ -155,6 +182,7 @@ def decl_buffer(
     data_alignment=-1,
     offset_factor=0,
     buffer_type="",
+    axis_separators=None,
     span=None,
 ):
     """Declare a new symbolic buffer.
@@ -204,6 +232,11 @@ def decl_buffer(
         without considering whether dimension size equals to one.
         TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
 
+    axis_separators : list of int, optional
+        If passed, a list of separators between groups of axes,
+        each of which is flattened to an output axis.  For flat
+        memory spaces, should either be None, or an empty list.
+
     span: Optional[Span]
         The location of the decl_buffer creation in the source.
 
@@ -254,6 +287,10 @@ def decl_buffer(
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
     dtype = "float32" if dtype is None else dtype
     strides = () if strides is None else strides
+
+    if axis_separators is None:
+        axis_separators = []
+
     if offset_factor != 0 and elem_offset is None:
         shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32"
         elem_offset = Var("%s_elem_offset" % name, shape_dtype)
@@ -272,6 +309,7 @@ def decl_buffer(
         data_alignment,
         offset_factor,
         buffer_type,
+        axis_separators,
         span,
     )
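
A small, hedged sketch (not part of the patch) of the new axis_separators
argument together with the get_flattened_buffer and offset_of helpers added
above; the buffer name and shape are arbitrary:

    import tvm
    from tvm import tir

    # One separator after axis 0: axes 1 and 2 form a second group.
    buf = tir.decl_buffer((16, 16, 4), "float32", name="X", axis_separators=[1])

    flat = buf.get_flattened_buffer()
    # flat.shape should be [16, 64]: one output axis per separator-delimited group.

    # offset_of maps indices in the original layout to indices in the flattened
    # buffer, so [2, 3, 1] should become [2, 3*4 + 1] == [2, 13].
    print(buf.offset_of([2, 3, 1]))
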
 
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index bcebab9..fdee18f 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -45,6 +45,9 @@ class PrimFunc(BaseFunc):
     buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]
         The buffer binding map.
 
+    preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]]
+        The buffer binding map, prior to any flattening.
+
     attrs: Optional[tvm.Attrs]
         Attributes of the function, can be None
 
@@ -52,9 +55,20 @@ class PrimFunc(BaseFunc):
         The location of this itervar in the source code.
     """
 
-    def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, span=None):
+    def __init__(
+        self,
+        params,
+        body,
+        ret_type=None,
+        buffer_map=None,
+        preflattened_buffer_map=None,
+        attrs=None,
+        span=None,
+    ):
+
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
+        preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map
         for x in params:
             x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
             if isinstance(x, Buffer):
@@ -67,8 +81,15 @@ class PrimFunc(BaseFunc):
                 raise TypeError("params can only contain Var or Buffer")
 
         self.__init_handle_by_constructor__(
-            _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span  # type: ignore
-        )
+            _ffi_api.PrimFunc,
+            param_list,
+            body,
+            ret_type,
+            buffer_map,
+            preflattened_buffer_map,
+            attrs,
+            span,
+        )  # type: ignore
 
     def with_body(self, new_body, span=None):
         """Create a new PrimFunc with the same set signatures but a new body.
@@ -86,7 +107,15 @@ class PrimFunc(BaseFunc):
         new_func : PrimFunc
             The created new function.
         """
-        return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)
+        return PrimFunc(
+            self.params,
+            new_body,
+            self.ret_type,
+            self.buffer_map,
+            self.preflattened_buffer_map,
+            self.attrs,
+            span,
+        )
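
A minimal sketch (names and shapes assumed, not from the patch) of constructing
a PrimFunc with the new preflattened_buffer_map argument, pairing a flattened
parameter buffer with its original 2-d shape:

    import tvm
    from tvm import tir

    a = tir.Var("a", "handle")
    A_flat = tir.decl_buffer((256,), "float32", name="A")
    A_pre = tir.decl_buffer((16, 16), "float32", name="A_pre", data=A_flat.data)

    func = tir.PrimFunc(
        params=[a],
        body=tir.Evaluate(0),
        buffer_map={a: A_flat},
        preflattened_buffer_map={a: A_pre},
    )
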
 
     def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
         """Specialize parameters of PrimFunc
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index a71476b..334902b 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -15,12 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Developer API of IR node builder make function."""
+import tvm
 from tvm._ffi.base import string_types
-from tvm.runtime import ObjectGeneric, DataType, convert, const
-from tvm.ir import container as _container, PointerType, PrimType
+from tvm.runtime import ObjectGeneric, convert, const
+from tvm.ir import container as _container
 
 from . import stmt as _stmt
 from . import expr as _expr
+from . import buffer as _buffer
 from . import op
 
 
@@ -43,84 +45,77 @@ class BufferVar(ObjectGeneric):
 
     Do not create it directly, create use IRBuilder.
 
-    BufferVars support array access either via a linear index, or, if given a
-    shape, via a multidimensional index.
+    Array access through a BufferVar must use the same number of
+    indices as the underlying buffer was declared to have.
 
     Examples
     --------
     In the following example, x is a BufferVar.
-    :code:`x[0] = ...` directly emit a store to the IRBuilder,
-    :code:`x[10]` translates to Load.
+    :code:`x[0] = ...` directly emits a BufferStore to the IRBuilder,
+    :code:`x[10]` translates to BufferLoad.
 
     .. code-block:: python
 
-        # The following code generate IR for x[0] = x[
+        # The following code generates IR for x[0] = x[10] + 1
         ib = tvm.tir.ir_builder.create()
-        x = ib.pointer("float32")
+        x = ib.allocate("float32", 20)
         x[0] = x[10] + 1
 
+        # Array access using a multidimensional index
         y = ib.allocate("float32", (32, 32))
-        # Array access using a linear index
-        y[(2*32) + 31] = 0.
-        # The same array access using a multidimensional index
         y[2, 31] = 0.
 
     See Also
     --------
     IRBuilder.pointer
-    IRBuilder.buffer_ptr
     IRBuilder.allocate
+
     """
 
-    def __init__(self, builder, buffer_var, shape, content_type):
+    def __init__(self, builder, buffer, content_type):
         self._builder = builder
-        self._buffer_var = buffer_var
-        self._shape = shape
+        self._buffer = buffer
         self._content_type = content_type
 
     def asobject(self):
-        return self._buffer_var
+        return self._buffer
 
     @property
     def dtype(self):
         return self._content_type
 
-    def _linear_index(self, index):
-        if not isinstance(index, tuple) or self._shape is None:
-            return index
-        assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % (
-            len(index),
-            len(self._shape),
-        )
-        dim_size = 1
-        lidx = 0
-        for dim, idx in zip(reversed(self._shape), reversed(index)):
-            lidx += idx * dim_size
-            dim_size *= dim
-        return lidx
+    def _normalize_index(self, index):
+        try:
+            index = [*index]
+        except TypeError:
+            index = [index]
+
+        index = [x.var if isinstance(x, _expr.IterVar) else x for x in index]
+
+        # Workaround to support previous behavior of ir_builder
+        # indexing by a single index, treating the buffer as if were
+        # already flattened.
+        if len(index) == 1 and len(self._buffer.shape) != 1:
+            index = tvm.topi.utils.unravel_index(index[0], self._buffer.shape)
+
+        return index
 
     def __getitem__(self, index):
-        t = DataType(self._content_type)
-        index = self._linear_index(index)
-        if t.lanes > 1:
-            base = index * t.lanes
-            stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype)
-            index = _expr.Ramp(base, stride, t.lanes)
-        return _expr.Load(self._content_type, self._buffer_var, index)
+        index = self._normalize_index(index)
+        return _expr.BufferLoad(self._buffer, index)
 
     def __setitem__(self, index, value):
+        index = self._normalize_index(index)
+
         value = convert(value)
-        if value.dtype != self._content_type:
+        value_element = value.dtype.split("x", maxsplit=1)[0]
+        content_element = self._content_type.split("x", maxsplit=1)[0]
+        if value_element != content_element:
             raise ValueError(
                 "data type does not match content type %s vs %s" % (value.dtype, self._content_type)
             )
-        index = self._linear_index(index)
-        t = DataType(self._content_type)
-        if t.lanes > 1:
-            base = index * t.lanes
-            stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype)
-            index = _expr.Ramp(base, stride, t.lanes)
-        self._builder.emit(_stmt.Store(self._buffer_var, value, index))
+
+        self._builder.emit(_stmt.BufferStore(self._buffer, value, index))
 
 
 class IRBuilder(object):
@@ -394,7 +389,7 @@ class IRBuilder(object):
         self.emit(lambda x: _stmt.LetStmt(var, value, x))
         return var
 
-    def allocate(self, dtype, shape, name="buf", scope=""):
+    def allocate(self, dtype, shape, name="buf", axis_separators=None, scope=""):
         """Create an allocate statement.
 
         Parameters
@@ -408,6 +403,12 @@ class IRBuilder(object):
         name : str, optional
             The name of the buffer.
 
+        axis_separators : list of int, optional
+
+            If passed, a list of separators between groups of axes,
+            each of which is flattened to an output axis.  For flat
+            memory spaces, should either be None, or an empty list.
+
         scope : str, optional
             The scope of the buffer.
 
@@ -416,12 +417,18 @@ class IRBuilder(object):
         -------
         buffer : BufferVar
             The buffer var representing the buffer.
+
         """
-        buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
         if not isinstance(shape, (list, tuple, _container.Array)):
             shape = [shape]
+
+        buffer = _buffer.decl_buffer(
+            shape, dtype, name, scope=scope, axis_separators=axis_separators
+        )
+
+        buffer_var = buffer.data
         self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
-        return BufferVar(self, buffer_var, shape, dtype)
+        return BufferVar(self, buffer, dtype)
 
     def pointer(self, content_type, name="ptr", scope=""):
         """Create pointer variable with content type.
@@ -442,10 +449,10 @@ class IRBuilder(object):
         ptr : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope))
-        return BufferVar(self, buffer_var, None, content_type)
+        buffer = _buffer.decl_buffer(shape=[1], dtype=content_type, name=name, scope=scope)
+        return BufferVar(self, buffer, content_type)
 
-    def buffer_ptr(self, buf, shape=None):
+    def buffer_ptr(self, buf):
         """Create pointer variable corresponds to buffer ptr.
 
         Parameters
@@ -453,15 +460,12 @@ class IRBuilder(object):
         buf : Buffer
             The buffer to be extracted.
 
-        shape : Tuple
-            Optional shape of the buffer. Overrides existing buffer shape.
-
         Returns
         -------
         ptr : BufferVar
             The buffer var representing the buffer.
         """
-        return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype)
+        return BufferVar(self, buf, buf.dtype)
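
A short, hedged sketch (assumed usage, not from the patch) of the updated
ir_builder behavior: allocate() now returns a BufferVar backed by a real
Buffer, and element access emits BufferLoad/BufferStore with the buffer's
declared dimensionality:

    import tvm
    from tvm import tir

    ib = tir.ir_builder.create()
    y = ib.allocate("float32", (32, 32), name="y", scope="global")
    y[2, 31] = y[2, 30] + 1.0  # multidimensional BufferLoad / BufferStore
    stmt = ib.get()
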
 
     def likely(self, expr):
         """Add likely tag for expression.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 74e1f70..802fdc5 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -74,6 +74,19 @@ def InjectPrefetch():
     return _ffi_api.InjectPrefetch()  # type: ignore
 
 
+def ApplyLayoutTransforms():
+    """Reshape buffers that appear in the "layout_transform_map"
+    function attribute.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+
+    """
+    return _ffi_api.ApplyLayoutTransforms()  # type: ignore
+
+
 def StorageFlatten(cache_line_size, create_bound_attribute: bool = False):
     """Flatten the multi-dimensional read/write to 1D.
 
@@ -784,7 +797,7 @@ def ExtractPrimFuncConstants():
     return _ffi_api.ExtractPrimFuncConstants()  # type: ignore
 
 
-def RenomalizeSplitPattern():
+def RenormalizeSplitPattern():
     """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
 
     Returns
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index 32f20a1..8bfc803 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -149,7 +149,6 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
         warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
         m = data.shape[1]
         nb = w_indptr.shape[0] - 1
-        nnzb = w_data.shape[0]
         # treat csr like block size 1 bsr
         if len(w_data.shape) == 1:
             bs_n = 1
@@ -181,7 +180,7 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
 
         out_ptr = ib.buffer_ptr(out)
         data_ptr = ib.buffer_ptr(data)
-        w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k))
+        w_data_ptr = ib.buffer_ptr(w_data)
         w_indices_ptr = ib.buffer_ptr(w_indices)
         w_indptr_ptr = ib.buffer_ptr(w_indptr)
 
@@ -238,10 +237,11 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
             elem_idx = bb * rowlength_bi + tx
             with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                 with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                    if use_warp_storage:
-                        w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
-                    else:
-                        w_data_cache[warp, tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                    data_indices = [row_start + elem_idx] + (
+                        [y, z] if len(w_data.shape) > 1 else []
+                    )
+                    cache_indices = [tx, y, z] if use_warp_storage else [warp, tx, y, z]
+                    w_data_cache[cache_indices] = w_data_ptr[data_indices]
             with ib.for_range(0, mi, name="i") as i:
                 # thread local block matmul
                 with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index 0e39a6c..af68ee9 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -311,9 +311,17 @@ def unravel_index(idx, shape):
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
     indices = []
-    for i in range(len(shape) - 1, -1, -1):
-        indices.append(idxm(idx, shape[i]))
-        idx = idxd(idx, shape[i])
+    for i, dim in enumerate(reversed(shape)):
+        if dim == 0:
+            indices.append(0)
+        elif i == len(shape) - 1:
+            # Assuming the index is in-bounds, the last coordinate is
+            # already less than dim, and doesn't need to be reduced
+            # modulo dim.
+            indices.append(idx)
+        else:
+            indices.append(idxm(idx, dim))
+            idx = idxd(idx, dim)
     indices = indices[::-1]
     return indices
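
To illustrate the comment above with a concrete case (a hedged example, not
part of the patch): for shape (3, 5), the flat index 13 corresponds to the
coordinates [2, 3], since 13 == 2*5 + 3, and the outermost coordinate is taken
directly rather than being reduced by a final modulo:

    from tvm import topi

    coords = topi.utils.unravel_index(13, (3, 5))
    # coords should be equivalent to [13 // 5, 13 % 5], i.e. [2, 3].
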
 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e11bd02..7320453 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -191,7 +191,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
     // truc div
     TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
     // floor div
-    TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
+    TVM_TRY_REWRITE(floordiv(x, y) * y + floormod(x, y), x);
+    TVM_TRY_REWRITE(y * floordiv(x, y) + floormod(x, y), x);
+    TVM_TRY_REWRITE(floormod(x, y) + floordiv(x, y) * y, x);
+    TVM_TRY_REWRITE(floormod(x, y) + y * floordiv(x, y), x);
+
     TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
                        c2.Eval()->value > 0);
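
The added rules generalize the constant-divisor pattern to the full
Euclidean-division identity floordiv(x, y) * y + floormod(x, y) == x for an
arbitrary divisor y, including the commuted forms. A hedged sketch of
exercising this from Python (the variable names are arbitrary):

    import tvm
    from tvm import tir

    x = tir.Var("x", "int32")
    y = tir.Var("y", "int32")

    analyzer = tvm.arith.Analyzer()
    expr = tir.floordiv(x, y) * y + tir.floormod(x, y)
    print(analyzer.rewrite_simplify(expr))  # expected to print: x
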
 
diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc
index 59cac9c..17a05f0 100644
--- a/src/autotvm/feature_visitor.cc
+++ b/src/autotvm/feature_visitor.cc
@@ -97,14 +97,16 @@ void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
 }
 
 // memory access
-void FeatureVisitor::VisitExpr_(const LoadNode* op) {
-  EnterMem_(op->buffer_var, op->index);
+void FeatureVisitor::VisitExpr_(const BufferLoadNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers";
+  EnterMem_(op->buffer->data, op->indices[0]);
   StmtExprVisitor::VisitExpr_(op);
   ExitMem_();
 }
 
-void FeatureVisitor::VisitStmt_(const StoreNode* op) {
-  EnterMem_(op->buffer_var, op->index);
+void FeatureVisitor::VisitStmt_(const BufferStoreNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers";
+  EnterMem_(op->buffer->data, op->indices[0]);
   StmtExprVisitor::VisitStmt_(op);
   ExitMem_();
 }
diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h
index 8180839..3d34882 100644
--- a/src/autotvm/feature_visitor.h
+++ b/src/autotvm/feature_visitor.h
@@ -66,8 +66,8 @@ class FeatureVisitor : public StmtExprVisitor {
   void VisitStmt_(const AttrStmtNode* op) final;
 
   // memory access
-  void VisitExpr_(const LoadNode* op) final;
-  void VisitStmt_(const StoreNode* op) final;
+  void VisitExpr_(const BufferLoadNode* op) final;
+  void VisitStmt_(const BufferStoreNode* op) final;
 
   using StmtExprVisitor::VisitExpr_;
   using StmtExprVisitor::VisitStmt_;
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 5872a49..24c7ee7 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -274,6 +274,14 @@ void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLIN
 
 void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; }
 
+void CodeGenHybrid::VisitExpr_(const BufferLoadNode* op, std::ostream& os) {  // NOLINT(*)
+  LOG(FATAL) << "Phase 0 has no BufferLoad(s)!";
+}
+
+void CodeGenHybrid::VisitStmt_(const BufferStoreNode* op) {
+  LOG(FATAL) << "Phase 0 has no BufferStore(s)!";
+}
+
 void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Phase 0 has no Let(s)!";
 }
diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index 47c13f7..da45ffb 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -89,6 +89,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // expression
   void VisitExpr_(const VarNode* op, std::ostream& os) override;           // NOLINT(*)
   void VisitExpr_(const LoadNode* op, std::ostream& os) override;          // NOLINT(*)
+  void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override;    // NOLINT(*)
   void VisitExpr_(const LetNode* op, std::ostream& os) override;           // NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) override;          // NOLINT(*)
   void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override;  // NOLINT(*)
@@ -120,6 +121,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // statment
   void VisitStmt_(const LetStmtNode* op) override;
   void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const ProducerStoreNode* op) override;
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index e229da4..16d4772 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -223,6 +223,9 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
   if (!is_zero(buf->elem_offset)) {
     doc << ", elem_offset=" << Print(buf->elem_offset);
   }
+  if (buf->axis_separators.size()) {
+    doc << ", axis_separators=" << Print(buf->axis_separators);
+  }
   if (GetRef<Buffer>(buf).scope() != "global") {
     doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
   }
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index e1ccd2f..a6e5066 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -265,7 +265,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc PrintRange(const RangeNode* op);
   Doc PrintArray(const ArrayNode* op);
   Doc PrintBuffer(const BufferNode* op);
-  Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
+  Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers);
   Doc AllocBufferDeclaration(const Buffer& buf);
   Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
   Doc PrintBlockVarRemaps();
@@ -912,16 +912,21 @@ Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_prece
 }
 
 Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
+  if (!buffer_var_usage_.count(op->var)) {
+    buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->var).value_or({});
+
   Doc doc;
   if (current_num_ != num_child_ - 1) {
     doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
-    doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
-                                         << PrintBody(op->body));
+    doc << Doc::Indent(
+        4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
   } else {
     if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
     doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
         << Doc::NewLine();
-    doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body);
+    doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
   }
   return doc;
 }
@@ -1008,8 +1013,59 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
+namespace {
+struct AllocUsage {
+  Buffer alloc_buffer;
+  Array<Buffer> aliasing_buffers;
+};
+
+template <typename AllocNode>
+AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
+  Map<Var, Array<Buffer>>& cache = *cache_ptr;
+  if (!cache.count(op->buffer_var)) {
+    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+  }
+  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
+
+  auto is_exact_match = [](Buffer a, Buffer b) {
+    if (a->dtype != b->dtype) return false;
+    if (a->shape.size() != b->shape.size()) return false;
+
+    arith::Analyzer analyzer;
+    for (size_t i = 0; i < a->shape.size(); i++) {
+      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  // If the buffer allocated via T.allocate is an exact match to the
+  // usage of the buffer later on, then that buffer is the return
+  // value of T.allocate, and no T.buffer_decl statement is needed.
+  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
+                      0, kDefault);
+  bool found_alloc_buf = false;
+  Array<Buffer> aliasing_buffers;
+  for (const auto& buf : buffer_usage) {
+    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
+      alloc_buffer = buf;
+      found_alloc_buf = true;
+    } else {
+      aliasing_buffers.push_back(buf);
+    }
+  }
+
+  return AllocUsage{alloc_buffer, aliasing_buffers};
+}
+}  // namespace
+
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  var_not_in_headers_.insert(op->buffer_var.get());
+  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
+  Buffer& alloc_buffer = usage.alloc_buffer;
+  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
+  buf_not_in_headers_.insert(alloc_buffer.get());
+  var_not_in_headers_.insert(alloc_buffer->data.get());
 
   auto storage_scope = GetPtrStorageScope(op->buffer_var);
   Doc func_call;
@@ -1027,13 +1083,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
 
   Doc doc;
   if (current_num_ != num_child_ - 1) {
-    doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
-    doc << Doc::Indent(4, Doc::NewLine()
-                              << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
-                              << PrintBody(op->body));
+    doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
+    doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
+                                         << PrintBody(op->body));
   } else {
-    doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
-    doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
+    doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
+    doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body);
   }
   TryDeallocVar(op->buffer_var);
   return doc;
@@ -1069,16 +1124,25 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
   }
   auto ndarray_str = ss.str();
 
+  auto usage = FindAllocateUsage(alloc, &buffer_var_usage_);
+  Buffer& alloc_buffer = usage.alloc_buffer;
+  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
+  buf_not_in_headers_.insert(alloc_buffer.get());
+  var_not_in_headers_.insert(alloc_buffer->data.get());
+
+  Doc func_call;
+  func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype)
+            << ", " << Print(alloc->extents) << ")";
+
   Doc doc;
   var_not_in_headers_.insert(alloc->buffer_var.get());
   if (current_num_ != num_child_ - 1) {
-    doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", "
-        << Print(alloc->extents) << ")";
-    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body));
+    doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
+    doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
+                                         << PrintBody(alloc->body));
   } else {
-    doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", "
-        << PrintDType(alloc->dtype) << ", " << Print(alloc->extents);
-    doc << ")" << Doc::NewLine() << PrintBody(alloc->body);
+    doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
+    doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body);
   }
   return doc;
 }
@@ -1465,9 +1529,30 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
     if (simple_buf.count(buf)) continue;
     buf_not_in_headers_.insert(buf.get());
     body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(";
+    ICHECK(memo_buf_decl_.count(buf));
     body << Print((*it).first) << ", " << memo_buf_decl_[buf];
     body << ")" << Doc::NewLine();
   }
+  // print preflattened buffer map
+  for (const auto& param : op->params) {
+    auto pf_buf_it = op->preflattened_buffer_map.find(param);
+    if (pf_buf_it != op->preflattened_buffer_map.end()) {
+      const Buffer& preflattened = (*pf_buf_it).second;
+
+      auto buf_it = op->buffer_map.find(param);
+      ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name
+                                             << " with no corresponding post-flattened buffer.";
+      const Buffer& postflattened = (*buf_it).second;
+
+      // Call Print() without assigning in order to fill memo_buf_decl_.
+      Print(preflattened);
+      buf_not_in_headers_.insert(preflattened.get());
+      ICHECK(memo_buf_decl_.count(preflattened));
+
+      body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", "
+           << memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine();
+    }
+  }
   // print body
   body << "# body" << Doc::NewLine();
   if (op->body->IsInstance<BlockRealizeNode>() &&
@@ -1586,13 +1671,9 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
   return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
 }
 
-Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
-  if (!buffer_var_usage_.count(buffer_var)) {
-    buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body);
-  }
-  Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
+Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers) {
   Doc decls;
-  for (const auto& buf_usage : buffer_usage) {
+  for (const auto& buf_usage : aliasing_buffers) {
     decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
           << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
     buf_not_in_headers_.insert(buf_usage.get());
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 3d2f0fc..0629ccd 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -423,15 +423,17 @@ class AOTExecutorCodegen : public MixedModeVisitor {
    */
   void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
     // Define intermediate DLTensor to load/store the data
-    auto tmp0 = te::Var("tmp0", DataType::Handle());
-    auto tmp1 = te::Var("tmp1", DataType::Handle());
+    tir::Buffer tmp_read =
+        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
+    tir::Buffer tmp_write =
+        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
     te::Var loop_idx("i", DataType::Int(32));
-    auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true());
+    auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
     // Copy the variable from the input to the output
     tir::Stmt copy =
         tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial,
-                 tir::Store(tmp1, tir::Let(tmp0, in, retval_i), loop_idx, tir::const_true()));
-    stmts_.push_back(tir::LetStmt(tmp1, out, copy));
+                 tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
+    stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
   }
 
   /*
@@ -689,7 +691,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
 
     // Make the PrimFunc
-    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
+    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {},
                          DictAttrs(dict_attrs));
   }
 
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 530d649..46eacec 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -108,7 +108,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
     }
 
     tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map,
-                                   DictAttrs(dict_attrs));
+                                   Map<tir::Var, tir::Buffer>(), DictAttrs(dict_attrs));
     ir_module_->Add(global_var, replacement_func);
   }
 
diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
index 6794594..86f55ca 100644
--- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
+++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
@@ -52,8 +52,8 @@ class ConvertAddToSubtract : public MixedModeMutator {
   }
 
  private:
-  tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) {
-    return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true());
+  tir::BufferLoad LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) {
+    return tir::BufferLoad(buffer, {index});
   }
 
   void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
@@ -71,7 +71,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
 
     te::Var index("index", DataType::Int(32));
     tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
-    tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true());
+    tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
     tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
 
     Map<tir::Var, tir::Buffer> buffer_map = {
@@ -81,7 +81,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
     };
 
     tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
-                                                   buffer_map, DictAttrs(dict_attrs));
+                                                   buffer_map, {}, DictAttrs(dict_attrs));
 
     // Switch to TIRToRuntime hook for testing
     Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index b5316f2..3f7da4e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1571,11 +1571,15 @@ inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop,
                                 const te::Tensor& step, tvm::DataType dtype,
                                 std::string name = "T_arange_dynamic",
                                 std::string tag = topi::kInjective) {
+  ICHECK_EQ(start.ndim(), 0);
+  ICHECK_EQ(stop.ndim(), 0);
+  ICHECK_EQ(step.ndim(), 0);
   tvm::PrimExpr num_elem = tvm::tir::Var("num_elem");
   return te::compute(
       {num_elem},
       [&](const Array<tvm::tir::Var>& indices) {
-        return tvm::cast(dtype, start[0] + step[0] * indices[0]);
+        Array<PrimExpr> empty_indices;
+        return tvm::cast(dtype, start(empty_indices) + step(empty_indices) * indices[0]);
       },
       name, tag);
 }
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 3a8391e..a078cab 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -253,6 +253,7 @@ class ConstantFolder : public MixedModeMutator {
     // Use a fresh build context in case we are already in a build context.
     // needed for both execution and creation(due to JIT)
     With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());
+
     Map<String, ObjectRef> dict =
         (module_->attrs.defined()) ? module_->attrs->dict : Map<String, ObjectRef>();
     Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(),
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 6d9d980..ded346e 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -810,11 +810,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
   llvm::Value* arg_value = builder_->CreateInBoundsGEP(
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(begin));
-  TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
+  TypedPointer arg_tcode =
+      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
   llvm::Value* ret_value = builder_->CreateInBoundsGEP(
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(end));
-  TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
+  TypedPointer ret_tcode =
+      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
 
 #if TVM_LLVM_VERSION >= 90
   auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc
index 496c73a..3258703 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -319,11 +319,13 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array<Pri
   llvm::Value* arg_value = builder_->CreateInBoundsGEP(
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(begin));
-  TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
+  TypedPointer arg_tcode =
+      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
   llvm::Value* ret_value = builder_->CreateInBoundsGEP(
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(end));
-  TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
+  TypedPointer ret_tcode =
+      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
 
 #if TVM_LLVM_VERSION >= 90
   auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 0545d0b..cc2e495 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -437,6 +437,13 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
   if (auto* ptr = type.as<PrimTypeNode>()) {
     return DTypeToLLVMType(ptr->dtype);
   } else if (auto* ptr = type.as<PointerTypeNode>()) {
+    // LLVM IR doesn't allow void*, so we need to recognize this
+    // pattern explicitly.
+    if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) {
+      if (primtype->dtype.is_void()) {
+        return t_void_p_;
+      }
+    }
     // TODO(tvm-team) consider put storage scope into the pointer type.
     return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace());
   } else if (IsVoidType(type)) {
@@ -781,17 +788,35 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {
   return ptr;
 }
 
-CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer,
-                                                       llvm::Value* index) {
-  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
-  ICHECK(btype != nullptr);
-  llvm::Type* llvm_type = DTypeToLLVMType(t);
-  llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace());
-  if (btype != ttype) {
-    buffer = builder_->CreatePointerCast(buffer, ttype);
+CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
+                                                       DataType buffer_element_dtype,
+                                                       llvm::Value* index, DataType value_dtype) {
+  llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
+  ICHECK(buffer_ptr_type != nullptr);
+  auto address_space = buffer_ptr_type->getAddressSpace();
+
+  llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype);
+  llvm::PointerType* element_ptr_type =
+      DTypeToLLVMType(buffer_element_dtype)->getPointerTo(address_space);
+  llvm::Type* value_type = DTypeToLLVMType(value_dtype);
+  llvm::PointerType* value_ptr_type = value_type->getPointerTo(address_space);
+
+  ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer";
+
+  if (buffer_ptr_type != element_ptr_type) {
+    buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type);
   }
-  llvm::Value* ptr = builder_->CreateInBoundsGEP(llvm_type, buffer, index);
-  return TypedPointer(llvm_type, ptr);
+  ICHECK(!HasAlignmentPadding(buffer_element_dtype))
+      << "DType " << buffer_element_dtype
+      << " has padding for alignment.  TVM data arrays are expected to be densely packed, with no "
+         "padding for alignment.";
+  llvm::Value* value_ptr = builder_->CreateInBoundsGEP(element_type, buffer_ptr, index);
+
+  if (element_ptr_type != value_ptr_type) {
+    value_ptr = builder_->CreatePointerCast(value_ptr, value_ptr_type);
+  }
+
+  return TypedPointer(value_type, value_ptr);
 }
 
 llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
@@ -976,15 +1001,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   } else if (op->op.same_as(builtin::tvm_storage_sync())) {
     return CreateStorageSync(op);
   } else if (op->op.same_as(builtin::address_of())) {
-    const LoadNode* l = op->args[0].as<LoadNode>();
-    ICHECK(op->args.size() == 1 && l);
-    TypedPointer buffer_ptr;
-    if (const RampNode* r = l->index.as<RampNode>()) {
-      PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
-      buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index));
-    } else {
-      buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index));
+    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
+    ICHECK(op->args.size() == 1 && load);
+    ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations.";
+    PrimExpr index = load->indices[0];
+    if (const RampNode* r = index.as<RampNode>()) {
+      index = r->base;
     }
+    TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
+                                              MakeValue(index), load->dtype);
     unsigned addrspace =
         llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
     return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
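
With this change, builtin::address_of() wraps a BufferLoad rather than a Load, and a vectorized access takes the address of the ramp base. A minimal TIR-level sketch of constructing such a call, assuming a flat float32 buffer (illustrative only):

    import tvm
    from tvm import tir

    buf = tir.decl_buffer((64,), "float32", "A")
    load = tir.BufferLoad(buf, [tir.Ramp(0, 1, 4)])            # float32x4 access
    addr = tir.call_intrin("handle", "tir.address_of", load)   # roughly &A[0] after codegen
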
@@ -1236,15 +1261,40 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
+  LOG(FATAL) << "Unexpected deprecated LoadNode.  Use BufferLoadNode instead.";
+  return NULL;
+}
+
+bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
+  const llvm::DataLayout& data_layout = module_->getDataLayout();
+  int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype));
+  int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of()));
+  return bytes != bytes_scalar * dtype.lanes();
+}
+
+llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";
+
   DataType t = op->dtype;
-  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
-  llvm::Value* buffer = MakeValue(op->buffer_var);
-  llvm::Value* index = MakeValue(op->index);
+  DataType buffer_element_dtype = op->buffer->dtype;
+  Var buffer_var = op->buffer->data;
+  PrimExpr buffer_index = op->indices[0];
 
-  if (t.lanes() == 1) {
+  bool is_volatile = volatile_buf_.count(buffer_var.get());
+
+  if (t.lanes() == buffer_element_dtype.lanes()) {
     int alignment, native_bits;
-    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
-    TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index);
+    GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits);
+
+    TypedPointer buffer_ptr;
+    if (HasAlignmentPadding(buffer_element_dtype)) {
+      buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(),
+                                   MakeValue(buffer_element_dtype.lanes() * buffer_index), t);
+    } else {
+      buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
+                                   MakeValue(buffer_index), t);
+    }
+
 #if TVM_LLVM_VERSION >= 110
     llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
                                                        llvm::Align(alignment), is_volatile);
@@ -1254,22 +1304,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
 #else
     llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
 #endif
-    AddAliasInfo(load, op->buffer_var.get(), op->index);
+    AddAliasInfo(load, buffer_var.get(), buffer_index);
     return load;
   } else {
     // vector load
-    if (const RampNode* ramp = op->index.as<RampNode>()) {
+    if (const RampNode* ramp = buffer_index.as<RampNode>()) {
       if (is_one(ramp->stride)) {
         int alignment, native_bits;
-        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
-        ICHECK_EQ(ramp->lanes, t.lanes());
+        GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits);
+        ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes());
         // The index argument is element-based, to create buffer pointer for t's element type.
-        TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
-        unsigned addrspace =
-            llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
-        buffer_ptr.type = DTypeToLLVMType(t);
-        buffer_ptr.addr =
-            builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace));
+        TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype,
+                                                  MakeValue(ramp->base), t);
 #if TVM_LLVM_VERSION >= 110
         llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
                                                            llvm::Align(alignment), is_volatile);
@@ -1279,7 +1325,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
 #else
         llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
 #endif
-        AddAliasInfo(load, op->buffer_var.get(), op->index);
+        AddAliasInfo(load, buffer_var.get(), buffer_index);
         return load;
       }
     }
@@ -1288,7 +1334,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
   int basic_align = t.bits() / 8;
   llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t));
   auto f = [&](int i, llvm::Value* index) {
-    TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index);
+    TypedPointer buffer_ptr =
+        CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, index, t.element_of());
 #if TVM_LLVM_VERSION >= 110
     llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
                                                        llvm::Align(basic_align), is_volatile);
@@ -1299,9 +1346,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
     llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile);
 #endif
     ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
-    AddAliasInfo(load, op->buffer_var.get(), PrimExpr());
+    AddAliasInfo(load, buffer_var.get(), PrimExpr());
   };
-  this->Scalarize(op->index, f);
+  this->Scalarize(buffer_index, f);
   return ret;
 }
 
@@ -1366,17 +1413,34 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
 }
 
 void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
-  ICHECK(is_one(op->predicate)) << op->predicate;
-  DataType t = op->value.dtype();
-  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
-  llvm::Value* buffer = MakeValue(op->buffer_var);
-  llvm::Value* index = MakeValue(op->index);
+  LOG(FATAL) << "Unexpected deprecated StoreNode.  Use BufferStoreNode instead.";
+}
+
+void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";
+
+  DataType value_dtype = op->value.dtype();
+  DataType buffer_element_dtype = op->buffer->dtype;
+  Var buffer_var = op->buffer->data;
+  PrimExpr buffer_index = op->indices[0];
+
+  bool is_volatile = volatile_buf_.count(buffer_var.get());
+  llvm::Value* buffer = MakeValue(buffer_var);
   llvm::Value* value = MakeValue(op->value);
 
-  if (t.lanes() == 1) {
+  if (value_dtype.lanes() == buffer_element_dtype.lanes()) {
     int alignment, native_bits;
-    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
-    TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index);
+    GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits);
+
+    TypedPointer buffer_ptr;
+    if (HasAlignmentPadding(buffer_element_dtype)) {
+      buffer_ptr =
+          CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(),
+                          MakeValue(buffer_element_dtype.lanes() * buffer_index), value_dtype);
+    } else {
+      buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
+                                   MakeValue(buffer_index), value_dtype);
+    }
 #if TVM_LLVM_VERSION >= 110
     llvm::StoreInst* store =
         builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile);
@@ -1384,20 +1448,21 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
     llvm::StoreInst* store =
         builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile);
 #endif
-    AddAliasInfo(store, op->buffer_var.get(), op->index);
+    AddAliasInfo(store, buffer_var.get(), buffer_index);
     return;
   } else {
     // vector store
-    if (const RampNode* ramp = op->index.as<RampNode>()) {
+    if (const RampNode* ramp = buffer_index.as<RampNode>()) {
       if (is_one(ramp->stride)) {
         int alignment, native_bits;
-        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
-        ICHECK_EQ(ramp->lanes, t.lanes());
+        GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits);
+        ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes());
         // The index argument is element-based, to create buffer pointer for t's element type.
-        TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
+        TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
+                                                  MakeValue(ramp->base), value_dtype);
         unsigned addrspace =
             llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
-        buffer_ptr.type = DTypeToLLVMType(t);
+        buffer_ptr.type = DTypeToLLVMType(value_dtype);
         buffer_ptr.addr =
             builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace));
 #if TVM_LLVM_VERSION >= 110
@@ -1407,16 +1472,17 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
         llvm::StoreInst* store =
             builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile);
 #endif
-        AddAliasInfo(store, op->buffer_var.get(), op->index);
+        AddAliasInfo(store, buffer_var.get(), buffer_index);
         return;
       }
     }
   }
-  ICHECK_GE(t.bits(), 8);
+  ICHECK_GE(value_dtype.bits(), 8);
   // scalarized store.
-  int basic_align = t.bits() / 8;
+  int basic_align = value_dtype.bits() / 8;
   auto f = [&](int i, llvm::Value* index) {
-    TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index);
+    TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
+                                              index, value_dtype.element_of());
 #if TVM_LLVM_VERSION >= 110
     llvm::StoreInst* store =
         builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr,
@@ -1425,9 +1491,9 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
     llvm::StoreInst* store = builder_->CreateAlignedStore(
         builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile);
 #endif
-    AddAliasInfo(store, op->buffer_var.get(), PrimExpr());
+    AddAliasInfo(store, buffer_var.get(), PrimExpr());
   };
-  this->Scalarize(op->index, f);
+  this->Scalarize(buffer_index, f);
 }
 
 void CodeGenLLVM::VisitStmt_(const ForNode* op) {
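
The reworked CreateBufferPtr separates the buffer's element dtype from the dtype being loaded or stored, and HasAlignmentPadding rejects element types whose LLVM allocation size includes padding (for example, a 3-lane float32 vector is commonly allocated as 16 bytes even though three densely packed elements occupy 12, so such a buffer is addressed through its scalar element type instead). At the TIR level the two dtypes differ whenever a vectorized value is read from a scalar-typed buffer; a small sketch, assuming the standard tir Python API:

    import tvm
    from tvm import tir

    buf = tir.decl_buffer((128,), "float32", "A")
    i = tir.Var("i", "int32")
    vec_load = tir.BufferLoad(buf, [tir.Ramp(i * 4, 1, 4)])
    assert vec_load.dtype == "float32x4"  # value dtype differs from the buffer element dtype
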
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 5431e92..e8cbe7a 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -171,12 +171,14 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   llvm::Value* VisitExpr_(const SelectNode* op) override;
   llvm::Value* VisitExpr_(const LetNode* op) override;
   llvm::Value* VisitExpr_(const LoadNode* op) override;
+  llvm::Value* VisitExpr_(const BufferLoadNode* op) override;
   llvm::Value* VisitExpr_(const CallNode* op) override;
   llvm::Value* VisitExpr_(const RampNode* op) override;
   llvm::Value* VisitExpr_(const ShuffleNode* op) override;
   llvm::Value* VisitExpr_(const BroadcastNode* op) override;
   // stmt
   void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const WhileNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
@@ -319,6 +321,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   // Get alignment given index.
   void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment,
                     int* p_native_bits);
+  // Returns whether the LLVM type has padding for alignment
+  bool HasAlignmentPadding(DataType dtype);
   // Get constant string
   llvm::Constant* GetConstString(const std::string& str);
   // do a scalarize call with f
@@ -338,7 +342,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
-  TypedPointer CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index);
+  TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
+                               llvm::Value* index, DataType value_dtype);
   // Vector concatenation.
   llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
   llvm::Value* CreateVecFlip(llvm::Value* vec);
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 01c1c91..1752c2a 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -159,78 +159,58 @@ void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src,
 }
 
 // Print a reference expression to a buffer.
-std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) {
+std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
+  const VarNode* buffer_var = buffer->data.get();
   std::ostringstream os;
-  std::string vid = GetVarID(buffer);
+  std::string vid = GetVarID(buffer_var);
   std::string scope;
-  if (alloc_storage_scope_.count(buffer)) {
-    scope = alloc_storage_scope_.at(buffer);
+  if (alloc_storage_scope_.count(buffer_var)) {
+    scope = alloc_storage_scope_.at(buffer_var);
   }
-  bool is_vol = IsVolatile(buffer);
-  if (t.lanes() == 1) {
-    if (!HandleTypeMatch(buffer, t) || is_vol) {
-      os << "((";
-      if (is_vol) {
-        os << "volatile ";
-      }
-      // Scope may not be part of type.
-      if (!scope.empty() && IsScopePartOfType()) {
-        PrintStorageScope(scope, os);
-      }
-      PrintType(t, os);
-      os << "*)" << vid << ')';
-    } else {
-      os << vid;
-    }
-    os << "[(";
-    PrintExpr(index, os);
-    os << ")";
-    if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
-      os << " / " << (32 / t.bits());
-    }
-    os << ']';
-  } else {
-    // Buffer declared as vector type.
-    // optimize for case where it is in register,
-    if (HandleTypeMatch(buffer, t) && !is_vol) {
-      // optimize for constant access
-      if (auto* ptr = index.as<tir::IntImmNode>()) {
-        int64_t offset = ptr->value;
-        ICHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type";
-        os << vid << '[' << (offset / t.lanes()) << ']';
-        return os.str();
-      }
-    }
-    os << "((";
+  bool is_vol = IsVolatile(buffer_var);
+
+  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
+    std::ostringstream ptr_os;
+    ptr_os << "(";
     if (is_vol) {
-      os << "volatile ";
+      ptr_os << "volatile ";
     }
     if (!scope.empty() && IsScopePartOfType()) {
-      PrintStorageScope(scope, os);
-    }
-    PrintType(t, os);
-    os << "*)(";
-    if (!HandleTypeMatch(buffer, t.element_of())) {
-      os << '(';
-      if (!scope.empty() && IsScopePartOfType()) {
-        PrintStorageScope(scope, os);
-      }
-      PrintType(t.element_of(), os);
-      os << "*)";
-    }
-    if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
-      os << vid << ") + (";
-      PrintExpr(index, os);
-      os << ")";
-      os << " / " << t.lanes();
-      os << ")[0]";
-    } else {
-      os << vid << " + (";
-      PrintExpr(index, os);
-      os << ")";
-      os << "))[0]";
+      PrintStorageScope(scope, ptr_os);
     }
+    PrintType(pointed_to, ptr_os);
+    ptr_os << "*)";
+    return ptr_os.str();
+  };
+
+  DataType buffer_element_dtype = buffer->dtype;
+
+  std::string buffer_str = vid;
+  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
+    std::stringstream temp;
+    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
+    buffer_str = temp.str();
+  }
+
+  std::string index_str = PrintExpr(index);
+  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
+    // This is a special case, because CodegenCUDA::PrintType()
+    // returns "int" for bool and for 4-bit integers. In most cases,
+    // we divide by the number of lanes to determine the index.
+    // However, the backing type for scalar int4 and scalar bool is
+    // int32.  Therefore, we need to divide by the ratio of their
+    // sizes in that case.
+    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();
+
+    os << "*("
+       << "(" << ptr_cast(t) << vid << ")"
+       << " + " << index_str << " / " << div_factor << ")";
+  } else if (t == buffer_element_dtype) {
+    os << buffer_str << "[" << index_str << "]";
+  } else {
+    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
   }
+
   return os.str();
 }
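
As a worked example of the div_factor logic above: a scalar int4 access at index i is backed by 32-bit storage, so the emitted reference divides the index by 8 (eight 4-bit values per 32-bit word), roughly *(((int*)A) + i / 8); a vector access with lanes > 1 divides the index by the lane count instead, as before.
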
 
@@ -334,11 +314,11 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i,
   stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec;
 }
 
-std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) {
+std::string CodeGenC::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) {
   return GetBufferRef(t, buffer, base);
 }
 
-void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
+void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
                              const std::string& value) {
   std::string ref = GetBufferRef(t, buffer, base);
   this->PrintIndent();
@@ -586,17 +566,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       PrintExpr(op->args[2], os);
       os << ")";
     } else if (op->op.same_as(builtin::address_of())) {
-      const LoadNode* l = op->args[0].as<LoadNode>();
-      ICHECK(op->args.size() == 1 && l);
-      os << "((";
-      this->PrintType(l->dtype.element_of(), os);
-      os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "
-         << "(";
-      this->PrintExpr(l->index, os);
-      if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) {
-        os << " / " << (32 / l->dtype.bits());
-      }
-      os << "))";
+      const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
+      ICHECK(op->args.size() == 1 && load);
+      ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations.";
+      os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))";
     } else if (op->op.same_as(builtin::tvm_struct_get())) {
       ICHECK_EQ(op->args.size(), 3U);
       os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
@@ -681,18 +654,27 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
 }
 
 void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
+  LOG(FATAL) << "Unexpected deprecated LoadNode.  Use BufferLoadNode instead.";
+}
+
+void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) {  // NOLINT(*)
+  ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported.";
+
+  DataType value_dtype = op->dtype;
+  PrimExpr index = op->indices[0];
+  Var buffer_var = op->buffer->data;
+  DataType element_dtype = op->buffer->dtype;
+
   int lanes = op->dtype.lanes();
   // declare type.
-  if (op->dtype.lanes() == 1) {
-    std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index);
+  if (value_dtype.lanes() == element_dtype.lanes()) {
+    std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
     HandleVolatileLoads(ref, op, os);
   } else {
-    ICHECK(is_one(op->predicate)) << "predicated load is not supported";
-
     bool can_vector_load = false;
     arith::PVar<PrimExpr> base;
-    if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
-      const RampNode* ramp = op->index.as<RampNode>();
+    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
+      const RampNode* ramp = index.as<RampNode>();
       ICHECK(ramp);
       arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
       // The condition: {k * coeff + base} divisible by the alignment for any k
@@ -702,19 +684,19 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
     }
 
     if (can_vector_load) {
-      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
+      std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
       HandleVolatileLoads(ref, op, os);
     } else {
       std::ostringstream svalue_expr;
-      std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype());
-      std::string vid = GetVarID(op->buffer_var.get());
+      std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
+      std::string vid = GetVarID(buffer_var.get());
       DataType elem_type = op->dtype.element_of();
       for (int i = 0; i < lanes; ++i) {
         std::ostringstream value_temp;
-        if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
+        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
           value_temp << "((";
-          if (op->buffer_var.get()->dtype.is_handle()) {
-            auto it = alloc_storage_scope_.find(op->buffer_var.get());
+          if (buffer_var.get()->dtype.is_handle()) {
+            auto it = alloc_storage_scope_.find(buffer_var.get());
             if (it != alloc_storage_scope_.end()) {
               PrintStorageScope(it->second, value_temp);
             }
@@ -725,7 +707,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
           value_temp << vid;
         }
         value_temp << '[';
-        PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp);
+        PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
         value_temp << ']';
         PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
       }
@@ -735,35 +717,44 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
 }
 
 void CodeGenC::VisitStmt_(const StoreNode* op) {
-  DataType t = op->value.dtype();
-  if (t.lanes() == 1) {
+  LOG(FATAL) << "Unexpected deprecated StoreNode.  Use BufferStoreNode instead.";
+}
+
+void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
+
+  DataType value_dtype = op->value.dtype();
+  DataType element_dtype = op->buffer->dtype;
+  PrimExpr index_expr = op->indices[0];
+  Var buffer_var = op->buffer->data;
+
+  if (value_dtype.lanes() == element_dtype.lanes()) {
     std::string value = this->PrintExpr(op->value);
-    std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
+    std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
     this->PrintIndent();
     stream << ref << " = " << value << ";\n";
   } else {
-    ICHECK(is_one(op->predicate)) << "Predicated store is not supported";
     arith::PVar<PrimExpr> base;
 
-    if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
+    if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
       std::string value = this->PrintExpr(op->value);
-      this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
+      this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
     } else {
       // The assignment below introduces a side effect, and the resulting value cannot
       // be reused across multiple expressions, thus a new scope is needed
       int vec_scope = BeginScope();
 
       // store elements separately
-      std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
+      std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype());
       std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
-      std::string vid = GetVarID(op->buffer_var.get());
-      for (int i = 0; i < t.lanes(); ++i) {
+      std::string vid = GetVarID(buffer_var.get());
+      for (int i = 0; i < value_dtype.lanes(); ++i) {
         this->PrintIndent();
-        DataType elem_type = t.element_of();
-        if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
+        DataType elem_type = value_dtype.element_of();
+        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
           stream << "((";
-          if (op->buffer_var.get()->dtype.is_handle()) {
-            auto it = alloc_storage_scope_.find(op->buffer_var.get());
+          if (buffer_var.get()->dtype.is_handle()) {
+            auto it = alloc_storage_scope_.find(buffer_var.get());
             if (it != alloc_storage_scope_.end()) {
               PrintStorageScope(it->second, stream);
             }
@@ -774,7 +765,7 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
           stream << vid;
         }
         stream << '[';
-        PrintVecElemLoad(index, op->index.dtype(), i, stream);
+        PrintVecElemLoad(index, index_expr.dtype(), i, stream);
         stream << "] = ";
         PrintVecElemLoad(value, op->value.dtype(), i, stream);
         stream << ";\n";
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 2af77bb..4f67195 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -126,6 +126,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // expression
   void VisitExpr_(const VarNode* op, std::ostream& os) override;        // NOLINT(*)
   void VisitExpr_(const LoadNode* op, std::ostream& os) override;       // NOLINT(*)
+  void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const LetNode* op, std::ostream& os) override;        // NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) override;       // NOLINT(*)
   void VisitExpr_(const AddNode* op, std::ostream& os) override;        // NOLINT(*)
@@ -155,6 +156,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // statement
   void VisitStmt_(const LetStmtNode* op) override;
   void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const WhileNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
@@ -176,9 +178,9 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs,
                                 std::ostream& os);  // NOLINT(*)
   // print vector load
-  virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base);
+  virtual std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base);
   // print vector store
-  virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
+  virtual void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
                              const std::string& value);  // NOLINT(*)
   // print load of single element
   virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i,
@@ -201,7 +203,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // Print reference to struct location
   std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
   // Print reference to a buffer as type t in index.
-  virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index);
+  virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index);
 
   /*!
    * \brief Handle volatile loads.
@@ -211,7 +213,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    * does not implement volatile member functions. CUDA codegen will cast
    * away volatile qualifier from CUDA __half types.
    */
-  virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) {
+  virtual void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
+                                   std::ostream& os) {
     // By default, do nothing but print the loaded value.
     os << value;
   }
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 7ddea46..db23c01 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -97,6 +97,10 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     os << "void*";
     return;
   }
+  if (t.is_void()) {
+    os << "void";
+    return;
+  }
   if (t == DataType::Bool()) {
     os << "bool";
     return;
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 984f8a1..0dda079 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -171,6 +171,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     os << "void*";
     return;
   }
+
+  if (t.is_void()) {
+    os << "void";
+    return;
+  }
+
   bool fail = false;
   if (t.is_float()) {
     switch (t.bits()) {
@@ -1115,12 +1121,12 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode
   return 0;
 }
 
-void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op,
+void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
                                       std::ostream& os) {
   // Cast away volatile qualifier for fp16 types. That is, only loads and
   // stores are volatile. The loaded objects are not marked as volatile.
   //
-  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) {
+  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) {
     os << "(";
     PrintType(op->dtype, os);
     os << ")(" << value << ")";
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 385b734..673753c 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -76,7 +76,8 @@ class CodeGenCUDA final : public CodeGenC {
 
  private:
   // Handle volatile loads
-  void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final;
+  void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
+                           std::ostream& os) final;
 
   // Whether scope such as "__shared__" or "__constant__"  is part of type.
   bool IsScopePartOfType() const final { return false; }
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index b44afec..a76da36 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -177,6 +177,11 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     os << "void*";
     return;
   }
+
+  if (t.is_void()) {
+    os << "void";
+    return;
+  }
   if (t == DataType::Bool()) {
     os << "bool";
     return;
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index a9cd9d8..a0e19ca 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -174,6 +174,10 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     os << "void*";
     return;
   }
+  if (t.is_void()) {
+    os << "void";
+    return;
+  }
   if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -256,21 +260,22 @@ void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) {  // NOLINT(*
   }
 }
 
-void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base,
+void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base,
                                  std::ostream& os) {  // NOLINT(*)
-  if (!HandleTypeMatch(buffer, t.element_of())) {
+  const VarNode* buffer_var = buffer->data.get();
+  if (!HandleTypeMatch(buffer_var, t.element_of())) {
     os << '(';
-    auto it = alloc_storage_scope_.find(buffer);
+    auto it = alloc_storage_scope_.find(buffer_var);
     if (it != alloc_storage_scope_.end()) {
       PrintStorageScope(it->second, os);
     }
     PrintType(t.element_of(), os);
     os << "*)";
   }
-  os << GetVarID(buffer) << " + ";
+  os << GetVarID(buffer_var) << " + ";
   PrintExpr(base, os);
 }
-std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) {
+std::string CodeGenOpenCL::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) {
   std::ostringstream os;
   os << "vload" << t.lanes() << "(0, ";
   PrintVecAddr(buffer, t, base, os);
@@ -278,7 +283,7 @@ std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExp
   return os.str();
 }
 
-void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
+void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
                                   const std::string& value) {
   this->PrintIndent();
   stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
@@ -337,13 +342,17 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType
 }
 
 void CodeGenOpenCL::VisitStmt_(const StoreNode* op) {
+  LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+}
+
+void CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) {
   if (auto call = op->value.as<CallNode>()) {
     if (call->op.same_as(builtin::texture2d_load())) {
       need_texture_ssa_ = false;
       // If storing a texture load into a buffer, don't use an
       // intermediate local unless the buffer allocation is a
       // single element selected from the texture read.
-      auto it = allocation_size_.find(op->buffer_var.get());
+      auto it = allocation_size_.find(op->buffer->data.get());
       if (it != allocation_size_.end() && it->second == 1) {
         need_texture_ssa_ = true;
       }
@@ -371,16 +380,17 @@ void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) {
 void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) {
   if (op->op.same_as(builtin::address_of())) {
     // Overload tvm_address_of to add storage scope (e.g. __global).
-    const LoadNode* load = op->args[0].as<LoadNode>();
+    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
     ICHECK(op->args.size() == 1 && load);
+    ICHECK_EQ(load->indices.size(), 1) << "CodeGenOpenCL only supports flat memory allocations.";
     os << "((";
-    auto it = alloc_storage_scope_.find(load->buffer_var.get());
+    auto it = alloc_storage_scope_.find(load->buffer->data.get());
     if (it != alloc_storage_scope_.end()) {
       PrintStorageScope(it->second, os);
     }
     this->PrintType(load->dtype.element_of(), os);
-    os << " *)" << this->GetVarID(load->buffer_var.get()) << " + ";
-    this->PrintExpr(load->index, os);
+    os << " *)" << this->GetVarID(load->buffer->data.get()) << " + ";
+    this->PrintExpr(load->indices[0], os);
     os << ')';
   } else if (op->op.same_as(builtin::texture2d_store())) {
     auto* ptr_type = op->args[0].as<VarNode>()->type_annotation.as<PointerTypeNode>();
diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h
index 2670c60..3508eef 100644
--- a/src/target/source/codegen_opencl.h
+++ b/src/target/source/codegen_opencl.h
@@ -47,11 +47,11 @@ class CodeGenOpenCL final : public CodeGenC {
   void PrintStorageSync(const CallNode* op) final;                           // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;                        // NOLINT(*)
   void PrintType(const Type& type, std::ostream& os) final;                  // NOLINT(*)
-  std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final;
-  void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base,
+  std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final;
+  void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
                      const std::string& value) final;  // NOLINT(*)
   // the address of load/store
-  void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base,
+  void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base,
                     std::ostream& os);                                           // NOLINT(*)
   void PrintRestrict(const Var& v, std::ostream& os) final;                      // NOLINT(*)
   std::string CastFromTo(std::string value, DataType from, DataType target);     // NOLINT(*)
@@ -64,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC {
   void VisitExpr_(const CastNode* op, std::ostream& os) final;       // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;   // NOLINT(*)
   void VisitStmt_(const StoreNode* op) final;                        // NOLINT(*)
+  void VisitStmt_(const BufferStoreNode* op) final;                  // NOLINT(*)
 
   // overload min and max to avoid ambiguous call errors
   void VisitExpr_(const MinNode* op, std::ostream& os) final;
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 5dcf158..5acb420 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -119,6 +119,10 @@ void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) {  // NOLINT(
     os << "void*";
     return;
   }
+  if (type.is_void()) {
+    os << "void";
+    return;
+  }
   if (type.is_float()) {
     if (type.bits() == 32) {
       os << "float";
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 1d30b9b..0427d8c 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -412,22 +412,23 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
   return builder_->Concat(values);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
-  ICHECK(is_one(op->predicate));
+spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
+  Var buffer_var = op->buffer->data;
+  PrimExpr prim_index = op->indices[0];
 
   DataType desired_read_type = op->dtype;
   if (desired_read_type == DataType::Bool()) {
     desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes());
   }
 
-  const VarNode* buffer_var = op->buffer_var.get();
-  auto it = storage_info_.find(buffer_var);
+  auto it = storage_info_.find(buffer_var.get());
   ICHECK(it != storage_info_.end());
   StorageInfo& info = it->second;
-  info.CheckContentType(desired_read_type, op->index.dtype().lanes());
+  info.CheckContentType(desired_read_type, prim_index.dtype().lanes());
 
   spirv::SType content_type = builder_->GetSType(info.element_type);
-  spirv::Value buffer = MakeValue(op->buffer_var);
+  spirv::Value buffer = MakeValue(buffer_var);
   spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);
 
   uint32_t mask = spv::MemoryAccessMaskNone;
@@ -438,7 +439,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
   if (desired_read_type == info.element_type) {
     // Requested a single value from an array.  This may be a scalar load
     // or a vectorized load, based on the array element type.
-    spirv::Value index = MakeValue(op->index);
+    spirv::Value index = MakeValue(prim_index);
     spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
     spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
     // OpTypeBool have no physical address/storage.  Here, cast from
@@ -457,13 +458,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
       spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
       values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
     };
-    this->Scalarize(op->index, f);
+    this->Scalarize(prim_index, f);
     return builder_->Concat(values);
 
   } else {
     LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint
                << "' with element type " << info.element_type << " using index of type "
-               << op->index->dtype << " to produce output of type " << op->dtype;
+               << prim_index->dtype << " to produce output of type " << op->dtype;
     return spirv::Value();
   }
 }
@@ -483,15 +484,18 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv:
   }
 }
 
-void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
-  ICHECK(is_one(op->predicate));
-  auto it = storage_info_.find(op->buffer_var.get());
+void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
+  Var buffer_var = op->buffer->data;
+  PrimExpr prim_index = op->indices[0];
+
+  auto it = storage_info_.find(buffer_var.get());
   ICHECK(it != storage_info_.end());
   StorageInfo& info = it->second;
-  info.CheckContentType(op->value.dtype(), op->index.dtype().lanes());
+  info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes());
 
   spirv::SType content_type = builder_->GetSType(info.element_type);
-  spirv::Value buffer = MakeValue(op->buffer_var);
+  spirv::Value buffer = MakeValue(buffer_var);
   spirv::Value value = MakeValue(op->value);
   spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);
 
@@ -505,7 +509,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
     // or a vectorized store, based on the array element type.
     ICHECK_EQ(info.element_type, op->value.dtype())
         << "Vulkan only allow one type access to the same buffer";
-    spirv::Value index = MakeValue(op->index);
+    spirv::Value index = MakeValue(prim_index);
     spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
     builder_->MakeInst(spv::OpStore, ptr, value, mask);
 
@@ -517,12 +521,12 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
       spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
       builder_->MakeInst(spv::OpStore, ptr, elem, mask);
     };
-    this->Scalarize(op->index, f);
+    this->Scalarize(prim_index, f);
 
   } else {
     LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '"
-               << op->buffer_var->name_hint << "' with element type " << info.element_type
-               << " using index of type " << op->index->dtype;
+               << buffer_var->name_hint << "' with element type " << info.element_type
+               << " using index of type " << prim_index->dtype;
   }
 }
 
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 74b62e7..08b9db0 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -100,9 +100,9 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
   spirv::Value VisitExpr_(const CallNode* op) override;
   spirv::Value VisitExpr_(const RampNode* op) override;
   spirv::Value VisitExpr_(const BroadcastNode* op) override;
-  spirv::Value VisitExpr_(const LoadNode* op) override;
+  spirv::Value VisitExpr_(const BufferLoadNode* op) override;
   // stmt
-  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const WhileNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc
index 402e329..e704054 100644
--- a/src/target/stackvm/codegen_stackvm.cc
+++ b/src/target/stackvm/codegen_stackvm.cc
@@ -140,12 +140,21 @@ int CodeGenStackVM::GetVarID(const VarNode* v) const {
 }
 
 void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
-  this->Push(op->buffer_var);
+  LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+}
+
+void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers.  "
+                                   << "Has StorageFlatten (TE-based schedules) or "
+                                   << "FlattenBuffer (TIR-based schedules) been run?";
+  auto index = op->indices[0];
+
+  this->Push(op->buffer->data);
   StackVM::OpCode code = StackVM::GetLoad(op->dtype);
-  if (const IntImmNode* index = op->index.as<IntImmNode>()) {
-    this->PushOp(code, index->value);
+  if (const IntImmNode* int_index = index.as<IntImmNode>()) {
+    this->PushOp(code, int_index->value);
   } else {
-    this->Push(op->index);
+    this->Push(index);
     this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes());
     this->PushOp(StackVM::MUL_I64);
     this->PushOp(StackVM::ADDR_ADD);
@@ -154,13 +163,22 @@ void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
 }
 
 void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
-  this->Push(op->buffer_var);
+  LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+}
+
+void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) {
+  ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers.  "
+                                   << "Has StorageFlatten (TE-based schedules) or "
+                                   << "FlattenBuffer (TIR-based schedules) been run?";
+  auto index = op->indices[0];
+
+  this->Push(op->buffer->data);
   StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
-  if (const IntImmNode* index = op->index.as<IntImmNode>()) {
+  if (const IntImmNode* int_index = index.as<IntImmNode>()) {
     this->Push(op->value);
-    this->PushOp(code, index->value);
+    this->PushOp(code, int_index->value);
   } else {
-    this->Push(op->index);
+    this->Push(index);
     this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes());
     this->PushOp(StackVM::MUL_I64);
     this->PushOp(StackVM::ADDR_ADD);
@@ -175,11 +193,13 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
 
 void CodeGenStackVM::VisitExpr_(const CallNode* op) {
   if (op->op.same_as(builtin::address_of())) {
-    const LoadNode* l = op->args[0].as<LoadNode>();
-    ICHECK(op->args.size() == 1 && l);
-    this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
-    this->Push(l->index);
-    this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes());
+    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
+    ICHECK(op->args.size() == 1 && load);
+    ICHECK_EQ(load->indices.size(), 1) << "CodeGenStackVM only supports flat memory allocations.";
+
+    this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get()));
+    this->Push(load->indices[0]);
+    this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes());
     this->PushOp(StackVM::MUL_I64);
     this->PushOp(StackVM::ADDR_ADD);
   } else if (op->op.same_as(builtin::reinterpret())) {
diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h
index 480ffc7..ae6f316 100644
--- a/src/target/stackvm/codegen_stackvm.h
+++ b/src/target/stackvm/codegen_stackvm.h
@@ -108,6 +108,7 @@ class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>,
   // expression
   void VisitExpr_(const VarNode* op) final;
   void VisitExpr_(const LoadNode* op) final;
+  void VisitExpr_(const BufferLoadNode* op) final;
   void VisitExpr_(const LetNode* op) final;
   void VisitExpr_(const CallNode* op) final;
   void VisitExpr_(const AddNode* op) final;
@@ -136,6 +137,7 @@ class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>,
   // statement
   void VisitStmt_(const LetStmtNode* op) final;
   void VisitStmt_(const StoreNode* op) final;
+  void VisitStmt_(const BufferStoreNode* op) final;
   void VisitStmt_(const ForNode* op) final;
   void VisitStmt_(const IfThenElseNode* op) final;
   void VisitStmt_(const AllocateNode* op) final;
diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index 2ed5fd4..e419377 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -134,29 +134,25 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
 
   // If we load from and then store into the same res_handles in the thread_allreduce intrinsic,
   // something goes wrong, so we use an extra variable here for normal reduction.
-  std::vector<Var> normal_res_handles;
+  std::vector<Buffer> normal_res_buffers;
   std::vector<Stmt> normal_init, normal_update;
   if (!normal_red.empty()) {
-    normal_res_handles.reserve(size);
+    normal_res_buffers.reserve(size);
     normal_init.reserve(size);
     normal_update.resize(size);
     const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
     ICHECK(combiner);
     Array<PrimExpr> lhs;
     for (size_t i = 0; i < size; ++i) {
-      DataType t = reduces[i]->dtype;
-      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
-                                      PointerType(PrimType(t), "local"));
-      lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
+      normal_res_buffers.push_back(
+          decl_buffer({1}, reduces[i]->dtype, "normal_reduce_temp" + std::to_string(i), "local"));
+      lhs.push_back(BufferLoad(normal_res_buffers[i], {0}));
     }
     Array<PrimExpr> init_value = combiner->identity_element;
     Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
     for (size_t i = 0; i < size; ++i) {
-      DataType t = reduces[i]->dtype;
-      normal_init.emplace_back(
-          Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
-      normal_update.emplace_back(
-          Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+      normal_init.emplace_back(BufferStore(normal_res_buffers[i], init_value[i], {0}));
+      normal_update.emplace_back(BufferStore(normal_res_buffers[i], update_value[i], {0}));
     }
   }
 
@@ -164,8 +160,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
   for (size_t i = 0; i < size; ++i) {
     if (!normal_red.empty()) {
-      DataType t = reduces[i]->dtype;
-      freduce_args.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
+      freduce_args.push_back(BufferLoad(normal_res_buffers[i], {0}));
     } else {
       freduce_args.push_back(reduces[0]->source[i]);
     }
@@ -174,12 +169,15 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   // No constraints on the thread reduction step. It may have redundant
   // computation for rare cases. TODO(tvm-team): revisit this.
   freduce_args.push_back(const_true(1));
-  std::vector<Var> res_handles(size);
+  std::vector<Buffer> res_buffers(size);
   for (size_t idx = 0; idx < size; ++idx) {
-    DataType dtype = reduces[idx]->dtype;
-    res_handles[idx] =
-        Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local"));
-    freduce_args.push_back(res_handles[idx]);
+    res_buffers[idx] =
+        decl_buffer({1}, reduces[idx]->dtype, "reduce_temp" + std::to_string(idx), "local");
+    // Make a BufferLoad object so that we can pass the entire Buffer
+    // object through to LowerThreadAllreduce.  The index here is
+    // unused.
+    PrimExpr dummy_load = BufferLoad(res_buffers[idx], {0});
+    freduce_args.push_back(dummy_load);
   }
 
   for (IterVar iv : stage->leaf_iter_vars) {
@@ -216,18 +214,18 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
 
   std::vector<Stmt> assigns(size);
   for (size_t idx = 0; idx < size; ++idx) {
-    DataType t = reduces[idx]->dtype;
-    assigns[idx] = ProducerStore(stage->op.output(idx),
-                                 Load(t, res_handles[idx], 0, const_true(t.lanes())), args);
+    assigns[idx] = ProducerStore(stage->op.output(idx), BufferLoad(res_buffers[idx], {0}), args);
   }
   Stmt assign_body = SeqStmt::Flatten(assigns);
   assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
   Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
   for (size_t idx = size; idx != 0; --idx) {
-    body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+    const auto& res_buffer = res_buffers[idx - 1];
+    body = Allocate(res_buffer->data, res_buffer->dtype, res_buffer->shape, const_true(), body);
     if (!normal_red.empty()) {
-      body =
-          Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+      const auto& normal_res_buffer = normal_res_buffers[idx - 1];
+      body = Allocate(normal_res_buffer->data, normal_res_buffer->dtype, normal_res_buffer->shape,
+                      const_true(), body);
     }
   }
   body = Substitute(body, value_map);
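
As an illustrative aside (not part of this patch), the pattern introduced above can be sketched as follows: each reduction temporary is a one-element Buffer in "local" scope, all accesses go through BufferLoad/BufferStore, and the allocation still refers to the buffer's backing Var. The function name and values in the sketch are made up.

    // Minimal sketch of the temporary-buffer pattern used in
    // MakeCrossThreadReduction above; not part of the patch.
    #include <tvm/tir/buffer.h>
    #include <tvm/tir/expr.h>
    #include <tvm/tir/op.h>
    #include <tvm/tir/stmt.h>

    using namespace tvm;
    using namespace tvm::tir;

    Stmt ReduceTempSketch(DataType dtype, PrimExpr update_value, Stmt body) {
      // One-element scratch buffer in "local" scope, replacing the old bare
      // Var handle plus Load/Store nodes.
      Buffer temp = decl_buffer({1}, dtype, "reduce_temp_sketch", "local");

      // Reads and writes carry the whole Buffer object.
      PrimExpr current = BufferLoad(temp, {0});
      Stmt store = BufferStore(temp, current + update_value, {0});

      // The allocation is still expressed through the buffer's data Var.
      return Allocate(temp->data, temp->dtype, temp->shape, const_true(),
                      SeqStmt({store, body}));
    }
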
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index d45f29e..b1056ac 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -79,6 +79,22 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>*
     } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
       state[s->parent] = state[s->rebased];
     } else if (rel.as<SingletonNode>()) {
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      // Currently, this marks all original iter vars as deriving from
+      // a thread bind if any of the transformed variables are bound,
+      // even if the inverse expression for that iter var doesn't
+      // depend on the bound variable.
+
+      // TODO(Lunderberg): For each original variable, check
+      // whether any variable in the inverse expression for it has a
+      // thread binding.
+      bool is_thread_binding = false;
+      for (const auto& iter_var : s->transformed_variables) {
+        is_thread_binding = is_thread_binding || state[iter_var];
+      }
+      for (const auto& iter_var : s->original_variables) {
+        state[iter_var] = is_thread_binding;
+      }
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -157,6 +173,29 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
       Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
       Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      bool missing_originals = false;
+      for (const auto& iter_var : s->original_variables) {
+        if (!state.count(iter_var)) {
+          ICHECK(allow_missing);
+          missing_originals = true;
+        }
+      }
+      if (missing_originals) {
+        continue;
+      }
+
+      Array<Range> original_ranges;
+      for (const auto& iter_var : s->original_variables) {
+        original_ranges.push_back(state[iter_var]);
+      }
+      Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges);
+
+      ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
+      for (size_t i = 0; i < updated_ranges.size(); i++) {
+        Update(p_state, s->transformed_variables[i], updated_ranges[i], actx);
+      }
+
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -225,6 +264,29 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
         state[s->parent] = value;
       }
     } else if (rel.as<SingletonNode>()) {
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      bool missing_transformed = false;
+      for (const auto& iter_var : s->transformed_variables) {
+        if (!state.count(iter_var)) {
+          ICHECK(allow_missing);
+          missing_transformed = true;
+        }
+      }
+      if (missing_transformed) {
+        continue;
+      }
+
+      Array<PrimExpr> transformed_indices;
+      for (const auto& iter_var : s->transformed_variables) {
+        transformed_indices.push_back(state[iter_var]);
+      }
+      Array<PrimExpr> original_indices = s->inverse_transformation->MapIndices(transformed_indices);
+
+      ICHECK_EQ(original_indices.size(), s->original_variables.size());
+      for (size_t i = 0; i < original_indices.size(); i++) {
+        state[s->original_variables[i]] = original_indices[i];
+      }
+
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -270,6 +332,28 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
       state[s->rebased] = value;
     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
       state[s->iter] = make_zero(s->iter->var.dtype());
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      bool missing_originals = false;
+      for (const auto& iter_var : s->original_variables) {
+        if (!state.count(iter_var)) {
+          ICHECK(allow_missing);
+          missing_originals = true;
+        }
+      }
+      if (missing_originals) {
+        continue;
+      }
+
+      Array<PrimExpr> original_indices;
+      for (const auto& iter_var : s->original_variables) {
+        original_indices.push_back(state[iter_var]);
+      }
+      Array<PrimExpr> transformed_indices = s->forward_transformation->MapIndices(original_indices);
+
+      ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
+      for (size_t i = 0; i < transformed_indices.size(); i++) {
+        state[s->transformed_variables[i]] = transformed_indices[i];
+      }
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -351,6 +435,26 @@ void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>&
   *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
 }
 
+Array<IntSet> PassUpDomain(const TransformNode* s,
+                           const std::unordered_map<IterVar, Range>& dom_map,
+                           const Map<IterVar, IntSet>& transformed_domains) {
+  Array<IntSet> output;
+
+  Array<PrimExpr> transformed_indices;
+  for (const auto& iter_var : s->transformed_variables) {
+    transformed_indices.push_back(iter_var->var);
+  }
+
+  Array<PrimExpr> transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices);
+
+  ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
+  for (size_t i = 0; i < transformed_exprs.size(); i++) {
+    output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains));
+  }
+
+  return output;
+}
+
 void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                   std::unordered_map<IterVar, IntSet>* p_state) {
   auto& state = *p_state;
@@ -370,6 +474,16 @@ void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>&
       PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
       state[r->parent] = parent;
     } else if (rel.as<SingletonNode>()) {
+    } else if (const TransformNode* r = rel.as<TransformNode>()) {
+      Map<IterVar, IntSet> transformed_domains;
+      for (const auto& var : r->transformed_variables) {
+        transformed_domains.Set(var, state.at(var));
+      }
+      auto original_ranges = PassUpDomain(r, dom_map, transformed_domains);
+      ICHECK_EQ(original_ranges.size(), r->original_variables.size());
+      for (size_t i = 0; i < original_ranges.size(); i++) {
+        state[r->original_variables[i]] = original_ranges[i];
+      }
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -509,6 +623,22 @@ void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
       state[s->parent] = state.at(s->rebased);
     } else if (rel.as<SingletonNode>()) {
       // nop
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      // Currently, this marks all original iter vars as requiring
+      // bounds checks if any of the transformed variables require
+      // bounds checks, even if the inverse expression for that iter
+      // var doesn't depend on the bound variable.
+
+      // TODO(Lunderberg): For each original variable, check
+      // whether any variable in the inverse expression for it
+      // requires bounds checking.
+      bool needs_bounds_check = false;
+      for (const auto& iter_var : s->transformed_variables) {
+        needs_bounds_check = needs_bounds_check || state[iter_var];
+      }
+      for (const auto& iter_var : s->original_variables) {
+        state[iter_var] = needs_bounds_check;
+      }
     } else {
       LOG(FATAL) << "unknown relation type";
     }
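
As an illustrative aside (not part of this patch), the direction of the two IndexMaps in the Transform relation above can be summarized with a toy transpose in plain C++: PassDownIndex/PassDownDomain apply the forward map to the original axes, while PassUpIndex/PassUpDomain recover the original axes through the inverse map. The maps and values below are made up.

    // Toy illustration of forward vs. inverse index mapping; plain C++, not TVM.
    #include <array>
    #include <cassert>

    int main() {
      // Forward map of a transpose: (i, j) -> (j, i).
      auto forward = [](std::array<int, 2> idx) {
        return std::array<int, 2>{idx[1], idx[0]};
      };
      // Inverse map: (j, i) -> (i, j).
      auto inverse = [](std::array<int, 2> idx) {
        return std::array<int, 2>{idx[1], idx[0]};
      };

      std::array<int, 2> original{3, 7};
      auto transformed = forward(original);   // PassDownIndex direction
      auto recovered = inverse(transformed);  // PassUpIndex direction
      assert(recovered == original);
      return 0;
    }
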
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 2f74d29..0fcd613 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -25,8 +25,10 @@
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
 
+#include <algorithm>
 #include <stack>
 #include <unordered_set>
+#include <vector>
 
 #include "graph.h"
 
@@ -429,6 +431,80 @@ Stage& Stage::rolling_buffer() {
   self->rolling_buffer = true;
   return *this;
 }
+Stage& Stage::transform_layout(const Array<Var>& initial_indices,
+                               const Array<PrimExpr>& final_indices,
+                               Array<IterVar>* out_iter_vars) {
+  StageNode* self = operator->();
+  IndexMap map(initial_indices, final_indices);
+  self->layout_transforms.push_back(map);
+
+  auto* compute = self->op.as<ComputeOpNode>();
+
+  // Can only rewrite the indices of compute op nodes.
+  if (!compute) {
+    return *this;
+  }
+
+  CHECK_EQ(initial_indices.size(), compute->axis.size())
+      << "Expected number of initial indices in transformation to match the dimension of "
+      << self->op->name;
+
+  // Locate the IterVar objects for the data axes.
+  auto leaf_iter_range = [&]() -> std::pair<size_t, size_t> {
+    std::vector<size_t> leaf_var_indices;
+    for (const auto& axis : compute->axis) {
+      leaf_var_indices.push_back(
+          FindLeafVar(self->all_iter_vars.CopyOnWrite(), self->leaf_iter_vars.CopyOnWrite(), axis));
+    }
+    auto minmax_element = std::minmax_element(leaf_var_indices.begin(), leaf_var_indices.end());
+    return {*minmax_element.first, *minmax_element.second + 1};
+  }();
+  CHECK_EQ(leaf_iter_range.first + compute->axis.size(), leaf_iter_range.second)
+      << "Cannot transform indices if they have already been reordered";
+
+  // Determine the updated ranges of iteration.
+  Array<Range> initial_ranges;
+  for (const auto& iter_var : compute->axis) {
+    initial_ranges.push_back(iter_var->dom);
+  }
+  Array<Range> final_ranges = map->MapRanges(initial_ranges);
+
+  // Make IterVar objects to represent the new iterations.
+  auto inverse = map.Inverse(initial_ranges);
+  Array<IterVar> final_indices_iter;
+  ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size());
+  for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+    final_indices_iter.push_back(IterVar(final_ranges[i], inverse->initial_indices[i], kDataPar));
+  }
+
+  // Append the new IterVar objects to all_iter_vars
+  for (const auto& iter_var : final_indices_iter) {
+    self->all_iter_vars.push_back(iter_var);
+  }
+
+  // Replace the existing IterVar objects in leaf_iter_vars with the
+  // new IterVar objects.
+  self->leaf_iter_vars.erase(self->leaf_iter_vars.begin() + leaf_iter_range.first,
+                             self->leaf_iter_vars.begin() + leaf_iter_range.second);
+  self->leaf_iter_vars.insert(self->leaf_iter_vars.begin() + leaf_iter_range.first,
+                              final_indices_iter.begin(), final_indices_iter.end());
+
+  // Define a relationship for each new axis
+  self->relations.push_back(Transform(compute->axis, final_indices_iter, map, inverse));
+
+  // Return the iteration variables as an output.
+  if (out_iter_vars) {
+    *out_iter_vars = final_indices_iter;
+  }
+
+  return *this;
+}
+
+Stage& Stage::set_axis_separators(const Array<IntImm>& axis_separators) {
+  StageNode* self = operator->();
+  self->axis_separators = axis_separators;
+  return *this;
+}
 
 Stage CopyStage(const Stage& s) {
   ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->());
@@ -711,6 +787,16 @@ Singleton::Singleton(IterVar iter) {
   data_ = std::move(n);
 }
 
+Transform::Transform(Array<IterVar> original_variables, Array<IterVar> transformed_variables,
+                     IndexMap forward_transformation, IndexMap inverse_transformation) {
+  auto n = make_object<TransformNode>();
+  n->original_variables = original_variables;
+  n->transformed_variables = transformed_variables;
+  n->forward_transformation = forward_transformation;
+  n->inverse_transformation = inverse_transformation;
+  data_ = std::move(n);
+}
+
 SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) {
   ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>();
   n->clauses = std::move(conditions);
@@ -895,6 +981,16 @@ TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffe
 
 TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer);
 
+TVM_REGISTER_GLOBAL("te.StageTransformLayout")
+    .set_body_typed([](Stage stage, const Array<Var>& initial_indices,
+                       const Array<PrimExpr>& final_indices) {
+      Array<IterVar> new_iter_vars;
+      stage.transform_layout(initial_indices, final_indices, &new_iter_vars);
+      return new_iter_vars;
+    });
+
+TVM_REGISTER_GLOBAL("te.StageSetAxisSeparators").set_body_method(&Stage::set_axis_separators);
+
 TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize);
 
 TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group);
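
As an illustrative aside (not part of this patch), the new Stage::transform_layout entry point added above can be called from C++ roughly as below; the stage and axis names are hypothetical, and only the signature shown in this commit is assumed.

    // Hedged usage sketch of Stage::transform_layout; names are illustrative.
    #include <tvm/te/schedule.h>
    #include <tvm/tir/var.h>

    using namespace tvm;

    void TransposeLayoutSketch(te::Stage stage) {
      // Map the two logical axes (i, j) to the transposed order (j, i).
      tir::Var i("i"), j("j");
      Array<tir::IterVar> new_iter_vars;
      stage.transform_layout({i, j}, {j, i}, &new_iter_vars);
      // new_iter_vars now iterates over the transformed layout and can be
      // used for further scheduling of this stage.
    }
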
diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index 99e02cc..47ef4af 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -40,12 +40,36 @@ namespace te {
 
 using namespace tir;
 
+// Annotate the statement with the layout transforms and axis
+// separators of the stage.  These annotations are removed during
+// SchedulePostProcToPrimFunc.  Afterwards, layout transforms are
+// specified in the PrimFunc attrs, and the axis_separators are
+// specified in the BufferNode.
+Stmt WrapLayoutTransformationAttrs(const Stage& stage, Stmt body) {
+  if (stage->layout_transforms.size()) {
+    for (int i = 0; i < stage->op->num_outputs(); i++) {
+      body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->layout_transforms},
+                      tir::attr::layout_transforms, 1, body);
+    }
+  }
+
+  if (stage->axis_separators.size()) {
+    for (int i = 0; i < stage->op->num_outputs(); i++) {
+      body = AttrStmt(Array<ObjectRef>{stage->op.output(i), stage->axis_separators},
+                      tir::attr::axis_separators, 1, body);
+    }
+  }
+
+  return body;
+}
+
 Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map, Stmt consumer,
                   bool debug_keep_trivial_loop) {
   Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
   if (s->double_buffer) {
     producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer);
   }
+  producer = WrapLayoutTransformationAttrs(s, producer);
   Stmt pipeline = producer;
 
   if (consumer.defined() && !is_no_op(consumer)) {
@@ -209,6 +233,23 @@ class SchedulePostProc : public StmtExprMutator {
           return this->VisitStmt(op->body);
         }
       }
+    } else if (op->attr_key == tir::attr::layout_transforms ||
+               op->attr_key == tir::attr::axis_separators) {
+      auto arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2);
+
+      Stmt body = op->body;
+
+      Tensor tensor = Downcast<Tensor>(arr[0]);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmt(Array<ObjectRef>{it->second.output(tensor->value_index), arr[1]},
+                          op->attr_key, op->value, this->VisitStmt(op->body));
+        } else {
+          return this->VisitStmt(op->body);
+        }
+      }
     }
     return StmtExprMutator::VisitStmt_(op);
   }
@@ -349,12 +390,16 @@ Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_tri
     Stage s = sch->stages[i - 1];
     ICHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops";
     ICHECK(s->op.defined());
-    // no need to specify place holder op.
-    if (s->op.as<PlaceholderOpNode>()) continue;
     // Remove grouping sugar, get the real attach spec.
     Stage attach_spec = s.GetAttachSpec();
 
-    if (scan_init.count(s->op)) {
+    if (s->op.as<PlaceholderOpNode>()) {
+      // Placeholders don't need any realize/provide statements, but
+      // may be annotated with transform_layout or set_axis_separators
+      // to indicate the physical layout of an input, and must still
+      // have the layout attributes attached.
+      body = WrapLayoutTransformationAttrs(s, std::move(body));
+    } else if (scan_init.count(s->op)) {
       ICHECK(body.defined());
       InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
       body = mu(std::move(body));
@@ -381,6 +426,7 @@ Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_tri
           << body;
     }
   }
+
   SchedulePostProc post_proc;
   post_proc.Init(sch);
   return post_proc(std::move(body));
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 7e8b12b..0cf6e54 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -42,6 +42,7 @@
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <functional>
 #include <unordered_map>
 #include <utility>
 
@@ -55,6 +56,7 @@ Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") {
     name += ".v" + std::to_string(tensor->value_index);
   }
   Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope);
+
   return buffer;
 }
 
@@ -86,6 +88,17 @@ class TensorToBufferMapper : public StmtExprMutator {
       Tensor tensor = Downcast<Tensor>(op->node);
       Buffer buffer = GetOrAllocBuffer(tensor);
       return AttrStmt(buffer, op->attr_key, op->value, op->body);
+    } else if (op->attr_key == tir::attr::layout_transforms ||
+               op->attr_key == tir::attr::axis_separators) {
+      auto arr = Downcast<Array<ObjectRef>>(op->node);
+      ICHECK_EQ(arr.size(), 2);
+
+      Stmt body = op->body;
+
+      Tensor tensor = Downcast<Tensor>(arr[0]);
+      Buffer buffer = GetBuffer(tensor);
+
+      return AttrStmt(Array<ObjectRef>{buffer, arr[1]}, op->attr_key, 1, body);
     } else {
       return ret;
     }
@@ -108,7 +121,7 @@ class TensorToBufferMapper : public StmtExprMutator {
     auto ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<ProducerStoreNode>();
 
-    return BufferStore(buffer, op->value, op->indices);
+    return BufferStore(buffer, op->value, GetIndices(op->indices, buffer->shape));
   }
 
   PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
@@ -116,7 +129,7 @@ class TensorToBufferMapper : public StmtExprMutator {
     op = ret.as<ProducerLoadNode>();
     Tensor tensor = Downcast<Tensor>(op->producer);
     Buffer buffer = GetBuffer(tensor);
-    return tir::BufferLoad(buffer, op->indices);
+    return tir::BufferLoad(buffer, GetIndices(op->indices, buffer->shape));
   }
 
  private:
@@ -134,46 +147,279 @@ class TensorToBufferMapper : public StmtExprMutator {
     return buffer;
   }
 
-  // maps tensor to buffer.
+  Array<PrimExpr> GetIndices(const Array<PrimExpr>& tensor_indices,
+                             const Array<PrimExpr>& buffer_shape) {
+    if (tensor_indices.size() == buffer_shape.size()) {
+      return tensor_indices;
+    } else if (tensor_indices.size() == 1) {
+      // Workaround to support previous behavior of tensor indexing by
+      // a single index, treating the tensor as if it were already
+      // flattened by a row-major traversal.
+      PrimExpr unravel = tensor_indices[0];
+      Array<PrimExpr> rev_indices;
+      for (size_t i = buffer_shape.size(); i > 0; i--) {
+        PrimExpr dim = buffer_shape[i - 1];
+        rev_indices.push_back(indexmod(unravel, dim));
+        unravel = indexdiv(unravel, dim);
+      }
+      return Array<PrimExpr>(rev_indices.rbegin(), rev_indices.rend());
+    } else {
+      LOG(FATAL) << "Cannot produce indices for " << buffer_shape.size()
+                 << "-dimensional TIR buffer using " << tensor_indices.size()
+                 << "-dimensional tensor indices.";
+      return {};
+    }
+  }
+
+  // Maps tensor to buffer.
   std::unordered_map<Tensor, Buffer> buffer_map_;
 };
 
+/*! Move layout transform annotations from attributes into the PrimFunc's "layout_transform_map" attr. */
+class LayoutTransformAttrUnwrapper : StmtExprMutator {
+ public:
+  static tir::PrimFunc Apply(tir::PrimFunc func) {
+    // Collect the physical layout annotations in the body, which may
+    // refer to input arguments.
+    auto layout_map = Collector::Collect(func->body);
+
+    if (layout_map.size()) {
+      func = WithAttr(std::move(func), "layout_transform_map", layout_map);
+
+      auto write_ptr = func.CopyOnWrite();
+      write_ptr->body = LayoutTransformAttrUnwrapper()(func->body);
+    }
+
+    return func;
+  }
+
+  LayoutTransformAttrUnwrapper() {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    auto ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+
+    if (op->attr_key == tir::attr::layout_transforms) {
+      return op->body;
+    } else {
+      return ret;
+    }
+  }
+
+ private:
+  /*! Collect the physical layout information of all tensors in the statement.
+   *
+   * Must be done before constructing the buffers, since the
+   * attributes could either apply to the external buffers or to
+   * internal allocations.
+   */
+  class Collector : StmtExprVisitor {
+   public:
+    static Map<Buffer, Array<IndexMap>> Collect(Stmt stmt) {
+      Collector collector;
+      collector(std::move(stmt));
+      return std::move(collector.layout_map_);
+    }
+
+    Collector() {}
+
+    void VisitStmt_(const AttrStmtNode* op) final {
+      if (op->attr_key == tir::attr::layout_transforms) {
+        auto arr = Downcast<Array<ObjectRef>>(op->node);
+        ICHECK_EQ(arr.size(), 2);
+
+        auto buffer = Downcast<Buffer>(arr[0]);
+        auto layout_transforms = Downcast<Array<IndexMap>>(arr[1]);
+        layout_map_.Set(buffer, layout_transforms);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+
+    Map<Buffer, Array<IndexMap>> layout_map_;
+  };
+
+  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
+
+  Map<Buffer, Array<IndexMap>> layout_map_;
+};
+
+/*! Move axis_separators from an attribute to a buffer property. */
+class AxisSeparatorsAttrUnwrapper : StmtExprMutator {
+ public:
+  static tir::PrimFunc Apply(tir::PrimFunc func) {
+    // Collect the physical layout annotations in the body, which may
+    // refer to input arguments.
+    auto axis_separators_map = Collector::Collect(func->body);
+
+    if (axis_separators_map.size()) {
+      auto write_ptr = func.CopyOnWrite();
+      auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map);
+      write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map);
+      write_ptr->body = pass(func->body);
+    }
+
+    return func;
+  }
+
+  explicit AxisSeparatorsAttrUnwrapper(Map<Buffer, Array<IntImm>> axis_separators_map)
+      : axis_separators_map_(axis_separators_map) {}
+
+  Map<Var, Buffer> UpdateExternBufferMap(const Map<Var, Buffer>& orig) {
+    Map<Var, Buffer> output;
+    for (const auto& kv : orig) {
+      output.Set(kv.first, GetRemappedBuffer(kv.second));
+    }
+    return output;
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    auto ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+
+    if (op->attr_key == tir::attr::axis_separators) {
+      return op->body;
+    } else {
+      return ret;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    auto node = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+ private:
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    Buffer new_buf = GetRemappedBuffer(node->buffer);
+    if (!node->buffer.same_as(new_buf)) {
+      auto writer = node.CopyOnWrite();
+      writer->buffer = new_buf;
+    }
+    return node;
+  }
+
+  Buffer GetRemappedBuffer(Buffer buf) {
+    // If this buffer has already been remapped, then return the
+    // previous value.
+    auto key = buf.get();
+    {
+      auto it = buffer_remap_.find(key);
+      if (it != buffer_remap_.end()) {
+        return it->second;
+      }
+    }
+
+    // Otherwise, check if we need to add axis_separators to this
+    // buffer.
+    auto lookup = axis_separators_map_.Get(buf);
+    if (lookup) {
+      Array<IntImm> axis_separators = lookup.value();
+      if (axis_separators.size()) {
+        auto write_ptr = buf.CopyOnWrite();
+        write_ptr->axis_separators = axis_separators;
+      }
+    }
+
+    // And cache the result for next time.
+    buffer_remap_[key] = buf;
+
+    return buf;
+  }
+
+  /*! Collect the axis separator information of all tensors in the statement.
+   *
+   * Must be done before constructing the buffers, since the
+   * attributes could either apply to the external buffers or to
+   * internal allocations.
+   */
+  class Collector : StmtExprVisitor {
+   public:
+    static Map<Buffer, Array<IntImm>> Collect(Stmt stmt) {
+      Collector collector;
+      collector(std::move(stmt));
+      return std::move(collector.axis_separators_map_);
+    }
+
+    Collector() {}
+
+    void VisitStmt_(const AttrStmtNode* op) final {
+      if (op->attr_key == tir::attr::axis_separators) {
+        auto arr = Downcast<Array<ObjectRef>>(op->node);
+        ICHECK_EQ(arr.size(), 2);
+
+        auto buffer = Downcast<Buffer>(arr[0]);
+        auto axis_separators = Downcast<Array<IntImm>>(arr[1]);
+        axis_separators_map_.Set(buffer, axis_separators);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+
+    Map<Buffer, Array<IntImm>> axis_separators_map_;
+  };
+
+  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
+
+  Map<Buffer, Array<IntImm>> axis_separators_map_;
+};
+
 PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
                                     Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
-  std::unordered_map<Tensor, Buffer> extern_buffer;
+  std::unordered_map<Tensor, Buffer> extern_tensor_map;
 
   if (extern_buffer_opt.defined()) {
     auto v = extern_buffer_opt.value();
-    extern_buffer = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
+    extern_tensor_map = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
   }
 
   Array<tir::Var> params;
   Map<tir::Var, tir::Buffer> buffer_map;
 
-  for (auto var : arg_list) {
-    if (auto* n = var.as<tir::VarNode>()) {
+  for (auto arg : arg_list) {
+    if (auto* n = arg.as<tir::VarNode>()) {
+      tir::Var var = GetRef<tir::Var>(n);
       params.push_back(GetRef<tir::Var>(n));
-    } else if (auto* n = var.as<te::TensorNode>()) {
+    } else if (auto* n = arg.as<te::TensorNode>()) {
       te::Tensor tensor = GetRef<te::Tensor>(n);
-      ICHECK(!extern_buffer.count(tensor));
+      ICHECK(!extern_tensor_map.count(tensor));
 
       tir::Buffer buffer = CreateBufferFor(tensor);
       tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
       params.push_back(bptr);
       buffer_map.Set(bptr, buffer);
-      extern_buffer[tensor] = buffer;
-    } else {
-      tir::Buffer buffer = Downcast<tir::Buffer>(var);
+      extern_tensor_map[tensor] = buffer;
+    } else if (auto* n = arg.as<tir::BufferNode>()) {
+      tir::Buffer buffer = GetRef<tir::Buffer>(n);
       tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
       params.push_back(bptr);
       buffer_map.Set(bptr, buffer);
+    } else {
+      LOG(FATAL) << "Expected argument to be Var, Tensor, or Buffer, but received "
+                 << arg->GetTypeKey();
     }
   }
 
-  body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
+  body = TensorToBufferMapper(std::move(extern_tensor_map))(std::move(body));
+
+  PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map);
+
+  func = LayoutTransformAttrUnwrapper::Apply(std::move(func));
+  func = AxisSeparatorsAttrUnwrapper::Apply(std::move(func));
+
   // We mark this PrimFunc as coming from a TE schedule
-  return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), "from_legacy_te_schedule",
-                  Bool(true));
+  func = WithAttr(func, "from_legacy_te_schedule", Bool(true));
+
+  return func;
 }
 
 TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
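
As an illustrative aside (not part of this patch), the single-index workaround in GetIndices above is a row-major unravel: the flat index is peeled apart with mod/div from the innermost dimension outwards. A stand-alone sketch in plain C++ (shape and index values made up):

    // Row-major unravel of a flat index, mirroring GetIndices; plain C++, not TVM.
    #include <cassert>
    #include <vector>

    std::vector<long> UnravelRowMajor(long flat, const std::vector<long>& shape) {
      std::vector<long> indices(shape.size());
      // Peel off one coordinate at a time, innermost dimension first.
      for (size_t i = shape.size(); i > 0; i--) {
        indices[i - 1] = flat % shape[i - 1];
        flat = flat / shape[i - 1];
      }
      return indices;
    }

    int main() {
      // For a [2, 3, 4] buffer, flat index 17 = 1*(3*4) + 1*4 + 1 -> (1, 1, 1).
      std::vector<long> idx = UnravelRowMajor(17, {2, 3, 4});
      assert((idx == std::vector<long>{1, 1, 1}));
      return 0;
    }
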
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 3038eca..974f6ec 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -141,14 +141,13 @@ Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
 void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef<Var>(op)); }
 
 void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
-  UpdateOpaque(op->buffer_var);
-  ExprVisitor::VisitExpr_(op);
+  LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
 }
 
 void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
   std::vector<arith::IntSet> relaxed_region;
   for (const PrimExpr& index : op->indices) {
-    relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+    relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
   }
   Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
   ExprVisitor::VisitExpr_(op);
@@ -194,14 +193,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
 }
 
 void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
-  UpdateOpaque(op->buffer_var);
-  StmtVisitor::VisitStmt_(op);
+  LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
 }
 
 void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
   std::vector<arith::IntSet> relaxed_region;
   for (const PrimExpr& index : op->indices) {
-    relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+    relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
   }
   Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
   StmtVisitor::VisitStmt_(op);
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc
index e680d68..b71e6b2 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -43,6 +43,13 @@ class LCADetector : public StmtExprVisitor {
       detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
     }
 
+    // The root node must be explicitly present in the list of
+    // ancestor_scopes_.  We cannot use nullptr to represent the root
+    // node, as that is also used to represent a scope that hasn't
+    // been observed before.
+    ScopeInfo root(nullptr, nullptr, 0);
+    detector.ancestor_scopes_.push_back(&root);
+
     detector(func->body);
     // Prepare the return
     Map<Buffer, Optional<Stmt>> buffer_lca;
@@ -120,13 +127,11 @@ class LCADetector : public StmtExprVisitor {
 
   // Explicitly visit buffer data in Load and Store nodes.
   void VisitExpr_(const LoadNode* op) final {
-    ExprVisitor::VisitExpr_(op);
-    VisitBufferVar(op->buffer_var.get());
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
   }
 
   void VisitStmt_(const StoreNode* op) final {
-    StmtVisitor::VisitStmt_(op);
-    VisitBufferVar(op->buffer_var.get());
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
   }
 
   void VisitBufferVar(const VarNode* op) {
@@ -137,6 +142,7 @@ class LCADetector : public StmtExprVisitor {
   }
 
   void UpdateBufferLCA(const BufferNode* buffer) {
+    buffer_var_map_.emplace(buffer->data.get(), buffer);
     if (match_buffers_.find(buffer) == match_buffers_.end()) {
       // Ignore buffer created by block match_buffer
       const ScopeInfo*& lca = buffer_lca_[buffer];
@@ -169,8 +175,11 @@ class LCADetector : public StmtExprVisitor {
     return lhs;
   }
 
-  /*! \brief The ancestor scope stacks info (Block and For), initialized with Null. */
-  std::vector<const ScopeInfo*> ancestor_scopes_ = {nullptr};
+  /*! \brief The ancestor scope stacks info (Block and For).  The
+   *  first element is initialized in LCADetector::Detect to represent
+   *  the root scope.
+   */
+  std::vector<const ScopeInfo*> ancestor_scopes_ = {};
   /*! \brief The map from Buffer to its LCA ForNode/BlockNode. */
   std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
   /*! \brief The map from Buffer data to the Buffer. */
diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc
index 26cf66c..1309681 100644
--- a/src/tir/analysis/device_constraint_utils.cc
+++ b/src/tir/analysis/device_constraint_utils.cc
@@ -210,6 +210,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
 
     // Start with a copy of the current prim_func buffer map.
     Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end());
+    Map<Var, Buffer> new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(),
+                                                 prim_func->preflattened_buffer_map.end());
     bool any_change = false;
 
     // For each constrained parameter...
@@ -223,6 +225,23 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
         any_change = true;
       }
       new_buffer_map.Set(param, new_buffer);
+
+      // Rewrite the pre-flattened buffers to account for constraint.
+      // This only has an impact if the IRModule being analyzed has
+      // already been run through the StorageFlatten or FlattenBuffer
+      // passes.
+      if (auto opt = prim_func->preflattened_buffer_map.Get(param)) {
+        Buffer pf_buffer = opt.value();
+        if (pf_buffer.same_as(buffer)) {
+          new_preflattened_buffer_map.Set(param, new_buffer);
+        } else {
+          const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device);
+          if (!new_buffer.same_as(pf_buffer)) {
+            any_change = true;
+          }
+          new_preflattened_buffer_map.Set(param, new_buffer);
+        }
+      }
     }
     // Make sure we have accounted for all prim_func parameters.
     CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
@@ -240,7 +259,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
 
     if (any_change) {
       return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type,
-                      std::move(new_buffer_map), prim_func->attrs, prim_func->span);
+                      std::move(new_buffer_map), std::move(new_preflattened_buffer_map),
+                      prim_func->attrs, prim_func->span);
     } else {
       return prim_func;
     }
@@ -425,9 +445,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
     PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope);
     Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span);
     var_subst_.emplace(buffer->data.get(), new_data);
-    Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset,
-                      buffer->name, buffer->data_alignment, buffer->offset_factor,
-                      buffer->buffer_type, buffer->span);
+    Buffer new_buffer = buffer;
+    new_buffer.CopyOnWrite()->data = new_data;
     buffer_subst_.emplace(buffer.get(), new_buffer);
     return new_buffer;
   }
diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc
index c4acd2b..f92afc4 100644
--- a/src/tir/analysis/var_touch.cc
+++ b/src/tir/analysis/var_touch.cc
@@ -44,13 +44,21 @@ class VarTouchVisitor : public StmtExprVisitor {
 
   void VisitExpr_(const VarNode* op) final { Handle(op); }
 
+  void VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+  }
+
   void VisitStmt_(const StoreNode* op) final {
-    Handle(op->buffer_var.get());
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    Handle(op->buffer->data.get());
     StmtVisitor::VisitStmt_(op);
   }
 
-  void VisitExpr_(const LoadNode* op) final {
-    Handle(op->buffer_var.get());
+  void VisitExpr_(const BufferLoadNode* op) final {
+    Handle(op->buffer->data.get());
     ExprVisitor::VisitExpr_(op);
   }
 
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index c1579c2..b082581 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -184,7 +184,15 @@ class GPUCodeVerifier : public StmtExprVisitor {
     StmtVisitor::VisitStmt_(op);
   }
 
-  void VisitExpr_(const LoadNode* op) {
+  void VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+  }
+
+  void VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) {
     if (op->dtype.lanes() > 1) {
       if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
         std::stringstream s;
@@ -197,7 +205,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
     ExprVisitor::VisitExpr_(op);
   }
 
-  void VisitStmt_(const StoreNode* op) {
+  void VisitStmt_(const BufferStoreNode* op) {
     if (op->value->dtype.lanes() > 1) {
       if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
           max_vector_bytes_) {
diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc
index b6c41b9..6ee30e0 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -89,12 +89,20 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   }
 
   void VisitExpr_(const LoadNode* op) final {
-    HandleLoadStoreToVariable(op->buffer_var);
-    return StmtExprVisitor::VisitExpr_(op);
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
   }
 
   void VisitStmt_(const StoreNode* op) final {
-    HandleLoadStoreToVariable(op->buffer_var);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    HandleLoadStoreToVariable(op->buffer->data);
+    return StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    HandleLoadStoreToVariable(op->buffer->data);
     return StmtExprVisitor::VisitStmt_(op);
   }
   //@}
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 24aacc3..4fe9b16 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -48,10 +48,10 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
 }
 
 Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope,
-                   Span span) {
+                   Array<IntImm> axis_separators, Span span) {
   DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
   return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
-                Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault, span);
+                Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault, axis_separators, span);
 }
 
 // Split the given expression w.r.t the add operator
@@ -243,82 +243,187 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
   return no_opt_sum;
 }
 
+Array<PrimExpr> Buffer::OffsetOf(Array<PrimExpr> input_indices) const {
+  return (*this)->ElemOffset(std::move(input_indices));
+}
+
 // The buffer offset in convention of number of elements of
 // original data ignoring number of lanes.
 // We also perform optimization to simplify the indexing expression.
-PrimExpr BufferNode::ElemOffset(Array<PrimExpr> index) const {
-  PrimExpr base = this->elem_offset;
+Array<PrimExpr> BufferNode::ElemOffset(Array<PrimExpr> input_indices) const {
+  ICHECK_EQ(shape.size(), input_indices.size())
+      << "Buffer " << this->name << " is " << shape.size()
+      << "-dimensional, cannot be indexed with the " << input_indices.size()
+      << "-dimensional indices provided.";
+
+  if (strides.size()) {
+    ICHECK_EQ(this->strides.size(), input_indices.size())
+        << "If strides are defined, "
+        << "the dimensionality of the strides must match the dimensionality of the indices given.";
+  }
+
+  // TODO(Lunderberg): Better handling for cases where there is more
+  // than one output index.  Currently, this only allows elem_offset
+  // to be non-zero for flat memory allocations.
+  Array<PrimExpr> elem_offsets = {};
+  if (elem_offset.defined() && !is_zero(elem_offset)) {
+    elem_offsets = {elem_offset};
+  }
+
+  if (elem_offsets.size()) {
+    ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1)
+        << "If element offsets are defined, "
+        << "there must be one element offset for each output index.";
+  }
+
+  Array<PrimExpr> output_indices(axis_separators.size() + 1, 0);
+
+  size_t current_output_axis = 0;
+
   arith::Analyzer ana;
-  if (this->strides.size() == 0) {
-    // Scalar case
-    if (this->shape.size() == 0 && index.size() == 1) {
-      auto is_int = index[0].as<IntImmNode>();
-      ICHECK(is_int && is_int->value == 0);
-      base = base + index[0];
-    } else {
-      ICHECK_EQ(this->shape.size(), index.size());
-      if (index.size() > 0) {
-        PrimExpr offset = index[0];
-        for (size_t i = 1; i < index.size(); ++i) {
-          offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]);
-        }
-        base = base + offset;
-      }
+
+  for (size_t i = 0; i < input_indices.size(); i++) {
+    if ((current_output_axis < axis_separators.size()) &&
+        (i == size_t(axis_separators[current_output_axis]->value))) {
+      current_output_axis++;
     }
-  } else {
-    ICHECK_EQ(this->strides.size(), index.size());
-    if (is_zero(base)) {
-      base = MergeMulMod(&ana, index[0] * this->strides[0]);
+
+    PrimExpr output_index = output_indices[current_output_axis];
+    if (strides.size()) {
+      output_index = output_index + input_indices[i] * strides[i];
     } else {
-      base = MergeMulMod(&ana, base + index[0] * this->strides[0]);
+      output_index = output_index * this->shape[i] + input_indices[i];
+    }
+
+    if (i > 0) {
+      output_index = MergeMulMod(&ana, output_index);
     }
-    for (size_t i = 1; i < index.size(); ++i) {
-      base = MergeMulMod(&ana, base + index[i] * this->strides[i]);
+
+    output_indices.Set(current_output_axis, output_index);
+  }
+
+  if (elem_offsets.size()) {
+    for (size_t i = 0; i < output_indices.size(); i++) {
+      output_indices.Set(i, output_indices[i] + elem_offsets[i]);
     }
   }
-  return base;
+
+  return SimplifyArray(&ana, output_indices);
 }
 
-inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
-  PrimExpr offset = n->ElemOffset(index);
+inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
+  Array<PrimExpr> offsets = n->ElemOffset(index);
+  // If the Buffer has element type with more than one lane, scale to
+  // get the offset in number of scalars.
   if (n->dtype.lanes() != 1) {
-    offset = offset * make_const(offset.dtype(), dtype.lanes());
+    PrimExpr last_offset = offsets[offsets.size() - 1];
+    offsets.Set(offsets.size() - 1, last_offset * make_const(last_offset.dtype(), dtype.lanes()));
   }
+
+  // If the requested type has more than one lane, make a RampNode at
+  // that offset.
   if (dtype.lanes() != 1) {
-    return tir::Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
+    PrimExpr last_offset = offsets[offsets.size() - 1];
+    PrimExpr stride = make_const(last_offset.dtype(), 1);
+    offsets.Set(offsets.size() - 1, tir::Ramp(last_offset, stride, dtype.lanes()));
+  }
+
+  return offsets;
+}
+
+Buffer Buffer::GetFlattenedBuffer() const {
+  auto self = operator->();
+
+  // These checks ensure that all output axes contain at least one
+  // input axis.
+  for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
+    auto sep = self->axis_separators[i]->value;
+    auto next_sep = self->axis_separators[i + 1]->value;
+    ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
+  }
+  if (self->axis_separators.size()) {
+    auto first_sep = self->axis_separators[0]->value;
+    ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
+                            << "so that the first output axis contains at least one input axis";
+    auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
+    ICHECK_LT(last_sep, self->shape.size())
+        << "Last output axis must contain at least one input axis.";
+  }
+
+  Array<PrimExpr> output_shape;
+  if (self->strides.size()) {
+    // If strides are defined, then the extent of each flattened
+    // buffer is the stride*size for the first input axis used for
+    // each output axis.
+    ICHECK_EQ(self->shape.size(), self->strides.size());
+    output_shape.push_back(self->strides[0] * self->shape[0]);
+    for (const auto& sep : self->axis_separators) {
+      output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]);
+    }
+
   } else {
-    return offset;
+    // Otherwise, the extent of each flattened buffer is the product
+    // of the extents of each input axis used to generate that output
+    // axis.  This also "flattens" rank-0 tensors to a rank-1 buffer
+    // of shape [1].
+    output_shape = Array<PrimExpr>(self->axis_separators.size() + 1, 1);
+    size_t current_output_index = 0;
+    for (size_t i = 0; i < self->shape.size(); i++) {
+      if ((current_output_index < self->axis_separators.size()) &&
+          (i == size_t(self->axis_separators[current_output_index]->value))) {
+        current_output_index += 1;
+      }
+      output_shape.Set(current_output_index, output_shape[current_output_index] * self->shape[i]);
+    }
   }
+
+  // The axis_separators for the output buffer.
+  Array<IntImm> output_axis_separators;
+  for (size_t i = 0; i < self->axis_separators.size(); i++) {
+    auto dtype = self->axis_separators[i]->dtype;
+    output_axis_separators.push_back(IntImm(dtype, i + 1));
+  }
+
+  Buffer output = *this;
+  auto writer = output.CopyOnWrite();
+  writer->shape = output_shape;
+  writer->axis_separators = output_axis_separators;
+  writer->strides = {};
+
+  return output;
 }
 
-PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType dtype) const {
+PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType value_dtype) const {
   // specially handle bool, stored as DataType::Int(8)
   const BufferNode* n = operator->();
   ICHECK(n != nullptr);
-  ICHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0)
-      << "Cannot load " << dtype << " from buffer of " << n->dtype;
-  if (dtype == DataType::Bool()) {
-    return tir::Cast(DataType::Bool(),
-                     tir::Load(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
-                               const_true()));
-  } else {
-    return tir::Load(dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes()));
+  ICHECK(value_dtype.element_of() == n->dtype.element_of() &&
+         value_dtype.lanes() % n->dtype.lanes() == 0)
+      << "Cannot load " << value_dtype << " from buffer of " << n->dtype;
+
+  Array<PrimExpr> indices = begin;
+  int factor = value_dtype.lanes() / n->dtype.lanes();
+  if (factor > 1) {
+    indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor));
   }
+  return BufferLoad(*this, indices);
 }
 
 Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
   // specially handle bool, stored as DataType::Int(8)
   const BufferNode* n = operator->();
   ICHECK(n != nullptr);
-  DataType dtype = value.dtype();
-  ICHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0)
-      << "Cannot store " << dtype << " to buffer of " << n->dtype;
-  if (value.dtype() == DataType::Bool()) {
-    return tir::Store(n->data, tir::Cast(DataType::Int(8), value),
-                      BufferOffset(n, begin, DataType::Int(8)), const_true());
-  } else {
-    return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes()));
+  DataType value_dtype = value.dtype();
+  ICHECK(value_dtype.element_of() == n->dtype.element_of() &&
+         value_dtype.lanes() % n->dtype.lanes() == 0)
+      << "Cannot store " << value_dtype << " to buffer of " << n->dtype;
+
+  Array<PrimExpr> indices = begin;
+  int factor = value_dtype.lanes() / n->dtype.lanes();
+  if (factor > 1) {
+    indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor));
   }
+  return BufferStore(*this, value, indices);
 }
 
 String Buffer::scope() const {
@@ -353,7 +458,10 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
   ICHECK(n != nullptr);
   arith::Analyzer ana;
   begins = SimplifyArray(&ana, begins);
-  PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins));
+  Array<PrimExpr> elem_offset = n->ElemOffset(begins);
+  elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); });
+  ICHECK_EQ(elem_offset.size(), 1) << "MakeSlice currently supports only flat 1-d memory.";
+
   Array<PrimExpr> strides = n->strides;
   if (strides.size() == 0) {
     bool can_relax = true;
@@ -372,7 +480,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
       return MakeStrideView().MakeSlice(begins, extents);
     }
   }
-  return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice",
+  return Buffer(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
                 n->data_alignment, 0, n->buffer_type);
 }
 
@@ -407,15 +515,27 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
 
 Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
                PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
-               BufferType buffer_type, Span span) {
+               BufferType buffer_type, Array<IntImm> axis_separators, Span span) {
   DataType storage_dtype = dtype;
   // specially handle bool
   if (storage_dtype == DataType::Bool()) {
     storage_dtype = DataType::Int(8);
   }
-  ICHECK(IsPointerType(data->type_annotation, storage_dtype))
-      << "Buffer data field expect to have the right pointer type annotation"
-      << " annotation=" << data->type_annotation << ", storage_dtype=" << storage_dtype;
+  // The buffer dtype may differ from the dtype of the underlying
+  // allocation, such as a single allocation that backs multiple
+  // tensors without a common datatype.  Therefore, we check that the
+  // data pointer is a pointer, but not the exact type of the
+  // pointed-to values.
+
+  // TODO(Lunderberg): Use an explicit pointer cast for the data
+  // pointer.  Should be done alongside extensions to StmtExprMutator
+  // to more easily handle buffer/buffer_var updates.
+  ICHECK(data->type_annotation.defined())
+      << "Variable " << data->name_hint << " is missing a type annotation.";
+  ICHECK(data->type_annotation.as<PointerTypeNode>())
+      << "Variable " << data->name_hint << " is not a pointer.";
+  ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
+      << "Variable " << data->name_hint << " does not point to a primitive.";
 
   auto n = make_object<BufferNode>();
   n->data = std::move(data);
@@ -423,6 +543,7 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
 
   n->shape = std::move(shape);
   n->strides = std::move(strides);
+  n->axis_separators = std::move(axis_separators);
   n->name = std::move(name);
   if (!elem_offset.defined()) {
     elem_offset = make_const(n->DefaultIndexType(), 0);
@@ -455,15 +576,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(BufferNode);
 
 TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
-  ICHECK_EQ(args.size(), 10);
+  ICHECK_EQ(args.size(), 11);
   auto buffer_type = args[8].operator String();
   BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
-  *ret =
-      Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]);
+  *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type,
+                args[9], args[10]);
 });
 
 TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);
 
+TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer);
+
+TVM_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf);
+
 TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload);
 
 TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore);
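
As an illustrative aside (not part of this patch), the core of the non-flat indexing above can be reduced to a small amount of arithmetic: axis_separators partition the input axes into groups, BufferNode::ElemOffset flattens each group row-major into one output index, and Buffer::GetFlattenedBuffer reports one extent per group. A stand-alone sketch in plain C++ (shapes and indices made up), ignoring strides, elem_offset, and simplification:

    // Grouped row-major offset computation mirroring BufferNode::ElemOffset;
    // plain C++, not TVM.
    #include <cassert>
    #include <vector>

    std::vector<long> GroupedOffset(const std::vector<long>& indices,
                                    const std::vector<long>& shape,
                                    const std::vector<size_t>& axis_separators) {
      std::vector<long> output(axis_separators.size() + 1, 0);
      size_t group = 0;
      for (size_t i = 0; i < indices.size(); i++) {
        // Start a new output index when crossing an axis separator.
        if (group < axis_separators.size() && i == axis_separators[group]) {
          group++;
        }
        output[group] = output[group] * shape[i] + indices[i];
      }
      return output;
    }

    int main() {
      // A 4-d buffer with axis_separators = {2} becomes a 2-d physical buffer:
      // axes (0, 1) feed the first output index, axes (2, 3) the second, and
      // GetFlattenedBuffer would report the physical shape [5*6, 7*8].
      std::vector<long> out = GroupedOffset({1, 2, 3, 4}, {5, 6, 7, 8}, {2});
      assert(out.size() == 2);
      assert(out[0] == 1 * 6 + 2);  // row-major within the first group
      assert(out[1] == 3 * 8 + 4);  // row-major within the second group
      return 0;
    }
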
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index fbbd4a9..ef533ef 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -626,6 +626,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // Load
 Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) {
+  LOG(FATAL) << "Unexpected use of deprecated Load node for buffer " << buffer_var->name_hint
+             << ".  Use BufferLoad instead.";
   ICHECK(buffer_var.defined());
   ICHECK(predicate.defined());
   ICHECK(index.defined());
@@ -1056,12 +1058,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; });
 
 // BufferLoad
+void BufferLoadNode::LegalizeDType() {
+  int index_lanes = 1;
+  for (const auto& index : indices) {
+    index_lanes *= index.dtype().lanes();
+  }
+
+  int buffer_lanes = buffer->dtype.lanes();
+
+  this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes);
+}
+
 BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) {
+  ICHECK_EQ(buffer->shape.size(), indices.size())
+      << "Buffer " << buffer->name << " is " << buffer->shape.size()
+      << "-dimensional, cannot be indexed with the " << indices.size()
+      << "-dimensional indices provided.";
+
   ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
-  node->dtype = buffer->dtype;
   node->buffer = std::move(buffer);
   node->indices = std::move(indices);
   node->span = std::move(span);
+  node->LegalizeDType();
   data_ = std::move(node);
 }
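
As a rough sketch of the new behavior (illustrative only, not part of this
patch): BufferLoad indices must now match the buffer's dimensionality, and
LegalizeDType widens the buffer's scalar dtype by the combined lanes of the
indices.

    #include <tvm/tir/buffer.h>
    #include <tvm/tir/expr.h>

    using namespace tvm;
    using namespace tvm::tir;

    void VectorizedLoadDtype() {
      // 1-d float32 buffer with 16 elements.
      Buffer buf = decl_buffer({16}, DataType::Float(32), "A");

      // A Ramp index with 4 lanes addresses elements [0, 4).
      BufferLoad load(buf, {Ramp(0, 1, 4)});

      // LegalizeDType sets the dtype to the buffer's scalar dtype widened by
      // the combined lanes of the indices: float32x4 here.
      ICHECK(load->dtype == DataType::Float(32, 4));
    }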
 
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 4c5ea5b..c8dc846 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -35,8 +35,7 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
 void ExprVisitor::VisitExpr_(const AnyNode* op) {}
 
 void ExprVisitor::VisitExpr_(const LoadNode* op) {
-  this->VisitExpr(op->index);
-  this->VisitExpr(op->predicate);
+  LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
 }
 
 void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
@@ -127,13 +126,8 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
 PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef<PrimExpr>(op); }
 
 PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
-  PrimExpr index = this->VisitExpr(op->index);
-  PrimExpr predicate = this->VisitExpr(op->predicate);
-  if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
-    return GetRef<PrimExpr>(op);
-  } else {
-    return Load(op->dtype, op->buffer_var, index, predicate);
-  }
+  LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+  return PrimExpr();
 }
 
 PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index f58dd8a..b9c3029 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -29,7 +29,9 @@ namespace tvm {
 namespace tir {
 // Get the function type of a PrimFunc
 PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
-                   Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
+                   Map<tir::Var, Buffer> buffer_map,
+                   Optional<Map<tir::Var, Buffer>> preflattened_buffer_map, DictAttrs attrs,
+                   Span span) {
   // Assume void-return type for now
   // TODO(tvm-team) consider type deduction from body.
   if (!ret_type.defined()) {
@@ -40,6 +42,7 @@ PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
   n->body = std::move(body);
   n->ret_type = std::move(ret_type);
   n->buffer_map = std::move(buffer_map);
+  n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map<tir::Var, Buffer>());
   n->attrs = std::move(attrs);
   n->checked_type_ = n->func_type_annotation();
   n->span = std::move(span);
@@ -118,8 +121,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_GLOBAL("tir.PrimFunc")
     .set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
-                       Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
-      return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
+                       Map<tir::Var, Buffer> buffer_map,
+                       Map<tir::Var, Buffer> preflattened_buffer_map, DictAttrs attrs, Span span) {
+      return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span);
     });
 
 TVM_REGISTER_GLOBAL("tir.TensorIntrin")
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
new file mode 100644
index 0000000..ba0998e
--- /dev/null
+++ b/src/tir/ir/index_map.cc
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_map.cc
+ */
+
+#include "tvm/tir/index_map.h"
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_set.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/op.h>
+
+#include <sstream>
+
+namespace tvm {
+namespace tir {
+
+IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
+  auto n = make_object<IndexMapNode>();
+  n->initial_indices = std::move(initial_indices);
+  n->final_indices = std::move(final_indices);
+  data_ = std::move(n);
+}
+
+IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
+  // Dummy variables to represent the inverse's inputs.
+  Array<Var> output_vars;
+  for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
+    PrimExpr index = (*this)->final_indices[i];
+    // TODO(Lunderberg): Better names for these variables.  A variable
+    // that is passed through unmodified (`index` is an element of
+    // `initial_indices`) should use that input index's name.  A pair
+    // of output indices variables split from a single input index
+    // should be named (X.outer,X.inner).
+    std::stringstream ss;
+    ss << "axis" << i;
+    Var var_index(ss.str(), index.dtype());
+    output_vars.push_back(var_index);
+  }
+
+  // Dummy ranges for the extent of each input.
+  Map<Var, Range> input_iters;
+  ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size());
+  for (size_t i = 0; i < initial_ranges.size(); i++) {
+    input_iters.Set((*this)->initial_indices[i], initial_ranges[i]);
+  }
+
+  // Unpack the output indices into linear combinations of the initial
+  // indices.
+  arith::Analyzer analyzer;
+  auto diagnostics = DiagnosticContext::Default(IRModule());
+  auto iter_map =
+      DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer, diagnostics);
+  CHECK(iter_map.size()) << "Index transformation was not bijective.";
+
+  // Determine expressions for the input variables, in terms of the
+  // output variables.
+  Map<Var, PrimExpr> inverse_exprs_map =
+      InverseAffineIterMap(iter_map, Array<PrimExpr>(output_vars.begin(), output_vars.end()));
+
+  // Unpack the map to an array, maintaining the same parameter order.
+  Array<PrimExpr> inverse_exprs;
+  for (const auto& index : (*this)->initial_indices) {
+    inverse_exprs.push_back(inverse_exprs_map.at(index));
+  }
+
+  return IndexMap(output_vars, inverse_exprs);
+}
+
+Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices) const {
+  ICHECK_EQ(indices.size(), initial_indices.size());
+
+  arith::Analyzer analyzer;
+
+  for (size_t i = 0; i < initial_indices.size(); i++) {
+    analyzer.Bind(initial_indices[i], indices[i]);
+  }
+
+  Array<PrimExpr> output;
+  for (const auto& output_dim : final_indices) {
+    output.push_back(analyzer.Simplify(output_dim));
+  }
+
+  return output;
+}
+
+Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
+  ICHECK_EQ(ranges.size(), initial_indices.size());
+
+  Map<Var, Range> input_iters;
+  for (size_t i = 0; i < initial_indices.size(); i++) {
+    input_iters.Set(initial_indices[i], ranges[i]);
+  }
+
+  std::unordered_map<const VarNode*, arith::IntSet> dom_map;
+  for (size_t i = 0; i < initial_indices.size(); i++) {
+    dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
+  }
+
+  Array<Range> output;
+  for (const auto& final_index : final_indices) {
+    auto int_set = arith::EvalSet(final_index, dom_map);
+    output.push_back(Range::FromMinExtent(int_set.min(), int_set.max() - int_set.min() + 1));
+  }
+
+  return output;
+}
+
+Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
+  ICHECK_EQ(shape.size(), initial_indices.size());
+
+  Array<Range> ranges;
+  for (auto& dim : shape) {
+    ranges.push_back(Range(0, dim));
+  }
+  Array<Range> mapped = MapRanges(std::move(ranges));
+
+  Array<PrimExpr> output;
+  for (auto& range : mapped) {
+    ICHECK(is_zero(range->min));
+    output.push_back(range->extent);
+  }
+
+  return output;
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IndexMapNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IndexMapNode*>(node.get());
+      p->stream << "index_map(" << op->initial_indices << ", " << op->final_indices << ")";
+    });
+
+TVM_REGISTER_NODE_TYPE(IndexMapNode);
+
+}  // namespace tir
+}  // namespace tvm
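
A minimal usage sketch of the new IndexMap class (illustrative, not part of
this patch), using a transpose as the layout transformation:

    #include <tvm/ir/expr.h>
    #include <tvm/tir/index_map.h>
    #include <tvm/tir/var.h>

    using namespace tvm;
    using namespace tvm::tir;

    void TransposeIndexMap() {
      Var i("i"), j("j");

      // Maps logical indices (i, j) to physical indices (j, i).
      IndexMap transpose({i, j}, {j, i});

      // A logical shape of (16, 32) maps to a physical shape of (32, 16).
      Array<PrimExpr> physical_shape = transpose->MapShape({16, 32});

      // Inverting over the logical extents recovers the original indices;
      // Inverse() checks that the mapping is bijective over these ranges.
      IndexMap inverse = transpose.Inverse({Range(0, 16), Range(0, 32)});
      (void)physical_shape;
      (void)inverse;
    }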
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 1269607..3914f41 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -241,6 +241,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // Store
 Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) {
+  LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint
+             << ".  Use BufferStore instead.";
   ICHECK(value.defined());
   ICHECK(index.defined());
   ICHECK(predicate.defined());
@@ -341,7 +343,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 // Allocate
 Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
                    Stmt body, Map<String, ObjectRef> annotations, Span span) {
-  CHECK(IsPointerType(buffer_var->type_annotation, dtype))
+  CHECK(IsPointerType(buffer_var->type_annotation, dtype) ||
+        (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8))))
       << "The allocated data type (" << dtype
       << ") does not match the type annotation of the buffer " << buffer_var << " ("
       << buffer_var->type_annotation
@@ -668,6 +671,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // BufferStore
 BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
+  ICHECK_EQ(buffer->shape.size(), indices.size())
+      << "Buffer " << buffer->name << " is " << buffer->shape.size()
+      << "-dimensional, cannot be indexed with the " << indices.size()
+      << "-dimensional indices provided.";
+
   ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
   node->buffer = std::move(buffer);
   node->value = std::move(value);
@@ -760,7 +768,12 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) {
 BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
   Array<Range> region;
   for (const PrimExpr& index : indices) {
-    region.push_back(Range::FromMinExtent(index, 1));
+    if (const RampNode* ramp_index = index.as<RampNode>()) {
+      region.push_back(
+          Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes));
+    } else {
+      region.push_back(Range::FromMinExtent(index, 1));
+    }
   }
   return BufferRegion(buffer, region);
 }
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 949e8a1..c4d7ad0 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -64,9 +64,7 @@ void StmtVisitor::VisitStmt_(const AllocateConstNode* op) {
 }
 
 void StmtVisitor::VisitStmt_(const StoreNode* op) {
-  this->VisitExpr(op->value);
-  this->VisitExpr(op->index);
-  this->VisitExpr(op->predicate);
+  LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
 }
 
 void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
@@ -358,18 +356,8 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
-  PrimExpr value = this->VisitExpr(op->value);
-  PrimExpr index = this->VisitExpr(op->index);
-  PrimExpr predicate = this->VisitExpr(op->predicate);
-  if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) {
-    return GetRef<Stmt>(op);
-  } else {
-    auto n = CopyOnWrite(op);
-    n->value = std::move(value);
-    n->index = std::move(index);
-    n->predicate = std::move(predicate);
-    return Stmt(n);
-  }
+  LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  return Stmt();
 }
 
 Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
@@ -664,23 +652,51 @@ class IRSubstitute : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
-    op = ret.as<LoadNode>();
-    if (auto mapped_var = vmap_(op->buffer_var)) {
-      return Load(op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
-    } else {
-      return ret;
-    }
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    Stmt ret = StmtExprMutator::VisitStmt_(op);
-    op = ret.as<StoreNode>();
-    if (auto mapped_var = vmap_(op->buffer_var)) {
-      return Store(Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
-    } else {
-      return ret;
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    Buffer new_buf = GetRemappedBuffer(node->buffer);
+
+    if (!new_buf.same_as(node->buffer)) {
+      auto writer = node.CopyOnWrite();
+      writer->buffer = new_buf;
     }
+
+    return node;
+  }
+
+  Buffer GetRemappedBuffer(Buffer buf) {
+    auto key = buf.get();
+    auto it = buf_remap_.find(key);
+    if (it != buf_remap_.end()) {
+      return it->second;
+    }
+
+    if (auto mapped_var = vmap_(buf->data)) {
+      auto writer = buf.CopyOnWrite();
+      writer->data = Downcast<Var>(mapped_var);
+    }
+
+    buf_remap_[key] = buf;
+    return buf;
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
@@ -696,7 +712,17 @@ class IRSubstitute : public StmtExprMutator {
   }
 
  private:
+  // Caller provided function that defines the variables to be remapped.
   std::function<Optional<PrimExpr>(const Var&)> vmap_;
+
+  /* \brief Generated map to track buffers being remapped.
+   *
+   * If a `Var BufferNode::data` is remapped, then all buffers
+   * containing that data pointer should also be remapped.  This map
+   * is used to track buffer modifications, and ensure all instances
+   * of a buffer are replaced by the same modified buffer object.
+   */
+  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
 };
 
 Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) {
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 6231bb2..ed3ecec 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -456,13 +456,9 @@ class CacheReadRewriter : public StmtExprMutator {
     return ExprMutator::VisitExpr_(load);
   }
 
-  PrimExpr VisitExpr_(const LoadNode* load) final {
-    if (load->buffer_var.same_as(info_->read_buffer->data)) {
-      ObjectPtr<LoadNode> n = make_object<LoadNode>(*load);
-      n->buffer_var = info_->write_buffer->data;
-      return PrimExpr(n);
-    }
-    return ExprMutator::VisitExpr_(load);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
@@ -575,22 +571,14 @@ class CacheWriteRewriter : public StmtExprMutator {
     return ExprMutator::VisitExpr_(load);
   }
 
-  PrimExpr VisitExpr_(const LoadNode* load) final {
-    if (load->buffer_var.same_as(info_->write_buffer->data)) {
-      ObjectPtr<LoadNode> n = make_object<LoadNode>(*load);
-      n->buffer_var = info_->read_buffer->data;
-      return PrimExpr(n);
-    }
-    return ExprMutator::VisitExpr_(load);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
-  Stmt VisitStmt_(const StoreNode* store) final {
-    if (store->buffer_var.same_as(info_->write_buffer->data)) {
-      ObjectPtr<StoreNode> n = make_object<StoreNode>(*store);
-      n->buffer_var = info_->read_buffer->data;
-      return Stmt(n);
-    }
-    return StmtMutator::VisitStmt_(store);
+  Stmt VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 9a9860b..d7556ed 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -200,14 +200,14 @@ class BaseInliner : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(var);
   }
 
-  PrimExpr VisitExpr_(const LoadNode* load) final {
-    CheckOpaqueAccess(load->buffer_var.get());
-    return StmtExprMutator::VisitExpr_(load);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
-  Stmt VisitStmt_(const StoreNode* store) final {
-    CheckOpaqueAccess(store->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(store);
+  Stmt VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
   }
 
   Stmt VisitStmt_(const ForNode* loop) final {
@@ -284,6 +284,31 @@ class BaseInliner : public StmtExprMutator {
     }
   }
 
+  /*!
+   * \brief Count the number of undefined variables that are not used
+   * as buffer objects.
+   *
+   * This is used to determine whether inlining or reverse inlining is
+   * possible.  The only undefined variables present should be the
+   * load/store indices, or buffer access based on those indices.
+   *
+   * \param stmt The statement in which to count undefined variables
+   */
+  static int GetNumUndefinedNonpointerVars(const Stmt& stmt) {
+    auto undefined_vars = UndefinedVars(stmt, {});
+    // Buffer pointers and the inlined indices are allowed, but no
+    // other variables may appear in the inlined block.
+    int num_nonpointer_vars = 0;
+    for (const auto& var : undefined_vars) {
+      bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() &&
+                        var->type_annotation.as<PointerTypeNode>();
+      if (!is_pointer) {
+        num_nonpointer_vars++;
+      }
+    }
+    return num_nonpointer_vars;
+  }
+
  private:
   /*!
    * \brief Add the buffers in the block signature to the `buffer_var_map_`,
@@ -417,7 +442,8 @@ class ComputeInliner : public BaseInliner {
     if (inlined_store_ == nullptr) {
       return false;
     }
-    int n_vars = UndefinedVars(GetRef<Stmt>(inlined_store_), {}).size();
+
+    int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
     if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) {
       return false;
     }
@@ -484,7 +510,7 @@ class ReverseComputeInliner : public BaseInliner {
       // Failure: no BufferLoad from the `inlined_buffer_`
       return false;
     }
-    int n_vars = UndefinedVars(GetRef<BufferStore>(inlined_store_), {}).size();
+    int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
     for (const BufferLoadNode* load : loads) {
       if (!UpdateAndCheckIndexVars(load->indices, n_vars)) {
         // Failure: incorrect or inconsistent index vars
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 1e566a9..d7cd731 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -154,23 +154,34 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
   const Stmt nop = Evaluate(0);
   // dimension checks
   PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
+
+  // Helper functions for shape/stride name formatting
+  auto shape_handle_name = [&]() { return arg_name + ".shape"; };
+  auto stride_handle_name = [&]() { return arg_name + ".strides"; };
+  auto array_element_name = [&](const std::string& arr_name, size_t k) {
+    std::stringstream ss;
+    ss << arr_name << '[' << k << ']';
+    return ss.str();
+  };
+  auto shape_element_name = [&](size_t k) { return array_element_name(shape_handle_name(), k); };
+  auto stride_element_name = [&](size_t k) { return array_element_name(stride_handle_name(), k); };
+
   PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
   std::ostringstream ndim_err_msg;
   ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
   auto msg = tvm::tir::StringImm(ndim_err_msg.str());
   asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
   // type checks
-  DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
-  type_err_msg << arg_name << ".dtype is expected to be " << dtype;
+  type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
   PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) ==
-                       IntImm(DataType::UInt(8), dtype.code()) &&
+                       IntImm(DataType::UInt(8), buffer->dtype.code()) &&
                    TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) ==
-                       IntImm(DataType::UInt(8), dtype.bits()) &&
+                       IntImm(DataType::UInt(8), buffer->dtype.bits()) &&
                    TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
-                       IntImm(DataType::UInt(16), dtype.lanes()));
-  if (!(dtype == DataType::Int(1) || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
-        dtype == DataType::UInt(16))) {
+                       IntImm(DataType::UInt(16), buffer->dtype.lanes()));
+  if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) ||
+        buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::UInt(16))) {
     auto type_msg = tvm::tir::StringImm(type_err_msg.str());
     asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
     asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
@@ -185,27 +196,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                                      IntImm(DataType::Int(32), buffer->data_alignment), nop));
   }
 
-  Var v_shape(arg_name + ".shape", DataType::Handle());
+  // shape field
+  Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type,
+                                 shape_handle_name());
+  Var v_shape(shape_handle_name(), DataType::Handle());
   def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
   init_nest_.emplace_back(
-      LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop));
+      LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop));
   for (size_t k = 0; k < buffer->shape.size(); ++k) {
-    if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) {
+    if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) ||
+        buffer->dtype == DataType::Int(1)) {
       break;
     }
-    std::ostringstream field_name;
-    field_name << v_shape->name_hint << '[' << k << ']';
     Bind_(buffer->shape[k],
-          cast(buffer->shape[k].dtype(),
-               Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))),
-          field_name.str(), true);
+          cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})),
+          shape_element_name(k), true);
   }
   // strides field
-  Var v_strides(arg_name + ".strides", DataType::Handle());
-  def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
-  init_nest_.emplace_back(
-      LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
-  PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
+  Buffer buf_strides = decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
+                                   tvm_shape_type, arg_name + ".strides");
+  def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
+  init_nest_.emplace_back(LetStmt(
+      buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
+  PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
@@ -213,14 +226,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
     Array<PrimExpr> conds;
     for (size_t i = buffer->shape.size(); i != 0; --i) {
       size_t k = i - 1;
-      PrimExpr svalue =
-          cast(stype, Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
+      PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
       conds.push_back(expect_stride == svalue);
       expect_stride = expect_stride * buffer->shape[k];
     }
     std::ostringstream stride_err_msg;
-    stride_err_msg << arg_name << ".strides:"
-                   << " expected to be compact array";
+    stride_err_msg << stride_handle_name() << ": expected to be compact array";
     if (conds.size() != 0) {
       auto stride_msg = tvm::tir::StringImm(stride_err_msg.str());
       Stmt check = AssertStmt(
@@ -235,34 +246,26 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
     PrimExpr stride = make_const(stype, 1);
     for (size_t i = buffer->shape.size(); i != 0; --i) {
       size_t k = i - 1;
-      std::ostringstream field_name;
-      field_name << v_strides->name_hint << '[' << k << ']';
       PrimExpr value =
-          cast(buffer->shape[k].dtype(),
-               Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
+          cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
       value = tvm::if_then_else(v_strides_is_null, stride, value);
       value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
-      Bind_(buffer->strides[k], value, field_name.str(), true);
+      Bind_(buffer->strides[k], value, stride_element_name(k), true);
       stride = analyzer_.Simplify(stride * buffer->shape[k]);
     }
   } else {
     PrimExpr stride_from_shape = 1;
 
     for (int k = buffer->strides.size() - 1; k >= 0; k--) {
-      std::ostringstream field_name;
-      field_name << v_strides->name_hint << '[' << k << ']';
-
       PrimExpr explicit_stride =
-          cast(buffer->shape[k].dtype(),
-               Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
+          cast(buffer->shape[k].dtype(), BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
 
       Bind_(buffer->strides[k],
             tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride),
-            field_name.str(), true);
+            stride_element_name(k), true);
 
       stride_from_shape *=
-          cast(buffer->shape[k].dtype(),
-               Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1)));
+          cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
     }
   }
   // Byte_offset field.
diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc
index 79c4068..193584f 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -199,11 +199,11 @@ class BF16LowerRewriter : public StmtExprMutator {
     Stmt ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<BufferStoreNode>();
 
-    auto it = buffer_remap_.find(op->buffer);
-    if (it != buffer_remap_.end()) {
-      return BufferStore(it->second, op->value, op->indices);
-    } else {
+    Buffer new_buf = GetRemappedBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return ret;
+    } else {
+      return BufferStore(new_buf, op->value, op->indices);
     }
   }
 
@@ -229,50 +229,34 @@ class BF16LowerRewriter : public StmtExprMutator {
     Stmt ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<BufferRealizeNode>();
 
-    auto it = buffer_remap_.find(op->buffer);
-    if (it != buffer_remap_.end()) {
-      return BufferRealize(it->second, op->bounds, op->condition, op->body);
-    } else {
+    Buffer new_buf = GetRemappedBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return ret;
+    } else {
+      return BufferRealize(new_buf, op->bounds, op->condition, op->body);
     }
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    // NOTE: we do not explicit recursivly mutate op->buffer_var
-    Stmt ret = StmtExprMutator::VisitStmt_(op);
-    op = ret.as<StoreNode>();
-
-    auto it = var_remap_.find(op->buffer_var);
-    if (it != var_remap_.end()) {
-      return Store(it->second, op->value, op->index, op->predicate);
-    } else {
-      return ret;
-    }
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     PrimExpr ret = StmtExprMutator::VisitExpr_(op);
     op = ret.as<BufferLoadNode>();
 
-    auto it = buffer_remap_.find(op->buffer);
-    if (it != buffer_remap_.end()) {
-      return BufferLoad(it->second, op->indices);
-    } else {
+    Buffer new_buf = GetRemappedBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
       return ret;
+    } else {
+      return BufferLoad(new_buf, op->indices);
     }
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
-    op = ret.as<LoadNode>();
-
-    if (op->dtype.is_bfloat16()) {
-      auto it = var_remap_.find(op->buffer_var);
-      ICHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped";
-      return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate);
-    } else {
-      return ret;
-    }
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   PrimExpr VisitExpr_(const FloatImmNode* op) final {
@@ -284,9 +268,10 @@ class BF16LowerRewriter : public StmtExprMutator {
   }
 
   void AlterBuffers(PrimFuncNode* op) {
-    std::vector<std::pair<Var, Buffer>> changes;
+    Map<Var, Buffer> new_buffer_map;
 
     for (auto& itr : op->buffer_map) {
+      auto param_var = itr.first;
       auto oldbuf = itr.second;
       if (oldbuf->dtype.is_bfloat16()) {
         DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
@@ -296,18 +281,69 @@ class BF16LowerRewriter : public StmtExprMutator {
                              oldbuf->buffer_type);
         buffer_remap_[oldbuf] = newbuf;
         var_remap_[oldbuf->data] = buffer_var;
-        changes.emplace_back(itr.first, newbuf);
+        new_buffer_map.Set(param_var, newbuf);
       } else {
-        changes.emplace_back(itr);
+        new_buffer_map.Set(param_var, oldbuf);
+      }
+    }
+
+    // Most passes do not change the preflattened buffer map, nor
+    // should they change it.  This is an exception, because the Var
+    // associated with the `BufferNode::data` in
+    // `PrimFunc::buffer_map` may be replaced, and the corresponding
+    // Var in the `PrimFunc::preflattened_buffer_map` must also be
+    // replaced.
+    Map<Var, Buffer> new_preflattened_buffer_map;
+    for (auto& itr : op->preflattened_buffer_map) {
+      auto param_var = itr.first;
+      auto oldbuf = itr.second;
+      if (oldbuf->dtype.is_bfloat16()) {
+        auto it = new_buffer_map.find(param_var);
+        ICHECK(it != new_buffer_map.end())
+            << "PrimFunc parameter " << param_var->name_hint
+            << " is associated with the pre-flattened buffer " << oldbuf->name
+            << ", but isn't associated with any post-flatten buffer.";
+        const Buffer& flatbuf = (*it).second;
+        DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
+        auto newbuf = Buffer(flatbuf->data, dtype, oldbuf->shape, oldbuf->strides,
+                             oldbuf->elem_offset, oldbuf->name, oldbuf->data_alignment,
+                             oldbuf->offset_factor, oldbuf->buffer_type);
+        buffer_remap_[oldbuf] = newbuf;
+        new_preflattened_buffer_map.Set(param_var, newbuf);
+      } else {
+        new_preflattened_buffer_map.Set(param_var, oldbuf);
       }
     }
 
     if (buffer_remap_.size() != 0) {
-      op->buffer_map = Map<Var, Buffer>(changes.begin(), changes.end());
+      op->buffer_map = new_buffer_map;
+      op->preflattened_buffer_map = new_preflattened_buffer_map;
     }
   }
 
  private:
+  Buffer GetRemappedBuffer(Buffer buf) {
+    auto buf_it = buffer_remap_.find(buf);
+    if (buf_it != buffer_remap_.end()) {
+      return buf_it->second;
+    }
+
+    Buffer new_buf = buf;
+
+    auto var_it = var_remap_.find(buf->data);
+    if (var_it != var_remap_.end()) {
+      DataType dtype =
+          buf->dtype.is_bfloat16() ? DataType::UInt(16, buf->dtype.lanes()) : buf->dtype;
+      new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name,
+                       buf->data_alignment, buf->offset_factor, buf->buffer_type,
+                       buf->axis_separators, buf->span);
+    }
+
+    buffer_remap_[buf] = new_buf;
+
+    return new_buf;
+  }
+
   std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
 };
diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc
index 944a67a..1d2b2db 100644
--- a/src/tir/transforms/bind_params.cc
+++ b/src/tir/transforms/bind_params.cc
@@ -53,12 +53,11 @@ class ParamsCollector : public StmtExprVisitor {
     return constant_list_;
   }
 
-  void VisitExpr_(const LoadNode* ln) {
-    if (constant_map_.find(ln->buffer_var) != constant_map_.end()) {
-      auto it =
-          std::find(constant_list_.begin(), constant_list_.end(), ln->buffer_var.operator->());
+  void VisitExpr_(const BufferLoadNode* ln) {
+    if (constant_map_.find(ln->buffer->data) != constant_map_.end()) {
+      auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get());
       if (it == constant_list_.end()) {
-        constant_list_.push_back(ln->buffer_var.operator->());
+        constant_list_.push_back(ln->buffer->data.get());
       }
     }
     StmtExprVisitor::VisitExpr_(ln);
diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc
index 3b6af06..85aac3c 100644
--- a/src/tir/transforms/bound_checker.cc
+++ b/src/tir/transforms/bound_checker.cc
@@ -37,25 +37,30 @@
 namespace tvm {
 namespace tir {
 
+// TODO(Lunderberg): Move this pass to be before
+// StorageFlatten/FlattenBuffer.  That will simplify this pass,
+// because it can check directly against the buffer limits.
 class BoundCollector : public StmtVisitor {
  public:
   BoundCollector() {}
 
   void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == tir::attr::buffer_bound) {
-      if (const VarNode* key = op->node.as<VarNode>()) {
-        mem_to_shape[key] = op->value;
+      const VarNode* key = op->node.as<VarNode>();
+      const CallNode* container = op->value.as<CallNode>();
+      if (key && container) {
+        mem_to_shape[key] = container->args;
       }
     }
     StmtVisitor::VisitStmt_(op);
   }
   // Hashtable which maps buffer_var to shape.
-  std::unordered_map<const VarNode*, PrimExpr> mem_to_shape;
+  std::unordered_map<const VarNode*, Array<PrimExpr>> mem_to_shape;
 };
 
 class BoundChecker : public StmtExprMutator {
  public:
-  explicit BoundChecker(const std::unordered_map<const VarNode*, PrimExpr>& mem_to_shape)
+  explicit BoundChecker(const std::unordered_map<const VarNode*, Array<PrimExpr>>& mem_to_shape)
       : mem_to_shape_(mem_to_shape) {}
 
   Stmt VisitStmt_(const AllocateNode* op) final {
@@ -73,21 +78,31 @@ class BoundChecker : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+
   Stmt VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
     store_scope_bound_collector_.clear();
     process_store_ = true;
     unsafe_rewritten_ = false;
     StmtExprMutator::VisitStmt_(op);
     process_store_ = false;
-    if (CanInstrument(op->index, op->buffer_var)) {
-      Collect(op->index, op->buffer_var);
+    if (CanInstrument(op->indices, op->buffer->data)) {
+      Collect(op->indices, op->buffer->data);
     }
     // The collector should have at least one item.
     if (store_scope_bound_collector_.size()) {
       PrimExpr condition = MakeCondition();
       if (!condition.as<StringImmNode>()) {
         Stmt nop = Evaluate(1);
-        Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate);
+        Stmt then_case = GetRef<Stmt>(op);
         Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
         Stmt body = IfThenElse(condition, then_case, else_case);
         return body;
@@ -96,9 +111,9 @@ class BoundChecker : public StmtExprMutator {
     return GetRef<Stmt>(op);
   }
 
-  PrimExpr VisitExpr_(const LoadNode* op) final {
-    if (CanInstrument(op->index, op->buffer_var)) {
-      Collect(op->index, op->buffer_var);
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    if (CanInstrument(op->indices, op->buffer->data)) {
+      Collect(op->indices, op->buffer->data);
     }
     return StmtExprMutator::VisitExpr_(op);
   }
@@ -108,79 +123,106 @@ class BoundChecker : public StmtExprMutator {
     return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
   }
 
-  void Update(const Var& buffer_var, const Array<PrimExpr>& new_shape, const DataType& type) {
+  void Update(const Var& buffer_var, Array<PrimExpr> new_shape, const DataType& type) {
     // Sanity check at first.
-    if (!new_shape.size()) {
+    if (!ShapeIsValid(new_shape)) {
       return;
     }
 
-    for (size_t i = 0; i < new_shape.size(); ++i) {
-      if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() ||
-          is_negative_const(new_shape[i])) {
-        return;
+    new_shape.MutateByApply([&](const PrimExpr& dim) {
+      // Cast to uint64 to avoid potential overflow.
+      return make_const(DataType::UInt(64), type.lanes()) * dim;
+    });
+    mem_to_shape_[buffer_var.get()] = new_shape;
+  }
+
+  bool ShapeIsValid(const Array<PrimExpr>& shape) const {
+    if (!shape.defined()) {
+      return false;
+    }
+    for (const auto& dim : shape) {
+      if (!IsValidScalar(dim) || is_negative_const(dim)) {
+        return false;
       }
     }
 
-    // Scalarize the shape.
-    PrimExpr shape =
-        Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0]));
-    for (size_t i = 1; i < new_shape.size(); ++i) {
-      // Cast to unsigned to avoid integer overlow at frist.
-      shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()),
-                             Cast(DataType::UInt(64), new_shape[i])));
-    }
-    mem_to_shape_[buffer_var.get()] = shape;
+    return true;
   }
 
-  bool IndexIsValid(const PrimExpr& index) const {
-    if (!index.defined()) {
+  bool IndicesAreValid(const Array<PrimExpr>& indices) const {
+    if (!indices.defined()) {
       return false;
     }
 
-    if (const RampNode* ramp_index = index.as<RampNode>()) {
-      return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() &&
-             ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() &&
-             (ramp_index->lanes > 0);
+    for (const auto& index : indices) {
+      if (!index.defined()) {
+        return false;
+      }
+
+      if (const RampNode* ramp_index = index.as<RampNode>()) {
+        if (!IsValidScalar(ramp_index->base)) {
+          return false;
+        }
+        if (!IsValidScalar(ramp_index->stride)) {
+          return false;
+        }
+        if (ramp_index->lanes <= 0) {
+          return false;
+        }
+      }
     }
     return true;
   }
 
-  bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const {
-    return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) &&
-           !unsafe_rewritten_;
+  bool IsValidScalar(const PrimExpr& expr) const {
+    return expr.defined() && expr.dtype().is_scalar();
+  }
+
+  bool CanInstrument(const Array<PrimExpr>& indices, const Var& buffer_var) const {
+    return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
+           IndicesAreValid(indices) && !unsafe_rewritten_;
   }
 
-  void Collect(PrimExpr index, Var buffer_var) {
-    store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()]));
+  void Collect(Array<PrimExpr> indices, Var buffer_var) {
+    store_scope_bound_collector_.push_back(
+        std::make_pair(indices, mem_to_shape_[buffer_var.get()]));
   }
 
   PrimExpr MakeCondition() {
     PrimExpr condition;
-    for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
-      std::pair<PrimExpr, PrimExpr> buffer_to_mem = store_scope_bound_collector_[i];
-      PrimExpr index = buffer_to_mem.first;
-      PrimExpr upper_bound = buffer_to_mem.second;
-
-      if (const RampNode* ramp_index = index.as<RampNode>()) {
-        // In case index is base + stride * i.
-        // Non inclusive range.
-        index = Add(ramp_index->base, Mul(ramp_index->stride, make_const(ramp_index->stride.dtype(),
-                                                                         ramp_index->lanes - 1)));
+    for (const auto& pair : store_scope_bound_collector_) {
+      Array<PrimExpr> indices = pair.first;
+      Array<PrimExpr> shape = pair.second;
+
+      ICHECK_EQ(indices.size(), shape.size())
+          << "Mismatch between dimension of physical shape and physical indices";
+
+      for (size_t i = 0; i < indices.size(); i++) {
+        PrimExpr index = indices[i];
+        PrimExpr upper_bound = shape[i];
+
+        if (const RampNode* ramp_index = index.as<RampNode>()) {
+          // In case index is base + stride * i.
+          // Non inclusive range.
+          index = Add(ramp_index->base,
+                      Mul(ramp_index->stride,
+                          make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1)));
+        }
+
+        // Try to simplify index and bound.
+        index = analyzer_.Simplify(index);
+        upper_bound = analyzer_.Simplify(upper_bound);
+
+        // Cast to the same type - signed, to be able to check lower bound.
+        index = Cast(DataType::Int(64), index);
+        upper_bound = Cast(DataType::Int(64), upper_bound);
+
+        // Looks like a lower bound should always be zero after normalization.
+        PrimExpr lower_bound = make_zero(DataType::Int(64));
+
+        PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound));
+        condition = condition.defined() ? And(condition, current_condition) : current_condition;
       }
-
-      // Try to simplify index and bound.
-      index = analyzer_.Simplify(index);
-      upper_bound = analyzer_.Simplify(upper_bound);
-
-      // Cast to the same type - signed, to be able to check lower bound.
-      index = Cast(DataType::Int(64), index);
-      upper_bound = Cast(DataType::Int(64), upper_bound);
-
-      // Looks like a lower bound should always be zero after normalization.
-      PrimExpr lower_bound = make_zero(DataType::Int(64));
-
-      PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound));
-      condition = !i ? current_condition : And(condition, current_condition);
     }
     return condition;
   }
@@ -190,11 +232,11 @@ class BoundChecker : public StmtExprMutator {
   // Whether we face tvm_if_then_else intrinsic.
   bool unsafe_rewritten_{false};
   // Pool which collects the pair of index and shape for specific store/load.
-  std::vector<std::pair<PrimExpr, PrimExpr>> store_scope_bound_collector_;
+  std::vector<std::pair<Array<PrimExpr>, Array<PrimExpr>>> store_scope_bound_collector_;
   // Error message.
   const char* const error_message_ = "OUT OF THE BOUNDS";
   // Hashtable which maps buffer_var to shape.
-  std::unordered_map<const VarNode*, PrimExpr> mem_to_shape_;
+  std::unordered_map<const VarNode*, Array<PrimExpr>> mem_to_shape_;
   // internal analyzer
   arith::Analyzer analyzer_;
 };
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 20ddd7f..6a31739 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -99,13 +99,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
   void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }
 
   void VisitExpr_(const LoadNode* op) final {
-    StmtExprVisitor::VisitExpr_(op);
-    VisitBufferVar(op->buffer_var);
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
   }
 
   void VisitStmt_(const StoreNode* op) final {
-    StmtExprVisitor::VisitStmt_(op);
-    VisitBufferVar(op->buffer_var);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
   }
 
   void VisitStmt_(const ForNode* op) final {
@@ -217,7 +215,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
           continue;
         }
         auto dom_it = dom_map_.find(v);
-        ICHECK(dom_it != dom_map_.end());
+        ICHECK(dom_it != dom_map_.end())
+            << "Could not find domain for loop variable " << v->name_hint;
         non_relaxed[i] = dom_it->second;
         dom_map_.erase(dom_it);
       }
diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc
index 7a6d2d3..f3a9f99 100644
--- a/src/tir/transforms/coproc_sync.cc
+++ b/src/tir/transforms/coproc_sync.cc
@@ -39,18 +39,24 @@ namespace tir {
 class CoProcTouchedBuffer : public StmtExprVisitor {
  public:
   void VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+  void VisitExpr_(const BufferLoadNode* op) final {
     if (in_scope_) {
-      touched_[op->buffer_var.get()].coproc = true;
+      touched_[op->buffer->data.get()].coproc = true;
     } else {
-      touched_[op->buffer_var.get()].normal = true;
+      touched_[op->buffer->data.get()].normal = true;
     }
     StmtExprVisitor::VisitExpr_(op);
   }
-  void VisitStmt_(const StoreNode* op) final {
+  void VisitStmt_(const BufferStoreNode* op) final {
     if (in_scope_) {
-      touched_[op->buffer_var.get()].coproc = true;
+      touched_[op->buffer->data.get()].coproc = true;
     } else {
-      touched_[op->buffer_var.get()].normal = true;
+      touched_[op->buffer->data.get()].normal = true;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
@@ -325,7 +331,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
     Array<arith::IntSet> wset;
     for (const AccessEntry& acc : wvec) {
       ICHECK(acc.dtype == wvec[0].dtype);
-      wset.push_back(acc.touched);
+      ICHECK_EQ(acc.touched.size(), 1) << "CoProcBarrierDetector expects flat memory";
+      wset.push_back(acc.touched[0]);
     }
     Range none;
     Range r = arith::Union(wset).CoverRange(none);
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index e9d99cd..c7cc51d 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -46,13 +46,30 @@ PrimExpr BufferArea(const Buffer& buffer) {
 }
 
 /*!
- * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store
+ * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension
  */
 class BufferFlattener : public StmtExprMutator {
  public:
-  static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); }
+  static PrimFunc Flatten(PrimFunc func) {
+    Map<Var, Buffer> preflattened_buffer_map =
+        Merge(func->buffer_map, func->preflattened_buffer_map);
+
+    auto pass = BufferFlattener(func->buffer_map);
+
+    auto writer = func.CopyOnWrite();
+    writer->body = pass.VisitStmt(func->body);
+    writer->preflattened_buffer_map = preflattened_buffer_map;
+    writer->buffer_map = pass.updated_extern_buffer_map_;
+    return func;
+  }
 
  private:
+  explicit BufferFlattener(const Map<Var, Buffer>& extern_buffer_map) {
+    for (const auto& kv : extern_buffer_map) {
+      updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second));
+    }
+  }
+
   Stmt VisitStmt_(const BlockRealizeNode* op) final {
     // We have convert blocks into opaque blocks in previous passes.
     ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
@@ -67,8 +84,8 @@ class BufferFlattener : public StmtExprMutator {
     }
     // Step 3. Handle allocations in reverse order
     for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
-      const Buffer& buffer = new_block->alloc_buffers[i - 1];
-      body = MakeAllocStmt(buffer, std::move(body));
+      Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]);
+      body = Allocate(buffer->data, buffer->dtype, buffer->shape, const_true(), std::move(body));
     }
     return body;
   }
@@ -112,11 +129,6 @@ class BufferFlattener : public StmtExprMutator {
     return body;
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    return store->buffer.vstore(store->indices, store->value);
-  }
-
   PrimExpr VisitExpr_(const VarNode* op) final {
     Var var = GetRef<Var>(op);
     auto it = unit_loop_vars_.find(var);
@@ -131,16 +143,69 @@ class BufferFlattener : public StmtExprMutator {
     }
   }
 
+  Buffer GetFlattenedBuffer(Buffer buf) {
+    auto it = buffer_remap_.find(buf);
+    if (it != buffer_remap_.end()) {
+      return it->second;
+    }
+
+    auto flattened = buf.GetFlattenedBuffer();
+
+    // TODO(Lunderberg): Move the handling of boolean into a
+    // dedicated pass.
+    if (flattened->dtype == DataType::Bool()) {
+      auto writer = flattened.CopyOnWrite();
+      writer->dtype = DataType::Int(8);
+    }
+
+    buffer_remap_[buf] = flattened;
+    return flattened;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+    // Handle casts from the value's dtype to the dtype of the
+    // backing array.
+    // TODO(Lunderberg): Move the handling of boolean into a
+    // dedicated pass.
+    if (store->value.dtype() == DataType::Bool()) {
+      ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
+          << "Expected int8 backing array for boolean tensor";
+      auto writer = store.CopyOnWrite();
+      writer->value = tir::Cast(DataType::Int(8), store->value);
+    }
+    auto flattened_indices = store->buffer->ElemOffset(store->indices);
+    return VisitBufferAccess(std::move(store));
+  }
+
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    bool load_returns_bool = (op->dtype == DataType::Bool());
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    return load->buffer.vload(load->indices, load->dtype);
+    load = VisitBufferAccess(load);
+
+    // Handle casts from dtype of the backing array to value's dtype.
+    // TODO(Lunderberg): Move the handling of boolean into a
+    // dedicated pass.
+    if (load_returns_bool) {
+      ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
+          << "Expected int8 backing array for boolean tensor";
+      return tir::Cast(DataType::Bool(), load);
+    } else {
+      return std::move(load);
+    }
   }
 
-  static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) {
-    String storage_scope = buffer.scope();
-    PrimExpr area = BufferArea(buffer);
-    body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body));
-    return body;
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    ICHECK(node->buffer.defined());
+    auto flattened_indices = node->buffer->ElemOffset(node->indices);
+    Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);
+
+    auto writer = node.CopyOnWrite();
+    writer->buffer = flattened_buffer;
+    writer->indices = flattened_indices;
+    return node;
   }
 
   static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag,
@@ -176,14 +241,18 @@ class BufferFlattener : public StmtExprMutator {
 
   /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
+
+  /*! \brief Map of buffers being remapped. */
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
+
+  /*! \brief The updated external buffer map. */
+  Map<Var, Buffer> updated_extern_buffer_map_;
 };
 
 PrimFunc FlattenBuffer(PrimFunc f) {
   // Only apply this pass to TIR that is not from TE schedules
   if (!IsFromLegacyTESchedule(f)) {
-    PrimFuncNode* fptr = f.CopyOnWrite();
-    fptr->body = BufferFlattener::Flatten(f);
-    return f;
+    return BufferFlattener::Flatten(f);
   } else {
     return f;
   }
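
As a rough illustration (not part of this patch) of what VisitBufferAccess does
for a buffer with no axis separators: the indices are flattened through
Buffer::ElemOffset and the buffer itself through Buffer::GetFlattenedBuffer.

    #include <tvm/tir/buffer.h>
    #include <tvm/tir/var.h>

    using namespace tvm;
    using namespace tvm::tir;

    void FlattenTwoDimensionalAccess() {
      Var i("i"), j("j");
      Buffer buf = decl_buffer({16, 32}, DataType::Float(32), "A");

      // With no axis separators, the 2-d index (i, j) flattens to the single
      // index i*32 + j.
      Array<PrimExpr> flat_indices = buf->ElemOffset({i, j});
      ICHECK_EQ(flat_indices.size(), 1);

      // The flattened buffer is 1-d with 16*32 = 512 elements.
      Buffer flat_buf = buf.GetFlattenedBuffer();
      ICHECK_EQ(flat_buf->shape.size(), 1);
    }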
diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc
index 9e74b8c..81842ff 100644
--- a/src/tir/transforms/inject_copy_intrin.cc
+++ b/src/tir/transforms/inject_copy_intrin.cc
@@ -69,9 +69,9 @@ class CopyIntrinInjector : public StmtMutator {
       loops.push_back(op);
       body = op->body;
     }
-    const StoreNode* store = body.as<StoreNode>();
+    auto store = body.as<BufferStoreNode>();
     if (store == nullptr) {
-      *error_info = "the 'StoreNode' of body is a nullptr.";
+      *error_info = "the body is not a 'BufferStoreNode'";
       return false;
     }
     // Expr sel_cond, sel_true_value, sel_false_value;
@@ -81,17 +81,17 @@ class CopyIntrinInjector : public StmtMutator {
                     select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
 
     const CastNode* cast = store->value.as<CastNode>();
-    const LoadNode* load = store->value.as<LoadNode>();
+    auto load = store->value.as<BufferLoadNode>();
     if (0 == loops.size()) {
       ICHECK(!has_cond);
     }
     // for now only support true condition matching
     if (has_cond) {
-      load = sel_true_value.Eval().as<LoadNode>();
+      load = sel_true_value.Eval().as<BufferLoadNode>();
     }
     // cast can be part of the pattern
     if (cast != nullptr) {
-      load = cast->value.as<LoadNode>();
+      load = cast->value.as<BufferLoadNode>();
     }
     if (load == nullptr) {
       *error_info = "the 'LoadNode' of body is a nullptr.";
@@ -102,8 +102,17 @@ class CopyIntrinInjector : public StmtMutator {
     for (const ForNode* op : loops) {
       loop_vars.push_back(op->loop_var);
     }
-    Array<PrimExpr> store_strides = arith::DetectLinearEquation(store->index, loop_vars);
-    Array<PrimExpr> load_strides = arith::DetectLinearEquation(load->index, loop_vars);
+    // TODO(Lunderberg): Move this pass to be before
+    // StorageFlatten/FlattenBuffer.  That will simplify the
+    // implementation, since the pre-flattened indices/strides can be
+    // used directly.
+    ICHECK((store->indices.size() == 1) && (load->indices.size() == 1))
+        << "InjectCopyIntrin expects flat 1-d buffers.  "
+        << "Has StorageFlatten (TE-based schedules) or "
+        << "FlattenBuffer (TIR-based schedules) been run?";
+
+    Array<PrimExpr> store_strides = arith::DetectLinearEquation(store->indices[0], loop_vars);
+    Array<PrimExpr> load_strides = arith::DetectLinearEquation(load->indices[0], loop_vars);
     if (load_strides.size() == 0 || store_strides.size() == 0) return false;
     Array<PrimExpr> dst_shape;
     const size_t loop_var_size = loop_vars.size();
@@ -160,10 +169,21 @@ class CopyIntrinInjector : public StmtMutator {
       src_strides.push_back(make_const(DataType::Int(32), 1));
       dst_strides.push_back(make_const(DataType::Int(32), 1));
     }
-    Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides,
-                        store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault);
-    Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset,
-                        load->buffer_var->name_hint, 0, 0, kDefault);
+    Buffer dst = store->buffer;
+    {
+      auto writer = dst.CopyOnWrite();
+      writer->shape = dst_shape;
+      writer->strides = dst_strides;
+      writer->elem_offset = store_strides[loop_var_size];
+    }
+
+    Buffer src = load->buffer;
+    {
+      auto writer = src.CopyOnWrite();
+      writer->shape = src_shape;
+      writer->strides = src_strides;
+      writer->elem_offset = src_elem_offset;
+    }
     *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
     if (!out->defined()) {
       *error_info = "flower function did not return correct stmt";
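
For context on the stride detection above: once the indices are flat, the pattern match reduces to a single arith::DetectLinearEquation call per access. A small sketch (illustrative only; values and names are invented):

    // Illustrative sketch only; not part of this patch.
    #include <tvm/arith/pattern.h>
    #include <tvm/tir/op.h>

    void StrideSketch() {
      using namespace tvm;
      using namespace tvm::tir;
      Var x("x"), y("y");
      // A flattened index such as x*32 + y, as produced by
      // StorageFlatten/FlattenBuffer for a [*, 32] buffer.
      PrimExpr index = x * 32 + y;
      // Returns {32, 1, 0}: the coefficient of each loop var plus the
      // constant base, which become the copy intrinsic's strides.
      Array<PrimExpr> strides = arith::DetectLinearEquation(index, {x, y});
    }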
diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc
index 0b45bde..03f2ccd 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -107,15 +107,15 @@ class DoubleBufferInjector : public StmtExprMutator {
     auto it = dbuffer_info_.find(buf);
     if (it != dbuffer_info_.end()) {
       it->second.scope = GetPtrStorageScope(op->buffer_var);
-      it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                                make_const(DataType::Int(32), 1), op->extents) *
-                          op->dtype.lanes();
+
+      ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers.  "
+                                       << "Has StorageFlatten (TE-based schedules) or "
+                                       << "FlattenBuffer (TIR-based schedules) been run?";
+      it->second.stride = op->extents[0];
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AllocateNode>();
-      Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
-      for (PrimExpr e : op->extents) {
-        new_extents.push_back(e);
-      }
+
+      Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)};
       ICHECK(it->second.loop != nullptr);
       auto& alloc_nest = loop_allocs_[it->second.loop];
       alloc_nest.emplace_back(
@@ -170,34 +170,77 @@ class DoubleBufferInjector : public StmtExprMutator {
     return stmt;
   }
 
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+
   Stmt VisitStmt_(const StoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<StoreNode>();
-    auto it = dbuffer_info_.find(op->buffer_var.get());
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+    auto it = dbuffer_info_.find(node->buffer->data.get());
     if (it != dbuffer_info_.end()) {
       const StorageEntry& e = it->second;
       ICHECK(in_double_buffer_scope_);
-      ICHECK(e.stride.defined());
-      return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index,
-                   op->predicate);
-    } else {
-      return stmt;
+      ICHECK(e.switch_write_var.defined());
+
+      ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers.  "
+                                         << "Has StorageFlatten (TE-based schedules) or "
+                                         << "FlattenBuffer (TIR-based schedules) been run?";
+
+      auto writer = node.CopyOnWrite();
+      writer->buffer = GetRemappedBuffer(node->buffer, e.stride);
+      writer->indices = {e.switch_write_var * e.stride + node->indices[0]};
     }
+
+    return std::move(node);
   }
 
-  PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<LoadNode>();
-    auto it = dbuffer_info_.find(op->buffer_var.get());
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+
+    auto it = dbuffer_info_.find(node->buffer->data.get());
     if (it != dbuffer_info_.end()) {
       const StorageEntry& e = it->second;
-      ICHECK(e.stride.defined());
       ICHECK(e.switch_read_var.defined());
-      return Load(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index,
-                  op->predicate);
-    } else {
-      return expr;
+
+      ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers.  "
+                                         << "Has StorageFlatten (TE-based schedules) or "
+                                         << "FlattenBuffer (TIR-based schedules) been run?";
+
+      auto writer = node.CopyOnWrite();
+      writer->buffer = GetRemappedBuffer(node->buffer, e.stride);
+      writer->indices = {e.switch_read_var * e.stride + node->indices[0]};
     }
+
+    return std::move(node);
+  }
+
+  Buffer GetRemappedBuffer(Buffer buf, PrimExpr stride) {
+    auto key = buf.get();
+    auto it = buf_remap_.find(key);
+    if (it != buf_remap_.end()) {
+      return it->second;
+    }
+
+    ICHECK(stride.defined());
+    // TODO(Lunderberg): Move this pass to before
+    // StorageFlatten/FlattenBuffer.  That will simplify the
+    // implementation, to be the insertion of a new dimension for the
+    // buffer, rather than adjusting the other indices.
+    ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers.  "
+                                    << "Has StorageFlatten (TE-based schedules) or "
+                                    << "FlattenBuffer (TIR-based schedules) been run?";
+    auto writer = buf.CopyOnWrite();
+    writer->shape = {buf->shape[0] * stride};
+
+    buf_remap_[key] = buf;
+    return buf;
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
@@ -261,6 +304,8 @@ class DoubleBufferInjector : public StmtExprMutator {
   std::unordered_map<const ForNode*, std::vector<Stmt> > loop_pre_;
   // The allocation size of the buffer
   std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_;
+  // The updated Buffer objects
+  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
 };
 
 namespace transform {
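
To make the double-buffer remapping concrete: GetRemappedBuffer doubles the flat backing storage, and the load/store visitors add a switch term to every index. A hypothetical sketch (not part of the patch; names are invented):

    // Illustrative sketch only; not part of this patch.
    #include <tvm/tir/buffer.h>
    #include <tvm/tir/op.h>
    #include <tvm/tir/stmt.h>

    void DoubleBufferSketch() {
      using namespace tvm;
      using namespace tvm::tir;
      // A flat 64-element allocation written inside a double-buffer scope.
      Buffer buf = decl_buffer({64}, DataType::Float(32), "A");
      PrimExpr stride = buf->shape[0];
      // GetRemappedBuffer doubles the backing storage...
      Buffer doubled = buf;
      doubled.CopyOnWrite()->shape = {stride * 2};
      // ...and each access gains a switch term: A[i] becomes
      // A[db_switch*64 + i], selecting the active half.
      Var db_switch("db_switch"), i("i");
      BufferStore store(doubled, make_const(DataType::Float(32), 0),
                        {db_switch * stride + i});
    }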
diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc
index 4964bec..f6ce88c 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -50,7 +50,10 @@ class ExprTouched final : public StmtExprVisitor {
     StmtExprVisitor::VisitStmt(n);
   }
   void VisitExpr_(const LoadNode* op) final {
-    HandleUseVar(op->buffer_var.get());
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+  }
+  void VisitExpr_(const BufferLoadNode* op) final {
+    HandleUseVar(op->buffer->data.get());
     StmtExprVisitor::VisitExpr_(op);
   }
   void VisitExpr_(const VarNode* op) final { HandleUseVar(op); }
@@ -101,11 +104,18 @@ class VarTouchedAnalysis : public StmtVisitor {
     Record(op->var.get(), tc);
     this->VisitStmt(op->body);
   }
+
   void VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
     ExprTouched tc(touched_var_, false);
     tc(op->value);
-    tc(op->index);
-    Record(op->buffer_var.get(), tc);
+    for (const auto& index : op->indices) {
+      tc(index);
+    }
+    Record(op->buffer->data.get(), tc);
   }
   void VisitStmt_(const ForNode* op) final {
     ExprTouched tc(touched_var_, false);
@@ -204,20 +214,6 @@ class VTInjector : public StmtExprMutator {
   PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
     return index + var_ * alloc_extent;
   }
-  // Load
-  PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<LoadNode>();
-    if (touched_var_.count(op->buffer_var.get())) {
-      visit_touched_var_ = true;
-    }
-    auto it = alloc_remap_.find(op->buffer_var.get());
-    if (it != alloc_remap_.end()) {
-      return Load(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate);
-    } else {
-      return expr;
-    }
-  }
   // Expression.
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->op.same_as(builtin::tvm_access_ptr())) {
@@ -230,7 +226,8 @@ class VTInjector : public StmtExprMutator {
       PrimExpr offset = this->VisitExpr(op->args[2]);
       PrimExpr extent = this->VisitExpr(op->args[3]);
       PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
-      offset = stride * var_ + offset;
+      offset = RewriteIndex(offset, stride);
+
       return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
     } else if (op->op.same_as(builtin::tvm_context_id())) {
       return allow_share_ ? GetRef<PrimExpr>(op) : var_;
@@ -242,21 +239,61 @@ class VTInjector : public StmtExprMutator {
     trigger_base_inject_ = !allow_share_;
     return StmtExprMutator::VisitStmt_(op);
   }
+  // Load
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
   // Store
   Stmt VisitStmt_(const StoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<StoreNode>();
-    if (touched_var_.count(op->buffer_var.get())) {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+  // BufferLoad
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+  // BufferStore
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    trigger_base_inject_ = !allow_share_;
+    return VisitBufferAccess(std::move(node));
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    if (touched_var_.count(node->buffer->data.get())) {
       visit_touched_var_ = true;
     }
-    trigger_base_inject_ = !allow_share_;
-    auto it = alloc_remap_.find(op->buffer_var.get());
+
+    auto it = alloc_remap_.find(node->buffer->data.get());
     if (it != alloc_remap_.end()) {
-      return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate);
-    } else {
-      return stmt;
+      ICHECK_EQ(node->indices.size(), 1)
+          << "InjectVirtualThread expects rewritten allocations to be flat memory.";
+      auto writer = node.CopyOnWrite();
+      writer->buffer = GetRemappedBuffer(node->buffer, it->second);
+      writer->indices = {RewriteIndex(node->indices[0], it->second)};
+    }
+
+    return node;
+  }
+
+  Buffer GetRemappedBuffer(Buffer buf, PrimExpr alloc_extent) {
+    auto key = buf.get();
+    auto it = buf_remap_.find(key);
+    if (it != buf_remap_.end()) {
+      return it->second;
     }
+
+    ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened.";
+    auto writer = buf.CopyOnWrite();
+    writer->shape = {buf->shape[0] * alloc_extent};
+
+    buf_remap_[key] = buf;
+    return buf;
   }
+
   // Attribute
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
@@ -354,46 +391,44 @@ class VTInjector : public StmtExprMutator {
   }
   // Allocate
   Stmt VisitStmt_(const AllocateNode* op) final {
+    Allocate node = GetRef<Allocate>(op);
+
     PrimExpr condition = this->VisitExpr(op->condition);
+
+    Array<PrimExpr> extents = op->extents;
+    extents.MutateByApply([this](const PrimExpr& extent) { return this->VisitExpr(extent); });
+
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
 
-    bool changed = false;
-    Array<PrimExpr> extents;
-    for (size_t i = 0; i < op->extents.size(); i++) {
-      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
-      if (visit_touched_var_ && !vt_loop_injected_) {
-        return InjectVTLoop(GetRef<Stmt>(op), true);
-      }
-      if (!new_ext.same_as(op->extents[i])) changed = true;
-      extents.push_back(new_ext);
-    }
     visit_touched_var_ = false;
 
-    Stmt body;
-    // always rewrite if not allow sharing.
+    // Rewrite the buffer if its shape or any value stored in it
+    // depends on the virtual thread var.  If `allow_share_` is false,
+    // then the buffer is always rewritten, even if separate virtual
+    // threads only read from the buffer.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                              make_const(DataType::Int(32), 1), op->extents) *
-                        op->dtype.lanes();
-      Array<PrimExpr> other;
-      other.push_back(make_const(op->extents[0].dtype(), num_threads_));
-      for (PrimExpr e : extents) {
-        other.push_back(e);
-      }
-      extents = other;
-      changed = true;
-      // mark this buffer get touched.
+
+      // TODO(Lunderberg): Move pass to apply before
+      // StorageFlatten/FlattenBuffer.  Would rewrite the Buffer to
+      // add the injected virtual thread as the first index.
+      ICHECK_EQ(extents.size(), 1)
+          << "InjectVirtualThread expects rewritten allocations to be flat memory.";
+      PrimExpr stride = extents[0];
+      extents = {stride * num_threads_};
+
+      // Mark the buffer var as touched.  BufferLoad/BufferStore should
+      // access locations at `current_index + stride*vthread_var`.
       alloc_remap_[op->buffer_var.get()] = stride;
-      // Mutate the body.
-      body = this->VisitStmt(op->body);
-    } else {
-      // Mutate the body.
-      body = this->VisitStmt(op->body);
     }
-    if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) {
+
+    // Mutate the body.  Depends on alloc_remap_.
+    auto body = this->VisitStmt(op->body);
+
+    if (extents.same_as(op->extents) && body.same_as(op->body) &&
+        condition.same_as(op->condition)) {
       return GetRef<Stmt>(op);
     } else {
       return Allocate(op->buffer_var, op->dtype, extents, condition, body);
@@ -448,8 +483,21 @@ class VTInjector : public StmtExprMutator {
   const std::unordered_set<const VarNode*>& touched_var_;
   // Whether sharing is allowed.
   bool allow_share_;
-  // The allocations that get touched -> extent
+  /* \brief The allocations that get touched -> extent
+   *
+   * Maps from the buffer_var of an allocate node to the original
+   * extent of the allocation.  Used when rewriting the indices of
+   * BufferLoad/BufferStore.
+   */
   std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
+  /*! \brief Map of buffers that are modified.
+   *
+   * Buffers allocated or written to within the virtual thread loop
+   * must have one copy per virtual thread.  This is done by enlarging
+   * the allocated buffer size, then modifying the indices at which
+   * each virtual thread accesses the buffer.
+   */
+  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
 };
 
 class VirtualThreadInjector : public StmtMutator {
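
The virtual-thread rewrite above multiplies the flat allocation extent by the number of virtual threads, and RewriteIndex offsets every access into the copy owned by the current thread. A hypothetical sketch (names invented; not part of the patch):

    // Illustrative sketch only; not part of this patch.
    #include <tvm/tir/buffer.h>
    #include <tvm/tir/op.h>

    void VirtualThreadSketch() {
      using namespace tvm;
      using namespace tvm::tir;
      int num_threads = 4;
      Buffer buf = decl_buffer({256}, DataType::Float(32), "scratch");
      PrimExpr alloc_extent = buf->shape[0];
      Var vthread("vthread"), i("i");
      // New allocation extent: one copy of the buffer per virtual thread.
      PrimExpr new_extent = alloc_extent * num_threads;
      // RewriteIndex(index, alloc_extent): each access is offset into
      // the copy owned by the current virtual thread.
      PrimExpr new_index = i + vthread * alloc_extent;
    }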
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 4eb9cc5..700c993 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -101,45 +101,89 @@ class IRConvertSSA final : public StmtExprMutator {
     const Var& v = op->var;
     if (defined_.count(v.get())) {
       PrimExpr value = this->VisitExpr(op->value);
-      Var new_var(v->name_hint, v.dtype());
-      scope_[v.get()].push_back(new_var);
+      ScopedRedefine redefine(this, v);
       PrimExpr body = this->VisitExpr(op->body);
-      scope_[v.get()].pop_back();
-      return Let(new_var, value, body);
+      return Let(redefine.new_var, value, body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitExpr_(op);
     }
   }
+
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<LoadNode>();
-    const VarNode* v = op->buffer_var.get();
-    if (scope_.count(v) && !scope_[v].empty()) {
-      return Load(op->dtype, scope_[v].back(), op->index, op->predicate);
-    } else {
-      return expr;
-    }
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
+
   Stmt VisitStmt_(const StoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<StoreNode>();
-    const VarNode* v = op->buffer_var.get();
-    if (scope_.count(v) && !scope_[v].empty()) {
-      return Store(scope_[v].back(), op->value, op->index, op->predicate);
-    } else {
-      return stmt;
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto output = VisitBufferAccess(std::move(node));
+    return std::move(output);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto output = VisitBufferAccess(std::move(node));
+    return std::move(output);
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    Buffer new_buf = GetRemappedBuffer(node->buffer);
+    if (!new_buf.same_as(node->buffer)) {
+      auto writer = node.CopyOnWrite();
+      writer->buffer = new_buf;
     }
+
+    return node;
   }
+
+  Buffer GetRemappedBuffer(Buffer buf) {
+    // Determine the buffer var that should be in the updated buffer,
+    // given the current scope.  If no redefines are present, then the
+    // buffer var is unchanged.
+    Var new_buffer_var = buf->data;
+    auto var_it = scope_.find(buf->data.get());
+    if (var_it != scope_.end() && !var_it->second.empty()) {
+      new_buffer_var = var_it->second.back();
+    }
+
+    // If no mapping is required, return the original buffer.
+    if (new_buffer_var.same_as(buf->data)) {
+      return buf;
+    }
+
+    // If the current scope already has a mapping of this buffer, use
+    // the mapped buffer.
+    auto key = buf.get();
+    std::vector<Buffer>& buffers = buf_remap_[key];
+    if (buffers.size() && buffers.back()->data.same_as(new_buffer_var)) {
+      return buffers.back();
+    }
+
+    // Otherwise, make and return a new buffer object that uses the
+    // new buffer var, pushing it onto the scoped stack of existing
+    // buffers.  This will be popped when the new_buffer_var
+    // redefinition is popped.
+    Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset,
+                   buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type,
+                   buf->axis_separators, buf->span);
+    buffers.push_back(new_buf);
+    return new_buf;
+  }
+
   Stmt VisitStmt_(const LetStmtNode* op) final {
     const Var& v = op->var;
     if (defined_.count(v.get())) {
       PrimExpr value = this->VisitExpr(op->value);
-      Var new_var(v->name_hint, v.dtype());
-      scope_[v.get()].push_back(new_var);
+      ScopedRedefine redefine(this, v);
       Stmt body = this->VisitStmt(op->body);
-      scope_[v.get()].pop_back();
-      return LetStmt(new_var, value, body);
+      return LetStmt(redefine.new_var, value, body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
@@ -148,12 +192,10 @@ class IRConvertSSA final : public StmtExprMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     const Var& v = op->loop_var;
     if (defined_.count(v.get())) {
-      Var new_var(v->name_hint, v.dtype());
-      scope_[v.get()].push_back(new_var);
+      ScopedRedefine redefine(this, v);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      scope_[v.get()].pop_back();
       op = stmt.as<ForNode>();
-      return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
+      return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
                  op->annotations);
     } else {
       defined_.insert(v.get());
@@ -163,12 +205,10 @@ class IRConvertSSA final : public StmtExprMutator {
   Stmt VisitStmt_(const AllocateNode* op) final {
     const Var& v = op->buffer_var;
     if (defined_.count(v.get())) {
-      Var new_var(v->name_hint, v->type_annotation);
-      scope_[v.get()].push_back(new_var);
+      ScopedRedefine redefine(this, v);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      scope_[v.get()].pop_back();
       op = stmt.as<AllocateNode>();
-      return Allocate(new_var, op->dtype, op->extents, op->condition, op->body);
+      return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
@@ -189,8 +229,34 @@ class IRConvertSSA final : public StmtExprMutator {
   }
 
  private:
+  struct ScopedRedefine {
+    ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) {
+      if (old_var->type_annotation.defined()) {
+        new_var = Var(old_var->name_hint, old_var->type_annotation);
+      } else {
+        new_var = Var(old_var->name_hint, old_var->dtype);
+      }
+      parent->scope_[old_var.get()].push_back(new_var);
+    }
+
+    ~ScopedRedefine() {
+      parent->scope_[old_var.get()].pop_back();
+      for (auto& kv : parent->buf_remap_) {
+        std::vector<Buffer>& buffers = kv.second;
+        if (buffers.size() && (buffers.back()->data.get() == new_var.get())) {
+          buffers.pop_back();
+        }
+      }
+    }
+
+    IRConvertSSA* parent;
+    Var old_var;
+    Var new_var;
+  };
+
   std::unordered_map<const VarNode*, std::vector<Var>> scope_;
   std::unordered_set<const VarNode*> defined_;
+  std::unordered_map<const BufferNode*, std::vector<Buffer>> buf_remap_;
 };
 
 Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); }
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index d7ae362..2234cc2 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -103,9 +103,11 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index,
  * \param offset the offset index.
  */
 inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
-  return Call(DataType::Handle(), builtin::address_of(),
-              {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
-                    const_true(dtype.lanes()))});
+  PrimExpr offset_expr = make_const(DataType::Int(32), offset * dtype.lanes());
+  Buffer dummy_buf(handle, dtype, {offset_expr + 1}, {}, 0, handle->name_hint, 0, 0, kDefault);
+  BufferLoad buf_load(dummy_buf, {offset_expr});
+
+  return Call(DataType::Handle(), builtin::address_of(), {buf_load});
 }
 
 /*!
@@ -119,8 +121,12 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) {
     offset = offset * make_const(offset.dtype(), dtype.lanes());
     offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
   }
-  return Call(DataType::Handle(), builtin::address_of(),
-              {Load(dtype, handle, offset, const_true(dtype.lanes()))});
+
+  Buffer dummy_buf(handle, dtype.element_of(), {offset + 1}, {}, 0, handle->name_hint, 0, 0,
+                   kDefault);
+  BufferLoad buf_load(dummy_buf, {offset});
+
+  return Call(DataType::Handle(), builtin::address_of(), {buf_load});
 }
 
 /*!
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index 4df38ff..df8bf69 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -314,7 +314,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buf
     // 3-rd argument: predicate
     parameters.push_back(const_true());
     // 4-th argument: destination
-    parameters.push_back(ct_buffer->data);
+    parameters.push_back(BufferLoad(ct_buffer, {0}));
     // next arguments: all the reduction threads
     for (const ForNode* reduction_loop : reduction_loops) {
       if (reduction_loop->thread_binding.defined()) {
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index 21f1b18..3cf5ed2 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -103,32 +103,69 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     }
   }
 
-  PrimExpr VisitExpr_(const LoadNode* load) final {
-    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
-    PrimExpr expr = StmtExprMutator::VisitExpr_(load);
-    load = expr.as<LoadNode>();
-    if (to_be_lowered) {
-      auto new_load_type = DataType::UInt(load->dtype.bits());
-      auto buffer_var = load->buffer_var;
-      auto it = var_remap_.find(buffer_var);
-      if (it != var_remap_.end()) {
-        buffer_var = it->second;
-      }
-      return Load(new_load_type, buffer_var, load->index, load->predicate);
-    }
-    return expr;
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    Stmt ret = StmtExprMutator::VisitStmt_(op);
-    op = ret.as<StoreNode>();
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto modified = VisitBufferAccess(node);
 
-    auto it = var_remap_.find(op->buffer_var);
-    if (it != var_remap_.end()) {
-      return Store(it->second, op->value, op->index, op->predicate);
+    // Legalizing the dtype is only needed for loads, not for stores,
+    // so it cannot be done in the shared VisitBufferAccess() helper.
+    if (node.same_as(modified)) {
+      return std::move(node);
     } else {
-      return ret;
+      auto writer = modified.CopyOnWrite();
+      writer->LegalizeDType();
+      return std::move(modified);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    Buffer new_buf = GetRemappedBuffer(node->buffer);
+    if (!new_buf.same_as(node->buffer)) {
+      auto writer = node.CopyOnWrite();
+      writer->buffer = new_buf;
+    }
+
+    return node;
+  }
+
+  Buffer GetRemappedBuffer(Buffer buf) {
+    auto key = buf;
+    auto cache_it = buf_remap_.find(key);
+    if (cache_it != buf_remap_.end()) {
+      return cache_it->second;
+    }
+
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code());
+
+    if (to_be_lowered) {
+      auto new_load_type = DataType::UInt(buf->dtype.bits());
+      auto writer = buf.CopyOnWrite();
+      writer->dtype = new_load_type;
+
+      auto var_it = var_remap_.find(buf->data);
+      if (var_it != var_remap_.end()) {
+        writer->data = var_it->second;
+      }
     }
+
+    buf_remap_[key] = buf;
+    return buf;
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
@@ -200,6 +237,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
   std::string target_;
   // remap buffer vars
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
 };
 
 namespace transform {
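
For reference, the dtype remapping above only swaps the storage type of the buffer; the custom type keeps its bit width. A toy sketch (illustrative only; the posit example is hypothetical):

    // Illustrative sketch only; not part of this patch.
    #include <tvm/runtime/data_type.h>

    void CustomDtypeSketch() {
      using namespace tvm::runtime;
      // Suppose a 16-bit custom datatype (e.g. a posit) was registered;
      // its type code falls at or above DataType::kCustomBegin.
      DataType custom(DataType::kCustomBegin, 16, 1);
      // The pass stores such values in a plain unsigned integer of the
      // same width, here uint16.
      DataType storage = DataType::UInt(custom.bits());
    }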
diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc
index 6bfbcef..5bde5cb 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -177,7 +177,7 @@ class MatchBufferLower : public StmtExprMutator {
     Bind(buffer->data, source_buffer->data, buffer->name + ".data");
 
     // Step.2.2. Update element offset
-    // Note we create Load via vload and try to reuse index calculate.
+    // We use the ElemOffset method to avoid duplicating the index calculation.
     {
       Array<PrimExpr> indices;
       indices.reserve(source->region.size());
@@ -185,11 +185,18 @@ class MatchBufferLower : public StmtExprMutator {
         indices.push_back(range->min);
       }
 
-      Load load = Downcast<Load>(source_buffer.vload(indices, source_buffer->dtype));
-      Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset");
-      CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
-          << "The source elem_offset " << load->index << " does not satisfy the offset_factor "
-          << buffer->offset_factor << ".";
+      Array<PrimExpr> buffer_start_indices = source_buffer->ElemOffset(indices);
+      if (buffer_start_indices.size() == 1) {
+        Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset");
+        CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
+            << "The source elem_offset " << buffer_start_indices[0]
+            << " does not satisfy the offset_factor " << buffer->offset_factor << ".";
+      } else {
+        // Non-zero elem_offset is ill-defined for non-flat memory.
+        // If needed in the future, will require `Array<PrimExpr>
+        // elem_offsets`, with one offset for each flattened index.
+        Bind(buffer->elem_offset, 0);
+      }
     }
 
     // Step 2.3. Check and update strides
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 1c6aa16..7e09943 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -98,36 +98,97 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
       if (warp_allocs_.count(repl)) {
-        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
         new_storage_scopes_[repl->buffer_var.get()] = "local";
       } else {
-        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
         new_storage_scopes_[repl->buffer_var.get()] = "shared";
       }
-      return stmt;
+      return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
     } else {
       return stmt;
     }
   }
+
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    auto it = load_remap_.find(op->buffer_var.get());
-    if (it != load_remap_.end()) {
-      ICHECK(is_zero(op->index));
-      return it->second;
-    } else {
-      return StmtExprMutator::VisitExpr_(op);
-    }
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    auto it = store_remap_.find(op->buffer_var.get());
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    {
+      auto it = load_remap_.find(op->buffer->data.get());
+      if (it != load_remap_.end()) {
+        for (const auto& index : op->indices) {
+          ICHECK(is_zero(index));
+        }
+        return it->second;
+      }
+    }
+
+    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    op = load.get();
+
+    {
+      auto it = buf_remap_.find(op->buffer.get());
+      if (it != buf_remap_.end()) {
+        return BufferLoad(it->second, op->indices, op->span);
+      }
+    }
+
+    {
+      auto it = var_remap_.find(op->buffer->data.get());
+      if (it != var_remap_.end()) {
+        Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
+                               op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
+                               op->buffer->data_alignment, op->buffer->offset_factor,
+                               op->buffer->buffer_type, op->buffer->axis_separators,
+                               op->buffer->span);
+        buf_remap_[op->buffer.get()] = remapped_buffer;
+        return BufferLoad(remapped_buffer, op->indices, op->span);
+      }
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+    auto it = store_remap_.find(store->buffer.get());
     if (it != store_remap_.end()) {
-      ICHECK(is_zero(op->index));
-      auto value = StmtExprMutator::VisitExpr(op->value);
-      return Store(it->second, value, 0, op->predicate);
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
+      for (const auto& index : op->indices) {
+        ICHECK(is_zero(index));
+      }
+
+      auto writer = store.CopyOnWrite();
+      writer->buffer = it->second;
+      return std::move(store);
+    }
+
+    {
+      auto it = buf_remap_.find(store->buffer.get());
+      if (it != buf_remap_.end()) {
+        return BufferStore(it->second, store->value, store->indices, store->span);
+      }
     }
+
+    {
+      auto it = var_remap_.find(store->buffer->data.get());
+      if (it != var_remap_.end()) {
+        Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
+                               store->buffer->strides, store->buffer->elem_offset,
+                               store->buffer->name, store->buffer->data_alignment,
+                               store->buffer->offset_factor, store->buffer->buffer_type,
+                               store->buffer->axis_separators, store->buffer->span);
+        buf_remap_[store->buffer.get()] = remapped_buffer;
+        return BufferStore(remapped_buffer, store->value, store->indices, store->span);
+      }
+    }
+
+    return std::move(store);
   }
 
   std::unordered_map<const VarNode*, String> new_storage_scopes_;
@@ -164,11 +225,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
       types[idx] = values[idx].dtype();
     }
-    std::vector<const VarNode*> buffers(size);
+    std::vector<Buffer> buffers(size);
     for (size_t idx = 0; idx < size; ++idx) {
-      const VarNode* buffer = call->args[2 + size + idx].as<VarNode>();
-      ICHECK(buffer);
-      buffers[idx] = buffer;
+      PrimExpr arg = call->args[2 + size + idx];
+      // Loads from boolean buffers may have cast nodes inserted by
+      // earlier passes.
+      if (auto cast = arg.as<CastNode>()) {
+        arg = cast->value;
+      }
+      buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
     }
 
     std::unordered_set<const VarNode*> reduce_set;
@@ -246,8 +311,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
 
     std::vector<Stmt> seq;
-    std::vector<Var> shared_bufs(size);
-    std::vector<Stmt> local_vars;
+    std::vector<Var> shared_buffer_vars(size);
+    std::vector<Buffer> shared_bufs(size);
+    std::vector<Buffer> local_bufs;
     //
     // This is an optimization. For small reduction sizes, it may be beneficial
     // for a single warp to perform the entire reduction. No trips to shared
@@ -271,19 +337,23 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // This is the index to the reduction variable, one reduction
       // variable per warp. Local scope seems easier to reason without
       // relying on a pattern match pass to fix it later.
-      PrimExpr index(0);
+      Array<PrimExpr> zero_indices = {0};
 
       for (size_t idx = 0; idx < size; ++idx) {
-        Type ptr_type = PointerType(PrimType(types[idx]));
-        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), ptr_type);
+        Array<PrimExpr> shape = {1};
+
+        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
+        Var buffer_var = buffer->data;
+
+        shared_buffer_vars[idx] = buffer_var;
+        shared_bufs[idx] = buffer;
+
         PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred));
+        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
 
-        // Uses a local variable to store the shuffled data.
-        // Later on, this allocation will be properly attached to this statement.
-        Var var("t" + std::to_string(idx), ptr_type);
-        Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0));
-        local_vars.push_back(s);
+        // Uses a local variable to store the shuffled data.  Later
+        // on, an allocation will be built for this local variable.
+        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
       }
 
       // The mask for this reducer, as this reducer may sit inside
@@ -291,18 +361,16 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // active channels.
       //
       DataType mask_dtype = DataType::UInt(32);
-      Var mask_var("mask", PointerType(PrimType(mask_dtype)));
+      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
       {
-        PrimExpr pred = const_true(1);
         PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
         if (group_extent > 1) {
           mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
         }
-        seq.emplace_back(Store(mask_var, mask, index, pred));
-        // Push allocation with an empty body. Later this will be fixed
-        // when the entire body is ready.
-        auto stmt = Allocate(mask_var, mask_dtype, {PrimExpr(1)}, pred, Evaluate(0));
-        local_vars.push_back(stmt);
+        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+        // Push the buffer description.  Later this will have an
+        // allocation built for it.
+        local_bufs.push_back(mask_buffer);
       }
 
       // Emit reductions within a warp.
@@ -314,9 +382,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         // Load reduction values, no synchronization needed.
         Array<PrimExpr> a, b;
         for (size_t i = 0; i < size; ++i) {
-          Var var = shared_bufs[i];
-          PrimExpr pred = const_true(types[i].lanes());
-          PrimExpr val = Load(types[i], var, index, pred);
+          Buffer shared_buf = shared_bufs[i];
+          BufferLoad val(shared_buf, zero_indices);
+          ICHECK_EQ(val->dtype, types[i]);
           a.push_back(val);
 
           // __shfl_*sync calls shall not appear in if_then_else expressions
@@ -332,12 +400,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
           // The former may cause dead lock as there is a divergent
           // branch with a warp sync call inside.
           //
-          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset);
-          const AllocateNode* repl = local_vars[i].as<AllocateNode>();
-          Stmt s = Store(repl->buffer_var, other, index, pred);
+          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
+          Buffer local_buf = local_bufs[i];
+          Stmt s = BufferStore(local_buf, other, zero_indices);
           seq.push_back(s);
 
-          PrimExpr load = Load(types[i], repl->buffer_var, index, pred);
+          BufferLoad load = BufferLoad(local_buf, zero_indices);
+          ICHECK_EQ(load->dtype, types[i]);
           b.push_back(load);
         }
 
@@ -347,9 +416,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         // Store the reduction result to itself.
         std::vector<Stmt> stores(size);
         for (size_t i = 0; i < size; ++i) {
-          Var var = shared_bufs[i];
-          PrimExpr pred = const_true(types[i].lanes());
-          stores[i] = Store(var, ret[i], index, pred);
+          Buffer buf = shared_bufs[i];
+          stores[i] = BufferStore(buf, ret[i], zero_indices);
         }
         seq.push_back(SeqStmt::Flatten(stores));
       }
@@ -359,34 +427,35 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // uniformly writing the same result.
       //
       for (size_t i = 0; i < size; ++i) {
-        Var var = shared_bufs[i];
-        PrimExpr pred = const_true(types[i].lanes());
-        PrimExpr val = Load(types[i], var, index, pred);
+        Buffer buf = shared_bufs[i];
+        PrimExpr val = BufferLoad(buf, zero_indices);
+        ICHECK_EQ(val->dtype, types[i]);
         PrimExpr splat =
-            WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, reduce_extent * group_index);
-        seq.push_back(Store(var, splat, index, pred));
+            WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
+        seq.push_back(BufferStore(buf, splat, zero_indices));
       }
 
       // Update existing allocations.
       for (size_t i = 0; i < size; ++i) {
-        ICHECK(!load_remap_.count(buffers[i]));
+        ICHECK(!load_remap_.count(buffers[i]->data.get()));
         PrimExpr pred = const_true(types[i].lanes());
-        Var var = shared_bufs[i];
-        load_remap_[buffers[i]] = Load(types[i], var, index, pred);
-        store_remap_[buffers[i]] = var;
+        Buffer buf = shared_bufs[i];
+        PrimExpr val = BufferLoad(buf, zero_indices);
+        ICHECK_EQ(val->dtype, types[i]);
+        load_remap_[buffers[i]->data.get()] = val;
+        store_remap_[buffers[i].get()] = buf;
         Array<PrimExpr> extents{PrimExpr(1)};
-        auto node = Allocate(var, types[i], extents, pred, Evaluate(0));
-        alloc_remap_[buffers[i]] = node;
+        auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
+        alloc_remap_[buffers[i]->data.get()] = node;
+        var_remap_[buffers[i]->data.get()] = buf->data;
         warp_allocs_.insert(node.get());
       }
     } else {
       if (reduce_extent == 1) {
         // special case, no reduction is needed.
-        std::vector<Stmt> stores(size);
+        std::vector<Stmt> stores;
         for (size_t i = 0; i < size; ++i) {
-          PrimExpr pred = const_true(types[i].lanes());
-          Var buffer_var = Downcast<Var>(call->args[2 + size + i]);
-          stores[i] = Store(buffer_var, values[i], 0, pred);
+          stores.push_back(BufferStore(buffers[i], values[i], {0}));
         }
         return SeqStmt::Flatten(stores);
       }
@@ -394,35 +463,38 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), PointerType(PrimType(types[idx])));
+        Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
+
+        shared_bufs[idx] = buffer;
+        shared_buffer_vars[idx] = buffer->data;
+
         PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(Store(shared_bufs[idx], values[idx],
-                               BufIndex(reduce_index, group_index, reduce_extent), pred));
+        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
+                                     {BufIndex(reduce_index, group_index, reduce_extent)}));
       }
       seq.emplace_back(SyncThread("shared"));
       seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
                                         reduce_extent, group_extent, contiguous_reduce_extent));
       for (size_t idx = 0; idx < size; ++idx) {
-        ICHECK(!load_remap_.count(buffers[idx]));
+        ICHECK(!load_remap_.count(buffers[idx]->data.get()));
         PrimExpr pred = const_true(types[idx].lanes());
-        load_remap_[buffers[idx]] =
-            Load(types[idx], shared_bufs[idx],
-                 BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
-        alloc_remap_[buffers[idx]] =
-            Allocate(shared_bufs[idx], types[idx],
+        BufferLoad load(shared_bufs[idx],
+                        {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
+        ICHECK_EQ(load->dtype, types[idx]);
+        load_remap_[buffers[idx]->data.get()] = load;
+        alloc_remap_[buffers[idx]->data.get()] =
+            Allocate(shared_bufs[idx]->data, types[idx],
                      {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
-        store_remap_[buffers[idx]] = shared_bufs[idx];
+        var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
+        store_remap_[buffers[idx].get()] = shared_bufs[idx];
       }
     }
 
     // Fix all local allocations as all statements are built.
     Stmt body = SeqStmt::Flatten(seq);
-    for (auto var : local_vars) {
-      const AllocateNode* repl = var.as<AllocateNode>();
-      if (repl) {
-        body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
-        new_storage_scopes_[repl->buffer_var.get()] = "local";
-      }
+    for (Buffer buf : local_bufs) {
+      body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
+      new_storage_scopes_[buf->data.get()] = "local";
     }
 
     return body;
@@ -430,8 +502,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
-                        const Array<Var>& shared_bufs, PrimExpr reduce_index, PrimExpr group_index,
-                        int reduce_extent, int group_extent, int contiguous_reduce_extent) {
+                        const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
+                        PrimExpr group_index, int reduce_extent, int group_extent,
+                        int contiguous_reduce_extent) {
     // Get next power of two
     int reduce_align = 1;
     while (reduce_extent > reduce_align) {
@@ -446,10 +519,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     auto fload = [&](int offset) {
       Array<PrimExpr> a, b;
       for (size_t i = 0; i < size; ++i) {
-        b.push_back(Load(types[i], shared_bufs[i],
-                         BufIndex(reduce_index + offset, group_index, reduce_extent),
-                         const_true()));
-        a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true()));
+        BufferLoad b_load(shared_bufs[i],
+                          {BufIndex(reduce_index + offset, group_index, reduce_extent)});
+        ICHECK_EQ(b_load->dtype, types[i]);
+        b.push_back(b_load);
+
+        BufferLoad a_load(shared_bufs[i], {buf_index});
+        ICHECK_EQ(a_load->dtype, types[i]);
+        a.push_back(a_load);
       }
       Array<PrimExpr> ret = (*combiner)(a, b);
       return ret;
@@ -457,7 +534,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     auto fstore = [&](const Array<PrimExpr>& ret) {
       std::vector<Stmt> stores(size);
       for (size_t i = 0; i < size; ++i) {
-        stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true());
+        stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
       }
       return SeqStmt::Flatten(stores);
     };
@@ -567,10 +644,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
 
   // Emit warp shuffle  calls.
-  PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, PrimExpr delta_or_lane) {
-    PrimExpr pred = const_true(1);
-    PrimExpr index(0);
-    PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
+  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
+    Array<PrimExpr> indices = {0};
+    PrimExpr mask = BufferLoad(mask_buffer, indices);
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
     Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
     return Call(val.dtype(), op, args);
@@ -640,9 +716,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // The load remap
   std::unordered_map<const VarNode*, PrimExpr> load_remap_;
   // The store remap
-  std::unordered_map<const VarNode*, Var> store_remap_;
+  std::unordered_map<const BufferNode*, Buffer> store_remap_;
   // Allocate remap
   std::unordered_map<const VarNode*, Stmt> alloc_remap_;
+  // BufferVar remap
+  std::unordered_map<const VarNode*, Var> var_remap_;
+  // Buffer remap
+  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
   // Allocate from warp reductions
   std::unordered_set<const void*> warp_allocs_;
   // Internal analyzer
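
The recurring idiom in the allreduce changes above is a one-element staging buffer declared with decl_buffer and accessed at index {0}, replacing the old raw Var plus Store/Load. A minimal sketch (illustrative only; names are invented):

    // Illustrative sketch only; not part of this patch.
    #include <tvm/tir/buffer.h>
    #include <tvm/tir/op.h>
    #include <tvm/tir/stmt.h>

    void StagingBufferSketch() {
      using namespace tvm;
      using namespace tvm::tir;
      // One-element staging buffer that replaces the old raw Var with
      // Store/Load at index 0.
      Buffer red_buf = decl_buffer({1}, DataType::Float(32), "red_buf0");
      Stmt init = BufferStore(red_buf, make_const(DataType::Float(32), 0), {0});
      PrimExpr val = BufferLoad(red_buf, {0});
    }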
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index bcf763c..7f0631d 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -34,16 +34,125 @@
 namespace tvm {
 namespace tir {
 
+class StackSizeChecker : public StmtExprVisitor {
+ public:
+  struct StackSizes {
+    // If a tvm_stack_make_shape call has no arguments, it is still
+    // valid and represents a scalar shape ().  Therefore, -1 is used
+    // to represent "no shape arguments exist", while 0 represents
+    // "shape arguments exist, all of which are size 0".
+    int64_t shape_stack{-1};
+    uint64_t array_stack{0};
+    uint64_t arg_stack{0};
+  };
+
+  static StackSizes Check(Stmt stmt) {
+    StackSizeChecker visitor;
+    visitor.VisitStmt(stmt);
+    return visitor.max_stack_;
+  }
+
+ private:
+  void VisitStmt_(const ForNode* op) final {
+    if (op->kind == ForKind::kParallel) {
+      // Parallel for loops have their own stack and allocations, so
+      // stop the recursion here.
+      return;
+    } else {
+      this->VisitStmt(op->body);
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::tvm_call_packed())) {
+      return MakeCallPacked(op, /* use_string_lookup */ true);
+    } else if (op->op.same_as(builtin::tvm_call_cpacked())) {
+      return MakeCallPacked(op, /* use_string_lookup */ false);
+    } else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
+      return MakeCallTracePacked(op);
+    } else if (op->op.same_as(builtin::tvm_stack_make_shape())) {
+      return MakeShape(op);
+    } else if (op->op.same_as(builtin::tvm_stack_make_array())) {
+      return MakeArray(op);
+    } else {
+      return StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  // call shape
+  void MakeShape(const CallNode* op) {
+    // if args.size() == 0, it is still valid and represents a scalar
+    // shape ().  Therefore, -1 is used to represent "no shape
+    // arguments exist", while 0 represents "shape arguments exist,
+    // all of which are size 0".
+    if (current_stack_.shape_stack == -1) {
+      current_stack_.shape_stack = 0;
+    }
+    current_stack_.shape_stack += op->args.size();
+    StmtExprVisitor::VisitExpr_(op);
+  }
+  // make array
+  void MakeArray(const CallNode* op) {
+    current_stack_.array_stack += 1;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+  // call packed.
+  void MakeCallPacked(const CallNode* op, bool use_string_lookup) {
+    StackSizes restore_stack = current_stack_;
+
+    size_t arg_count = op->args.size();
+
+    // cpacked expects a resource_handle parameter
+    if (!use_string_lookup) {
+      arg_count--;
+    }
+
+    current_stack_.arg_stack += arg_count;
+    // Specially handle the buffer packed intrinsic
+    StmtExprVisitor::VisitExpr_(op);
+    // Record the amount of stack space needed, then reset the stack
+    // position to its previous location.
+    UpdateMaxStack();
+    current_stack_ = restore_stack;
+  }
+
+  void MakeCallTracePacked(const CallNode* op) {
+    StackSizes restore_stack = current_stack_;
+
+    size_t args_size = op->args.size();
+    ICHECK_GT(args_size, 0);
+    current_stack_.arg_stack += args_size;
+
+    StmtExprVisitor::VisitExpr_(op);
+    // Record the amount of stack space needed, then reset the stack
+    // position to its previous location.
+    UpdateMaxStack();
+    current_stack_ = restore_stack;
+
+    // However, the arguments to this CallNode remain on top of the
+    // stack, so we can use more than one packed function's arguments
+    // with the one stack.
+    current_stack_.arg_stack = restore_stack.arg_stack + args_size - 1;
+  }
+
+  void UpdateMaxStack() {
+    max_stack_.arg_stack = std::max(current_stack_.arg_stack, max_stack_.arg_stack);
+    max_stack_.shape_stack = std::max(current_stack_.shape_stack, max_stack_.shape_stack);
+    max_stack_.array_stack = std::max(current_stack_.array_stack, max_stack_.array_stack);
+  }
+
+  StackSizes current_stack_;
+  StackSizes max_stack_;
+};
+
 // Calculate the statistics of packed function.
 // These information are needed during codegen.
 class BuiltinLower : public StmtExprMutator {
  public:
   // Record stack frame for existing scope.
   struct AllocaScope {
-    Var stack_shape = Var("stack_shape", DataType::Handle());
+    Buffer stack_shape;
     Var stack_array = Var("stack_array", DataType::Handle());
     Var stack_value = Var("stack_value", DataType::Handle());
-    Var stack_tcode = Var("stack_tcode", DataType::Handle());
+    Buffer stack_tcode;
 
     int64_t max_shape_stack{-1};
     uint64_t max_array_stack{0};
@@ -58,21 +167,41 @@ class BuiltinLower : public StmtExprMutator {
 
   // Allocate stack frames, only at parallel-for or root.
   Stmt VisitBodyAndRealizeAlloca(Stmt stmt) {
+    // Initial check to identify maximum stack sizes.  These are used
+    // to construct Buffer objects to hold the stack, which are then
+    // used when mutating.
+    auto max_sizes = StackSizeChecker::Check(stmt);
+
     alloca_scope_.emplace_back();
-    stmt = this->VisitStmt(stmt);
-    ICHECK(!alloca_scope_.empty());
     auto& scope = alloca_scope_.back();
-    if (scope.max_shape_stack != -1) {
-      stmt = LetStmt(scope.stack_shape, StackAlloca("shape", scope.max_shape_stack), stmt);
+
+    if (max_sizes.shape_stack != -1) {
+      scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), max_sizes.shape_stack)},
+                                      DataType::Int(64), "stack_shape");
+      stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", max_sizes.shape_stack), stmt);
     }
 
-    if (scope.max_array_stack != 0) {
-      stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_array_stack), stmt);
+    if (max_sizes.array_stack != 0) {
+      stmt = LetStmt(scope.stack_array, StackAlloca("array", max_sizes.array_stack), stmt);
     }
-    if (scope.max_arg_stack != 0) {
-      stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_arg_stack), stmt);
-      stmt = LetStmt(scope.stack_tcode, StackAlloca("arg_tcode", scope.max_arg_stack), stmt);
+
+    if (max_sizes.arg_stack != 0) {
+      scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), max_sizes.arg_stack)},
+                                      DataType::Int(32), "stack_tcode");
+      stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", max_sizes.arg_stack), stmt);
+
+      stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", max_sizes.arg_stack), stmt);
     }
+
+    // Copy these values from the earlier search, for use in bounds
+    // checks.
+    scope.max_shape_stack = max_sizes.shape_stack;
+    scope.max_array_stack = max_sizes.array_stack;
+    scope.max_arg_stack = max_sizes.arg_stack;
+
+    stmt = this->VisitStmt(stmt);
+
+    ICHECK(!alloca_scope_.empty());
     alloca_scope_.pop_back();
 
     return stmt;
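
With the maximum sizes known up front, the shape and tcode stacks are now described by tir::Buffer objects rather than bare handle Vars, and the body is wrapped in LetStmt bindings of each buffer's backing data Var. A minimal sketch of that wrapping, assuming an already-built handle-typed allocation expression in place of the StackAlloca() helper:

    #include <tvm/tir/buffer.h>
    #include <tvm/tir/expr.h>
    #include <tvm/tir/stmt.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Hypothetical helper, not part of the patch: declare a flat int64 stack
    // buffer of `size` elements and bind its backing Var to `alloc_expr`
    // (a handle-typed expression such as the StackAlloca("shape", size) call above).
    Stmt WrapWithShapeStack(Stmt body, PrimExpr alloc_expr, int64_t size) {
      Buffer stack_shape =
          decl_buffer({IntImm(DataType::Int(64), size)}, DataType::Int(64), "stack_shape");
      return LetStmt(stack_shape->data, alloc_expr, body);
    }
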
@@ -244,10 +373,10 @@ class BuiltinLower : public StmtExprMutator {
     op = expr.as<CallNode>();
     // no need to perform any store for a scalar shape
     for (size_t i = 0; i < op->args.size(); ++i) {
-      prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
-                                  ConstInt32(stack_begin + i), const_true(1)));
+      prep_seq.emplace_back(BufferStore(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
+                                        {ConstInt32(stack_begin + i)}));
     }
-    return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin);
+    return AddressOffset(scope.stack_shape->data, DataType::Int(64), stack_begin);
   }
   // make array
   PrimExpr MakeArray(const CallNode* op) {
@@ -328,17 +457,16 @@ class BuiltinLower : public StmtExprMutator {
         arg_tcode = kTVMStr;
       }
       if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
-      prep_seq.emplace_back(
-          Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
+      prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
     }
-    // UPDATE stack value
-    scope.max_arg_stack = std::max(scope.run_arg_stack, scope.max_arg_stack);
-    scope.max_shape_stack = std::max(scope.run_shape_stack, scope.max_shape_stack);
-    scope.max_array_stack = std::max(scope.run_array_stack, scope.max_array_stack);
+    // Verify that the stack usage does not exceed the earlier estimate.
+    ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack);
+    ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack);
+    ICHECK_LE(scope.run_array_stack, scope.max_array_stack);
     scope.run_shape_stack = restore_shape_stack;
     scope.run_array_stack = restore_array_stack;
     scope.run_arg_stack = arg_stack_begin;
-    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode,
+    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
                                    ConstInt32(arg_stack_begin),
                                    ConstInt32(arg_stack_begin + op->args.size() - 1)};
 
@@ -379,19 +507,18 @@ class BuiltinLower : public StmtExprMutator {
                                          builtin::kTVMValueContent, arg));
       int arg_tcode = api_type.code();
       ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
-      prep_seq.emplace_back(
-          Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
+      prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
     }
-    // UPDATE stack value
-    scope.max_arg_stack = std::max(scope.run_arg_stack, scope.max_arg_stack);
-    scope.max_shape_stack = std::max(scope.run_shape_stack, scope.max_shape_stack);
-    scope.max_array_stack = std::max(scope.run_array_stack, scope.max_array_stack);
+    // Verify that the stack usage does not exceed the earlier estimate.
+    ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack);
+    ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack);
+    ICHECK_LE(scope.run_array_stack, scope.max_array_stack);
     scope.run_shape_stack = restore_shape_stack;
     scope.run_array_stack = restore_array_stack;
     // Update the top of the stack, so we can use more than one
     // packed function's arguments with the one stack.
     scope.run_arg_stack = arg_stack_begin + args_size - 1;
-    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode,
+    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
                                    ConstInt32(arg_stack_begin),
                                    ConstInt32(arg_stack_begin + op->args.size() - 1),
                                    // Pass traced value.
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index f316ae9..4097111 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -114,19 +114,31 @@ class WarpStoreCoeffFinder : private StmtVisitor {
  private:
   /// Visitor implementation
   void VisitStmt_(const StoreNode* op) final {
-    if (op->buffer_var.get() == buffer_) {
-      if (op->value.dtype().lanes() == 1) {
-        UpdatePattern(op->index);
-      } else {
-        arith::PVar<PrimExpr> base;
-        ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
-            << "LowerWarpMemory failed due to store index=" << op->index
-            << ", can only handle continuous store";
-        UpdatePattern(base.Eval());
-      }
-    } else {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    if (op->buffer->data.get() != buffer_) {
       StmtVisitor::VisitStmt_(op);
+      return;
+    }
+
+    ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to be used as warp memory.  "
+                                     << "Has StorageFlatten (TE-based schedules) or "
+                                     << "FlattenBuffer (TIR-based schedules) been run?";
+
+    PrimExpr index = op->indices[0];
+    if (op->value.dtype().lanes() != 1) {
+      arith::PVar<PrimExpr> base;
+      ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index))
+          << "LowerWarpMemory failed due to store index=" << index
+          << ", can only handle continuous store";
+      index = base.Eval();
     }
+
+    UpdatePattern(index);
   }
 
   void UpdatePattern(const PrimExpr& index) {
@@ -239,35 +251,62 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
   Stmt VisitStmt_(const StoreNode* op) override {
-    if (op->buffer_var.get() == buffer_) {
-      PrimExpr local_index, group;
-      std::tie(local_index, group) = SplitIndexByGroup(op->index);
-      PrimExpr new_value = VisitExpr(op->value);
-      return Store(op->buffer_var, new_value, local_index, op->predicate);
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) override {
-    if (op->buffer_var.get() == buffer_) {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) override {
+    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+    if (store->buffer->data.get() == buffer_) {
+      ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to be used as warp memory.  "
+                                          << "Has StorageFlatten (TE-based schedules) or "
+                                          << "FlattenBuffer (TIR-based schedules) been run?";
+
       PrimExpr local_index, group;
-      std::tie(local_index, group) = SplitIndexByGroup(op->index);
-      // invariance: local index must do not contain warp id
-      ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); }))
-          << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
-          << " local_index=" << local_index;
-      PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
-      if (analyzer_->CanProveEqual(group, warp_index_)) {
-        return load_value;
-      }
-      PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
-      return Call(load_value.dtype(), builtin::tvm_warp_shuffle(),
-                  {mask, load_value, group, width_, warp_size_});
-    } else {
-      return StmtExprMutator::VisitExpr_(op);
+      std::tie(local_index, group) = SplitIndexByGroup(store->indices[0]);
+
+      auto writer = store.CopyOnWrite();
+      writer->indices = {local_index};
+    }
+
+    return std::move(store);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
+    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+
+    if (load->buffer->data.get() != buffer_) {
+      return std::move(load);
+    }
+
+    ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to be used as warp memory.  "
+                                     << "Has StorageFlatten (TE-based schedules) or "
+                                     << "FlattenBuffer (TIR-based schedules) been run?";
+
+    PrimExpr local_index, group;
+    std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]);
+    // invariant: local index must not contain the warp id
+    ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); }))
+        << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0]
+        << " local_index=" << local_index;
+
+    auto writer = load.CopyOnWrite();
+    writer->indices = {local_index};
+
+    if (analyzer_->CanProveEqual(group, warp_index_)) {
+      return std::move(load);
     }
+
+    PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
+    return Call(load.dtype(), builtin::tvm_warp_shuffle(), {mask, load, group, width_, warp_size_});
   }
+
+  // Split the index into its two components
   // <local_index, source_index>
   // local index is the index in the local
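
When the split yields a source group that differs from the accessing thread's own warp index, the rewritten load is fetched from the owning lane with a warp shuffle, as in the visitor above. A minimal sketch of the call being constructed, assuming local_load, src_lane, width, and warp_size are already in scope:

    #include <tvm/tir/builtin.h>
    #include <tvm/tir/expr.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Hypothetical helper, not part of the patch: broadcast `local_load` from the
    // lane identified by `src_lane` using the active-mask and shuffle intrinsics.
    PrimExpr MakeWarpShuffle(PrimExpr local_load, PrimExpr src_lane, PrimExpr width,
                             PrimExpr warp_size) {
      PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
      return Call(local_load.dtype(), builtin::tvm_warp_shuffle(),
                  {mask, local_load, src_lane, width, warp_size});
    }
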
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index d7e1bef..a31349f 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -61,34 +61,63 @@ class ReturnRewriter : public StmtMutator {
       if (call->op.same_as(builtin::ret())) {
         ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope.";
         ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
-        ret = WriteToOut(call->args[0], ret_var_, ret_tcode_);
+        ret = WriteToOut(call->args[0]);
       }
     }
     return ret;
   }
 
  private:
-  std::pair<int, PrimExpr> ConvertForFFI(PrimExpr val) {
+  struct ConvertedInfo {
+    int tcode{-1};
+    PrimExpr expr;
+    Buffer dummy_val_buffer;
+    Buffer dummy_tcode_buffer;
+  };
+
+  ConvertedInfo ConvertForFFI(PrimExpr val) {
+    ConvertedInfo info;
+
     // convert val's data type to FFI data type, return type code
     DataType dtype = val.dtype();
     if (dtype.is_int() || dtype.is_uint()) {
-      return {kTVMArgInt, Cast(DataType::Int(64), val)};
+      info.tcode = kTVMArgInt;
+      info.expr = Cast(DataType::Int(64), val);
     } else if (dtype.is_float()) {
-      return {kTVMArgFloat, Cast(DataType::Float(64), val)};
+      info.tcode = kTVMArgFloat;
+      info.expr = Cast(DataType::Float(64), val);
     } else if (dtype.is_void()) {
-      return {kTVMNullptr, val};
+      info.tcode = kTVMNullptr;
+      info.expr = val;
     } else {
       LOG(FATAL) << "data type " << dtype << " not supported yet";
     }
-    return {kTVMNullptr, val};
+
+    // If multiple return locations have the same data type, use the
+    // same dummy buffer declaration.
+    auto it = dummy_val_buffer_map_.find(info.tcode);
+    if (it != dummy_val_buffer_map_.end()) {
+      info.dummy_val_buffer = it->second;
+    } else {
+      info.dummy_val_buffer = Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0),
+                                     ret_var_->name_hint, 0, 0, kDefault);
+      dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer;
+    }
+
+    // The tcode is always a 32-bit int, so we don't need to have a separate map.
+    if (!dummy_tcode_buffer_.defined()) {
+      dummy_tcode_buffer_ = Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0),
+                                   ret_tcode_->name_hint, 0, 0, kDefault);
+    }
+    info.dummy_tcode_buffer = dummy_tcode_buffer_;
+
+    return info;
   }
 
-  Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) {
-    auto p = ConvertForFFI(val);
-    int tcode = p.first;
-    val = p.second;
-    Stmt store_val = Store(ret_var_, val, 0, const_true());
-    Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true());
+  Stmt WriteToOut(PrimExpr val) {
+    auto info = ConvertForFFI(val);
+    Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
+    Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0});
     Stmt ret_zero = Evaluate(tvm::ret(0));
     return SeqStmt({store_val, store_tcode, ret_zero});
   }
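
The rewritten return therefore lowers to two buffer stores followed by returning zero. A minimal sketch of the sequence WriteToOut assembles, with val_buf and tcode_buf standing in for info.dummy_val_buffer and info.dummy_tcode_buffer:

    #include <tvm/tir/buffer.h>
    #include <tvm/tir/op.h>
    #include <tvm/tir/stmt.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Hypothetical helper, not part of the patch: store the FFI-converted value
    // and its type code through the given buffers, then evaluate tir.ret(0).
    Stmt MakeReturnSequence(Buffer val_buf, Buffer tcode_buf, PrimExpr converted_value, int tcode) {
      Stmt store_val = BufferStore(val_buf, converted_value, {0});
      Stmt store_tcode = BufferStore(tcode_buf, tcode, {0});
      Stmt ret_zero = Evaluate(tvm::ret(0));
      return SeqStmt({store_val, store_tcode, ret_zero});
    }
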
@@ -96,6 +125,9 @@ class ReturnRewriter : public StmtMutator {
   Var ret_var_;
   Var ret_tcode_;
   int in_parallel_{0};
+
+  std::unordered_map<int, Buffer> dummy_val_buffer_map_;
+  Buffer dummy_tcode_buffer_;
 };
 
 Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
@@ -131,10 +163,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
   // Data field definitions
   // The packed fields
   Var v_packed_args("args", DataType::Handle());
-  Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
+  Buffer buf_packed_arg_type_ids = decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())},
+                                               DataType::Int(32), "arg_type_ids");
   Var v_num_packed_args("num_args", DataType::Int(32));
-  Var v_out_ret_value("out_ret_value", DataType::Handle());
-  Var v_out_ret_tcode("out_ret_tcode", DataType::Handle());
+  Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void())));
+  Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32))));
   Var v_resource_handle("resource_handle", DataType::Handle());
   // The arguments of the function.
   Array<Var> args;
@@ -166,7 +199,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
   // add signature for packed arguments.
   if (pack_args) {
     args.push_back(v_packed_args);
-    args.push_back(v_packed_arg_type_ids);
+    args.push_back(buf_packed_arg_type_ids->data);
     args.push_back(v_num_packed_args);
   }
 
@@ -185,21 +218,21 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
       continue;
     }
 
-    auto it = func_ptr->buffer_map.find(param);
-    if (it != func_ptr->buffer_map.end()) {
-      buffer_def.emplace_back(v_arg, (*it).second);
+    if (func_ptr->preflattened_buffer_map.count(param)) {
+      buffer_def.emplace_back(v_arg, func_ptr->preflattened_buffer_map[param]);
+    } else if (func_ptr->buffer_map.count(param)) {
+      buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
     } else {
       var_def.emplace_back(v_arg, param);
     }
+
     if (i < num_packed_args) {
       // Value loads
       seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
       // type code checks
       Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
-      seq_init.emplace_back(LetStmt(tcode,
-                                    Load(DataType::Int(32), v_packed_arg_type_ids,
-                                         IntImm(DataType::Int(32), i), const_true(1)),
-                                    nop));
+      seq_init.emplace_back(
+          LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
       DataType t = v_arg.dtype();
       if (t.is_handle()) {
         std::ostringstream msg;
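
The arg_type_ids parameter illustrates the split the patch makes throughout: the PrimFunc signature still carries the raw handle Var (buf_packed_arg_type_ids->data), while element accesses go through the Buffer so the dtype and extent stay attached. A minimal sketch of that relationship, using hypothetical names:

    #include <tvm/tir/buffer.h>
    #include <tvm/tir/expr.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Hypothetical fragment, not part of the patch.
    void DemoArgTypeIds() {
      // A flat int32 buffer holding four type codes.
      Buffer arg_type_ids =
          decl_buffer({IntImm(DataType::Int(32), 4)}, DataType::Int(32), "arg_type_ids");
      // The backing Var is what goes into the function's parameter list.
      Var handle_param = arg_type_ids->data;
      // Element reads go through the Buffer, as in the LetStmt binding above.
      PrimExpr code0 = BufferLoad(arg_type_ids, {IntImm(DataType::Int(32), 0)});
      (void)handle_param;
      (void)code0;
    }
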
diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
index b10e443..e61af84 100644
--- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
@@ -102,12 +102,17 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
     alloc_info_[buf].level = level;
     StmtExprVisitor::VisitStmt_(op);
   }
+
   void VisitStmt_(const StoreNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
     scope_.push_back(StmtEntry());
     // visit subexpr
     StmtExprVisitor::VisitStmt_(op);
     // Add write access.
-    const VarNode* buf = op->buffer_var.get();
+    const VarNode* buf = op->buffer->data.get();
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
@@ -122,6 +127,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
       linear_seq_.push_back(e);
     }
   }
+
   void VisitStmt_(const EvaluateNode* op) final {
     scope_.push_back(StmtEntry());
     // visit subexpr
@@ -133,10 +139,15 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
       linear_seq_.push_back(e);
     }
   }
+
   void VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
     // Add read access.
     StmtExprVisitor::VisitExpr_(op);
-    const VarNode* buf = op->buffer_var.get();
+    const VarNode* buf = op->buffer->data.get();
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
@@ -145,10 +156,13 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
       }
     }
   }
+
   void VisitExpr_(const CallNode* op) final {
     if (op->op.same_as(builtin::address_of())) {
-      const LoadNode* l = op->args[0].as<LoadNode>();
-      this->VisitExpr(l->index);
+      const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
+      for (const auto& index : load->indices) {
+        this->VisitExpr(index);
+      }
     } else {
       StmtExprVisitor::VisitExpr_(op);
     }
@@ -294,22 +308,61 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
-      PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
-      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
-      return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
-    }
-    return StmtExprMutator::VisitExpr_(op);
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
-      PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
-      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
-      PrimExpr value = StmtExprMutator::VisitExpr(op->value);
-      return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    if (IsDynamicSharedMemory(node->buffer->data)) {
+      ICHECK_EQ(node->indices.size(), 1)
+          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
+          << "and is to be run after "
+          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+      Array<PrimExpr> indices = {node->indices[0] +
+                                 this->GetBufferOffset(node->buffer->data, node->buffer->dtype)};
+
+      auto writer = node.CopyOnWrite();
+      writer->buffer = GetUpdatedBuffer(node->buffer);
+      writer->indices = indices;
     }
-    return StmtExprMutator::VisitStmt_(op);
+
+    return node;
+  }
+
+  Buffer GetUpdatedBuffer(Buffer buffer) {
+    auto key = buffer.get();
+    auto it = buffer_remap_.find(key);
+    if (it != buffer_remap_.end()) {
+      return it->second;
+    }
+
+    if (IsDynamicSharedMemory(buffer->data)) {
+      ICHECK_EQ(buffer->shape.size(), 1)
+          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
+          << "and is to be run after "
+          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+      auto writer = buffer.CopyOnWrite();
+      writer->data = merged_buf_var_;
+    }
+
+    buffer_remap_[key] = buffer;
+    return buffer;
   }
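
Because BufferLoad and BufferStore now expose the same buffer/indices interface, a single templated helper can rewrite both kinds of access, as VisitBufferAccess does above. A minimal sketch of that idiom, assuming a hypothetical offset rewrite:

    #include <tvm/tir/expr.h>
    #include <tvm/tir/op.h>
    #include <tvm/tir/stmt.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Hypothetical helper, not part of the patch: shift the sole index of a flat
    // access by `offset`; Node may be either BufferLoad or BufferStore.
    template <typename Node>
    Node AddOffset(Node node, PrimExpr offset) {
      Array<PrimExpr> indices = {node->indices[0] + offset};
      auto writer = node.CopyOnWrite();
      writer->indices = indices;
      return node;
    }
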
 
   PrimExpr VisitExpr_(const CallNode* op) final {
@@ -542,6 +595,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
   PrimExpr merged_alloc_size_{0};
   // The mapping from the original buffer var to its offset in the merged buffer
   std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
+  // The mapping from the original buffer objects to their location in the merged buffer.
+  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
   // The flag indicating whether the merged buffer has been allocated
   bool allocated_{false};
   // Locations of free ops.
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index dd5f54e..d5d1456 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -205,12 +205,52 @@ class DataTypeRewriter : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    PrimExpr value = this->VisitExpr(op->value);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = GetRef<BufferStore>(op);
+
+    auto value = this->VisitExpr(op->value);
+    auto indices = VisitIndices(op->indices);
+
+    if (!value.same_as(op->value) || !indices.same_as(op->indices)) {
+      auto writer = store.CopyOnWrite();
+      writer->value = value;
+      writer->indices = indices;
+    }
+
+    return std::move(store);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad load = GetRef<BufferLoad>(op);
+
+    auto indices = VisitIndices(op->indices);
+
+    if (!indices.same_as(op->indices)) {
+      auto writer = load.CopyOnWrite();
+      writer->indices = indices;
+    }
+
+    return std::move(load);
+  }
+
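
The pattern in these two visitors is to revisit the children and only touch the node when one of them was actually rewritten, so unmodified subtrees remain shared rather than copied. A minimal sketch of that copy-only-on-change idiom, with hypothetical names:

    #include <utility>

    #include <tvm/tir/expr.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Hypothetical helper, not part of the patch: install the new indices only
    // when the caller handed back a different Array object (same_as compares
    // object identity, which is what the visitors above rely on).
    PrimExpr RewriteIndicesIfChanged(BufferLoad load, Array<PrimExpr> new_indices) {
      if (!new_indices.same_as(load->indices)) {
        auto writer = load.CopyOnWrite();
        writer->indices = new_indices;
      }
      return std::move(load);
    }
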
... 14486 lines suppressed ...