Posted to commits@tvm.apache.org by wu...@apache.org on 2022/07/28 18:13:38 UTC

[tvm] branch main updated: [TIR] Asynchronous stage in software pipeline (#12171)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c737fbd5b [TIR] Asynchronous stage in software pipeline (#12171)
3c737fbd5b is described below

commit 3c737fbd5baccc60aff355b40105220c148b7d7f
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Jul 29 03:13:31 2022 +0900

    [TIR] Asynchronous stage in software pipeline (#12171)
    
    * [TIR] Support asynchronous stages in software pipeline transform
    
    * Support interleaved async producers separated by a consumer
    
    * clean up
    
    * adding doc
    
    * adding doc
    
    * simplifying
    
    * make wait count computation a two pass process
    
    * commit_stage -> commit_queue, wait_stage -> wait_queue
    
    * make async_commit_queue special scope stmt
    
    * codegen async_commit_queue in cuda
    
    * clean up
    
    * clean up
    
    * Move block predicate outside of commit_queue
    
    * updating test
    
    * test updated
    
    * changed async_wait to an annotation
    
    * update doc
    
    * update meaning of software_pipeline_async_stages
    
    * update test
    
    * fixing codegen
    
    * more fix
    
    * remove one of the tests that have async and sync ops in the same stage
    
    * format
    
    * lint and other fix
    
    * Define attr::software_pipeline_async_stages
    
    * populate wait count in a separate function
    
    * fold variable consumed into AsyncStateLocal
    
    * introduce CompletePipelineLoopStatements function for further refactor
---
 include/tvm/tir/stmt.h                             |  27 ++
 src/target/source/codegen_cuda.cc                  |  18 +
 src/tir/transforms/inject_software_pipeline.cc     | 448 ++++++++++++++++++--
 src/tir/transforms/ir_utils.cc                     |   7 +
 src/tir/transforms/ir_utils.h                      |   5 +
 src/tir/transforms/remove_no_op.cc                 |  16 +
 src/tir/transforms/thread_storage_sync.cc          |  45 ++
 .../test_tir_transform_inject_software_pipeline.py | 469 +++++++++++++++++++--
 8 files changed, 966 insertions(+), 69 deletions(-)
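
At the schedule level, the new behavior is opt-in through a software_pipeline_async_stages
annotation placed alongside the existing stage/order annotations. A minimal sketch of the
intended usage, following the tests added in this commit (the block name "compute" and the
surrounding loop structure are assumptions for illustration):

    import tvm

    # mod: an IRModule whose "compute" block sits under the loop to be pipelined
    sch = tvm.tir.Schedule(mod)
    _, loop = sch.get_loops(sch.get_block("compute"))
    sch.annotate(loop, ann_key="software_pipeline_stage", ann_val=[0, 1])
    sch.annotate(loop, ann_key="software_pipeline_order", ann_val=[0, 1])
    # New in this commit: run stage 0 (e.g. the global-to-shared copy) asynchronously
    sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0])
    mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)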

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 2060fb7920..5dd4103e82 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1448,6 +1448,27 @@ constexpr const char* device_scope = "device_scope";
  */
 constexpr const char* async_scope = "async_scope";
 
+/*!
+ * \brief Annotations for invoking and synchronizing asynchronous operations.
+ *
+ * Synchronization is done in terms of a "queue": an abstract entity associated
+ * with each asynchronous unit, which tracks invocations and completions of
+ * asynchronous operations in FIFO order.
+ *
+ * Similar to the PTX instructions commit_group and wait_group, these annotations
+ * express synchronization by "counting":
+ *
+ * async_commit_queue(i): Group one or more invocations of async operations in the given scope,
+ * and "commit" (or push) them to queue i. A group of operations committed together is
+ * awaited as one chunk. Groups committed to the same queue complete in FIFO order.
+ *
+ * async_wait_queue(i, N): Block until only the N most recently committed groups are still
+ * in flight in queue i. N need not be a constant, but some backends may require a constant count.
+ */
+constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
+constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
+constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";
+
 /*!
  * \brief Mark that the shape of TensorCore fragment
  */
@@ -1483,6 +1504,12 @@ constexpr const char* software_pipeline_stage = "software_pipeline_stage";
 /*! \brief Mark the order of a statement in the software pipeline */
 constexpr const char* software_pipeline_order = "software_pipeline_order";
 
+/*! \brief List stages in the software pipeline that should run asynchronously
+ * \note All statements in the provided stages are assumed to have asynchronous
+ *       semantics (e.g. CUDA async global to shared memory copy).
+ */
+constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";
+
 /*! \brief Mark the buffers which are const-accessed and whose layout can be transformed. */
 constexpr const char* layout_free_buffers = "layout_free_buffers";
 
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 616e75f2e7..3ea6f8d9ed 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -917,6 +917,24 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
     const VarNode* buffer = op->node.as<VarNode>();
     const StringImmNode* layout_str = op->value.as<StringImmNode>();
     fragment_layouts[buffer] = layout_str->value;
+  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
+    const IntImmNode* queue_id = op->value.as<IntImmNode>();
+    ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
+    this->VisitStmt(op->body);
+    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
+    this->VisitExpr(commit_group, this->stream);
+    return;
+  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
+    auto wait_attrs = GetAsyncWaitAttributes(op);
+    auto queue_id = wait_attrs.first.as<IntImmNode>();
+    ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
+    auto wait_cnt = wait_attrs.second;
+    auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
+    this->VisitExpr(wait_group, this->stream);
+    auto inner = op->body.as<AttrStmtNode>();
+    ICHECK(inner);
+    this->VisitStmt(inner->body);
+    return;
   }
   CodeGenC::VisitStmt_(op);
 }
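
On CUDA, only queue 0 is supported, and the two scopes lower to the builtins ptx_commit_group
and ptx_wait_group, which correspond to the PTX cp.async.commit_group and cp.async.wait_group
instructions. As the tests further below do, generating the async copies at build time is
gated behind a pass config (a sketch; assumes an Ampere or newer GPU):

    with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
        f = tvm.build(sch.mod["main"], target="cuda")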
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index b4a597fe97..227935bf72 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -25,6 +25,8 @@
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/transform.h>
 
+#include <unordered_set>
+
 #include "../../support/utils.h"
 #include "../schedule/utils.h"
 #include "./ir_utils.h"
@@ -60,13 +62,14 @@ Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer)
   return block;
 }
 
-/*! Structure that represents the stage and order of the software pipeline component. */
-struct PipelineStageOrder {
+/*! Structure that represents the provided annotation per block or loop. */
+struct PipelineAnnotation {
   int stage;
   int order;
+  bool async;
 };
 
-using PipelineInfo = std::unordered_map<Block, PipelineStageOrder, ObjectPtrHash, ObjectPtrEqual>;
+using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>;
 
 struct BufferAccessInfo {
   int def = -1;  // the defining stage of the buffer
@@ -99,6 +102,8 @@ class PipelineOpaqueAccessRewriter {
     static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync();
     static const auto& mma_sync = builtin::tvm_mma_sync();
     static const auto& access_ptr = builtin::tvm_access_ptr();
+    static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix();
+    static const auto& ptx_mma = builtin::ptx_mma();
     if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) {
       const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
       auto it = buffer_remap_.find(buffer);
@@ -122,24 +127,11 @@ class PipelineOpaqueAccessRewriter {
       }
       return Call(call->dtype, call->op, new_args, call->span);
     } else if (call->op.same_as(access_ptr)) {
-      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[1]));
-      auto it = buffer_remap_.find(buffer);
-      if (it != buffer_remap_.end()) {
-        Array<PrimExpr> new_args = call->args;
-        const Buffer& new_buffer = (*it).second;
-        const PrimExpr& old_index = call->args[2];
-        PrimExpr offset;
-        if (new_buffer->strides.empty()) {
-          offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                         make_const(DataType::Int(32), 1), buffer->shape);
-        } else {
-          offset = new_buffer->strides[0];
-        }
-        PrimExpr new_index =
-            old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
-        new_args.Set(2, new_index);
-        return Call(call->dtype, call->op, new_args, call->span);
-      }
+      return RewriteBufferAccess(call, {1});
+    } else if (call->op.same_as(ptx_mma)) {
+      return RewriteBufferAccess(call, {6, 8, 10});
+    } else if (call->op.same_as(ptx_ldmatrix)) {
+      return RewriteBufferAccess(call, {3});
     }
     return call;
   }
@@ -166,6 +158,32 @@ class PipelineOpaqueAccessRewriter {
     return new_buffer_offset;
   }
 
+  PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int> arg_indices) {
+    auto product = [](const Array<PrimExpr>& input) {
+      return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+                   make_const(DataType::Int(32), 1), input);
+    };
+    Array<PrimExpr> new_args = call->args;
+    for (int i : arg_indices) {
+      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        const Buffer& new_buffer = (*it).second;
+        const PrimExpr& old_index = call->args[i + 1];
+        PrimExpr offset;
+        if (new_buffer->strides.empty()) {
+          offset = product(buffer->shape);
+        } else {
+          offset = new_buffer->strides[0];
+        }
+        PrimExpr new_index =
+            old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
+        new_args.Set(i + 1, new_index);
+      }
+    }
+    return Call(call->dtype, call->op, new_args, call->span);
+  }
+
   const Map<Var, Buffer>& buffer_data_to_buffer_;
   const Map<Buffer, Buffer>& buffer_remap_;
   const For& pipeline_loop_;
@@ -494,6 +512,267 @@ class PipelineRewriter : public StmtExprMutator {
     return Buffer(new_buffer);
   }
 
+  // Per-stage states that need to be tracked across pipeline prologue, body, and epilogue.
+  struct AsyncStateGlobal {
+    // Buffers that this stage asynchronously writes.
+    std::unordered_set<const BufferNode*> dst_buffers;
+    // An imaginary index that the latest async operation associated with this stage has written
+    // into. Only valid if all associated predicates are true, so that we can count the number of
+    // async invocations exactly. When valid, it equals the "sum of extents of loops that have
+    // been executed" - 1, e.g. for the epilogue it is prologue extent + body extent - 1. This
+    // is only needed to compute the wait count for an epilogue without async producers.
+    Optional<PrimExpr> producer_head{PrimExpr(-1)};
+
+    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
+  };
+
+  // Per-stage states that are local to each of pipeline prologue, body, and epilogue.
+  struct AsyncStateLocal {
+    struct {
+      // The index into the list of blocks; async_wait_queue should be attached at the
+      // beginning of the block at this index.
+      int insert_before;
+      // in_flight_count would be a more precise name, but the implementation uses wait_count for
+      // brevity.
+      PrimExpr wait_count{nullptr};
+
+      bool valid() const { return wait_count.defined(); }
+    } pending_wait;
+
+    // Destination buffers of async operations that have been encountered so far in the loop
+    //
+    // for (size_t i = 0; i < new_blocks.size(); ++i) {
+    //    ...
+    // }
+    //
+    // This is for tracking which async operations have been issued at the "current" iteration, up
+    // until a point where we encounter a consumer of async result buffers. This is used to decide
+    // if the producer_head of each buffer points to a copy written in the current or previous
+    // iteration.
+    std::unordered_set<const BufferNode*> seen;
+
+    // A symbolic expression representing the index the latest async operation associated with this
+    // stage has written into, at the "current" iteration.
+    Optional<PrimExpr> producer_head;
+    // The predicate of BlockRealize containing the async operation of this stage.
+    Optional<PrimExpr> predicate;
+    // Indices into the list of blocks where an async_commit_queue scope should be attached.
+    // If multiple async producers are interleaved with their consumer in between, we need a
+    // separate async_commit_queue for each producer; thus, we need multiple sets of indices.
+    std::vector<std::vector<size_t>> commit_groups;
+
+    // This is set to true when we reach a stage that consumes this async stage.
+    bool consumed{false};
+  };
+
+  /*! Structure holding intermediate information for pipeline loop rewriting. */
+  struct RewrittenBlockInfo {
+    int stage;
+    PrimExpr predicate;
+    Block block;
+    PrimExpr access_index;
+    bool is_async;
+  };
+
+  // Determine where to insert async_wait and the corresponding wait count.
+  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
+                          arith::Analyzer* ana_normalized,
+                          const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group,
+                          std::map<int, AsyncStateLocal>* async_states_local) {
+    for (size_t i = 0; i < new_blocks.size(); ++i) {
+      if (new_blocks[i].is_async) {
+        // Record the fact that we have encountered these write buffers.
+        for (auto write_region : new_blocks[i].block->writes) {
+          (*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get());
+        }
+      }
+
+      int producer_stage_idx = -1;
+      for (auto read_region : new_blocks[i].block->reads) {
+        for (auto kv : async_states) {
+          if (kv.first <= new_blocks[i].stage && kv.second.writes(read_region->buffer)) {
+            // Found an earlier stage where read_region->buffer was asynchronously written
+            ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first)
+                << "A dependency on multiple async stages is not supported";
+            producer_stage_idx = kv.first;
+          }
+        }
+      }
+
+      if (producer_stage_idx == -1) continue;
+
+      // The following logic has become complicated to handle cases like this:
+      //
+      // for i in range(13):
+      //     # Stage 0
+      //     async_commit_queue(0):
+      //        async_scope:
+      //           A_shared[(i + 3) % 4] = A[...]
+      //
+      //
+      //     # Stage 1
+      //     async_wait_queue(0, 5):
+      //        compute(A_shared[i], B_shared[i])
+      //
+      //     # Stage 0
+      //     async_commit_queue(0)
+      //        async_scope:
+      //           B_shared[(i + 3) % 4] = B[...]
+      //
+      //
+      // Here, multiple async producers in the same stage are interleaved with their consumer in
+      // between. Since each buffer is associated with different commit groups, the wait_count
+      // before the consumer should be bigger than the simpler case:
+      //
+      // for i in range(13):
+      //     # Stage 0
+      //     async_commit_queue(0):
+      //        async_scope:
+      //           A_shared[(i + 3) % 4] = A[...]
+      //           B_shared[(i + 3) % 4] = B[...]
+      //
+      //     # Stage 1
+      //     async_wait_queue(0, 3):
+      //        compute(A_shared[i], B_shared[i])
+      //
+      // The correct wait_count can be determined by considering each commit group separately, and
+      // summing "per-commit" wait_counts.
+      //
+      // From A_shared's perspective, it allows (i + 3) - i = 3 async commit groups to be in
+      // flight, while from B_shared's perspective, the producer head at compute points to the
+      // copy done by the previous iteration, so its wait_count is ((i - 1) + 3) - i = 2. The
+      // sum of the two wait_counts gives 5.
+
+      auto& dep_local_state = (*async_states_local)[producer_stage_idx];
+      const auto num_commit_group = dep_local_state.commit_groups.size();
+      std::vector<Optional<PrimExpr>> producer_head_per_commit;
+
+      if (num_commit_group == 0) {
+        // Epilogue, no async producer. Since "local" producer_head is not available, use
+        // "global" producer_head.
+        ICHECK(!dep_local_state.producer_head);
+        producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head);
+      } else {
+        ICHECK(dep_local_state.producer_head);
+        std::vector<bool> need_wait_count(num_commit_group, true);
+
+        for (auto read_region : new_blocks[i].block->reads) {
+          if (!async_states[producer_stage_idx].writes(read_region->buffer)) continue;
+          auto commit_group_id = buffer_to_commit_group.at(read_region->buffer.get());
+          if (!need_wait_count[commit_group_id]) continue;
+
+          if (!dep_local_state.seen.count(read_region->buffer.get())) {
+            // Multiple async producers interleaved: The most recent async write is from the
+            // previous iteration. This is the B_shared case above.
+            producer_head_per_commit.push_back(dep_local_state.producer_head.value() - 1);
+          } else {
+            // Normal case
+            producer_head_per_commit.push_back(dep_local_state.producer_head.value());
+          }
+
+          need_wait_count[commit_group_id] = false;
+        }
+      }
+
+      auto wait_count = [=, &ana_normalized]() {
+        auto sum = PrimExpr(0);
+        for (auto producer_head : producer_head_per_commit) {
+          if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) {
+            // Here, new_blocks[i].access_index corresponds to "consumer_head".
+            // The difference of producer_head and consumer_head is precisely the number of
+            // async commit groups that can still be in flight after this wait.
+            sum += analyzer_.Simplify(producer_head.value() - new_blocks[i].access_index);
+          } else {
+            // The precise count cannot be determined, give up.
+            return PrimExpr(0);
+          }
+        }
+        return sum;
+      }();
+
+      auto& pending_wait = dep_local_state.pending_wait;
+
+      if (!pending_wait.valid()) {
+        pending_wait = {static_cast<int>(i), wait_count};
+      } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) {
+        // Coalesce multiple wait_queue if the later one allows fewer in-flight ops.
+        pending_wait = {pending_wait.insert_before, wait_count};
+      }
+    }
+  }
+
+  // Given pipelined blocks and async-related information, generate final loop statements with async
+  // scopes (if any).
+  Array<Stmt> CompletePipelineLoopStatements(
+      const std::vector<RewrittenBlockInfo>& blocks,
+      const std::map<int, AsyncStateLocal>& async_states_local,
+      arith::Analyzer* ana_normalized) const {
+    std::vector<RewrittenBlockInfo> new_blocks = blocks;
+    std::vector<int> commit_group_indices(new_blocks.size(), -1);
+    for (const auto& kv : async_states_local) {
+      const int stage_id = kv.first;
+      const AsyncStateLocal& state = kv.second;
+
+      if (!state.commit_groups.empty()) {
+        for (size_t i = 0; i < state.commit_groups.size(); ++i) {
+          for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
+            ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
+            commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
+          }
+        }
+      }
+
+      if (state.pending_wait.valid()) {
+        auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) {
+          auto& block = new_blocks[i].block;
+          BlockNode* n = block.CopyOnWrite();
+          auto zero = make_zero(DataType::Int(32));
+          n->body =
+              AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
+                       AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body));
+        };
+
+        if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) {
+          // If the async operation that this wait_queue is waiting on is predicated, and we
+          // cannot prove that the predicate is always true, the precise wait count is only
+          // valid at iterations where the predicate is true.
+          auto wait_count = Call(DataType::Int(32), builtin::if_then_else(),
+                                 {state.predicate.value(), state.pending_wait.wait_count, 0});
+          attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count);
+        } else {
+          attach_wait_scope(state.pending_wait.insert_before, stage_id,
+                            state.pending_wait.wait_count);
+        }
+      }
+    }
+
+    Array<Stmt> stmts;
+
+    for (size_t i = 0; i < new_blocks.size();) {
+      if (commit_group_indices[i] == -1) {
+        // A synchronous block, not part of any commit group
+        stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
+        ++i;
+      } else {
+        Array<Stmt> group_bodies;
+        auto stage_id = commit_group_indices[i];
+        auto predicate = new_blocks[i].predicate;
+        for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) {
+          ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
+              << "Predicates in the same stage are expected to be identical";
+          group_bodies.push_back(new_blocks[i].block->body);
+        }
+        auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0];
+        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
+                                           tir::attr::async_commit_queue_scope, stage_id, body);
+        auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
+        stmts.push_back(BlockRealize({}, predicate, new_block));
+      }
+    }
+
+    return stmts;
+  }
+
   /*!
    * \brief Emit the pipeline loop in the given range.
    * \param start The start of the range
@@ -502,7 +781,6 @@ class PipelineRewriter : public StmtExprMutator {
    * \return The result loop.
    */
   Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) {
-    Array<Stmt> stmts;
     PrimExpr new_loop_var;
     PrimExpr extent = end - start;
 
@@ -519,6 +797,19 @@ class PipelineRewriter : public StmtExprMutator {
       analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
     }
 
+    // In contrast to analyzer_ which is bound to [start, end), this one is bound to
+    // the "normalized" range, [pipeline_loop_->min, extent).
+    arith::Analyzer ana_normalized;
+    if (!is_unit_loop) {
+      ana_normalized.Bind(Downcast<Var>(new_loop_var), Range(pipeline_loop_->min, extent));
+    }
+
+    std::vector<RewrittenBlockInfo> new_blocks;
+
+    // Async related
+    std::map<int, AsyncStateLocal> async_states_local;
+    std::unordered_map<const BufferNode*, int> buffer_to_commit_group;
+
     for (const Block& block : ordered_stmts_) {
       int stage = pipeline_info_.at(block).stage;
       PrimExpr skewed_loop_var = new_loop_var - stage;
@@ -530,20 +821,78 @@ class PipelineRewriter : public StmtExprMutator {
       Block new_block = Downcast<Block>(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                                                              pipeline_loop_, max_stage_ != 1,
                                                              fragment_info_)(block));
-      Map<Var, PrimExpr> subst_map;
-      if (is_unit_loop) {
-        subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var);
-      } else {
-        // normalize loop range
-        PrimExpr delta = start - pipeline_loop_->min;
-        subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta);
+
+      PrimExpr delta = start - pipeline_loop_->min;
+      // This variable corresponds to
+      // - "producer_head" if this stage is an async producer
+      // - "consumer_head" if this stage reads from asynchronously written buffers.
+      PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
+
+      // Adjust the block predicate and the body according to the final loop bound
+      //  [pipeline_loop_->min, extent).
+      if (!is_unit_loop) {
         Var loop_iter = Downcast<Var>(new_loop_var);
-        inbound = Substitute(inbound, Map<Var, PrimExpr>{{loop_iter, loop_iter + delta}});
+        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
+      }
+
+      new_block = Downcast<Block>(
+          Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
+
+      if (pipeline_info_[block].async) {
+        auto& local_state = async_states_local[stage];
+
+        int commit_group_id = -1;
+        if (local_state.commit_groups.empty() || local_state.consumed) {
+          // consumed == true means there is already a consumer stage waiting for an
+          // earlier async operation of this stage. In such cases, we make multiple commit queues
+          // for this stage.
+          commit_group_id = local_state.commit_groups.size();
+          local_state.commit_groups.push_back({new_blocks.size()});
+        } else {
+          // This is the case when one commit_queue groups multiple async blocks.
+          // with commit_queue(stage):
+          //   async_scope:
+          //     A_shared[...] = ...
+          //   async_scope:
+          //     B_shared[...] = ...
+
+          commit_group_id = local_state.commit_groups.size() - 1;
+          local_state.commit_groups.back().push_back(new_blocks.size());
+        }
+
+        for (auto write_region : new_block->writes) {
+          async_states[stage].dst_buffers.insert(write_region->buffer.get());
+          buffer_to_commit_group[write_region->buffer.get()] = commit_group_id;
+        }
+
+        local_state.producer_head = normalized_access_index;
+
+        if (!local_state.predicate || ana_normalized.CanProve(local_state.predicate.value())) {
+          local_state.predicate = inbound;
+        } else if (local_state.predicate) {
+          local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound);
+        }
+
+        BlockNode* n = new_block.CopyOnWrite();
+        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body);
+      }
+
+      new_blocks.push_back(
+          {stage, inbound, new_block, normalized_access_index, pipeline_info_[block].async});
+
+      for (auto read_region : new_block->reads) {
+        for (auto kv : async_states) {
+          int producer_stage_id = kv.first;
+          if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) {
+            async_states_local[producer_stage_id].consumed = true;
+          }
+        }
       }
-      new_block = Downcast<Block>(Substitute(new_block, subst_map));
-      stmts.push_back(BlockRealize({}, inbound, new_block));
     }
 
+    PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local);
+    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized);
+
     Stmt new_loop{nullptr};
 
     if (stmts.empty()) {
@@ -559,6 +908,24 @@ class PipelineRewriter : public StmtExprMutator {
       new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                      unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
     }
+
+    // Update producer heads in the global async states.
+    for (const auto& kv : async_states_local) {
+      const int stage_id = kv.first;
+      const AsyncStateLocal& state = kv.second;
+
+      if (state.predicate && ana_normalized.CanProve(state.predicate.value()) &&
+          async_states[stage_id].producer_head) {
+        // Advance the "global" producer head if it is still valid and we know exactly how
+        // much to increment it.
+        async_states[stage_id].producer_head =
+            async_states[stage_id].producer_head.value() + extent;
+      } else {
+        // Otherwise, invalidate the global producer head
+        async_states[stage_id].producer_head = NullOpt;
+      }
+    }
+
     return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
   }
 
@@ -572,6 +939,7 @@ class PipelineRewriter : public StmtExprMutator {
   int max_stage_ = -1;
   Map<Buffer, Buffer> buffer_remap_;
   Array<Block> ordered_stmts_;
+  std::map<int, AsyncStateGlobal> async_states;
 };
 
 /*!
@@ -727,11 +1095,23 @@ class PipelineInjector : private StmtExprMutator {
         Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_order));
     CHECK_EQ(pipeline_stages.size(), original_order.size());
     CHECK_EQ(pipeline_orders.size(), original_order.size());
+
+    std::unordered_set<int> pipeline_async_stages;
+    if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) {
+      for (auto s : Downcast<Array<Integer>>(annot)) {
+        pipeline_async_stages.insert(s->value);
+      }
+    }
+
     for (size_t i = 0; i < pipeline_stages.size(); i++) {
-      PipelineStageOrder stage_order{/*stage=*/static_cast<int>(pipeline_stages[i]->value),
-                                     /*order=*/static_cast<int>(pipeline_orders[i]->value)};
+      int stage = static_cast<int>(pipeline_stages[i]->value);
+      bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end();
+      PipelineAnnotation stage_order{stage,
+                                     /*order=*/static_cast<int>(pipeline_orders[i]->value),
+                                     is_async};
       pipeline_info.emplace(original_order[i], stage_order);
     }
+
     ValidatePipelineBody(pipeline_info, original_order);
 
     // Step 4: Rewrite the pipeline body.
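
The per-commit-group wait counts computed by PopulateWaitCounts can be checked by hand for
the interleaved example in the comment above. A plain-Python sketch of the arithmetic (values
taken from that comment):

    i = 0  # any steady-state iteration of the pipelined loop
    consumer_head = i
    # A_shared: the copy for this iteration has already been issued at the consumer.
    wait_a = (i + 3) - consumer_head        # 3 groups may remain in flight
    # B_shared: the most recent copy was issued in the previous iteration.
    wait_b = ((i - 1) + 3) - consumer_head  # 2 groups may remain in flight
    assert wait_a + wait_b == 5             # the async_wait_inflight_count in the test below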
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 700c9931bb..66b04bd678 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -441,5 +441,12 @@ void ConditionalBoundsContext::ExitWithScope() {
   }
 }
 
+std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
+  ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope);
+  auto inner = op->body.as<AttrStmtNode>();
+  ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
+  return std::make_pair(op->value, inner->value);
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 2234cc22bc..d89ee36196 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -35,6 +35,7 @@
 #include <limits>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 namespace tvm {
@@ -306,6 +307,10 @@ struct FragmentInfo {
  */
 std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const Stmt& stmt);
 
+// Return the queue id and the in-flight count associated with the given
+// attr::async_wait_queue_scope annotation.
+std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op);
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_TRANSFORMS_IR_UTILS_H_
diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index c8c77b8bad..ce0d9b87c4 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -21,6 +21,7 @@
  * \file remove_no_op.cc
  * \brief Remove no op from the stmt
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/op.h>
@@ -30,6 +31,8 @@
 
 #include <unordered_map>
 
+#include "ir_utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -44,7 +47,20 @@ class NoOpRemover : public StmtMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_debug_skip_region") {
       return MakeEvaluate(0);
+    } else if (op->attr_key == attr::async_wait_queue_scope) {
+      auto wait_attrs = GetAsyncWaitAttributes(op);
+      auto wait_cnt = wait_attrs.second;
+      arith::Analyzer ana;
+      if (ana.CanProve(wait_cnt < 0)) {
+        // A negative wait count can arise if it depends on a loop variable.
+        // For example, a wait count 1 - i can be negative after loop unrolling.
+        // We assume that such a wait is a no-op.
+        auto inner = op->body.as<AttrStmtNode>();
+        ICHECK(inner);
+        return StmtMutator::VisitStmt(inner->body);
+      }
     }
+
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<AttrStmtNode>();
     return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
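
Concretely, once an epilogue is unrolled, a wait count that depends on the loop variable is
instantiated per iteration and may become provably negative, at which point the wait is
dropped (a sketch with a hypothetical count of 1 - i):

    for i in range(3):
        wait_count = 1 - i  # i = 0 -> 1, i = 1 -> 0, i = 2 -> -1 (provably negative, removed)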
diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc
index ce3f8fd3e3..954f4f7cc4 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -230,6 +230,48 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
   StorageScope sync_scope_;
 };
 
+// There are cases where a necessary syncthreads is not inserted by ThreadSyncInserter.
+// For example, a syncthreads is needed after async_wait_queue in the second loop below,
+// but since ThreadSyncInserter is not aware of asynchronous semantics, it cannot tell
+// that a syncthreads is needed there.
+//
+// // Pipeline prologue
+// for i in range(125):
+//    async_commit_queue(0):
+//       async_scope:
+//          shared[(i + 3) % 4] = ...
+// ...
+//
+// // Pipeline Epilogue
+// for i in range(3):
+//    async_wait_queue(0, 2 - i):
+//       local[...] = shared[(i + 125) % 4]
+
+// This class adds a syncthreads after every async_wait_queue. That includes syncthreads
+// that ThreadSyncInserter could insert as well, but ThreadSyncInserter will not insert
+// duplicate syncthreads if it finds an existing one at the synchronization point.
+class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
+ public:
+  explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(sync_scope) {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::async_wait_queue_scope) {
+      auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
+                                {StringImm(sync_scope_.to_string())}));
+      auto inner = op->body.as<AttrStmtNode>();
+      ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
+      auto zero = make_zero(DataType::Int(32));
+      auto new_body = SeqStmt({sync, inner->body});
+      return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value,
+                      AttrStmt(zero, tir::attr::async_wait_inflight_count, inner->value, new_body));
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  StorageScope sync_scope_;
+};
+
 class ThreadSyncInserter : public StmtExprMutator {
  public:
   ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs)
@@ -384,6 +426,9 @@ class ThreadSyncInserter : public StmtExprMutator {
 
 Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
   StorageScope sync_scope = StorageScope::Create(storage_scope);
+  if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
+    stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
+  }
   ThreadSyncPlanner planner(sync_scope);
   planner(stmt);
   return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
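
After ThreadSyncAfterWaitQueueInserter runs, every wait scope begins with a storage sync,
shown here in the same pseudo notation as the comment above (illustration only):

    # Pipeline epilogue, after the pass
    for i in range(3):
       async_wait_queue(0, 2 - i):
          tvm_storage_sync("shared")  # barrier inserted by this pass
          local[...] = shared[(i + 125) % 4]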
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2f08249ed7..edaeb7c9b6 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -92,26 +92,32 @@ def transformed_trivial_pipeline(
                 C[tx, 0] = B[0, tx, 0] + T.float32(1)
 
 
-@T.prim_func
-def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
-    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
-        for i in T.serial(
-            0,
-            16,
-            annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]},
-        ):
-            with T.block():
-                T.reads(A[tx, i])
-                T.writes(C[tx, i])
-                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
-                with T.block():
+def gen_simple_compute(num_stages):
+    @T.prim_func
+    def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+        for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+            for i in T.serial(
+                0,
+                16,
+                annotations={
+                    "software_pipeline_stage": [0, num_stages],
+                    "software_pipeline_order": [0, 1],
+                },
+            ):
+                with T.block("compute"):
                     T.reads(A[tx, i])
-                    T.writes(B[tx, 0])
-                    B[tx, 0] = A[tx, i] * T.float32(2)
-                with T.block():
-                    T.reads(B[tx, 0])
                     T.writes(C[tx, i])
-                    C[tx, i] = B[tx, 0] + T.float32(1)
+                    B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                    with T.block():
+                        T.reads(A[tx, i])
+                        T.writes(B[tx, 0])
+                        B[tx, 0] = A[tx, i] * T.float32(2)
+                    with T.block():
+                        T.reads(B[tx, 0])
+                        T.writes(C[tx, i])
+                        C[tx, i] = B[tx, 0] + T.float32(1)
+
+    return simple_compute
 
 
 @T.prim_func
@@ -156,7 +162,7 @@ def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16),
                 "software_pipeline_order": [0, 1, 2],
             },
         ):
-            with T.block():
+            with T.block("compute"):
                 T.reads(A[tx, i])
                 T.writes(D[tx, i])
                 B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
@@ -991,7 +997,7 @@ def simple_compute_missing_annotation(
 
 
 def test_simple_compute():
-    _check(simple_compute, transformed_simple_compute)
+    _check(gen_simple_compute(1), transformed_simple_compute)
 
 
 def test_trivial_pipeline():
@@ -1034,15 +1040,322 @@ def test_error_missing_annotation():
     _check_error(simple_compute_missing_annotation)
 
 
-@tvm.testing.requires_cuda
-def test_three_stage_gemm():
-    N = K = M = 4096
-    i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1]
+def test_simple_compute_async():
+    mod = tvm.IRModule.from_expr(gen_simple_compute(1))
+    sch = tvm.tir.Schedule(mod)
 
-    def is_ampere_or_newer():
-        arch = tvm.contrib.nvcc.get_target_compute_version()
-        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-        return major >= 8
+    _, loop = sch.get_loops(sch.get_block("compute"))
+    sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0])
+    mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
+
+    @T.prim_func
+    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
+        for tx in T.thread_binding(16, thread="threadIdx.x"):
+            with T.block():
+                T.reads(A[tx, 0:16])
+                T.writes(C[tx, 0:16])
+                B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, 0])
+                    T.writes(B[0, tx, 0])
+                    with T.attr(0, "async_commit_queue_scope", 0):
+                        with T.attr(0, "async_scope", 1):
+                            B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+                with T.block():
+                    T.reads(A[tx, 1:16], B[0:2, tx, 0])
+                    T.writes(B[0:2, tx, 0], C[tx, 0:15])
+                    for i in T.serial(15):
+                        with T.block():
+                            T.where(i + 1 < 16)
+                            T.reads(A[tx, i + 1])
+                            T.writes(B[(i + 1) % 2, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 0):
+                                with T.attr(0, "async_scope", 1):
+                                    B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                        with T.block():
+                            T.where(i + 1 - 1 < 16)
+                            T.reads(B[(i - 1 + 1) % 2, tx, 0])
+                            T.writes(C[tx, i - 1 + 1])
+                            with T.attr(0, "async_wait_queue_scope", 0):
+                                with T.attr(0, "async_wait_inflight_count", 1):
+                                    C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1)
+                with T.block():
+                    T.reads(B[15 % 2, tx, 0])
+                    T.writes(C[tx, 15])
+                    with T.attr(0, "async_wait_queue_scope", 0):
+                        with T.attr(0, "async_wait_inflight_count", 0):
+                            C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+
+    tvm.ir.assert_structural_equal(mod["main"], ref, True)
+
+    mod = tvm.IRModule.from_expr(gen_simple_compute(3))
+    sch = tvm.tir.Schedule(mod)
+
+    _, loop = sch.get_loops(sch.get_block("compute"))
+    sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0])
+    mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
+
+    @T.prim_func
+    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
+        for tx in T.thread_binding(16, thread="threadIdx.x"):
+            with T.block():
+                T.reads(A[tx, 0:16])
+                T.writes(C[tx, 0:16])
+                B = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, 0:3])
+                    T.writes(B[0:3, tx, 0])
+                    for i in T.unroll(3):
+                        with T.block():
+                            T.where(i < 16)
+                            T.reads(A[tx, i])
+                            T.writes(B[i % 4, tx, 0])
+                            T.attr(0, "async_commit_queue_scope", 0)
+                            T.attr(0, "async_scope", 1)
+                            B[i % 4, tx, 0] = A[tx, i] * T.float32(2)
+                with T.block():
+                    T.reads(A[tx, 3:16], B[0:4, tx, 0])
+                    T.writes(B[0:4, tx, 0], C[tx, 0:13])
+                    for i in T.serial(13):
+                        with T.block():
+                            T.where(i + 3 < 16)
+                            T.reads(A[tx, i + 3])
+                            T.writes(B[(i + 3) % 4, tx, 0])
+                            T.attr(0, "async_commit_queue_scope", 0)
+                            T.attr(0, "async_scope", 1)
+                            B[(i + 3) % 4, tx, 0] = A[tx, i + 3] * T.float32(2)
+                        with T.block():
+                            T.where(i + 3 - 3 < 16)
+                            T.reads(B[0:4, tx, 0])
+                            T.writes(C[tx, i - 3 + 3])
+                            with T.attr(0, "async_wait_queue_scope", 0):
+                                with T.attr(0, "async_wait_inflight_count", 3):
+                                    C[tx, i - 3 + 3] = B[(i - 3 + 3) % 4, tx, 0] + T.float32(1)
+                with T.block():
+                    T.reads(B[0:4, tx, 0])
+                    T.writes(C[tx, 13:16])
+                    for i in T.unroll(3):
+                        with T.block():
+                            T.where(i + 16 - 3 < 16)
+                            T.reads(B[0:4, tx, 0])
+                            T.writes(C[tx, i - 3 + 16])
+                            with T.attr(0, "async_wait_queue_scope", 0):
+                                with T.attr(0, "async_wait_inflight_count", 2 - i):
+                                    C[tx, i - 3 + 16] = B[(i - 3 + 16) % 4, tx, 0] + T.float32(1)
+
+    tvm.ir.assert_structural_equal(mod["main"], ref, True)
+
+
+def test_async_producer_interleaving():
+    @T.prim_func
+    def simple_compute(
+        A: T.Buffer[(16, 16), "float32"],
+        B: T.Buffer[(16, 16), "float32"],
+        C: T.Buffer[(16, 16), "float32"],
+    ):
+        for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+            for i in range(16):
+                with T.block("compute"):
+                    T.reads(A[tx, i])
+                    T.writes(C[tx, i])
+                    A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                    B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                    with T.block():
+                        T.reads(A[tx, i])
+                        T.writes(A_shared[tx, 0])
+                        A_shared[tx, 0] = A[tx, i]
+                    with T.block():
+                        T.reads(B[tx, i])
+                        T.writes(B_shared[tx, 0])
+                        B_shared[tx, 0] = B[tx, i]
+                    with T.block():
+                        T.reads(A_shared[tx, 0], B_shared[tx, 0])
+                        T.writes(C[tx, i])
+                        C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0]
+
+    mod = tvm.IRModule.from_expr(simple_compute)
+    sch = tvm.tir.Schedule(mod)
+
+    _, loop = sch.get_loops(sch.get_block("compute"))
+    sch.annotate(loop, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
+    sch.annotate(loop, ann_key="software_pipeline_order", ann_val=[0, 2, 1])
+    sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0])
+    mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
+
+    @T.prim_func
+    def ref(
+        A: T.Buffer[(16, 16), "float32"],
+        B: T.Buffer[(16, 16), "float32"],
+        C: T.Buffer[(16, 16), "float32"],
+    ) -> None:
+        for tx in T.thread_binding(16, thread="threadIdx.x"):
+            with T.block():
+                T.reads(A[tx, 0:16], B[tx, 0:16])
+                T.writes(C[tx, 0:16])
+                A_shared = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared")
+                B_shared = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, 0:3], B[tx, 0:3])
+                    T.writes(A_shared[0:3, tx, 0], B_shared[0:3, tx, 0])
+                    for i in T.unroll(3):
+                        with T.block():
+                            T.where(i < 16)
+                            T.reads(A[tx, i], B[tx, i])
+                            T.writes(A_shared[i % 4, tx, 0], B_shared[i % 4, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 0):
+                                with T.attr(0, "async_scope", 1):
+                                    A_shared[i % 4, tx, 0] = A[tx, i]
+                                with T.attr(0, "async_scope", 1):
+                                    B_shared[i % 4, tx, 0] = B[tx, i]
+                with T.block():
+                    T.reads(A[tx, 3:16], A_shared[0:4, tx, 0], B_shared[0:4, tx, 0], B[tx, 3:16])
+                    T.writes(A_shared[0:4, tx, 0], C[tx, 0:13], B_shared[0:4, tx, 0])
+                    for i in T.serial(13):
+                        with T.block():
+                            T.where(i + 3 < 16)
+                            T.reads(A[tx, i + 3])
+                            T.writes(A_shared[(i + 3) % 4, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 0):
+                                with T.attr(0, "async_scope", 1):
+                                    A_shared[(i + 3) % 4, tx, 0] = A[tx, i + 3]
+                        with T.block():
+                            T.where(i + 3 - 3 < 16)
+                            T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0])
+                            T.writes(C[tx, i - 3 + 3])
+                            with T.attr(0, "async_wait_queue_scope", 0):
+                                with T.attr(0, "async_wait_inflight_count", 5):
+                                    C[tx, i - 3 + 3] = (
+                                        A_shared[(i - 3 + 3) % 4, tx, 0]
+                                        + B_shared[(i - 3 + 3) % 4, tx, 0]
+                                    )
+                        with T.block():
+                            T.where(i + 3 < 16)
+                            T.reads(B[tx, i + 3])
+                            T.writes(B_shared[(i + 3) % 4, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 0):
+                                with T.attr(0, "async_scope", 1):
+                                    B_shared[(i + 3) % 4, tx, 0] = B[tx, i + 3]
+                with T.block():
+                    T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0])
+                    T.writes(C[tx, 13:16])
+                    for i in T.unroll(3):
+                        with T.block():
+                            T.where(i + 16 - 3 < 16)
+                            T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0])
+                            T.writes(C[tx, i - 3 + 16])
+                            with T.attr(0, "async_wait_queue_scope", 0):
+                                with T.attr(0, "async_wait_inflight_count", 2 - i):
+                                    C[tx, i - 3 + 16] = (
+                                        A_shared[(i - 3 + 16) % 4, tx, 0]
+                                        + B_shared[(i - 3 + 16) % 4, tx, 0]
+                                    )
+
+    tvm.ir.assert_structural_equal(mod["main"], ref, True)
+
+
+def test_three_stage_compute_two_stage_async():
+    mod = tvm.IRModule.from_expr(three_stage_compute)
+    sch = tvm.tir.Schedule(mod)
+
+    _, loop = sch.get_loops(sch.get_block("compute"))
+    sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0, 1])
+
+    mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
+
+    @T.prim_func
+    def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> None:
+        for tx in T.thread_binding(16, thread="threadIdx.x"):
+            with T.block():
+                T.reads(A[tx, 0:16])
+                T.writes(D[tx, 0:16])
+                B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+                C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, 0:2], B[0:2, tx, 0])
+                    T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
+                    for i in T.unroll(2):
+                        with T.block():
+                            T.where(i < 16)
+                            T.reads(A[tx, i])
+                            T.writes(B[i % 2, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 0):
+                                with T.attr(0, "async_scope", 1):
+                                    B[i % 2, tx, 0] = A[tx, i] * T.float32(2)
+                        with T.block():
+                            T.where(1 <= i and i - 1 < 16)
+                            T.reads(B[(i + 1) % 2, tx, 0])
+                            T.writes(C[(i + 1) % 2, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 1):
+                                with T.attr(0, "async_wait_queue_scope", 0):
+                                    with T.attr(0, "async_wait_inflight_count", 1):
+                                        with T.attr(0, "async_scope", 1):
+                                            C[(i - 1) % 2, tx, 0] = B[
+                                                (i - 1) % 2, tx, 0
+                                            ] + T.float32(2)
+                with T.block():
+                    T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
+                    T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
+                    for i in T.serial(14):
+                        with T.block():
+                            T.where(i + 2 < 16)
+                            T.reads(A[tx, i + 2])
+                            T.writes(B[i % 2, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 0):
+                                with T.attr(0, "async_scope", 1):
+                                    B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
+                        with T.block():
+                            T.where(i + 2 - 1 < 16)
+                            T.reads(B[(i + 1) % 2, tx, 0])
+                            T.writes(C[(i + 1) % 2, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 1):
+                                with T.attr(0, "async_wait_queue_scope", 0):
+                                    with T.attr(0, "async_wait_inflight_count", 1):
+                                        with T.attr(0, "async_scope", 1):
+                                            C[(i - 1 + 2) % 2, tx, 0] = B[
+                                                (i - 1 + 2) % 2, tx, 0
+                                            ] + T.float32(2)
+                        with T.block():
+                            T.where(i + 2 - 2 < 16)
+                            T.reads(C[0:2, tx, 0])
+                            T.writes(D[tx, i - 2 + 2])
+                            with T.attr(0, "async_wait_queue_scope", 1):
+                                with T.attr(0, "async_wait_inflight_count", 1):
+                                    D[tx, i - 2 + 2] = C[(i - 2 + 2) % 2, tx, 0] + T.float32(1)
+                with T.block():
+                    T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
+                    T.writes(C[0:2, tx, 0], D[tx, 14:16])
+                    for i in T.unroll(2):
+                        with T.block():
+                            T.where(i + 16 - 1 < 16)
+                            T.reads(B[(i + 1) % 2, tx, 0])
+                            T.writes(C[(i + 1) % 2, tx, 0])
+                            with T.attr(0, "async_commit_queue_scope", 1):
+                                with T.attr(0, "async_wait_queue_scope", 0):
+                                    with T.attr(0, "async_wait_inflight_count", 0 - i):
+                                        with T.attr(0, "async_scope", 1):
+                                            C[(i - 1 + 16) % 2, tx, 0] = B[
+                                                (i - 1 + 16) % 2, tx, 0
+                                            ] + T.float32(2)
+                        with T.block():
+                            T.where(i + 16 - 2 < 16)
+                            T.reads(C[0:2, tx, 0])
+                            T.writes(D[tx, i - 2 + 16])
+                            with T.attr(0, "async_wait_queue_scope", 1):
+                                with T.attr(
+                                    0,
+                                    "async_wait_inflight_count",
+                                    T.if_then_else(i + 16 - 1 < 16, 1, 0, dtype="int32"),
+                                ):
+                                    D[tx, i - 2 + 16] = C[(i - 2 + 16) % 2, tx, 0] + T.float32(1)
+
+    tvm.ir.assert_structural_equal(mod["main"], ref, True)
+
+
+N = K = M = 4096
+
+
+def get_mma_schedule():
+    i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [16, 2, 4, 1, 2], [128, 2, 1]
 
     def index_map(i, j):
         return (
@@ -1055,7 +1368,7 @@ def test_three_stage_gemm():
         te_workload.matmul(N, M, K, in_dtype="float16", out_dtype="float32")
     )
 
-    sch = mma_schedule(
+    return mma_schedule(
         workload,
         16,
         "float16",
@@ -1074,13 +1387,11 @@ def test_three_stage_gemm():
         "shared.dyn",
     )
 
-    k0 = sch.get_loops(sch.get_block("C_o_update"))[3]
-
-    sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
-    sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
 
-    if is_ampere_or_newer():
-        f = tvm.build(sch.mod["main"], target="cuda")
+def build_and_run(sch):
+    if tvm.testing.is_ampere_or_newer():
+        with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+            f = tvm.build(sch.mod["main"], target="cuda")
 
         dev = tvm.device("cuda", 0)
         a_np = np.random.uniform(size=(N, K)).astype("float16")
@@ -1093,5 +1404,93 @@ def test_three_stage_gemm():
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
 
+@tvm.testing.requires_cuda
+def test_async_pipelined_mma_gemm_simple():
+    sch = get_mma_schedule()
+
+    k0 = sch.get_loops(sch.get_block("C_o_update"))[3]
+
+    sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
+    sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
+    sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0])
+
+    seq = tvm.transform.Sequential(
+        [
+            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
+            tvm.tir.transform.ConvertBlocksToOpaque(),
+            tvm.tir.transform.UnifyThreadBinding(),
+            tvm.tir.transform.LowerMatchBuffer(),
+            tvm.tir.transform.InjectSoftwarePipeline(),
+        ]
+    )
+    mod = seq(sch.mod)
+
+    pipeline = mod["main"].body.block.body.body.body.body.body.block.body[1].block.body
+    prologue, body, epilogue = pipeline
+
+    commit_queue_scope = prologue.block.body.body.block.body
+    assert len(commit_queue_scope.body) == 2
+    assert commit_queue_scope.value == 0
+
+    commit_queue_scope = body.block.body.body[0].block.body
+    assert len(commit_queue_scope.body) == 2
+    assert commit_queue_scope.value == 0
+
+    assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count"
+    assert body.block.body.body[1].block.body.body.value == 3
+
+    assert epilogue.block.body.body.block.body.body.attr_key == "async_wait_inflight_count"
+    assert str(epilogue.block.body.body.block.body.body.value) == "(2 - i2_0_0: int32)"
+
+    build_and_run(sch)
+
+
+@tvm.testing.requires_cuda
+def test_async_nested_pipeline_mma_gemm_ideal_annotation():
+    sch = get_mma_schedule()
+
+    k0 = sch.get_loops(sch.get_block("C_o_update"))[3]
+    k1 = sch.get_loops(sch.get_block("C_o_update"))[4]
+
+    sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 2, 3, 3])
+    sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 3, 2, 4])
+    sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0])
+
+    sch.annotate(k1, ann_key="software_pipeline_stage", ann_val=[0, 0, 1])
+    sch.annotate(k1, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
+
+    seq = tvm.transform.Sequential(
+        [
+            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
+            tvm.tir.transform.ConvertBlocksToOpaque(),
+            tvm.tir.transform.UnifyThreadBinding(),
+            tvm.tir.transform.LowerMatchBuffer(),
+            tvm.tir.transform.InjectSoftwarePipeline(),
+        ]
+    )
+    mod = seq(sch.mod)
+
+    pipeline = mod["main"].body.block.body.body.body.body.body.block.body[1].block.body
+    prologue, body, epilogue = pipeline
+
+    commit_queue_scope = prologue.block.body.body[0].block.body
+    assert len(commit_queue_scope.body) == 2
+    assert commit_queue_scope.value == 0
+
+    assert prologue.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count"
+    assert prologue.block.body.body[1].block.body.body.value == 2
+
+    commit_queue_scope = body.block.body.body[0].block.body
+    assert len(commit_queue_scope.body) == 2
+    assert commit_queue_scope.value == 0
+
+    assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count"
+    assert body.block.body.body[1].block.body.body.value == 2
+
+    assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - i2_0_0: int32)"
+
+    build_and_run(sch)
+
+
 if __name__ == "__main__":
     tvm.testing.main()