You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/06/24 13:03:21 UTC

[tvm] branch unity updated: [Unity][Pass] FuseOps with partially accessed Tuple param (#15152)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7c3369c72f [Unity][Pass] FuseOps with partially accessed Tuple param (#15152)
7c3369c72f is described below

commit 7c3369c72f52467291a96346580c99cf093f179e
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sat Jun 24 06:03:14 2023 -0700

    [Unity][Pass] FuseOps with partially accessed Tuple param (#15152)
    
    This PR enhances the FuseOps pass to better support parameters in fused
    subfunctions where the parameter
    * denotes a Tuple, and
    * is accessed only by TupleGetItem in the fused subfunction.
    
    We refer to the property above as "a Tuple param being partially used".
    For example, in the following Relax function, parameter `x` is partially
    used, as it has 6 fields while only `x[0]` is effectively used in the
    function.
    
    ```python
    @R.function
    def fused_add_divide(
        x: R.Tuple(
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
        ),
        param_0: R.Tensor((), dtype="float32"),
        param_1: R.Tensor((), dtype="float32"),
    ) -> R.Tensor((2,), dtype="float32"):
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            x0: R.Tensor((2,), dtype="float32") = x[0]
            y0 = R.call_tir(cls.add, (x0, param_0), out_sinfo=R.Tensor((2,), dtype="float32"))
            gv = R.call_tir(cls.divide, (y0, param_1), out_sinfo=R.Tensor((2,), dtype="float32"))
            R.output(gv)
        return gv
    ```
    
    Prior to this PR, the FuseOps pass generated fused functions like the one
    above. As we can see, it frequently happens that an entire Tuple is passed
    into the function as input, while only a few of its fields are actually
    used in the function.
    
    This behavior leads to over-consumption of memory for the runtime
    kernel(s) that the fused subfunction will eventually become, which is a
    disaster when the Tuple is very long. This case does happen in our work
    on vertical applications (e.g., MLC-LLM): when the model parameters are
    passed as a Tuple parameter of the main Relax function, the Tuple can
    have hundreds of fields or even more. In this case, even compiling the
    model takes much longer, and the dumped IRModule TVMScript is tens of MB
    in size, which is unaffordable.
    
    Therefore, we update the FuseOps pass with better support for partially
    accessed Tuple parameters. With this enhancement, when a Tuple parameter
    is partially accessed, we replace the Tuple parameter with a list of new
    parameters, one for each of its fields that is accessed in the function.
    See the added unit test for an illustrative example.
---
 src/relax/transform/fuse_ops.cc               | 74 +++++++++++++++++++++++++--
 tests/python/relax/test_transform_fuse_ops.py | 62 +++++++++++++++++++++-
 2 files changed, 132 insertions(+), 4 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 1942fbddfe..7c4d0ad303 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -384,6 +384,7 @@ class FunctionCreator : public ExprMutator {
           const Tuple& args = Downcast<Tuple>(call->args[1]);
           for (const Expr& arg : args->fields) {
             CheckDefAndUpdateParam(arg);
+            ICHECK(GetStructInfoAs<TupleStructInfoNode>(arg) == nullptr);
           }
           // TODO(tvm-team): handle shape expr
         } else {
@@ -400,11 +401,21 @@ class FunctionCreator : public ExprMutator {
 
           for (const Expr& arg : call->args) {
             CheckDefAndUpdateParam(arg);
+            if (GetStructInfoAs<TupleStructInfoNode>(arg) != nullptr) {
+              // The argument is fully referenced. Thus we remove it from the mapping.
+              partially_used_tuple_params_.erase(arg.get());
+            }
           }
         }
       } else if (var_binding->value.as<TupleGetItemNode>()) {
         const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
         CheckDefAndUpdateParam(tuple_item->tuple);
+
+        if (partially_used_tuple_params_.find(tuple_item->tuple.get()) !=
+            partially_used_tuple_params_.end()) {
+          // Appending get-item index to the mapping.
+          partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index);
+        }
       }
 
       // Mark the binding variable as defined.
@@ -440,9 +451,51 @@ class FunctionCreator : public ExprMutator {
     // Step 1. Start constructing a new dataflow block.
     builder_->BeginDataflowBlock();
 
-    // Step 2. Visit each binding and collect outputs one by one.
+    // Step 2. Handling partially used tuple parameters: replacing entire tuple
+    // parameters with the parameters of its fields that are accessed in the
+    // function.
+    std::unordered_map<const ExprNode*, std::unordered_map<int, Var>> tuple_get_item_remap;
+    for (auto& [tuple_arg, item_indices] : partially_used_tuple_params_) {
+      ICHECK(!item_indices.empty());
+      int param_idx = tuple_param_idx_[tuple_arg];
+      Var param = params_[param_idx];
+      String param_name = params_[param_idx]->name_hint();
+      TupleStructInfo param_sinfo = Downcast<TupleStructInfo>(tuple_arg->struct_info_);
+
+      Array<Expr> item_args;
+      Array<Var> item_params;
+      item_args.reserve(item_indices.size());
+      item_params.reserve(item_indices.size());
+      for (int item_idx : item_indices) {
+        Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]);
+        item_args.push_back(TupleGetItem(GetRef<Expr>(tuple_arg), item_idx));
+        item_params.push_back(item_param);
+        tuple_get_item_remap[tuple_arg][item_idx] = item_param;
+      }
+      arguments_.erase(arguments_.begin() + param_idx);
+      arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end());
+      params_.erase(params_.begin() + param_idx);
+      params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end());
+    }
+
+    // Step 3. Visit each binding and collect outputs one by one.
     Array<Expr> outputs(output_vars_.size(), Expr());
     for (const Binding& binding : bindings_) {
+      // Special handling for TupleGetItem.
+      if (const auto* var_binding = binding.as<VarBindingNode>()) {
+        if (const auto* tuple_get_item = var_binding->value.as<TupleGetItemNode>()) {
+          auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get());
+          if (it != tuple_get_item_remap.end()) {
+            ICHECK(it->second.find(tuple_get_item->index) != it->second.end());
+            var_remap_[var_binding->var->vid] = it->second[tuple_get_item->index];
+            if (auto output_idx = GetOutputIndex(binding->var)) {
+              outputs.Set(*output_idx, it->second[tuple_get_item->index]);
+            }
+            continue;
+          }
+        }
+      }
+
       if (auto output_idx = GetOutputIndex(binding->var)) {
         // Case 1. It is an output binding
         // We only allow VarBinding as output.
@@ -457,7 +510,7 @@ class FunctionCreator : public ExprMutator {
       }
     }
 
-    // Step 3. Finish constructing the new block.
+    // Step 4. Finish constructing the new block.
     BindingBlock new_block = builder_->EndBlock();
     if (outputs.empty()) {
       // If the result is not used outside
@@ -532,9 +585,17 @@ class FunctionCreator : public ExprMutator {
         name = String("param_" + std::to_string(n_param_for_const_++));
       }
 
-      Var param(std::move(name), GetStructInfo(expr));
+      StructInfo param_sinfo = GetStructInfo(expr);
+      Var param(std::move(name), param_sinfo);
       arguments_.push_back(expr);
       params_.push_back(param);
+
+      // Mark the tuple parameter as partially referenced in the beginning.
+      // We will remove it from the mapping once we find it is fully referenced.
+      if (param_sinfo->IsInstance<TupleStructInfoNode>()) {
+        partially_used_tuple_params_[expr.get()] = {};
+        tuple_param_idx_[expr.get()] = static_cast<int>(arguments_.size()) - 1;
+      }
     }
   }
 
@@ -557,6 +618,13 @@ class FunctionCreator : public ExprMutator {
   std::vector<const VarNode*> output_vars_;
   /*! \brief Whether or not to lift bound constants to parameters */
   bool lift_constant_;
+  /*! \brief Mapping from tuple parameter of the function to its position index */
+  std::unordered_map<const ExprNode*, int> tuple_param_idx_;
+  /*!
+   * \brief Mapping from partially referenced tuple parameter to the list of
+   * indices at which the parameter is accessed by TupleGetItem
+   */
+  std::unordered_map<const ExprNode*, std::vector<int>> partially_used_tuple_params_;
 };
 
 /*!
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 169539b072..14c3dbe713 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1123,7 +1123,6 @@ def test_multiple_paths():
 
 
 def test_dead_group():
-
     # fmt: off
 
     @I.ir_module
@@ -1411,5 +1410,66 @@ def test_skipping_primvalue():
     _check(Module, Module)
 
 
+def test_partially_used_tuple_param():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor((2,), "float32"),
+                R.Tensor((2,), "float32"),
+                R.Tensor((2,), "float32"),
+                R.Tensor((2,), "float32"),
+                R.Tensor((2,), "float32"),
+                R.Tensor((2,), "float32"),
+            )
+        ):
+            with R.dataflow():
+                x0 = x[0]
+                y0 = R.emit_te(topi.add, x0, R.const(1, "float32"))
+                y1 = R.emit_te(topi.divide, y0, R.const(1, "float32"))
+                gv = y1
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_add_divide(
+            x_0: R.Tensor((2,), dtype="float32"),
+            param_0: R.Tensor((), dtype="float32"),
+            param_1: R.Tensor((), dtype="float32"),
+        ) -> R.Tensor((2,), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                y0 = R.emit_te(topi.add, x_0, param_0)
+                gv = R.emit_te(topi.divide, y0, param_1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+            )
+        ) -> R.Tensor((2,), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = x[0]
+                lv1: R.Tensor((2,), dtype="float32") = cls.fused_add_divide(
+                    lv, R.const(1, "float32"), R.const(1, "float32")
+                )
+                gv: R.Tensor((2,), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    _check(Module, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()