Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/09 19:09:35 UTC

[tvm] branch unity updated: [Unity][Pass] Enhance Dynamic-aware FuseOps (#14543)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b816aa26db [Unity][Pass] Enhance Dynamic-aware FuseOps (#14543)
b816aa26db is described below

commit b816aa26dbf1acba05c9d49ed39292f26e35c22b
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Mon Apr 10 03:09:28 2023 +0800

    [Unity][Pass] Enhance Dynamic-aware FuseOps (#14543)
    
    [Unity][Pass] Enhance Dynamic-aware fuse_ops
    
    The current fuse_ops pass only supports symbolic vars that are all defined
    in the parameters' `struct_info`. This commit enhances the pass to also
    support free symbolic vars by adding them to the parameter list explicitly,
    as sketched below.
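    
    For illustration, the lifting can be pictured as follows (TVMScript-like
    pseudocode; the parameter name `tir_vars` comes from the implementation
    below, the tensor shapes are hypothetical):
    
        # Before: ill-formed on its own, since `n` is used in the body
        # but bound by no parameter.
        def fused(x: R.Tensor((16,), "float32")):
            ...  # body builds shapes from the free var n
    
        # After: the pass appends a shape parameter that binds `n`; the
        # call site passes R.shape([n]) as the matching argument.
        def fused(x: R.Tensor((16,), "float32"), tir_vars: R.Shape(["n"])):
            ...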
---
 include/tvm/relax/analysis.h                       |  15 +++
 python/tvm/relax/analysis/analysis.py              |  34 ++++++
 src/relax/analysis/struct_info_analysis.cc         | 119 +++++++++++++++++++++
 src/relax/transform/fuse_ops.cc                    |  19 +++-
 tests/python/relax/test_analysis.py                |  13 +++
 .../relax/test_analysis_struct_info_analysis.py    |  26 ++++-
 tests/python/relax/test_transform_fuse_ops.py      |  42 ++++++++
 7 files changed, 263 insertions(+), 5 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index f38a202d49..59f9e475bf 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -268,6 +268,21 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
  */
 TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
 
+/*!
+ * \brief Get the TIR variables that are defined in the input function.
+ * The returned list is deduplicated - each TIR variable will appear at most once.
+ * \param func The function object to be analyzed.
+ * \return The list of TIR variables that are defined in the input function.
+ */
+TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Function& func);
+
+/*!
+ * \brief Get the TIR variables that are used but not defined in the input function.
+ * The returned list is deduplicated - each TIR variable will appear at most once.
+ * \param func The function object to be analyzed.
+ * \return The list of TIR variables that are used but not defined in the input function.
+ */
+TVM_DLL Array<tir::Var> FreeSymbolicVars(const Function& func);
 //-----------------------------------
 // General IR analysis
 //-----------------------------------
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index 82d1f7a828..3db79ed8be 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -184,6 +184,40 @@ def tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]:
     return _ffi_api.TIRVarsInStructInfo(sinfo)  # type: ignore
 
 
+def defined_symbolic_vars(func: Function) -> List[Var]:
+    """Get the TIR variables that defined in the input function.
+    The returned list is deduplicated - each TIR variable will appear at most once.
+
+    Parameters
+    ----------
+    func : Function
+        The function object to be analyzed.
+
+    Returns
+    -------
+    ret : List[Var]
+        The list of symbolic variables that are defined in the input function.
+    """
+    return _ffi_api.DefinedSymbolicVars(func)  # type: ignore
+
+
+def free_symbolic_vars(func: Function) -> List[Var]:
+    """Get the TIR variables that are used but not defined in the input function.
+    The returned list is deduplicated - each TIR variable will appear at most once.
+
+    Parameters
+    ----------
+    func : Function
+        The function object to be analyzed.
+
+    Returns
+    -------
+    ret : List[Var]
+        The list of symbolic variables that are used but not defined in the input function.
+    """
+    return _ffi_api.FreeSymbolicVars(func)  # type: ignore
+
+
 def bound_vars(expr: Expr) -> List[Var]:
     """
     Return all bound variables from expression expr.
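
A quick usage sketch of the two new entry points (the packed function name
"my_func" is hypothetical; everything else mirrors the unit tests added
further below):

    from tvm import relax as rx
    from tvm import tir

    m, p = tir.Var("m", "int64"), tir.Var("p", "int64")
    bb = rx.BlockBuilder()
    # `m` is defined by the parameter's struct info.
    x = rx.Var("x", rx.TensorStructInfo([m], "float32"))
    with bb.function("main", [x]):
        # `p` appears only in the output struct info of the packed call,
        # so it is used in "main" but never defined there.
        y = bb.emit(rx.call_dps_packed("my_func", x, rx.TensorStructInfo([p], "float32")))
        bb.emit_func_output(rx.const(1))
    func = bb.get()["main"]

    assert set(rx.analysis.defined_symbolic_vars(func)) == {m}
    assert set(rx.analysis.free_symbolic_vars(func)) == {p}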
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 55f58d8511..d2ef8c4e73 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -904,5 +904,124 @@ Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
 TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo")
     .set_body_typed([](const StructInfo& sinfo) { return TIRVarsInStructInfo(sinfo); });
 
+class SymbolicVarCollector : public relax::ExprVisitor,
+                             public relax::StructInfoVisitor,
+                             public tir::ExprVisitor {
+ public:
+  static Array<tir::Var> Free(const Function& func) {
+    SymbolicVarCollector collector;
+    collector.VisitExpr(func);
+    Array<tir::Var> ret{collector.free_symbolic_var_.begin(), collector.free_symbolic_var_.end()};
+    return ret;
+  }
+
+  static Array<tir::Var> Defined(const Function& func) {
+    SymbolicVarCollector collector;
+    collector.VisitExpr(func);
+    Array<tir::Var> ret{collector.defined_symbolic_var_.begin(),
+                        collector.defined_symbolic_var_.end()};
+    return ret;
+  }
+
+ private:
+  using relax::ExprVisitor::VisitExpr;
+  using relax::ExprVisitor::VisitExpr_;
+  using tir::ExprVisitor::VisitExpr;
+  using tir::ExprVisitor::VisitExpr_;
+
+  // Possible mode of visitor
+  enum class VisitMode {
+    /*! \brief Check all vars are well-defined. */
+    kDefault,
+    /*! \brief Match-define the vars on their first occurrence. */
+    kMatchVarDef,
+  };
+
+  void VisitExpr_(const FunctionNode* op) final {
+    WithMode(VisitMode::kMatchVarDef, [&]() {
+      ICHECK(mode_ == VisitMode::kMatchVarDef);
+      for (Var param : op->params) {
+        relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
+      }
+    });
+
+    relax::ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) final {
+    WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); });
+
+    relax::ExprVisitor::VisitBinding_(binding);
+  }
+
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) {
+    return this->VisitStructInfo(struct_info);
+  }
+
+  void VisitStructInfo_(const FuncStructInfoNode* op) final {
+    if (op->params.defined()) {
+      WithMode(VisitMode::kMatchVarDef, [&]() {
+        ICHECK(mode_ == VisitMode::kMatchVarDef);
+        for (StructInfo param : op->params.value()) {
+          this->VisitStructInfo(param);
+        }
+      });
+    }
+    this->VisitStructInfo(op->ret);
+  }
+
+  void VisitStructInfoExprField(const Expr& expr) final {
+    relax::ExprVisitor::VisitExpr(expr);
+    if (auto* shape = expr.as<relax::ShapeExprNode>()) {
+      for (const auto& val : shape->values) {
+        this->VisitStructInfoExprField(val);
+      }
+    }
+  }
+
+  void VisitStructInfoExprField(const PrimExpr& expr) final {
+    if (mode_ == VisitMode::kMatchVarDef && expr->IsInstance<tir::VarNode>()) {
+      // Record the symbolic var as defined on its first occurrence.
+      const auto& var = Downcast<tir::Var>(expr);
+      if (defined_symbolic_var_.count(var) == 0) {
+        defined_symbolic_var_.insert(var);
+      }
+    }
+    tir::ExprVisitor::VisitExpr(expr);
+  }
+
+  void VisitExpr_(const tir::VarNode* op) final {
+    tir::Var var = GetRef<tir::Var>(op);
+    // In the default mode, a var that has not been defined yet is free.
+    if (defined_symbolic_var_.count(var) == 0) {
+      free_symbolic_var_.insert(var);
+    }
+  }
+
+  // Run callback with mode.
+  template <typename FType>
+  void WithMode(VisitMode mode, FType callback) {
+    std::swap(mode_, mode);
+    callback();
+    std::swap(mode_, mode);
+  }
+
+  /*! \brief The current visit mode. */
+  VisitMode mode_ = VisitMode::kDefault;
+  /*! \brief The set of defined symbolic vars. */
+  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> defined_symbolic_var_;
+  /*! \brief The set of free/undefined symbolic vars. */
+  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> free_symbolic_var_;
+};
+
+Array<tir::Var> DefinedSymbolicVars(const Function& func) {
+  return SymbolicVarCollector::Defined(func);
+}
+Array<tir::Var> FreeSymbolicVars(const Function& func) { return SymbolicVarCollector::Free(func); }
+
+TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars);
+
+TVM_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars);
+
 }  // namespace relax
 }  // namespace tvm
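
One subtlety of the collector, visible in `VisitStructInfoExprField(const
PrimExpr&)` above: a symbolic var is match-defined only when it occurs as a
whole shape value. A var that shows up solely inside a compound expression
is a use, and is therefore reported as free. A minimal sketch of that
behavior, matching the unit test added below:

    from tvm import relax as rx
    from tvm import tir

    m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
    bb = rx.BlockBuilder()
    # `m` occurs as a standalone dimension and is defined on first
    # occurrence; `n` only occurs inside m + n and stays free.
    x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32"))
    with bb.function("main", [x]):
        bb.emit_func_output(rx.const(1))
    func = bb.get()["main"]

    assert set(rx.analysis.defined_symbolic_vars(func)) == {m}
    assert set(rx.analysis.free_symbolic_vars(func)) == {n}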
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index cf9bd0ac37..b01097aa1b 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -506,10 +506,21 @@ class FunctionCreator : public ExprMutator {
       body = builder_->Normalize(body);
       body = builder_->Normalize(SeqExpr({new_block}, body));
       group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
-      function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_,           //
-                                                          /*body=*/body,                //
-                                                          /*ret_struct_info=*/NullOpt,  //
-                                                          /*attrs=*/DictAttrs(group_attrs)));
+      Function function = Function(/*params=*/params_,           //
+                                   /*body=*/body,                //
+                                   /*ret_struct_info=*/NullOpt,  //
+                                   /*attrs=*/DictAttrs(group_attrs));
+      Array<PrimExpr> free_vars =
+          FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
+      if (!free_vars.empty()) {
+        params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars)));
+        arguments_.push_back(ShapeExpr(free_vars));
+        function = Function(/*params=*/params_,           //
+                            /*body=*/body,                //
+                            /*ret_struct_info=*/NullOpt,  //
+                            /*attrs=*/DictAttrs(group_attrs));
+      }
+      function_ = SymbolicVarRenewMutator::Renew(function);
     }
   }
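
In Python terms, the lifting above constructs IR equivalent to the
following sketch (the helper variable names are illustrative, though
`tir_vars` is the parameter name the pass actually uses):

    from tvm import relax as rx
    from tvm import tir

    n = tir.Var("n", "int64")
    # Extra parameter appended to the fused function, binding its free vars.
    extra_param = rx.Var("tir_vars", rx.ShapeStructInfo([n]))
    # Matching argument constructed at the original call site.
    extra_arg = rx.ShapeExpr([n])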
 
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 72a256d733..31efb98646 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -421,5 +421,18 @@ def test_reshape_pattern_reject_reduction():
     assert not has_reshape_pattern(reduction)
 
 
+def test_reshape_pattern_reject_reduction():
+    @T.prim_func
+    def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")):
+        for i0, i1 in T.grid(4, 4):
+            with T.block("identity"):
+                vi0, vi1 = T.axis.remap("SR", [i0, i1])
+                with T.init():
+                    B[vi0] = T.float32(0)
+                B[vi0] = B[vi0] + A[vi0, vi1]
+
+    assert not has_reshape_pattern(reduction)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index 62f44d15dd..85136d803b 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -18,9 +18,11 @@
 """Tests analysis functions of struct info"""
 
 import pytest
+
 import tvm
 import tvm.testing
-from tvm import relax as rx, TVMError
+from tvm import TVMError
+from tvm import relax as rx
 from tvm import tir
 
 
@@ -574,5 +576,27 @@ def test_tir_vars_in_struct_info():
     tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m])
 
 
+def test_symbolic_var_collector():
+    n, m, k, q, p = (
+        tir.Var("n", "int64"),
+        tir.Var("m", "int64"),
+        tir.Var("k", "int64"),
+        tir.Var("q", "int64"),
+        tir.Var("p", "int64"),
+    )
+    bb = rx.BlockBuilder()
+    x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32"))
+    with bb.function("main", [x]):
+        v0 = bb.match_cast(x, rx.TensorStructInfo([m, k], "float32"))
+        v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorStructInfo([p, q], "float32")))
+        bb.emit_func_output(rx.const(1))
+    func = bb.get()["main"]
+
+    defined_vars = set(rx.analysis.defined_symbolic_vars(func))
+    free_vars = set(rx.analysis.free_symbolic_vars(func))
+    assert defined_vars == {m, k}
+    assert free_vars == {n, p, q}
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 72f4e29a16..cf8efb0587 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1291,5 +1291,47 @@ def test_symbolic_shape_aware_fuse():
     _check(Before, Expected)
 
 
+def test_symbolic_shape_aware_fuse_2():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(s: R.Shape(["n"])):
+            n = T.int64()
+            with R.dataflow():
+                lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+                lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True)
+                gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_full_trilu_broadcast_to(
+            s: R.Shape(["n"]),
+        ) -> R.Tensor([1, 1, "n", "n"], "float32"):
+            R.func_attr({"Primitive": 1})
+            n = T.int64()
+            with R.dataflow():
+                lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+                lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True)
+                gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(s: R.Shape(["n"])) -> R.Tensor((1, 1, "n", "n"), dtype="float32"):
+            cls = Expected
+            n = T.int64()
+            with R.dataflow():
+                gv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to(
+                    R.shape([n])
+                )
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()