Posted to commits@tvm.apache.org by tq...@apache.org on 2023/03/24 20:34:04 UTC
[tvm] branch unity updated: [Unity] Support simple dynamic-shape-aware fusion (#14396)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 23803b7c1f [Unity] Support simple dynamic-shape-aware fusion (#14396)
23803b7c1f is described below
commit 23803b7c1ff23bd3746248021a13e9ae347c7627
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sat Mar 25 04:33:54 2023 +0800
[Unity] Support simple dynamic-shape-aware fusion (#14396)
This PR adds support for simple dynamic-shape-aware fusion, which is the first step towards supporting dynamic shapes. The main changes are as follows:
- Fix FuncStructInfo handling in the well-formed check
- Renew symbolic var definitions in fuse_ops to prevent malformed functions (a usage sketch follows this list)
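For context, a minimal usage sketch of the new behavior, mirroring the test added below in test_transform_fuse_ops.py (assumes the unity-branch Python bindings; the module and variable names here are illustrative):

    import tvm
    from tvm import relax, topi
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(["n", "m"], "float32")):
            # two fusable elementwise ops over a tensor of symbolic shape (n, m)
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
                gv = R.emit_te(topi.exp, lv0)
                R.output(gv)
            return gv

    # FuseOps should group the two calls into one primitive function whose
    # signature carries freshly renewed symbolic vars
    mod = relax.transform.FuseOps()(Module)
    mod.show()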
---
src/relax/analysis/well_formed.cc | 12 +++++++
src/relax/ir/expr_functor.cc | 5 ++-
src/relax/transform/fuse_ops.cc | 48 ++++++++++++++++++++++---
src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 4 +--
tests/python/relax/test_analysis_well_formed.py | 13 +++++++
tests/python/relax/test_transform_fuse_ops.py | 37 +++++++++++++++++++
6 files changed, 112 insertions(+), 7 deletions(-)
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 9a97931136..3eeefd0be5 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -402,6 +402,18 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}
+ void VisitStructInfo_(const FuncStructInfoNode* op) final {
+ if (op->params.defined()) {
+ WithMode(VisitMode::kMatchVarDef, [&]() {
+ ICHECK(mode_ == VisitMode::kMatchVarDef);
+ for (StructInfo param : op->params.value()) {
+ this->VisitStructInfo(param);
+ }
+ });
+ }
+ this->VisitStructInfo(op->ret);
+ }
+
void VisitStructInfoExprField(const Expr& expr) final {
if (mode_ == VisitMode::kMatchVarDef) {
// populate symbolic var in first occurrence
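The VisitStructInfo_ override above makes the checker treat symbolic vars that first occur in a FuncStructInfo's parameters as definitions rather than as undefined uses. A minimal construction sketch, assuming the unity-branch Python bindings:

    from tvm import relax as rx, tir

    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    # m and n are first defined by the parameter struct info, so a function
    # type like fn(Tensor((m, n), f32)) -> Tensor((m, n), f32) is well-formed
    param = rx.TensorStructInfo([m, n], "float32")
    fsinfo = rx.FuncStructInfo([param], param)
    print(fsinfo)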
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 174d40053f..3f0fc86a2a 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -577,7 +577,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
for (Var param : op->params) {
Var new_param = this->VisitVarDef(param);
params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ if (!param.same_as(new_param)) {
+ var_remap_[param->vid] = new_param;
+ all_params_unchanged = false;
+ }
}
Expr body = this->VisitWithNewScope(op->body, params);
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 24f068c03f..8e4346e206 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -33,6 +33,7 @@
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
+#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>
#include <optional>
@@ -344,6 +345,45 @@ class GraphCreator : public ExprVisitor {
std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
};
+/*!
+ * \brief Renew the definition of symbolic vars in Relax.
+ * \details This mutator is used to prevent the same symbolic var from being used in different
+ * functions, which is malformed.
+ */
+class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
+ public:
+ static Function Renew(const Function& function) {
+ SymbolicVarRenewMutator mutator;
+ return Downcast<Function>(mutator.VisitExpr(function));
+ }
+
+ private:
+ SymbolicVarRenewMutator() = default;
+ using relax::ExprMutator::VisitExpr;
+ using relax::ExprMutator::VisitExpr_;
+ using tir::ExprMutator::VisitExpr_;
+
+ PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); }
+
+ // TODO(Siyuan): enhance the method to the following steps:
+ // 1. Visit and replace all tir::Vars at the definition point
+ // 2. Revisit the function again and update the use side.
+ PrimExpr VisitExpr_(const tir::VarNode* op) final {
+ auto it = var_map_.find(GetRef<tir::Var>(op));
+ if (it != var_map_.end()) {
+ return (*it).second;
+ } else {
+ auto n = make_object<tir::VarNode>(*op);
+ tir::Var v(n);
+ var_map_.Set(GetRef<tir::Var>(op), v);
+ return v;
+ }
+ }
+
+ private:
+ Map<tir::Var, tir::Var> var_map_;
+};
+
/*!
* \brief The ExprMutator used to create a new grouped function
* \details The workflow of this ExprMutator is:
@@ -466,10 +506,10 @@ class FunctionCreator : public ExprMutator {
body = builder_->Normalize(body);
body = builder_->Normalize(SeqExpr({new_block}, body));
group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
- function_ = Function(/*params=*/params_, //
- /*body=*/body, //
- /*ret_struct_info=*/NullOpt, //
- /*attrs=*/DictAttrs(group_attrs));
+ function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_, //
+ /*body=*/body, //
+ /*ret_struct_info=*/NullOpt, //
+ /*attrs=*/DictAttrs(group_attrs)));
}
}
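SymbolicVarRenewMutator above exists because tir::Var identity is pointer-based: two vars with the same name hint are still distinct definitions, and reusing one var's definition across two functions is malformed. A small sketch of that distinction, using only mainline tvm.tir APIs:

    from tvm import tir

    n = tir.Var("n", "int64")
    fresh = tir.Var("n", "int64")  # same name hint, distinct definition
    assert not n.same_as(fresh)
    # the mutator maps each var it encounters to such a fresh copy, so the
    # grouped function no longer shares symbolic var definitions with its parent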
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index ad8770da8d..45342cf4ff 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -76,11 +76,11 @@ class CUDAGraphCache : public Object {
/*!
* \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
- * \param vm The virutal machine.
+ * \param vm The virtual machine.
* \param capture_func The function of type (args...) -> Tuple[ObjectRef], where 'args' are the
* static arguments that are the same for all invocations of the capture function, the returned
* tuple contains the intermediate tensors that will be used outside the capture function.
- * \params args The static arguments of the capture function
+ * \param args The static arguments of the capture function
* \param entry_index The unique index of the capture function used for lookup.
* \return The return value of the capture function.
*/
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 49d2b76011..b4b68504a4 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -520,5 +520,18 @@ def test_sinfo_erase_to_well_formed():
assert not rx.analysis.well_formed(mod)
+def test_func_sinfo_well_formed():
+ @R.function
+ def foo():
+ @R.function
+ def local(x: R.Tensor(["m", "n"], "float32")):
+ return x
+
+ return local
+
+ mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo))
+ assert rx.analysis.well_formed(mod)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 8f7d8bf40f..72f4e29a16 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1254,5 +1254,42 @@ def test_dead_group():
_check(mod, Expected)
+def test_symbolic_shape_aware_fuse():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor(["n", "m"], "float32")):
+ with R.dataflow():
+ lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
+ lv1 = R.emit_te(topi.exp, lv0)
+ gv = R.emit_te(topi.squeeze, lv1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def fused_add_exp_squeeze(
+ x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+ ) -> R.Tensor(["n", "m"], dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ with R.dataflow():
+ lv0 = R.emit_te(topi.add, x, p0)
+ lv1 = R.emit_te(topi.exp, lv0)
+ gv = R.emit_te(topi.squeeze, lv1)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
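For reference, the _check helper used by this test is defined elsewhere in the file and is not part of this diff; a plausible shape for it, assuming it applies FuseOps and asserts structural equality (hypothetical reconstruction, not from the commit):

    import tvm
    from tvm import relax

    def _check(mod_before, mod_expected):
        # hypothetical: apply operator fusion, then require structural equality
        mod_after = relax.transform.FuseOps()(mod_before)
        tvm.ir.assert_structural_equal(mod_after, mod_expected)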