Posted to commits@tvm.apache.org by ma...@apache.org on 2023/09/14 23:23:42 UTC
[tvm] branch unity updated: [Unity] Fix BYOC codegen for dynamic shapes (#15750)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 25d6c45ff4 [Unity] Fix BYOC codegen for dynamic shapes (#15750)
25d6c45ff4 is described below
commit 25d6c45ff4395536c1c53cfe87e6a7c8de34292d
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Sep 14 16:23:36 2023 -0700
[Unity] Fix BYOC codegen for dynamic shapes (#15750)
---
src/relax/transform/run_codegen.cc | 10 +--
tests/python/relax/test_transform_codegen_pass.py | 77 ++++++++++++++++++++++-
2 files changed, 81 insertions(+), 6 deletions(-)
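In short: when CodeGenRunner lowered a BYOC function, it cached the return
StructInfo together with the resulting extern func in extern_funcs_, keyed
by the GlobalVar, so every later call through the same GlobalVar reused the
first call site's output shape. That breaks under dynamic shapes, where two
call sites can bind different shape variables. The fix derives the return
StructInfo from each call site via GetStructInfo(call) and caches only the
extern func. A condensed sketch of the failure mode, lifted from the new
test below (names and shapes as in the test, not standalone code):

    # Two calls to the same offloaded function, with weights whose trailing
    # dimensions are distinct shape variables r1 and r2:
    lv = cls.fused_relax_matmul_cublas(x, w1)   # R.Tensor((1, r1), "float16")
    lv1 = cls.fused_relax_matmul_cublas(x, w2)  # R.Tensor((1, r2), "float16")
    # Before this fix, RunCodegen rewrote both calls with the cached struct
    # info of the first, so the second call's out_sinfo came out as (1, r1)
    # rather than (1, r2).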
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index fa726b82af..9955b5f483 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -85,16 +85,16 @@ class CodeGenRunner : ExprMutator {
return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
};
+ auto ret_sinfo = GetStructInfo(call);
if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
- return create_call_dps_packed(it->second.first, it->second.second);
+ return create_call_dps_packed(it->second, ret_sinfo);
} else {
// TODO(@sunggg): Is there any better way to get this func?
Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
Expr new_func = VisitExpr(func);
if (new_func->IsInstance<ExternFuncNode>()) {
- auto ret_sinfo = GetStructInfo(call);
- extern_funcs_[gvar_node] = {new_func, ret_sinfo};
+ extern_funcs_[gvar_node] = new_func;
// Remove the global symbol and codegen attributes from the function so that it can be
// removed from the module.
static const runtime::PackedFunc* RemoveFuncAttrFunc =
@@ -173,8 +173,8 @@ class CodeGenRunner : ExprMutator {
/*! \brief The names of all constants in the original module. */
Map<Constant, String> constant_names;
- /*! \brief Extern funcs and their return struct infos for each global variable. */
- std::unordered_map<const GlobalVarNode*, std::pair<Expr, StructInfo>> extern_funcs_;
+ /*! \brief Extern funcs for each global variable. */
+ std::unordered_map<const GlobalVarNode*, Expr> extern_funcs_;
};
} // namespace relax
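With the change above, each rewritten call site carries its own out_sinfo.
A minimal usage sketch (assuming a module shaped like the Before module in
the test below, with the cuBLAS codegen registered):

    import tvm.relax.backend.contrib.cublas  # registers the "cublas" codegen
    from tvm import relax

    # Compile every function carrying a "Codegen" attribute and rewrite its
    # call sites into R.call_dps_packed, deriving out_sinfo per call site.
    after = relax.transform.RunCodegen()(Before)
    after.show()  # the two matmul calls now get out_sinfo (1, r1) and (1, r2)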
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py
index 77756dc664..d103291388 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -21,7 +21,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
import numpy as np
-from tvm.script import relax as R
+from tvm.script import relax as R, ir as I, tir as T
from tvm.relax.testing import transform
import tempfile
from tvm.relax.transform.tuning_api import Trace
@@ -248,6 +248,81 @@ def test_multiple_calls_same_extern():
tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
+def test_dynamic_shape():
+ import tvm.relax.backend.contrib.cublas
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((1, 4096), dtype="float16"),
+ w1: R.Tensor((4096, "r1"), dtype="float16"),
+ w2: R.Tensor((4096, "r2"), dtype="float16"),
+ ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+ r1 = T.int64()
+ r2 = T.int64()
+ cls = Before
+ with R.dataflow():
+ lv: R.Tensor((1, r1), dtype="float16") = cls.fused_relax_matmul_cublas(x, w1)
+ lv1: R.Tensor((1, r2), dtype="float16") = cls.fused_relax_matmul_cublas(x, w2)
+ gv: R.Tuple(
+ R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+ ) = (lv, lv1)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_matmul_cublas(
+ x: R.Tensor((1, 4096), dtype="float16"), w1: R.Tensor((4096, "r1"), dtype="float16")
+ ) -> R.Tensor((1, "r1"), dtype="float16"):
+ r1 = T.int64()
+ R.func_attr({"Codegen": "cublas"})
+
+ @R.function
+ def gv(
+ x_1: R.Tensor((1, 4096), dtype="float16"),
+ w1_1: R.Tensor((4096, r1), dtype="float16"),
+ ) -> R.Tensor((1, r1), dtype="float16"):
+ R.func_attr({"Composite": "cublas.matmul"})
+ with R.dataflow():
+ gv_1: R.Tensor((1, r1), dtype="float16") = R.matmul(x_1, w1_1, out_dtype="void")
+ R.output(gv_1)
+ return gv_1
+
+ gv1: R.Tensor((1, r1), dtype="float16") = gv(x, w1)
+ return gv1
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 4096), dtype="float16"),
+ w1: R.Tensor((4096, "r1"), dtype="float16"),
+ w2: R.Tensor((4096, "r2"), dtype="float16"),
+ ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+ r1 = T.int64()
+ r2 = T.int64()
+ with R.dataflow():
+ lv = R.call_dps_packed(
+ "fused_relax_matmul_cublas",
+ (x, w1),
+ out_sinfo=R.Tensor((1, r1), dtype="float16"),
+ )
+ lv1 = R.call_dps_packed(
+ "fused_relax_matmul_cublas",
+ (x, w2),
+ out_sinfo=R.Tensor((1, r2), dtype="float16"),
+ )
+ gv: R.Tuple(
+ R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+ ) = (lv, lv1)
+ R.output(gv)
+ return gv
+
+ after = relax.transform.RunCodegen()(Before)
+ tvm.ir.assert_structural_equal(after["main"], Expected["main"])
+
+
# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
if __name__ == "__main__":
    tvm.testing.main()
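To run just the new test, something along these lines should work (assuming
a TVM build with the cuBLAS contrib codegen enabled):

    pytest tests/python/relax/test_transform_codegen_pass.py -k test_dynamic_shape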