You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/03/24 20:34:04 UTC

[tvm] branch unity updated: [Unity] Support simple dynamic-shape-aware fusion (#14396)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 23803b7c1f [Unity] Support simple dynamic-shape-aware fusion (#14396)
23803b7c1f is described below

commit 23803b7c1ff23bd3746248021a13e9ae347c7627
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sat Mar 25 04:33:54 2023 +0800

    [Unity] Support simple dynamic-shape-aware fusion (#14396)
    
    This PR adds support for simple dynamic-shape-aware fusion, which is the first step towards supporting dynamic shapes. The main changes are as follows:
    
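    - Fix the well-formed check for FuncStructInfo: symbolic vars that appear in a function signature's parameter struct info are now treated as definitions rather than as undefined uses
    - Renew symbolic var definitions in fuse_ops so that each fused function gets its own copies of the symbolic shape vars instead of sharing them with its caller, which would make the functions malformed
    - Record remapped function parameters in ExprMutator's var_remap_ so that uses in the function body refer to the new parameters
    - Fix doc-comment typos in cuda_graph_builtin.cc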
---
 src/relax/analysis/well_formed.cc               | 12 +++++++
 src/relax/ir/expr_functor.cc                    |  5 ++-
 src/relax/transform/fuse_ops.cc                 | 48 ++++++++++++++++++++++---
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc |  4 +--
 tests/python/relax/test_analysis_well_formed.py | 13 +++++++
 tests/python/relax/test_transform_fuse_ops.py   | 37 +++++++++++++++++++
 6 files changed, 112 insertions(+), 7 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 9a97931136..3eeefd0be5 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -402,6 +402,18 @@ class WellFormedChecker : public relax::ExprVisitor,
     }
   }
 
+  void VisitStructInfo_(const FuncStructInfoNode* op) final {
+    if (op->params.defined()) {
+      WithMode(VisitMode::kMatchVarDef, [&]() {
+        ICHECK(mode_ == VisitMode::kMatchVarDef);
+        for (StructInfo param : op->params.value()) {
+          this->VisitStructInfo(param);
+        }
+      });
+    }
+    this->VisitStructInfo(op->ret);
+  }
+
   void VisitStructInfoExprField(const Expr& expr) final {
     if (mode_ == VisitMode::kMatchVarDef) {
       // populate symbolic var in first occurrence
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 174d40053f..3f0fc86a2a 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -577,7 +577,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
   for (Var param : op->params) {
     Var new_param = this->VisitVarDef(param);
     params.push_back(new_param);
-    all_params_unchanged &= param.same_as(new_param);
+    if (!param.same_as(new_param)) {
+      var_remap_[param->vid] = new_param;
+      all_params_unchanged = false;
+    }
   }
 
   Expr body = this->VisitWithNewScope(op->body, params);
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 24f068c03f..8e4346e206 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -33,6 +33,7 @@
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tir/expr_functor.h>
 #include <tvm/tir/function.h>
 
 #include <optional>
@@ -344,6 +345,45 @@ class GraphCreator : public ExprVisitor {
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
 };
 
+/*!
+ * \brief Renew the definition of symbolic vars in Relax.
+ * \details This mutator is used to prevent the same symbolic var from being used in different
+ *          functions, which is malformed.
+ */
+class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
+ public:
+  static Function Renew(const Function& function) {
+    SymbolicVarRenewMutator mutator;
+    return Downcast<Function>(mutator.VisitExpr(function));
+  }
+
+ private:
+  SymbolicVarRenewMutator() = default;
+  using relax::ExprMutator::VisitExpr;
+  using relax::ExprMutator::VisitExpr_;
+  using tir::ExprMutator::VisitExpr_;
+
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); }
+
+  // TODO(Siyuan): enhance the method to the following steps:
+  // 1. Visit and replace all tir::Vars at the definition point
+  // 2. Revisit the function again and update the use side.
+  PrimExpr VisitExpr_(const tir::VarNode* op) final {
+    auto it = var_map_.find(GetRef<tir::Var>(op));
+    if (it != var_map_.end()) {
+      return (*it).second;
+    } else {
+      auto n = make_object<tir::VarNode>(*op);
+      tir::Var v(n);
+      var_map_.Set(GetRef<tir::Var>(op), v);
+      return v;
+    }
+  }
+
+ private:
+  Map<tir::Var, tir::Var> var_map_;
+};
+
 /*!
  * \brief The ExprMutator used to create a new grouped function
  * \details The workflow of this ExprMutator is:
@@ -466,10 +506,10 @@ class FunctionCreator : public ExprMutator {
       body = builder_->Normalize(body);
       body = builder_->Normalize(SeqExpr({new_block}, body));
       group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
-      function_ = Function(/*params=*/params_,           //
-                           /*body=*/body,                //
-                           /*ret_struct_info=*/NullOpt,  //
-                           /*attrs=*/DictAttrs(group_attrs));
+      function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_,           //
+                                                          /*body=*/body,                //
+                                                          /*ret_struct_info=*/NullOpt,  //
+                                                          /*attrs=*/DictAttrs(group_attrs)));
     }
   }
 
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index ad8770da8d..45342cf4ff 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -76,11 +76,11 @@ class CUDAGraphCache : public Object {
 
   /*!
    * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
-   * \param vm The virutal machine.
+   * \param vm The virtual machine.
    * \param capture_func The function of type (args...) -> Tuple[ObjectRef], where 'args' are the
    * static arguments that are the same for all invocations of the capture function, the returned
    * tuple contains the intermediate tensors that will be used outside the capture function.
-   * \params args The static arguments of the capture function
+   * \param args The static arguments of the capture function
    * \param entry_index The unique index of the capture function used for lookup.
    * \return The return value of the capture function.
    */
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 49d2b76011..b4b68504a4 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -520,5 +520,18 @@ def test_sinfo_erase_to_well_formed():
     assert not rx.analysis.well_formed(mod)
 
 
+def test_func_sinfo_well_formed():
+    @R.function
+    def foo():
+        @R.function
+        def local(x: R.Tensor(["m", "n"], "float32")):
+            return x
+
+        return local
+
+    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo))
+    assert rx.analysis.well_formed(mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 8f7d8bf40f..72f4e29a16 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1254,5 +1254,42 @@ def test_dead_group():
     _check(mod, Expected)
 
 
+def test_symbolic_shape_aware_fuse():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")):
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_add_exp_squeeze(
+            x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+        ) -> R.Tensor(["n", "m"], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, p0)
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()