You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2024/02/23 14:41:32 UTC
(tvm) branch main updated: [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89cc09c621 [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)
89cc09c621 is described below
commit 89cc09c62103d74dce02e03754261b1e205cadab
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Feb 23 08:41:26 2024 -0600
[Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)
* [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul
Prior to this commit, if the weight of a matmul had a dynamic shape, and that
matmul is being combined with the `CombineParallelMatmul` pass, it
could cause a segfault when `dim.as<IntImmNode>()` returns a null
pointer.
This commit adds explicit test cases for these dynamic shapes, and
updates `CombineParallelMatmul` to handle the dynamic shapes.
* Add Tuple constructor for PR-16589
---
include/tvm/relax/expr.h | 18 +++
src/relax/transform/combine_parallel_matmul.cc | 160 +++++++++++++--------
.../test_transform_combine_parallel_matmul.py | 123 +++++++++++++++-
3 files changed, 240 insertions(+), 61 deletions(-)
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index bb1b2c8dd7..23262ea817 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -320,6 +320,24 @@ class Tuple : public Expr {
*/
TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
+ /*!
+ * \brief Utility constructor to handle conversion to relax::Expr
+ *
+ * If the calling scope already has an array of a specific type of
+ * relax expression (e.g. `Array<relax::Var>`), it must be converted
+ * into an array of base type. This constructor handles the
+ * conversion to the base `Array<relax::Expr>`.
+ *
+ * \tparam RelaxExpr The type of relax expression passed in as an argument.
+ *
+ * \param fields The fields of a tuple.
+ *
+ * \param span The source span of the expression.
+ */
+ template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
+ TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
+ : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
+
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
index 3ea17fdd70..7e6aa6277b 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -71,7 +71,16 @@ struct Patterns {
WildcardPattern input;
std::vector<WildcardPattern> rhs;
std::vector<WildcardPattern> bias;
- std::vector<CallPattern> matmul, bias_add, activation;
+ std::vector<CallPattern> matmul;
+ std::vector<CallPattern> bias_add;
+ std::vector<CallPattern> activation;
+};
+
+struct SplitInfo {
+ Var rhs;
+ Optional<Var> bias;
+ PrimExpr split_size;
+ DFPattern pattern_to_replace;
};
Patterns CreatePatterns(const BranchInfo& branch_info) {
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;
- auto inp = matchings[patterns.input];
+ auto lhs = matchings[patterns.input];
+
+ const auto& patterns_to_replace = [&patterns, &branch_info]() {
+ if (branch_info.activation) return patterns.activation;
+ if (branch_info.bias_dim) return patterns.bias_add;
+ return patterns.matmul;
+ }();
- Array<Var> rhs, bias;
- for (auto ind : indices) {
- rhs.push_back(matchings[patterns.rhs[ind]]);
- if (branch_info.bias_dim) {
- ICHECK(matchings.count(patterns.bias[ind]));
- bias.push_back(matchings[patterns.bias[ind]]);
+ std::vector<SplitInfo> splits;
+ for (auto index : indices) {
+ Var rhs = matchings[patterns.rhs[index]];
+ Optional<Var> bias = NullOpt;
+ if (branch_info.bias_dim.has_value()) {
+ bias = matchings[patterns.bias[index]];
}
+ PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
+ DFPattern pattern_to_replace = patterns_to_replace[index];
+ splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
+ }
+ // At most one dynamic output shape can be part of the combined
+ // matmul, and it must be the last item in the split. Use
+ // `std::stable_sort` instead of `std::sort` to maintain a
+ // consistent order for all static shapes, and to consistently
+ // select the same dynamic weight to participate.
+ auto is_dynamic_split = [](const SplitInfo& split) -> bool {
+ return !split.split_size->IsInstance<IntImmNode>();
+ };
+ std::stable_sort(splits.begin(), splits.end(),
+ [&is_dynamic_split](const auto& a, const auto& b) {
+ return is_dynamic_split(a) < is_dynamic_split(b);
+ });
+ // Remove anything after the first dynamic shape participating
+ // in the combined matmul.
+ if (auto it = std::find_if(splits.begin(), splits.end(), is_dynamic_split);
+ it != splits.end()) {
+ splits.erase(it + 1, splits.end());
}
- if (!check(inp, rhs, bias, bindings)) {
+ if (splits.size() == 1) {
continue;
}
- auto make_tuple = [](const Array<Var>& var_array) {
- Array<Expr> exp_array;
- for (auto v : var_array) exp_array.push_back(v);
- return Tuple(exp_array);
- };
+ Array<Var> rhs;
+ Array<Var> bias;
+ for (const auto& split : splits) {
+ rhs.push_back(split.rhs);
+ if (split.bias) {
+ bias.push_back(split.bias.value());
+ }
+ }
- auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
- auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
- auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+ if (!check(lhs, rhs, bias, bindings)) {
+ continue;
+ }
- const auto& pattern_to_replace = [&patterns, &branch_info]() {
- if (branch_info.activation) return patterns.activation;
- if (branch_info.bias_dim) return patterns.bias_add;
- return patterns.matmul;
- }();
+ auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+ auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+ auto matmul_combined = matmul(lhs, concat_rhs, out_dtype);
if (branch_info.bias_dim) {
auto bias_dim = GetTensorSInfo(bias[0])->ndim;
- auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
+ auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
matmul_combined = add(matmul_combined, concat_bias);
}
@@ -191,20 +228,23 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
}
}
- int ind = 0;
+ int split_index = 0;
Array<IntImm> sections;
- for (int i = 0; i < static_cast<int>(indices.size()) - 1; ++i) {
- auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as<IntImmNode>();
- ind += width->value;
- sections.push_back(IntImm(DataType::Int(64), ind));
+ for (size_t i = 0; i + 1 < splits.size(); i++) {
+ auto width = splits[i].split_size.as<IntImmNode>();
+ ICHECK(width) << "InternalError: "
+ << "All splits except the last one must have a static shape";
+ split_index += width->value;
+ sections.push_back(IntImm(DataType::Int(64), split_index));
}
- int lhs_dim = GetTensorSInfo(inp)->ndim;
+ int lhs_dim = GetTensorSInfo(lhs)->ndim;
int split_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
auto chunks = split(matmul_combined, sections, split_axis);
- for (size_t i = 0; i < indices.size(); ++i) {
- auto bound_var = matchings[pattern_to_replace[indices[i]]];
+ for (size_t i = 0; i < splits.size(); i++) {
+ const auto& split = splits[i];
+ auto bound_var = matchings[split.pattern_to_replace];
replacements.Set(bound_var, TupleGetItem(chunks, i));
}
}
@@ -244,43 +284,43 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
PostOrderVisit(f, [&](const Expr& e) {
if (!e->IsInstance<CallNode>()) return;
- if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
- auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
- auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
- auto it = groups.find(matmul_lhs.get());
- BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
- std::optional<int> bias_dim = std::nullopt;
- std::optional<std::string> activation = std::nullopt;
+ auto match = ExtractMatchedExpr(pat, e, bindings);
+ if (!match) return;
- if (match.value().count(bias_pat)) {
- bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
- }
+ auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+ auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
- for (size_t i = 0; i < activations.size(); ++i) {
- if (match.value().count(activation_pat[i]) ||
- match.value().count(bias_activation_pat[i])) {
- activation = activations[i];
- }
+ std::optional<int> bias_dim = std::nullopt;
+ std::optional<std::string> activation = std::nullopt;
+
+ if (match.value().count(bias_pat)) {
+ bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+ }
+
+ for (size_t i = 0; i < activations.size(); ++i) {
+ if (match.value().count(activation_pat[i]) || match.value().count(bias_activation_pat[i])) {
+ activation = activations[i];
}
+ }
- if (!branch) {
- // Create a new subgraph with one matmul
- groups[matmul_lhs.get()] = {1, bias_dim, activation};
- } else {
- // Create a new branch in the existing parallel matmul subtree, and
- // invalidate bias and activation information when needed.
- branch->num_branches += 1;
+ if (auto it = groups.find(matmul_lhs.get()); it != groups.end()) {
+ // Create a new branch in the existing parallel matmul subtree, and
+ // invalidate bias and activation information when needed.
+ BranchInfo* branch = &it->second;
+
+ branch->num_branches += 1;
- if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
- branch->bias_dim = std::nullopt;
- }
+ if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
+ branch->bias_dim = std::nullopt;
+ }
- if (!activation || (branch->activation && *branch->activation != *activation)) {
- branch->activation = std::nullopt;
- }
+ if (!activation || (branch->activation && *branch->activation != *activation)) {
+ branch->activation = std::nullopt;
}
- return;
+ } else {
+ // Create a new subgraph with one matmul
+ groups[matmul_lhs.get()] = {1, bias_dim, activation};
}
});
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 7e7f2328f3..6168d0c58d 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,7 +525,16 @@ def test_check():
tvm.ir.assert_structural_equal(after, expected)
-def test_dynamic_rhs():
+def test_combine_matmul_of_static_and_dynamic_shapes():
+ """Combine two matmuls, one with dynamic shape
+
+ The `R.split` operator must have a static list of integer indices
+ at which to split the matmul output, because these integer indices
+ are stored as operator attributes. However, the last output can
+ still have a dynamic shape.
+
+ """
+
@R.function(private=True)
def before(
x: R.Tensor((2, 1024, 640), "float32"),
@@ -572,5 +581,117 @@ def test_dynamic_rhs():
tvm.ir.assert_structural_equal(after, expected)
+def test_combine_matmul_of_dynamic_and_static_shapes():
+ """Combine two matmuls, one with dynamic shape
+
+ Like `test_combine_matmul_of_static_and_dynamic_shapes`, but the
+ dynamic-shaped matmul is encountered first. Due to the
+ requirements imposed by `R.split` storing the split indices as
+ static integers, the static-shaped weights must occur first in the
+ concatenated weights.
+ """
+
+ @R.function(private=True)
+ def before(
+ x: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, "M"), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ ):
+ M = T.int64()
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ out = (lv0, lv1)
+ R.output(out)
+ return out
+
+ @R.function(private=True)
+ def expected(
+ x: R.Tensor((2, 1024, 640), dtype="float32"),
+ w0: R.Tensor((640, "M"), dtype="float32"),
+ w1: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tuple(
+ R.Tensor((2, 1024, "M"), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")
+ ):
+ M = T.int64()
+ with R.dataflow():
+ lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w1, w0), axis=1)
+ lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+ x, lv, out_dtype="float32"
+ )
+ lv2: R.Tuple(
+ R.Tensor((2, 1024, 640), dtype="float32"),
+ R.Tensor((2, 1024, M), dtype="float32"),
+ ) = R.split(lv1, indices_or_sections=[640], axis=2)
+ lv0: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+ lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+ out: R.Tuple(
+ R.Tensor((2, 1024, M), dtype="float32"),
+ R.Tensor((2, 1024, 640), dtype="float32"),
+ ) = (lv0, lv1_1)
+ R.output(out)
+ return out
+
+ after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_limit_one_dynamic_shape_in_combined_matmul():
+ """Combine two matmuls, one with dynamic shape
+
+ Like `test_combine_matmul_of_static_and_dynamic_shapes`, but with
+ two dynamic weights that could, in principle, be merged together.
+ Because `R.split` must have integer indices at which to split,
+ only one of the dynamic outputs can be part of the combined
+ matmul.
+ """
+
+ @R.function(private=True)
+ def before(
+ x: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, "M"), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, "N"), "float32"),
+ ):
+ M = T.int64()
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ lv2 = R.matmul(x, w2)
+ out = (lv0, lv1, lv2)
+ R.output(out)
+ return out
+
+ @R.function(private=True)
+ def expected(
+ x: R.Tensor((2, 1024, 640), dtype="float32"),
+ w0: R.Tensor((640, "M"), dtype="float32"),
+ w1: R.Tensor((640, 640), dtype="float32"),
+ w2: R.Tensor((640, "N"), "float32"),
+ ) -> R.Tuple(
+ R.Tensor((2, 1024, "M"), dtype="float32"),
+ R.Tensor((2, 1024, 640), dtype="float32"),
+ R.Tensor((2, 1024, "N"), dtype="float32"),
+ ):
+ M = T.int64()
+ with R.dataflow():
+ concat_weights = R.concat((w1, w0), axis=1)
+ concat_output = R.matmul(x, concat_weights, out_dtype="float32")
+ split_output: R.Tuple(
+ [R.Tensor([2, 1024, 640], dtype="float32"), R.Tensor([2, 1024, M], dtype="float32")]
+ ) = R.split(concat_output, indices_or_sections=[640], axis=2)
+ lv0 = split_output[1]
+ lv1 = split_output[0]
+ lv2 = R.matmul(x, w2)
+ out = (lv0, lv1, lv2)
+ R.output(out)
+ return out
+
+ after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+ tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()