You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2024/02/23 14:41:32 UTC

(tvm) branch main updated: [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 89cc09c621 [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)
89cc09c621 is described below

commit 89cc09c62103d74dce02e03754261b1e205cadab
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Feb 23 08:41:26 2024 -0600

    [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)
    
    * [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul
    
    Prior to this commit, if the weight of a matmul has a dynamic shape, and that
    matmul is being combined with the `CombineParallelMatmul` pass, it
    could cause a segfault when `dim.as<IntImmNode>()` returns a null
    pointer.
    
    This commit adds explicit test cases for these dynamic shapes, and
    updates `CombineParallelMatmul` to handle the dynamic shapes.
    
    * Add Tuple constructor for PR-16589
---
 include/tvm/relax/expr.h                           |  18 +++
 src/relax/transform/combine_parallel_matmul.cc     | 160 +++++++++++++--------
 .../test_transform_combine_parallel_matmul.py      | 123 +++++++++++++++-
 3 files changed, 240 insertions(+), 61 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index bb1b2c8dd7..23262ea817 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -320,6 +320,24 @@ class Tuple : public Expr {
    */
   TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
 
+  /*!
+   * \brief Utility constructor to handle conversion to relax::Expr
+   *
+   * If the calling scope already has an array of a specific type of
+   * relax expression (e.g. `Array<relax::Var>`), it must be converted
+   * into an array of base type.  This constructor handles the
+   * conversion to the base `Array<relax::Expr>`.
+   *
+   * \tparam RelaxExpr The type of relax expression passed in as an argument.
+   *
+   * \param fields The fields of a tuple.
+   *
+   * \param span The source span of the expression.
+   */
+  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
+  TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
+      : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
+
   TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
 };
diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
index 3ea17fdd70..7e6aa6277b 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -71,7 +71,16 @@ struct Patterns {
   WildcardPattern input;
   std::vector<WildcardPattern> rhs;
   std::vector<WildcardPattern> bias;
-  std::vector<CallPattern> matmul, bias_add, activation;
+  std::vector<CallPattern> matmul;
+  std::vector<CallPattern> bias_add;
+  std::vector<CallPattern> activation;
+};
+
+struct SplitInfo {
+  Var rhs;
+  Optional<Var> bias;
+  PrimExpr split_size;
+  DFPattern pattern_to_replace;
 };
 
 Patterns CreatePatterns(const BranchInfo& branch_info) {
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
     for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
       if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;
 
-      auto inp = matchings[patterns.input];
+      auto lhs = matchings[patterns.input];
+
+      const auto& patterns_to_replace = [&patterns, &branch_info]() {
+        if (branch_info.activation) return patterns.activation;
+        if (branch_info.bias_dim) return patterns.bias_add;
+        return patterns.matmul;
+      }();
 
-      Array<Var> rhs, bias;
-      for (auto ind : indices) {
-        rhs.push_back(matchings[patterns.rhs[ind]]);
-        if (branch_info.bias_dim) {
-          ICHECK(matchings.count(patterns.bias[ind]));
-          bias.push_back(matchings[patterns.bias[ind]]);
+      std::vector<SplitInfo> splits;
+      for (auto index : indices) {
+        Var rhs = matchings[patterns.rhs[index]];
+        Optional<Var> bias = NullOpt;
+        if (branch_info.bias_dim.has_value()) {
+          bias = matchings[patterns.bias[index]];
         }
+        PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
+        DFPattern pattern_to_replace = patterns_to_replace[index];
+        splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
+      }
+      // At most one dynamic output shape can be part of the combined
+      // matmul, and it must be the last item in the split.  Use
+      // `std::stable_sort` instead of `std::sort` to maintain a
+      // consistent order for all static shapes, and to consistently
+      // select the same dynamic weight to participate.
+      auto is_dynamic_split = [](const SplitInfo& split) -> bool {
+        return !split.split_size->IsInstance<IntImmNode>();
+      };
+      std::stable_sort(splits.begin(), splits.end(),
+                       [&is_dynamic_split](const auto& a, const auto& b) {
+                         return is_dynamic_split(a) < is_dynamic_split(b);
+                       });
+      // Remove anything after the first dynamic shape participating
+      // in the combined matmul.
+      if (auto it = std::find_if(splits.begin(), splits.end(), is_dynamic_split);
+          it != splits.end()) {
+        splits.erase(it + 1, splits.end());
       }
 
-      if (!check(inp, rhs, bias, bindings)) {
+      if (splits.size() == 1) {
         continue;
       }
 
-      auto make_tuple = [](const Array<Var>& var_array) {
-        Array<Expr> exp_array;
-        for (auto v : var_array) exp_array.push_back(v);
-        return Tuple(exp_array);
-      };
+      Array<Var> rhs;
+      Array<Var> bias;
+      for (const auto& split : splits) {
+        rhs.push_back(split.rhs);
+        if (split.bias) {
+          bias.push_back(split.bias.value());
+        }
+      }
 
-      auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
-      auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
-      auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+      if (!check(lhs, rhs, bias, bindings)) {
+        continue;
+      }
 
-      const auto& pattern_to_replace = [&patterns, &branch_info]() {
-        if (branch_info.activation) return patterns.activation;
-        if (branch_info.bias_dim) return patterns.bias_add;
-        return patterns.matmul;
-      }();
+      auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+      auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+      auto matmul_combined = matmul(lhs, concat_rhs, out_dtype);
 
       if (branch_info.bias_dim) {
         auto bias_dim = GetTensorSInfo(bias[0])->ndim;
-        auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
+        auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
         matmul_combined = add(matmul_combined, concat_bias);
       }
 
@@ -191,20 +228,23 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
         }
       }
 
-      int ind = 0;
+      int split_index = 0;
       Array<IntImm> sections;
-      for (int i = 0; i < static_cast<int>(indices.size()) - 1; ++i) {
-        auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as<IntImmNode>();
-        ind += width->value;
-        sections.push_back(IntImm(DataType::Int(64), ind));
+      for (size_t i = 0; i + 1 < splits.size(); i++) {
+        auto width = splits[i].split_size.as<IntImmNode>();
+        ICHECK(width) << "InternalError: "
+                      << "All splits except the last one must have a static shape";
+        split_index += width->value;
+        sections.push_back(IntImm(DataType::Int(64), split_index));
       }
 
-      int lhs_dim = GetTensorSInfo(inp)->ndim;
+      int lhs_dim = GetTensorSInfo(lhs)->ndim;
       int split_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
       auto chunks = split(matmul_combined, sections, split_axis);
 
-      for (size_t i = 0; i < indices.size(); ++i) {
-        auto bound_var = matchings[pattern_to_replace[indices[i]]];
+      for (size_t i = 0; i < splits.size(); i++) {
+        const auto& split = splits[i];
+        auto bound_var = matchings[split.pattern_to_replace];
         replacements.Set(bound_var, TupleGetItem(chunks, i));
       }
     }
@@ -244,43 +284,43 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
 
     PostOrderVisit(f, [&](const Expr& e) {
       if (!e->IsInstance<CallNode>()) return;
-      if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
-        auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
-        auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
 
-        auto it = groups.find(matmul_lhs.get());
-        BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
-        std::optional<int> bias_dim = std::nullopt;
-        std::optional<std::string> activation = std::nullopt;
+      auto match = ExtractMatchedExpr(pat, e, bindings);
+      if (!match) return;
 
-        if (match.value().count(bias_pat)) {
-          bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
-        }
+      auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+      auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
 
-        for (size_t i = 0; i < activations.size(); ++i) {
-          if (match.value().count(activation_pat[i]) ||
-              match.value().count(bias_activation_pat[i])) {
-            activation = activations[i];
-          }
+      std::optional<int> bias_dim = std::nullopt;
+      std::optional<std::string> activation = std::nullopt;
+
+      if (match.value().count(bias_pat)) {
+        bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+      }
+
+      for (size_t i = 0; i < activations.size(); ++i) {
+        if (match.value().count(activation_pat[i]) || match.value().count(bias_activation_pat[i])) {
+          activation = activations[i];
         }
+      }
 
-        if (!branch) {
-          // Create a new subgraph with one matmul
-          groups[matmul_lhs.get()] = {1, bias_dim, activation};
-        } else {
-          // Create a new branch in the existing parallel matmul subtree, and
-          // invalidate bias and activation information when needed.
-          branch->num_branches += 1;
+      if (auto it = groups.find(matmul_lhs.get()); it != groups.end()) {
+        // Create a new branch in the existing parallel matmul subtree, and
+        // invalidate bias and activation information when needed.
+        BranchInfo* branch = &it->second;
+
+        branch->num_branches += 1;
 
-          if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
-            branch->bias_dim = std::nullopt;
-          }
+        if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
+          branch->bias_dim = std::nullopt;
+        }
 
-          if (!activation || (branch->activation && *branch->activation != *activation)) {
-            branch->activation = std::nullopt;
-          }
+        if (!activation || (branch->activation && *branch->activation != *activation)) {
+          branch->activation = std::nullopt;
         }
-        return;
+      } else {
+        // Create a new subgraph with one matmul
+        groups[matmul_lhs.get()] = {1, bias_dim, activation};
       }
     });
 
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 7e7f2328f3..6168d0c58d 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,7 +525,16 @@ def test_check():
     tvm.ir.assert_structural_equal(after, expected)
 
 
-def test_dynamic_rhs():
+def test_combine_matmul_of_static_and_dynamic_shapes():
+    """Combine two matmuls, one with dynamic shape
+
+    The `R.split` operator must have a static list of integer indices
+    at which to split the matmul output, because these integer indices
+    are stored as operator attributes.  However, the last output can
+    still have a dynamic shape.
+
+    """
+
     @R.function(private=True)
     def before(
         x: R.Tensor((2, 1024, 640), "float32"),
@@ -572,5 +581,117 @@ def test_dynamic_rhs():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_combine_matmul_of_dynamic_and_static_shapes():
+    """Combine two matmuls, one with dynamic shape
+
+    Like `test_combine_matmul_of_static_and_dynamic_shapes`, but the
+    dynamic-shaped matmul is encountered first.  Due to the
+    requirements imposed by `R.split` storing the split indices as
+    static integers, the static-shaped weights must occur first in the
+    concatenated weights.
+    """
+
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, "M"), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            out = (lv0, lv1)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, "M"), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, "M"), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w1, w0), axis=1)
+            lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+                x, lv, out_dtype="float32"
+            )
+            lv2: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = R.split(lv1, indices_or_sections=[640], axis=2)
+            lv0: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+            lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+            out: R.Tuple(
+                R.Tensor((2, 1024, M), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+            ) = (lv0, lv1_1)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_limit_one_dynamic_shape_in_combined_matmul():
+    """Combine three matmuls, two of which have dynamic shapes
+
+    Like `test_combine_matmul_of_static_and_dynamic_shapes`, but with
+    two dynamic weights that could, in principle, be merged together.
+    Because `R.split` must have integer indices at which to split,
+    only one of the dynamic outputs can be part of the combined
+    matmul.  The static weight (`w1`) is placed first in the
+    concatenation, followed by one dynamic weight (`w0`); the other
+    dynamic-weight matmul (`x @ w2`) is left uncombined.
+    """
+
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, "M"), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, "N"), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            lv2 = R.matmul(x, w2)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, "M"), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, "N"), "float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, "M"), dtype="float32"),
+        R.Tensor((2, 1024, 640), dtype="float32"),
+        R.Tensor((2, 1024, "N"), dtype="float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            concat_weights = R.concat((w1, w0), axis=1)
+            concat_output = R.matmul(x, concat_weights, out_dtype="float32")
+            split_output: R.Tuple(
+                [R.Tensor([2, 1024, 640], dtype="float32"), R.Tensor([2, 1024, M], dtype="float32")]
+            ) = R.split(concat_output, indices_or_sections=[640], axis=2)
+            lv0 = split_output[1]
+            lv1 = split_output[0]
+            lv2 = R.matmul(x, w2)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()