Posted to commits@tvm.apache.org by ma...@apache.org on 2023/09/14 23:23:42 UTC

[tvm] branch unity updated: [Unity] Fix BYOC codegen for dynamic shapes (#15750)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 25d6c45ff4 [Unity] Fix BYOC codegen for dynamic shapes (#15750)
25d6c45ff4 is described below

commit 25d6c45ff4395536c1c53cfe87e6a7c8de34292d
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Sep 14 16:23:36 2023 -0700

    [Unity] Fix BYOC codegen for dynamic shapes (#15750)
---
 src/relax/transform/run_codegen.cc                | 10 +--
 tests/python/relax/test_transform_codegen_pass.py | 77 ++++++++++++++++++++++-
 2 files changed, 81 insertions(+), 6 deletions(-)
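
The fix makes RunCodegen take the output StructInfo of each offloaded call from the
call site itself rather than from a StructInfo cached together with the external
function. With dynamic shapes the same BYOC function can be called with different
symbolic output shapes, so a single cached StructInfo is not valid for every call.
A minimal usage sketch (not part of the commit; Before stands for an already
partitioned IRModule like the one in the new test below):

    import tvm
    from tvm import relax
    import tvm.relax.backend.contrib.cublas  # imported as in the test; assumed to make the cuBLAS BYOC backend available

    # Before carries functions annotated with R.func_attr({"Codegen": "cublas"}).
    mod = relax.transform.RunCodegen()(Before)
    # Each offloaded call is rewritten to R.call_dps_packed(...) with its own out_sinfo.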

diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index fa726b82af..9955b5f483 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -85,16 +85,16 @@ class CodeGenRunner : ExprMutator {
         return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
       };
 
+      auto ret_sinfo = GetStructInfo(call);
       if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
-        return create_call_dps_packed(it->second.first, it->second.second);
+        return create_call_dps_packed(it->second, ret_sinfo);
       } else {
         // TODO(@sunggg): Is there any better way to get this func?
         Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
         Expr new_func = VisitExpr(func);
 
         if (new_func->IsInstance<ExternFuncNode>()) {
-          auto ret_sinfo = GetStructInfo(call);
-          extern_funcs_[gvar_node] = {new_func, ret_sinfo};
+          extern_funcs_[gvar_node] = new_func;
           // Remove the global symbol and codegen attributes from the function so that it can be
           // removed from the module.
           static const runtime::PackedFunc* RemoveFuncAttrFunc =
@@ -173,8 +173,8 @@ class CodeGenRunner : ExprMutator {
 
   /*! \brief The names of all constants in the original module. */
   Map<Constant, String> constant_names;
-  /*! \brief Extern funcs and their return struct infos for each global variable.  */
-  std::unordered_map<const GlobalVarNode*, std::pair<Expr, StructInfo>> extern_funcs_;
+  /*! \brief Extern funcs for each global variable.  */
+  std::unordered_map<const GlobalVarNode*, Expr> extern_funcs_;
 };
 
 }  // namespace relax
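
The functional change in run_codegen.cc is that extern_funcs_ now caches only the
lowered ExternFunc, while the return StructInfo is looked up per call via
GetStructInfo(call). Previously the StructInfo captured at the first call was reused
on every cache hit, so a second call to the same offloaded function with a different
dynamic output shape inherited the shape of the first call. Roughly, in terms of the
test below (hypothetical pre-fix output, shown for illustration only):

    # Pre-fix (sketch): the StructInfo cached from the first call leaked into the second.
    #   lv  = R.call_dps_packed("fused_relax_matmul_cublas", (x, w1), out_sinfo=R.Tensor((1, r1), "float16"))
    #   lv1 = R.call_dps_packed("fused_relax_matmul_cublas", (x, w2), out_sinfo=R.Tensor((1, r1), "float16"))  # should be (1, r2)
    # Post-fix, each call carries the out_sinfo of its own call site, as in the Expected module below.
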
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py
index 77756dc664..d103291388 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -21,7 +21,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 import numpy as np
-from tvm.script import relax as R
+from tvm.script import relax as R, ir as I, tir as T
 from tvm.relax.testing import transform
 import tempfile
 from tvm.relax.transform.tuning_api import Trace
@@ -248,6 +248,81 @@ def test_multiple_calls_same_extern():
     tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
 
 
+def test_dynamic_shape():
+    import tvm.relax.backend.contrib.cublas
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4096), dtype="float16"),
+            w1: R.Tensor((4096, "r1"), dtype="float16"),
+            w2: R.Tensor((4096, "r2"), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+            r1 = T.int64()
+            r2 = T.int64()
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((1, r1), dtype="float16") = cls.fused_relax_matmul_cublas(x, w1)
+                lv1: R.Tensor((1, r2), dtype="float16") = cls.fused_relax_matmul_cublas(x, w2)
+                gv: R.Tuple(
+                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+                ) = (lv, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def fused_relax_matmul_cublas(
+            x: R.Tensor((1, 4096), dtype="float16"), w1: R.Tensor((4096, "r1"), dtype="float16")
+        ) -> R.Tensor((1, "r1"), dtype="float16"):
+            r1 = T.int64()
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def gv(
+                x_1: R.Tensor((1, 4096), dtype="float16"),
+                w1_1: R.Tensor((4096, r1), dtype="float16"),
+            ) -> R.Tensor((1, r1), dtype="float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    gv_1: R.Tensor((1, r1), dtype="float16") = R.matmul(x_1, w1_1, out_dtype="void")
+                    R.output(gv_1)
+                return gv_1
+
+            gv1: R.Tensor((1, r1), dtype="float16") = gv(x, w1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4096), dtype="float16"),
+            w1: R.Tensor((4096, "r1"), dtype="float16"),
+            w2: R.Tensor((4096, "r2"), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+            r1 = T.int64()
+            r2 = T.int64()
+            with R.dataflow():
+                lv = R.call_dps_packed(
+                    "fused_relax_matmul_cublas",
+                    (x, w1),
+                    out_sinfo=R.Tensor((1, r1), dtype="float16"),
+                )
+                lv1 = R.call_dps_packed(
+                    "fused_relax_matmul_cublas",
+                    (x, w2),
+                    out_sinfo=R.Tensor((1, r2), dtype="float16"),
+                )
+                gv: R.Tuple(
+                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+                ) = (lv, lv1)
+                R.output(gv)
+            return gv
+
+    after = relax.transform.RunCodegen()(Before)
+    tvm.ir.assert_structural_equal(after["main"], Expected["main"])
+
+
 # TODO(@sunggg):  test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
 
 if __name__ == "__main__":
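
The new test only checks structural equality of the IR produced by RunCodegen; assuming
a TVM build in which the cuBLAS BYOC codegen used by the test is available, it can be
run on its own with, for example,
"pytest tests/python/relax/test_transform_codegen_pass.py -k test_dynamic_shape".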