Posted to commits@tvm.apache.org by ma...@apache.org on 2023/05/10 10:31:20 UTC
[tvm] branch unity updated: [Unity] Fix CUDA graph rewrite var used before def (#14800)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b5928e775c [Unity] Fix CUDA graph rewrite var used before def (#14800)
b5928e775c is described below
commit b5928e775c308c9e42e2951ed5de726eb85d7ed1
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed May 10 03:31:12 2023 -0700
[Unity] Fix CUDA graph rewrite var used before def (#14800)
CUDA graph rewriting may result in reordering of the original bindings,
for example when a variable is used as an input to the lifted function.
If the variable comes from the output of another lifted function, we need
to make sure the output unpacking is emitted before the variable is used.
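For context, here is a minimal, self-contained Python sketch (not part of
the commit) of the on-demand emission that var_redef_ and var_remap_
implement in the rewriter below. All names in the sketch (pending_redef,
emitted, visit_use, visit_binding) are illustrative only; the real pass
operates on Relax bindings rather than strings.

pending_redef = {}  # var -> expression that unpacks the lifted-function output
emitted = {}        # var -> the new var it was rebound to

def emit_redef(var, expr, output):
    # Emit the unpacking binding and remember the remapped var.
    new_var = f"{var}_redef"
    output.append(f"{new_var} = {expr}")
    emitted[var] = new_var
    return new_var

def visit_use(var, output):
    # Use site: if the var has a pending redefinition that has not been
    # emitted yet, emit it now so the definition precedes the use.
    if var in emitted:
        return emitted[var]
    if var in pending_redef:
        return emit_redef(var, pending_redef[var], output)
    return var

def visit_binding(var, output):
    # Original definition site: emit the redefinition here only if a use
    # site has not already forced it out.
    if var in pending_redef and var not in emitted:
        emit_redef(var, pending_redef[var], output)

# Example: "alloc2" is an output of a captured subgraph; its unpacking stays
# pending until its original binding is visited or it is used as an input to
# another lifted function, whichever comes first.
out = []
pending_redef["alloc2"] = "gv1[0]"   # output 0 of a run_or_capture call
print(visit_use("alloc2", out))      # first use forces emission -> "alloc2_redef"
visit_binding("alloc2", out)         # original binding site: no duplicate emission
print(out)                           # ['alloc2_redef = gv1[0]']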
---
src/relax/transform/rewrite_cuda_graph.cc | 27 +++++++++--
src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 1 -
.../relax/test_transform_rewrite_cuda_graph.py | 53 ++++++++++++----------
3 files changed, 54 insertions(+), 27 deletions(-)
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 9621d9ff58..42ec5fca9d 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -461,6 +461,9 @@ class CUDAGraphRewriter : public ExprMutator {
}
Expr ret_value = builder_->Emit(launch_subgraph);
for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
+ // The unpacked result is saved in var_redef_. It will be emitted when 1) the var
+ // definition in the original IR is visited, or 2) the var is used as an input to another
+ // lifted function, whichever comes first.
var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
}
@@ -471,9 +474,9 @@ class CUDAGraphRewriter : public ExprMutator {
if (subgraph_launches_.count(op->var.get())) {
LaunchSubgraph(op, subgraph_launches_[op->var.get()]);
}
- if (auto it = var_redef_.find(op->var.get()); it != var_redef_.end()) {
- auto new_var = builder_->Emit(it->second, op->var->name_hint());
- var_remap_[op->var->vid] = new_var;
+ if (auto it = var_redef_.find(op->var.get());
+ it != var_redef_.end() && !var_remap_.count(op->var->vid)) {
+ EmitRedef(op->var.get(), it->second);
return;
}
if (lifted_bindings_.count(op->var.get())) {
@@ -483,6 +486,24 @@ class CUDAGraphRewriter : public ExprMutator {
ExprMutator::VisitBinding_(op);
}
+ Expr VisitExpr_(const VarNode* op) final {
+ if (auto it = var_remap_.find(op->vid); it != var_remap_.end()) {
+ return it->second;
+ }
+ if (auto it = var_redef_.find(op); it != var_redef_.end()) {
+ // This is the case where the var is used as an input to another lifted function when
+ // the original var definition has not been visited yet.
+ return EmitRedef(op, it->second);
+ }
+ return GetRef<Expr>(op);
+ }
+
+ Var EmitRedef(const VarNode* var, const Expr& redef) {
+ auto new_var = builder_->Emit(redef, var->name_hint());
+ var_remap_[var->vid] = new_var;
+ return new_var;
+ }
+
std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> subgraph_launches_;
std::unordered_map<const VarNode*, Expr> var_redef_;
std::unordered_set<const VarNode*> lifted_bindings_;
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 45342cf4ff..9d2025d647 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -87,7 +87,6 @@ class CUDAGraphCache : public Object {
ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args,
int64_t entry_index) {
if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) {
- LOG(INFO) << "HIT";
// Launch CUDA graph
const auto& [states, cuda_graph] = it->second;
cudaGraphExec_t cuda_graph_exec;
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 4fc4d6f4a1..40c0a4a876 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -29,15 +29,11 @@ def test_rewrite_cuda_graph():
def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
- # body
- # with T.block("root")
for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
with T.block("compute"):
i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4))
i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4))
- T.reads(rxplaceholder[i0, i1])
- T.writes(compute[i0, i1])
compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
@@ -54,12 +50,17 @@ def test_rewrite_cuda_graph():
alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
_4: R.Tuple = cls.exp(alloc1, alloc2)
_5: R.Tuple = R.memory.kill_tensor(alloc1)
- alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
- _6 = cls.exp(alloc2, alloc3)
+ storage2: R.Object = R.memory.alloc_storage(R.shape([32]), 0, "global", "float32")
+ alloc3: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage2, 0, R.shape([2, 4]), "float32")
+ _6: R.Tuple = cls.exp(alloc2, alloc3)
_7: R.Tuple = R.memory.kill_tensor(alloc2)
- _8: R.Tuple = R.memory.kill_storage(storage)
- _9: R.Tuple = R.memory.kill_storage(storage1)
- return alloc3
+ alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
+ _8 = cls.exp(alloc3, alloc4)
+ _9: R.Tuple = R.memory.kill_tensor(alloc3)
+ _10: R.Tuple = R.memory.kill_storage(storage)
+ _11: R.Tuple = R.memory.kill_storage(storage1)
+ _12: R.Tuple = R.memory.kill_storage(storage2)
+ return alloc4
@I.ir_module
@@ -80,40 +81,46 @@ def test_rewrite_cuda_graph():
compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
@R.function
- def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
- gv: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
- gv1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
- gv2: R.Tuple(R.Object, R.Object) = (gv, gv1)
- return gv2
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object):
+ storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+ storage1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+ storage2: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+ gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1, storage2)
+ return gv
@R.function
- def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
cls = Expected
_2: R.Tuple = cls.exp(alloc, alloc1)
_3: R.Tuple = R.memory.kill_tensor(alloc)
alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
_4: R.Tuple = cls.exp(alloc1, alloc2)
_5: R.Tuple = R.memory.kill_tensor(alloc1)
- gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+ alloc3: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage2, 0, R.shape([2, 4]), "float32")
+ _6: R.Tuple = cls.exp(alloc2, alloc3)
+ _7: R.Tuple = R.memory.kill_tensor(alloc2)
+ gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc3,)
return gv
@R.function
def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
cls = Expected
- gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),))
+ gv: R.Tuple(R.Object, R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object, R.Object),))
storage: R.Object = gv[0]
alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
_1: R.Tuple = cls.exp(x, alloc)
storage1: R.Object = gv[1]
alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
- gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
- alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
- alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
- _6: R.Tuple = cls.exp(alloc2, alloc3)
- _7: R.Tuple = R.memory.kill_tensor(alloc2)
+ storage2: R.Object = gv[2]
+ gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+ alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0]
+ alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+ _6: R.Tuple = cls.exp(alloc3, alloc4)
+ _7: R.Tuple = R.memory.kill_tensor(alloc3)
_8: R.Tuple = R.memory.kill_storage(storage)
_9: R.Tuple = R.memory.kill_storage(storage1)
- return alloc3
+ _10: R.Tuple = R.memory.kill_storage(storage2)
+ return alloc4
# fmt: on
after = relax.transform.RewriteCUDAGraph()(Before)