Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/11 15:24:40 UTC

(tvm) branch main updated: [Relax] CUDA graph rewrite treating StringImm as static (#16691)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 95f97e881a [Relax] CUDA graph rewrite treating StringImm as static (#16691)
95f97e881a is described below

commit 95f97e881a8988c801392f30994bad50b0451c9c
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Mar 11 11:24:33 2024 -0400

    [Relax] CUDA graph rewrite treating StringImm as static (#16691)
    
    The RewriteCUDAGraph pass failed to treat StringImm as a static
    expression, causing some CUDA graph rewrite opportunities to be missed.
    This PR fixes the issue.
---
 src/relax/transform/rewrite_cuda_graph.cc          |  3 +-
 .../relax/test_transform_rewrite_cuda_graph.py     | 57 +++++++++++++++++++++-
 2 files changed, 57 insertions(+), 3 deletions(-)
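For context, the snippet below is a minimal sketch (adapted from the test added
in this commit, with the module renamed to Module) of the kind of input that now
benefits: the R.str("string") argument to the packed call is a StringImm, and
with this change IsStatic() accepts it, so the call becomes eligible for CUDA
graph capture. It assumes a TVM build with Relax available; "dummy_func" is only
a placeholder name and is never executed, since the pass only rewrites the IR.

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main():
        storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
        alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32")
        # The R.str(...) argument is a StringImm; with this fix it is treated
        # as static, so the packed call can be captured into a CUDA graph.
        _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
        return R.tuple()


# Applying the pass lifts the allocation into a cuda_graph_alloc function and
# the packed call into a cuda_graph_capture function, as shown by the Expected
# module in the test below.
mod = relax.transform.RewriteCUDAGraph()(Module)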

diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 719703a3ec..b67a638dd6 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -348,7 +348,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }
 
   bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
-    if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>()) {
+    if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
+        expr->IsInstance<StringImmNode>()) {
       return true;
     }
     if (const auto* prim_value = expr.as<PrimValueNode>()) {
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 73aaf4dac5..dc115939a7 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -18,9 +18,11 @@
 import pytest
 
 import tvm
-from tvm import relax
-from tvm.script import tir as T, relax as R, ir as I
 import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 class BaseCompare(tvm.testing.CompareBeforeAfter):
@@ -704,5 +706,56 @@ def test_transform_is_no_op_when_disabled():
     tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
 
 
+def test_static_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32")
+            _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+            return R.tuple()
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage0: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object) = (storage0,)
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+            gv: R.Tuple = R.tuple()
+            return gv
+
+        @R.function
+        def main() -> R.Tuple:
+            cls = Expected
+            gv: R.Tuple(R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object),),
+            )
+            storage0: R.Object = gv[0]
+            alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage0, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            gv1: R.Tuple = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+                sinfo_args=(R.Tuple,),
+            )
+            return R.tuple()
+
+    mod = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()