You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/09/20 19:38:11 UTC
[tvm] branch main updated: [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA (#12785)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5dfa8da00e [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA (#12785)
5dfa8da00e is described below
commit 5dfa8da00ec658934f3fc0df8eb9f41a167e1545
Author: Adam Straw <as...@octoml.ai>
AuthorDate: Tue Sep 20 12:38:04 2022 -0700
[Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA (#12785)
* [Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to HexagonUserDMA
* save queue ID in `copy`, inspect in `wait` transform; add comments
* improve testing; parameters for shape, scope, dtype
* add log statements and adjust comments to clarify pass behavior
* generalize use_async_copy for pass enable
* use DLOG instead of LOG
* trigger ci
* trigger ci again
---
include/tvm/tir/builtin.h | 10 ++
include/tvm/tir/transform.h | 5 +
src/driver/driver_api.cc | 12 +-
src/runtime/hexagon/hexagon_device_api.cc | 25 +++
src/tir/op/builtin.cc | 6 +
src/tir/transforms/lower_async_dma.cc | 194 +++++++++++++++++++++
src/tir/transforms/lower_tvm_builtin.cc | 30 ++++
.../test_hexagon/test_software_pipeline_async.py | 86 +++++++++
.../test_tir_transform_inject_ptx_async_copy.py | 4 +-
.../test_tir_transform_inject_software_pipeline.py | 2 +-
10 files changed, 367 insertions(+), 7 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 12290a97c8..a1a97595bf 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -720,6 +720,16 @@ TVM_DLL const Op& texture2d_load();
*/
TVM_DLL const Op& mem_copy();
+/*!
+ * \brief Initiate a non-blocking DMA copy: dma_copy(queue_id, dst_addr, src_addr, size_in_bytes)
+ */
+TVM_DLL const Op& dma_copy();
+
+/*!
+ * \brief Wait until at most `inflight` DMAs remain in flight: dma_wait(queue_id, inflight)
+ */
+TVM_DLL const Op& dma_wait();
+
/*!
* \brief Provide a true statement that can be used for simplifications
*
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index fd4261e4a4..a4caeee436 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -485,6 +485,11 @@ TVM_DLL Pass TextureFlatten();
*/
TVM_DLL Pass LowerVtcmAlloc();
+/*!
+ * \brief Lower async commit/wait TIR attribute scopes to `dma_copy` / `dma_wait` builtin calls
+ */
+TVM_DLL Pass LowerAsyncDMA();
+
/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e528686d96..1a617dcd49 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -50,7 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
using runtime::PackedFunc;
using runtime::TVMArgs;
@@ -225,6 +225,11 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
+ bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
+
+ if (use_async_copy) {
+ pass_list.push_back(tir::transform::LowerAsyncDMA());
+ }
pass_list.push_back(tir::transform::UnrollLoop());
// Add user-defined phase-2 passes
@@ -543,10 +548,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
- bool use_ptx_async_copy =
- pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
+ bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
- if (use_ptx_async_copy) {
+ if (use_async_copy) {
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
}
diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc
index 463d9799b0..84232a6144 100644
--- a/src/runtime/hexagon/hexagon_device_api.cc
+++ b/src/runtime/hexagon/hexagon_device_api.cc
@@ -33,6 +33,7 @@
#include "../workspace_pool.h"
#include "hexagon_common.h"
+#include "hexagon_user_dma.h"
namespace tvm {
namespace runtime {
@@ -206,6 +207,30 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
+  int queue_id = args[0];
+  ICHECK_EQ(queue_id, 0) << "Hexagon supports just a single asynchronous queue for DMA";
+  void* dst = args[1];
+  void* src = args[2];
+  int size = args[3];
+  ICHECK_GT(size, 0) << "DMA copy size must be positive";
+  // retry while Copy reports DMA_RETRY; the first other status is returned to the caller
+  int ret = DMA_RETRY;
+  do {
+    ret = HexagonUserDMA::Get().Copy(dst, src, size);
+  } while (ret == DMA_RETRY);
+  *rv = static_cast<int32_t>(ret);
+});
+
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
+  int queue_id = args[0];
+  ICHECK_EQ(queue_id, 0) << "Hexagon supports just a single asynchronous queue for DMA";
+  int inflight = args[1];
+  ICHECK_GE(inflight, 0) << "DMA in-flight count must be non-negative";
+  HexagonUserDMA::Get().Wait(inflight);
+  *rv = static_cast<int32_t>(0);
+});
+
TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 9642f8e39f..1e2d790c76 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -288,6 +288,12 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(dma_copy)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(dma_wait)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);
diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
new file mode 100644
index 0000000000..78d363f67c
--- /dev/null
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_async_dma.cc
+ * \brief Lower async commit/wait TIR attribute scopes (e.g. from an async software pipeline)
+ *        to `dma_copy` / `dma_wait` builtin calls.
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <set>
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Rewrites async commit/wait attribute scopes into `dma_copy` / `dma_wait` builtin calls.
+ *
+ * Only a single loop copying one contiguous 1-D range from one buffer into another is lowered;
+ * anything else is left in place (its children are still visited) and a DLOG message says why.
+ */
+class AsyncDMALowerer : public StmtExprMutator {
+ public:
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == tir::attr::async_wait_queue_scope) {
+      return LowerWait(op);
+    }
+    if (op->attr_key == tir::attr::async_commit_queue_scope) {
+      return LowerCommit(op);
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  /*!
+   * \brief Lower an `async_wait_queue_scope` attribute.
+   *
+   * Converts this, for example:
+   *   attr [0] "async_wait_queue_scope" = 0;
+   *   attr [0] "async_wait_inflight_count" = 0;
+   *   <body>
+   *
+   * To this:
+   *   @tir.dma_wait(0, 0, dtype=int32)  // queue id, in flight count
+   *   <body>
+   */
+  Stmt LowerWait(const AttrStmtNode* op) {
+    // get queue ID
+    auto queue_id_node = op->value.as<IntImmNode>();
+    ICHECK(queue_id_node);
+    int queue_id = queue_id_node->value;
+
+    // only lower waits on queues for which LowerCommit has emitted at least one `dma_copy`
+    if (queue_ids_.find(queue_id) == queue_ids_.end()) {
+      DLOG(INFO) << "AsyncDMALowerer skipping wait because no copy on the queue ID observed in "
+                    "the `async_wait_queue_scope` transform has been lowered to `dma_copy`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    auto async_wait = op->body.as<AttrStmtNode>();
+    if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
+                    "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
+                    "`async_wait_inflight_count`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    auto call_dma_wait =
+        Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));
+
+    // the wait body may itself contain async scopes (e.g. commits to another queue), so it is
+    // mutated as well instead of being passed through untouched
+    return SeqStmt({call_dma_wait, VisitStmt(async_wait->body)});
+  }
+
+  /*!
+   * \brief Lower an `async_commit_queue_scope` attribute that wraps a contiguous copy loop.
+   *
+   * Converts this, for example:
+   *   attr [0] "async_commit_queue_scope" = 0;
+   *   attr [0] "async_scope" = 1;
+   *   for (ax0: int32, 0, 128) {
+   *     A_global[ax0] = A[ax0]
+   *   }
+   *
+   * To this (for 1-byte elements):
+   *   @tir.dma_copy(
+   *     0,  // queue id
+   *     @tir.address_of(A_global[0], dtype=handle),
+   *     @tir.address_of(A[0], dtype=handle),
+   *     128,  // size in bytes = loop extent * element size
+   *     dtype=int32
+   *   )
+   */
+  Stmt LowerCommit(const AttrStmtNode* op) {
+    // get queue ID
+    auto queue_id_node = op->value.as<IntImmNode>();
+    ICHECK(queue_id_node);
+    int queue_id = queue_id_node->value;
+
+    // walk the graph to verify this is a mem copy ...
+    // 1) async_commit_queue_scope contains async_scope
+    auto async_scope = op->body.as<AttrStmtNode>();
+    if (!async_scope || async_scope->attr_key != tir::attr::async_scope) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
+                    "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key "
+                    "`async_scope`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 2) async_scope contains single for loop
+    auto for_loop = async_scope->body.as<ForNode>();
+    if (!for_loop) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
+                    "`async_scope` does not contain a single `ForNode`";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 3) for loop contains buffer store with single index
+    auto store = for_loop->body.as<BufferStoreNode>();
+    if (!store || store->indices.size() != 1) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a "
+                    "single `BufferStoreNode` with a single index variable";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 4) buffer store value is a buffer load with single index
+    auto load = store->value.as<BufferLoadNode>();
+    if (!load || load->indices.size() != 1) {
+      DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a "
+                    "single `BufferLoadNode` with a single index variable";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 5) both buffers are compact, share an element type and are accessed one scalar at a time,
+    //    so the copy size in bytes is loop extent * element size
+    if (!store->buffer->strides.empty() || !load->buffer->strides.empty() ||
+        store->buffer->dtype != load->buffer->dtype || load->dtype.lanes() != 1) {
+      DLOG(INFO) << "AsyncDMALowerer skipping copy because the buffers are not compact, do not "
+                    "share a data type, or are accessed with vector loads";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // 6) both indices advance by exactly one element per loop iteration, so the loop touches a
+    //    single contiguous range in each buffer
+    arith::Analyzer analyzer;
+    if (!IsUnitStride(store->indices[0], for_loop, &analyzer) ||
+        !IsUnitStride(load->indices[0], for_loop, &analyzer)) {
+      DLOG(INFO) << "AsyncDMALowerer skipping copy because the store or load index does not "
+                    "advance by exactly one element per loop iteration";
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    // the whole loop (including its index) is replaced by one DMA spanning its iteration space,
+    // so the DMA gets the address of the element accessed by the first iteration (loop var = min)
+    PrimExpr store_base = IndexAtLoopStart(store->indices[0], for_loop, &analyzer);
+    PrimExpr load_base = IndexAtLoopStart(load->indices[0], for_loop, &analyzer);
+
+    // record the queue only once a copy on it has actually been lowered, so that LowerWait does
+    // not emit `dma_wait` for a queue whose copies were left as-is
+    queue_ids_.insert(queue_id);
+
+    return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
+                         {queue_id,
+                          Call(DataType::Handle(), builtin::address_of(),
+                               {BufferLoad(store->buffer, Array<PrimExpr>{store_base})}),
+                          Call(DataType::Handle(), builtin::address_of(),
+                               {BufferLoad(load->buffer, Array<PrimExpr>{load_base})}),
+                          for_loop->extent * load->dtype.bytes()}));
+  }
+
+  /*!
+   * \brief Whether `index` provably grows by exactly one when the loop variable of `loop` does.
+   *
+   * Returns false when this cannot be proven (e.g. a non-affine index, or no loop dependence).
+   */
+  static bool IsUnitStride(const PrimExpr& index, const ForNode* loop,
+                           arith::Analyzer* analyzer) {
+    Map<Var, PrimExpr> next_iteration = {{loop->loop_var, loop->loop_var + 1}};
+    return analyzer->CanProveEqual(Substitute(index, next_iteration) - index, 1);
+  }
+
+  /*! \brief Simplified value of `index` at the first iteration of `loop` (loop var = loop min). */
+  static PrimExpr IndexAtLoopStart(const PrimExpr& index, const ForNode* loop,
+                                   arith::Analyzer* analyzer) {
+    Map<Var, PrimExpr> first_iteration = {{loop->loop_var, loop->min}};
+    return analyzer->Simplify(Substitute(index, first_iteration));
+  }
+
+  /*! \brief Queue IDs on which at least one copy has been lowered to `dma_copy`. */
+  std::set<int> queue_ids_;
+};
+
+namespace transform {
+
+Pass LowerAsyncDMA() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto fptr = f.CopyOnWrite();
+    fptr->body = AsyncDMALowerer()(std::move(fptr->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA);
+} // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 9d0087cc7a..f79682ef7e 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -317,6 +317,10 @@ class BuiltinLower : public StmtExprMutator {
return make_zero(op->dtype);
} else if (op->op.same_as(builtin::mem_copy())) {
return MakeMemCopy(op);
+ } else if (op->op.same_as(builtin::dma_copy())) {
+ return MakeDMACopy(op);
+ } else if (op->op.same_as(builtin::dma_wait())) {
+ return MakeDMAWait(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
@@ -335,6 +339,32 @@ class BuiltinLower : public StmtExprMutator {
return VisitExpr(call_packed);
}
+  PrimExpr MakeDMACopy(const CallNode* op) {
+    PrimExpr queue_id = op->args[0];
+    PrimExpr dst = op->args[1];
+    PrimExpr src = op->args[2];
+    PrimExpr size = op->args[3];
+    // the device API name (device_api.<device>) is derived from a constant device type
+    const auto* dev_type = device_type_.as<IntImmNode>();
+    ICHECK(dev_type) << "dma_copy requires a constant device type";
+    std::string fdevapi_prefix = "device_api." + std::string(runtime::DeviceName(dev_type->value));
+    Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
+                            {StringImm(fdevapi_prefix + ".dma_copy"), queue_id, dst, src, size});
+    return VisitExpr(call_packed);
+  }
+
+  PrimExpr MakeDMAWait(const CallNode* op) {
+    PrimExpr queue_id = op->args[0];
+    PrimExpr inflight = op->args[1];
+    // the device API name (device_api.<device>) is derived from a constant device type
+    const auto* dev_type = device_type_.as<IntImmNode>();
+    ICHECK(dev_type) << "dma_wait requires a constant device type";
+    std::string fdevapi_prefix = "device_api." + std::string(runtime::DeviceName(dev_type->value));
+    Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
+                            {StringImm(fdevapi_prefix + ".dma_wait"), queue_id, inflight});
+    return VisitExpr(call_packed);
+  }
+
// call shape
PrimExpr MakeShape(const CallNode* op) {
// if args.size() == 0, it represents a scalar shape ()
diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
new file mode 100644
index 0000000000..6bcca90ec9
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import pytest
+import numpy as np
+
+import tvm
+from tvm import tir
+from tvm.contrib.hexagon.session import Session
+from tvm.script import tir as T
+
+outer = tvm.testing.parameter(8, 16)
+inner = tvm.testing.parameter(64, 128)
+scope = tvm.testing.parameter("global", "global.vtcm")
+dtype = tvm.testing.parameter("uint8", "float16")
+
+
+@tvm.testing.fixture
+def compute(outer, inner, dtype):
+    """Return (prim_func computing B = A + 1, NumPy reference of the same computation)."""
+
+    @T.prim_func
+    def plus_one_primfunc(A: T.Buffer[(outer, inner), dtype], B: T.Buffer[(outer, inner), dtype]):
+        for i in T.serial(outer):
+            for j in T.serial(inner):
+                with T.block("compute"):
+                    with T.block():
+                        B[i, j] = A[i, j] + T.cast(1, dtype)
+
+    # NumPy reference: element-wise add one
+    return plus_one_primfunc, lambda a: a + 1
+
+
+@tvm.testing.requires_hexagon
+def test_software_pipeline_with_cache_read(hexagon_launcher, compute, outer, inner, dtype, scope):
+    """Pipeline the cache-read copy (async stage 0) against B = A + 1 and check the result."""
+    sch = tir.Schedule(compute[0])
+    compute_block = sch.get_block("compute")
+    cache_read_block = sch.cache_read(compute_block, 0, scope)
+
+    i, _ = sch.get_loops(compute_block)
+    sch.compute_at(cache_read_block, i)
+    sch.annotate(i, "software_pipeline_stage", [0, 1])
+    sch.annotate(i, "software_pipeline_order", [0, 1])
+    sch.annotate(i, "software_pipeline_async_stages", [0])
+
+    a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
+    b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype)
+    ref = compute[1](a_np)
+
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+        func = tvm.build(
+            sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
+        )
+
+    with hexagon_launcher.start_session() as hexagon_session:
+        dev = hexagon_session.device
+        a = tvm.nd.array(a_np, device=dev)
+        b = tvm.nd.array(b_np, device=dev)
+        mod = hexagon_session.load_module(func)
+        mod(a, b)
+        # check inside the session so the device array is read back while the session is active
+        if "int" in dtype:
+            np.testing.assert_equal(b.numpy(), ref)
+        else:
+            np.testing.assert_allclose(b.numpy(), ref, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(args=sys.argv))
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 1a906b2fb6..7062d51297 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -138,7 +138,7 @@ def test_inject_async_copy():
if not tvm.testing.is_ampere_or_newer():
continue
- with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+ with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
A_np = np.random.rand(32, 128).astype(dtype)
@@ -166,7 +166,7 @@ def test_inject_async_copy_shared_dyn():
if not tvm.testing.is_ampere_or_newer():
return
- with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+ with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
A_np = np.random.rand(32, 128).astype("float16")
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index edaeb7c9b6..49255e0f20 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1390,7 +1390,7 @@ def get_mma_schedule():
def build_and_run(sch):
if tvm.testing.is_ampere_or_newer():
- with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+ with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
f = tvm.build(sch.mod["main"], target="cuda")
dev = tvm.device("cuda", 0)