You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/09 19:09:35 UTC
[tvm] branch unity updated: [Unity][Pass] Enhance Dynamic-aware FuseOps (#14543)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b816aa26db [Unity][Pass] Enhance Dynamic-aware FuseOps (#14543)
b816aa26db is described below
commit b816aa26dbf1acba05c9d49ed39292f26e35c22b
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Mon Apr 10 03:09:28 2023 +0800
[Unity][Pass] Enhance Dynamic-aware FuseOps (#14543)
[Unity][Pass] Enhance Dynamic-aware fuse_ops
The current fuse_ops pass only supports symbolic vars that are all defined
in the parameters' `struct_info`. This commit enhances the pass to support
free symbolic vars by explicitly passing them to the fused function through an extra shape parameter.
---
include/tvm/relax/analysis.h | 15 +++
python/tvm/relax/analysis/analysis.py | 34 ++++++
src/relax/analysis/struct_info_analysis.cc | 119 +++++++++++++++++++++
src/relax/transform/fuse_ops.cc | 19 +++-
tests/python/relax/test_analysis.py | 13 +++
.../relax/test_analysis_struct_info_analysis.py | 26 ++++-
tests/python/relax/test_transform_fuse_ops.py | 42 ++++++++
7 files changed, 263 insertions(+), 5 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index f38a202d49..59f9e475bf 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -268,6 +268,21 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
*/
TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
+/*!
+ * \brief Get the TIR variables defined in the input function (bare vars in param / match_cast struct info).
+ * The returned list is deduplicated - each TIR variable will appear at most once; its order is unspecified.
+ * \param func The function object to be analyzed.
+ * \return The list of TIR variables that are defined in the input function.
+ */
+TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Function& func);
+
+/*!
+ * \brief Get the TIR variables that are used but not defined in the input function.
+ * The returned list is deduplicated - each TIR variable will appear at most once; its order is unspecified.
+ * \param func The function object to be analyzed.
+ * \return The list of TIR variables that are used but not defined in the input function.
+ */
+TVM_DLL Array<tir::Var> FreeSymbolicVars(const Function& func);
//-----------------------------------
// General IR analysis
//-----------------------------------
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index 82d1f7a828..3db79ed8be 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -184,6 +184,40 @@ def tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]:
return _ffi_api.TIRVarsInStructInfo(sinfo) # type: ignore
+def defined_symbolic_vars(func: Function) -> List[Var]:
+ """Get the TIR variables that are defined in the input function.
+ The returned list is deduplicated - each TIR variable will appear at most once; order is unspecified.
+
+ Parameters
+ ----------
+ func : Function
+ The function object to be analyzed.
+
+ Returns
+ -------
+ ret : List[tir.Var]
+ The list of symbolic (TIR) variables that are defined in the input function.
+ """
+ return _ffi_api.DefinedSymbolicVars(func) # type: ignore
+
+
+def free_symbolic_vars(func: Function) -> List[Var]:
+ """Get the TIR variables that are used but not defined in the input function.
+ The returned list is deduplicated - each TIR variable will appear at most once; order is unspecified.
+
+ Parameters
+ ----------
+ func : Function
+ The function object to be analyzed.
+
+ Returns
+ -------
+ ret : List[tir.Var]
+ The list of symbolic (TIR) variables that are used but not defined in the input function.
+ """
+ return _ffi_api.FreeSymbolicVars(func) # type: ignore
+
+
def bound_vars(expr: Expr) -> List[Var]:
"""
Return all bound variables from expression expr.
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 55f58d8511..d2ef8c4e73 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -904,5 +904,124 @@ Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo")
.set_body_typed([](const StructInfo& sinfo) { return TIRVarsInStructInfo(sinfo); });
+class SymbolicVarCollector : public relax::ExprVisitor,  // collects defined / free symbolic vars of a relax Function
+ public relax::StructInfoVisitor,
+ public tir::ExprVisitor {
+ public:
+ static Array<tir::Var> Free(const Function& func) {
+ SymbolicVarCollector collector;
+ collector.VisitExpr(func);
+ Array<tir::Var> ret{collector.free_symbolic_var_.begin(), collector.free_symbolic_var_.end()};
+ return ret;
+ }
+
+ static Array<tir::Var> Defined(const Function& func) {
+ SymbolicVarCollector collector;
+ collector.VisitExpr(func);
+ Array<tir::Var> ret{collector.defined_symbolic_var_.begin(),
+ collector.defined_symbolic_var_.end()};
+ return ret;
+ }
+
+ private:
+ using relax::ExprVisitor::VisitExpr;
+ using relax::ExprVisitor::VisitExpr_;
+ using tir::ExprVisitor::VisitExpr;
+ using tir::ExprVisitor::VisitExpr_;
+
+ // Possible modes of the visitor.
+ enum class VisitMode {
+ /*! \brief Record uses only: a var that is not defined yet is recorded as free. */
+ kDefault,
+ /*! \brief Also record a bare var met in struct info as defined on its first occurrence. */
+ kMatchVarDef,
+ };
+
+ void VisitExpr_(const FunctionNode* op) final {
+ WithMode(VisitMode::kMatchVarDef, [&]() {  // parameter struct info defines vars
+ ICHECK(mode_ == VisitMode::kMatchVarDef);
+ for (Var param : op->params) {
+ relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
+ }
+ });
+
+ relax::ExprVisitor::VisitExpr_(op);  // then traverse the function in the enclosing mode
+ }
+
+ void VisitBinding_(const MatchCastNode* binding) final {  // match_cast struct info defines vars
+ WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); });
+
+ relax::ExprVisitor::VisitBinding_(binding);  // then visit the binding in the enclosing mode
+ }
+
+ void VisitExprDepStructInfoField(const StructInfo& struct_info) {  // NOTE(review): no final/override unlike sibling hooks; confirm it overrides the base virtual
+ return this->VisitStructInfo(struct_info);
+ }
+
+ void VisitStructInfo_(const FuncStructInfoNode* op) final {
+ if (op->params.defined()) {  // params of a function signature define vars
+ WithMode(VisitMode::kMatchVarDef, [&]() {
+ ICHECK(mode_ == VisitMode::kMatchVarDef);
+ for (StructInfo param : op->params.value()) {
+ this->VisitStructInfo(param);
+ }
+ });
+ }
+ this->VisitStructInfo(op->ret);  // return struct info is visited in the enclosing mode
+ }
+
+ void VisitStructInfoExprField(const Expr& expr) final {
+ relax::ExprVisitor::VisitExpr(expr);
+ if (auto* shape = expr.as<relax::ShapeExprNode>()) {
+ for (const auto& val : shape->values) {
+ this->VisitStructInfoExprField(val);  // route each dim through the mode-aware PrimExpr overload
+ }
+ }
+ }
+
+ void VisitStructInfoExprField(const PrimExpr& expr) final {
+ if (mode_ == VisitMode::kMatchVarDef && expr->IsInstance<tir::VarNode>()) {
+ // In match-def mode a bare var is defined on its first occurrence.
+ const auto& var = Downcast<tir::Var>(expr);
+ if (defined_symbolic_var_.count(var) == 0) {
+ defined_symbolic_var_.insert(var);
+ }
+ }
+ tir::ExprVisitor::VisitExpr(expr);  // record uses; a var defined just above is not marked free
+ }
+
+ void VisitExpr_(const tir::VarNode* op) final {
+ tir::Var var = GetRef<tir::Var>(op);
+ // A use of a var not defined so far marks it free; it stays free even if defined later.
+ if (defined_symbolic_var_.count(var) == 0) {
+ free_symbolic_var_.insert(var);
+ }
+ }
+
+ // Run callback with mode_ temporarily set to `mode`, then restore the previous mode.
+ template <typename FType>
+ void WithMode(VisitMode mode, FType callback) {
+ std::swap(mode_, mode);
+ callback();
+ std::swap(mode_, mode);
+ }
+
+ /*! \brief The current visit mode. */
+ VisitMode mode_ = VisitMode::kDefault;
+ /*! \brief The set of defined symbolic vars (unordered, so the returned Array order is unspecified). */
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> defined_symbolic_var_;
+ /*! \brief Free/undefined symbolic vars. NOTE(review): unordered, so the order FreeSymbolicVars returns (and thus the "tir_vars" param order built by FuseOps) may differ between runs; consider an insertion-ordered set. */
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> free_symbolic_var_;
+};
+
+Array<tir::Var> DefinedSymbolicVars(const Function& func) {
+ return SymbolicVarCollector::Defined(func);
+}
+Array<tir::Var> FreeSymbolicVars(const Function& func) { return SymbolicVarCollector::Free(func); }
+
+TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index cf9bd0ac37..b01097aa1b 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -506,10 +506,21 @@ class FunctionCreator : public ExprMutator {
body = builder_->Normalize(body);
body = builder_->Normalize(SeqExpr({new_block}, body));
group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
- function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_, //
- /*body=*/body, //
- /*ret_struct_info=*/NullOpt, //
- /*attrs=*/DictAttrs(group_attrs)));
+ Function function = Function(/*params=*/params_, //
+ /*body=*/body, //
+ /*ret_struct_info=*/NullOpt, //
+ /*attrs=*/DictAttrs(group_attrs));
+ Array<PrimExpr> free_vars =
+ FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
+ if (!free_vars.empty()) {
+ params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars)));
+ arguments_.push_back(ShapeExpr(free_vars));
+ function = Function(/*params=*/params_, //
+ /*body=*/body, //
+ /*ret_struct_info=*/NullOpt, //
+ /*attrs=*/DictAttrs(group_attrs));
+ }
+ function_ = SymbolicVarRenewMutator::Renew(function);
}
}
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 72a256d733..31efb98646 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -421,5 +421,18 @@ def test_reshape_pattern_reject_reduction():
assert not has_reshape_pattern(reduction)
+def test_reshape_pattern_reject_reduction():  # NOTE(review): same name as the test this hunk follows (see hunk header); this redefinition shadows it, so only one copy is collected
+ @T.prim_func
+ def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")):
+ for i0, i1 in T.grid(4, 4):
+ with T.block("identity"):
+ vi0, vi1 = T.axis.remap("SR", [i0, i1])
+ with T.init():
+ B[vi0] = T.float32(0)
+ B[vi0] = B[vi0] + A[vi0, vi1]
+
+ assert not has_reshape_pattern(reduction)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index 62f44d15dd..85136d803b 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -18,9 +18,11 @@
"""Tests analysis functions of struct info"""
import pytest
+
import tvm
import tvm.testing
-from tvm import relax as rx, TVMError
+from tvm import TVMError
+from tvm import relax as rx
from tvm import tir
@@ -574,5 +576,27 @@ def test_tir_vars_in_struct_info():
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m])
+def test_symbolic_var_collector():  # defined = bare dims in param / match_cast struct info; free = other used vars
+ n, m, k, q, p = (
+ tir.Var("n", "int64"),
+ tir.Var("m", "int64"),
+ tir.Var("k", "int64"),
+ tir.Var("q", "int64"),
+ tir.Var("p", "int64"),
+ )
+ bb = rx.BlockBuilder()
+ x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32"))
+ with bb.function("main", [x]):
+ v0 = bb.match_cast(x, rx.TensorStructInfo([m, k], "float32"))
+ v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorStructInfo([p, q], "float32")))
+ bb.emit_func_output(rx.const(1))
+ func = bb.get()["main"]
+
+ defined_vars = set(rx.analysis.defined_symbolic_vars(func))  # compared as sets: result order is unspecified
+ free_vars = set(rx.analysis.free_symbolic_vars(func))
+ assert defined_vars == {m, k}  # m: bare dim of param x; k: bare dim in the match_cast
+ assert free_vars == {n, p, q}  # n: only inside m + n; p, q: only in call_dps_packed's output struct info
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 72f4e29a16..cf8efb0587 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1291,5 +1291,47 @@ def test_symbolic_shape_aware_fuse():
_check(Before, Expected)
+def test_symbolic_shape_aware_fuse_2():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(s: R.Shape(["n"])):
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.emit_te(topi.full, [n, n], "float32", 0)  # no tensor input carries n, so n is free in the fused group
+ lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True)
+ gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def fused_full_trilu_broadcast_to(
+ s: R.Shape(["n"]),  # extra shape param FuseOps adds to carry the free var n
+ ) -> R.Tensor([1, 1, "n", "n"], "float32"):
+ R.func_attr({"Primitive": 1})
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+ lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True)
+ gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(s: R.Shape(["n"])) -> R.Tensor((1, 1, "n", "n"), dtype="float32"):
+ cls = Expected
+ n = T.int64()
+ with R.dataflow():
+ gv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to(
+ R.shape([n])  # n passed explicitly as the extra shape argument
+ )
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()