You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/09/20 19:38:11 UTC

[tvm] branch main updated: [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA (#12785)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5dfa8da00e [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA (#12785)
5dfa8da00e is described below

commit 5dfa8da00ec658934f3fc0df8eb9f41a167e1545
Author: Adam Straw <as...@octoml.ai>
AuthorDate: Tue Sep 20 12:38:04 2022 -0700

    [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA (#12785)
    
    * [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to HexagonUserDMA
    
    * save the queue ID in the `copy` transform and inspect it in the `wait` transform; add comments
    
    * improve testing; parameters for shape, scope, dtype
    
    * add log statements and adjust comments to clarify pass behavior
    
    * rename the `tir.use_ptx_async_copy` config option to `tir.use_async_copy` and use it to enable both LowerAsyncDMA and InjectPTXAsyncCopy
    
    * use DLOG instead of LOG
    
    * trigger ci
    
    * trigger ci again
---
 include/tvm/tir/builtin.h                          |  10 ++
 include/tvm/tir/transform.h                        |   5 +
 src/driver/driver_api.cc                           |  12 +-
 src/runtime/hexagon/hexagon_device_api.cc          |  25 +++
 src/tir/op/builtin.cc                              |   6 +
 src/tir/transforms/lower_async_dma.cc              | 194 +++++++++++++++++++++
 src/tir/transforms/lower_tvm_builtin.cc            |  30 ++++
 .../test_hexagon/test_software_pipeline_async.py   |  86 +++++++++
 .../test_tir_transform_inject_ptx_async_copy.py    |   4 +-
 .../test_tir_transform_inject_software_pipeline.py |   2 +-
 10 files changed, 367 insertions(+), 7 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 12290a97c8..a1a97595bf 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -720,6 +720,16 @@ TVM_DLL const Op& texture2d_load();
  */
 TVM_DLL const Op& mem_copy();
 
+/*!
+ * \brief Initiate a non-blocking DMA copy: dma_copy(queue_id, dst, src, size_in_bytes)
+ */
+TVM_DLL const Op& dma_copy();
+
+/*!
+ * \brief Wait until at most `inflight` DMAs remain in flight: dma_wait(queue_id, inflight)
+ */
+TVM_DLL const Op& dma_wait();
+
 /*!
  * \brief Provide a true statement that can be used for simplifications
  *
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index fd4261e4a4..a4caeee436 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -485,6 +485,11 @@ TVM_DLL Pass TextureFlatten();
  */
 TVM_DLL Pass LowerVtcmAlloc();
 
+/*!
+ * \brief Lower async commit/wait queue scopes to the dma_copy and dma_wait builtins
+ */
+TVM_DLL Pass LowerAsyncDMA();
+
 /*!
  * \brief Implements a Common Subexpression Elimination (CSE) for TIR
  *        which introduces let-in bindings for duplicated sub-expressions.
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e528686d96..1a617dcd49 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -50,7 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
 
 using runtime::PackedFunc;
 using runtime::TVMArgs;
@@ -225,6 +225,11 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   }
   // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
   pass_list.push_back(tir::transform::LowerVtcmAlloc());
+  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
+
+  if (use_async_copy) {
+    pass_list.push_back(tir::transform::LowerAsyncDMA());
+  }
   pass_list.push_back(tir::transform::UnrollLoop());
 
   // Add user-defined phase-2 passes
@@ -543,10 +548,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
 
-  bool use_ptx_async_copy =
-      pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
+  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
 
-  if (use_ptx_async_copy) {
+  if (use_async_copy) {
     mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
   }
 
diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc
index 463d9799b0..84232a6144 100644
--- a/src/runtime/hexagon/hexagon_device_api.cc
+++ b/src/runtime/hexagon/hexagon_device_api.cc
@@ -33,6 +33,7 @@
 
 #include "../workspace_pool.h"
 #include "hexagon_common.h"
+#include "hexagon_user_dma.h"
 
 namespace tvm {
 namespace runtime {
@@ -206,6 +207,30 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
   *rv = static_cast<int32_t>(0);
 });
 
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
+  int queue_id = args[0];
+  ICHECK_EQ(queue_id, 0) << "Hexagon supports just a single asynchronous queue for DMA";
+  void* dst = args[1];
+  void* src = args[2];
+  int size = args[3];
+  ICHECK_GT(size, 0);
+  // Retry while Copy() reports DMA_RETRY; any other status is returned to the caller.
+  int ret = DMA_RETRY;
+  do {
+    ret = HexagonUserDMA::Get().Copy(dst, src, size);
+  } while (ret == DMA_RETRY);
+  *rv = static_cast<int32_t>(ret);
+});
+
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
+  int queue_id = args[0];
+  ICHECK_EQ(queue_id, 0) << "Hexagon supports just a single asynchronous queue for DMA";
+  int inflight = args[1];
+  ICHECK_GE(inflight, 0);
+  HexagonUserDMA::Get().Wait(inflight);
+  *rv = static_cast<int32_t>(0);
+});
+
 TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
   int32_t device_type = args[0];
   int32_t device_id = args[1];
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 9642f8e39f..1e2d790c76 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -288,6 +288,12 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
 TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
                                                             Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(dma_copy)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(dma_wait)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(assume)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
     .set_num_inputs(1);
diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
new file mode 100644
index 0000000000..78d363f67c
--- /dev/null
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_async_dma.cc
+ * \brief Lower the async TIR primitives produced by software pipelining
+ *        (`async_commit_queue_scope` / `async_wait_queue_scope`) to the
+ *        `dma_copy` / `dma_wait` builtins.
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <set>
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Rewrites async commit / wait attribute scopes into DMA builtin calls.
+ *
+ * A scope that does not match the expected pattern is left as-is (a DLOG message says why)
+ * and its body is still visited.
+ */
+class AsyncDMALowerer : public StmtExprMutator {
+ public:
+  AsyncDMALowerer() {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == tir::attr::async_wait_queue_scope) {
+      return LowerWait(op);
+    }
+    if (op->attr_key == tir::attr::async_commit_queue_scope) {
+      return LowerCommit(op);
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  /*!
+   * \brief Lower an `async_wait_queue_scope` attribute to a `dma_wait` call.
+   *
+   * Converts this, for example:
+   *   attr [0] "async_wait_queue_scope" = 0;
+   *   attr [0] "async_wait_inflight_count" = 0;
+   *   <body>
+   * to this:
+   *   @tir.dma_wait(0, 0, dtype=int32)  (args: queue id, in-flight count)
+   *   <body, itself lowered>
+   */
+  Stmt LowerWait(const AttrStmtNode* op) {
+    // get queue ID
+    auto queue_id_node = op->value.as<IntImmNode>();
+    ICHECK(queue_id_node);
+    int queue_id = queue_id_node->value;
+
+    // abort if we have not seen this queue ID in `copy` transform
+    if (queue_ids_.count(queue_id) == 0) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
+                    "`async_wait_queue_scope` transform has not been previously observed in the "
+                    "`async_commit_queue_scope` transform";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    auto async_wait = op->body.as<AttrStmtNode>();
+    if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
+                    "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
+                    "`async_wait_inflight_count`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    auto call_dma_wait =
+        Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));
+
+    // Prepend the wait to the body. The body must be visited too: returning it unvisited
+    // would leave any async scopes nested inside it unlowered.
+    return SeqStmt({call_dma_wait, VisitStmt(async_wait->body)});
+  }
+
+  /*!
+   * \brief Lower an `async_commit_queue_scope` wrapping a 1-D copy loop to a `dma_copy` call.
+   *
+   * Converts this, for example:
+   *   attr [0] "async_commit_queue_scope" = 0;
+   *   attr [0] "async_scope" = 1;
+   *   for (ax0: int32, 0, 128) {
+   *     A_global[ax0] = A[ax0]
+   *   }
+   * to this:
+   *   @tir.dma_copy(
+   *     0,                                           (queue id)
+   *     @tir.address_of(A_global[0], dtype=handle),  (destination)
+   *     @tir.address_of(A[0], dtype=handle),         (source)
+   *     128 * <bytes per element>,                   (size in bytes)
+   *     dtype=int32
+   *   )
+   */
+  Stmt LowerCommit(const AttrStmtNode* op) {
+    // get queue ID
+    auto queue_id_node = op->value.as<IntImmNode>();
+    ICHECK(queue_id_node);
+    int queue_id = queue_id_node->value;
+
+    // save queue ID for inspection in `wait` transform
+    queue_ids_.insert(queue_id);
+
+    // walk the graph to verify this is a mem copy ...
+    // 1) async_commit_queue_scope contains async_scope
+    auto async_scope = op->body.as<AttrStmtNode>();
+    if (!async_scope || async_scope->attr_key != tir::attr::async_scope) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
+                    "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key "
+                    "`async_scope`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 2) async_scope contains single for loop
+    auto for_loop = async_scope->body.as<ForNode>();
+    if (!for_loop) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
+                    "`async_scope` does not contain a single `ForNode`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 3) for loop contains buffer store with single index
+    auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
+    if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
+      DLOG(INFO)
+          << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a "
+             "single `BufferStoreNode` with a single index variable";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 4) buffer store value is a buffer load with single index
+    auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
+    if (!bufferloadnode || bufferloadnode->indices.size() != 1) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a "
+                    "single `BufferLoadNode` with a single index variable";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // get store buffer; assert it exists and is contiguous given it uses a single index
+    auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
+    ICHECK(bufferstore && bufferstore->strides.empty());
+
+    // get load buffer; assert it exists and is contiguous given it uses a single index
+    auto bufferload = bufferloadnode->buffer.as<BufferNode>();
+    ICHECK(bufferload && bufferload->strides.empty());
+
+    // 5) both indices advance by exactly one element per loop iteration; otherwise (e.g.
+    // A_global[ax0] = A[2 * ax0]) the loop does not copy one contiguous range and a single
+    // DMA would move the wrong bytes
+    PrimExpr store_index = bufferstorenode->indices[0];
+    PrimExpr load_index = bufferloadnode->indices[0];
+    if (!IsUnitStride(store_index, for_loop->loop_var) ||
+        !IsUnitStride(load_index, for_loop->loop_var)) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the `BufferStoreNode` or `BufferLoadNode` "
+                    "index does not advance by exactly one element per `ForNode` iteration";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // The DMA replaces the entire for loop, so each address is the buffer element accessed on
+    // the loop's *first* iteration, i.e. with the loop variable bound to its min (which need
+    // not be zero).
+    PrimExpr store_base = IndexAtFirstIteration(store_index, for_loop);
+    PrimExpr load_base = IndexAtFirstIteration(load_index, for_loop);
+
+    return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
+                         {queue_id,
+                          Call(DataType::Handle(), builtin::address_of(),
+                               {BufferLoad(bufferstorenode->buffer, {store_base})}),
+                          Call(DataType::Handle(), builtin::address_of(),
+                               {BufferLoad(bufferloadnode->buffer, {load_base})}),
+                          for_loop->extent * bufferloadnode->dtype.bytes()}));
+  }
+
+  /*!
+   * \brief Whether \p index advances by exactly one element per iteration of \p loop_var.
+   *
+   * Proven symbolically: index(v + 1) - index(v) must simplify to 1 for an arbitrary v.
+   */
+  bool IsUnitStride(const PrimExpr& index, const Var& loop_var) {
+    Map<Var, PrimExpr> next_iteration = {{loop_var, loop_var + 1}};
+    return analyzer_.CanProveEqual(Substitute(index, next_iteration) - index, 1);
+  }
+
+  /*! \brief \p index evaluated at the first iteration of \p loop, then simplified. */
+  PrimExpr IndexAtFirstIteration(const PrimExpr& index, const ForNode* loop) {
+    Map<Var, PrimExpr> first_iteration = {{loop->loop_var, loop->min}};
+    return analyzer_.Simplify(Substitute(index, first_iteration));
+  }
+
+  /*! \brief Queue IDs seen in `async_commit_queue_scope`; only these are lowered on wait. */
+  std::set<int> queue_ids_;
+  /*! \brief Analyzer used to simplify and reason about the copy loop's buffer indices. */
+  arith::Analyzer analyzer_;
+};
+
+namespace transform {
+
+/*!
+ * \brief Create the pass that runs AsyncDMALowerer over the body of every PrimFunc.
+ */
+Pass LowerAsyncDMA() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto fptr = f.CopyOnWrite();
+    fptr->body = AsyncDMALowerer()(std::move(fptr->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA);
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 9d0087cc7a..f79682ef7e 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -317,6 +317,10 @@ class BuiltinLower : public StmtExprMutator {
       return make_zero(op->dtype);
     } else if (op->op.same_as(builtin::mem_copy())) {
       return MakeMemCopy(op);
+    } else if (op->op.same_as(builtin::dma_copy())) {
+      return MakeDMACopy(op);
+    } else if (op->op.same_as(builtin::dma_wait())) {
+      return MakeDMAWait(op);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
@@ -335,6 +339,32 @@ class BuiltinLower : public StmtExprMutator {
     return VisitExpr(call_packed);
   }
 
+  // Lower builtin::dma_copy(queue_id, dst, src, size) to "device_api.<device>.dma_copy".
+  PrimExpr MakeDMACopy(const CallNode* op) {
+    ICHECK_EQ(op->args.size(), 4U) << "dma_copy expects (queue_id, dst, src, size)";
+    Call call_packed =
+        Call(DataType::Int(32), builtin::tvm_call_packed(),
+             {StringImm(DeviceAPIPrefix() + ".dma_copy"), op->args[0], op->args[1],
+              op->args[2], op->args[3]});
+    return VisitExpr(call_packed);
+  }
+
+  // Lower builtin::dma_wait(queue_id, inflight) to "device_api.<device>.dma_wait".
+  PrimExpr MakeDMAWait(const CallNode* op) {
+    ICHECK_EQ(op->args.size(), 2U) << "dma_wait expects (queue_id, inflight)";
+    Call call_packed =
+        Call(DataType::Int(32), builtin::tvm_call_packed(),
+             {StringImm(DeviceAPIPrefix() + ".dma_wait"), op->args[0], op->args[1]});
+    return VisitExpr(call_packed);
+  }
+
+  // "device_api.<device name>"; checks the device type is a constant before dereferencing it.
+  std::string DeviceAPIPrefix() const {
+    const auto* device_type = device_type_.as<IntImmNode>();
+    ICHECK(device_type) << "DMA builtins require a constant device type";
+    return "device_api." + std::string(runtime::DeviceName(device_type->value));
+  }
+
   // call shape
   PrimExpr MakeShape(const CallNode* op) {
     // if args.size() == 0, it represents a scalar shape ()
diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
new file mode 100644
index 0000000000..6bcca90ec9
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Hexagon test: software pipeline whose async cache_read stage is lowered to User DMA."""
+
+import sys
+import pytest
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.contrib.hexagon.session import Session
+from tvm.script import tir as T
+
+outer = tvm.testing.parameter(8, 16)
+inner = tvm.testing.parameter(64, 128)
+scope = tvm.testing.parameter("global", "global.vtcm")
+dtype = tvm.testing.parameter("uint8", "float16")
+
+
+@tvm.testing.fixture
+def compute(outer, inner, dtype):
+    """Return a ``B = A + 1`` PrimFunc over ``(outer, inner)`` buffers and a NumPy reference."""
+
+    @T.prim_func
+    def plus_one_primfunc(A: T.Buffer[(outer, inner), dtype], B: T.Buffer[(outer, inner), dtype]):
+        for i in T.serial(outer):
+            for j in T.serial(inner):
+                with T.block("compute"):
+                    with T.block():
+                        B[i, j] = A[i, j] + T.cast(1, dtype)
+
+    def plus_one_ref(a):
+        return a + 1
+
+    return plus_one_primfunc, plus_one_ref
+
+
+@tvm.testing.requires_hexagon
+def test_software_pipeline_with_cache_read(hexagon_launcher, compute, outer, inner, dtype, scope):
+    """Cache-read A into ``scope`` as async pipeline stage 0, build with
+    ``tir.use_async_copy`` enabled, run on Hexagon and compare against NumPy."""
+    sch = tir.Schedule(compute[0])
+    compute_block = sch.get_block("compute")
+    cache_read_block = sch.cache_read(compute_block, 0, scope)
+
+    # Pipeline the outer loop in two stages: the copy into `scope` (stage 0, marked async)
+    # and the compute (stage 1).
+    i, _ = sch.get_loops(compute_block)
+    sch.compute_at(cache_read_block, i)
+    sch.annotate(i, "software_pipeline_stage", [0, 1])
+    sch.annotate(i, "software_pipeline_order", [0, 1])
+    sch.annotate(i, "software_pipeline_async_stages", [0])
+
+    a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
+    b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
+    ref = compute[1](a_np)
+
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+        func = tvm.build(
+            sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
+        )
+
+    with hexagon_launcher.start_session() as hexagon_session:
+        dev = hexagon_session.device
+        a = tvm.nd.array(a_np, device=dev)
+        b = tvm.nd.array(b_np, device=dev)
+        mod = hexagon_session.load_module(func)
+        mod(a, b)
+
+        if "int" in dtype:
+            np.testing.assert_equal(b.numpy(), ref)
+        else:
+            np.testing.assert_allclose(b.numpy(), ref, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 1a906b2fb6..7062d51297 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -138,7 +138,7 @@ def test_inject_async_copy():
         if not tvm.testing.is_ampere_or_newer():
             continue
 
-        with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+        with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
             mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
 
         A_np = np.random.rand(32, 128).astype(dtype)
@@ -166,7 +166,7 @@ def test_inject_async_copy_shared_dyn():
     if not tvm.testing.is_ampere_or_newer():
         return
 
-    with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
         mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
 
     A_np = np.random.rand(32, 128).astype("float16")
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index edaeb7c9b6..49255e0f20 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1390,7 +1390,7 @@ def get_mma_schedule():
 
 def build_and_run(sch):
     if tvm.testing.is_ampere_or_newer():
-        with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+        with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
             f = tvm.build(sch.mod["main"], target="cuda")
 
         dev = tvm.device("cuda", 0)