You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/11 15:24:40 UTC
(tvm) branch main updated: [Relax] CUDA graph rewrite treating StringImm as static (#16691)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 95f97e881a [Relax] CUDA graph rewrite treating StringImm as static (#16691)
95f97e881a is described below
commit 95f97e881a8988c801392f30994bad50b0451c9c
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Mar 11 11:24:33 2024 -0400
[Relax] CUDA graph rewrite treating StringImm as static (#16691)
The RewriteCUDAGraph pass failed to consider StringImm as a static
expression, causing the loss of some CUDA graph rewrite opportunities.
This PR fixes the issue.
---
src/relax/transform/rewrite_cuda_graph.cc | 3 +-
.../relax/test_transform_rewrite_cuda_graph.py | 57 +++++++++++++++++++++-
2 files changed, 57 insertions(+), 3 deletions(-)
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 719703a3ec..b67a638dd6 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -348,7 +348,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}
bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
- if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>()) {
+ if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
+ expr->IsInstance<StringImmNode>()) {
return true;
}
if (const auto* prim_value = expr.as<PrimValueNode>()) {
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 73aaf4dac5..dc115939a7 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -18,9 +18,11 @@
import pytest
import tvm
-from tvm import relax
-from tvm.script import tir as T, relax as R, ir as I
import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
class BaseCompare(tvm.testing.CompareBeforeAfter):
@@ -704,5 +706,56 @@ def test_transform_is_no_op_when_disabled():
tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
+def test_static_args():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main():
+ storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+ alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32")
+ _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+ return R.tuple()
+
+ @I.ir_module
+ class Expected:
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage0: R.Object = R.memory.alloc_storage(
+ R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+ )
+ gv: R.Tuple(R.Object) = (storage0,)
+ return gv
+
+ @R.function(private=True)
+ def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+ gv: R.Tuple = R.tuple()
+ return gv
+
+ @R.function
+ def main() -> R.Tuple:
+ cls = Expected
+ gv: R.Tuple(R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object),),
+ )
+ storage0: R.Object = gv[0]
+ alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+ storage0, R.prim_value(0), R.shape([8]), R.dtype("float32")
+ )
+ gv1: R.Tuple = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.run_or_capture",
+ (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+ sinfo_args=(R.Tuple,),
+ )
+ return R.tuple()
+
+ mod = relax.transform.RewriteCUDAGraph()(Before)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()