You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2024/02/23 14:28:19 UTC
(tvm) branch main updated: [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b5815753dc [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)
b5815753dc is described below
commit b5815753dcaf533d2fa27048b524623bbdf87376
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Feb 23 08:28:13 2024 -0600
[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)
* [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat
This commit implements an optional optimization pass
`relax.transform.ReorderPermuteDimsAfterConcat`, which reorders
expressions of the form `R.concat(R.permute_dims(A),
R.permute_dims(B))` into `R.permute_dims(R.concat(A,B))`.
This pass is intended to be used alongside `CombineParallelMatmul`.
After parallel matmuls are combined, this pass allows the
`R.permute_dims` to be lifted out of the `R.concat`, so that optimized
`nn.Linear` kernels can find the `R.matmul(x, R.permute_dims(weights))`
patterns they are looking for.
```python
@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query_T)
    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key_T)
    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value_T)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `CombineParallelMatmul`

    There's now only a single matmul to be performed, which is
    generally better than performing three small matmuls.  However,
    the optimized kernels for `nn.Linear` can no longer be applied,
    because the `R.concat` isn't part of the expected pattern.
    """
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)
    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `ReorderPermuteDimsAfterConcat`

    There's still only a single matmul, and the optimized kernels for
    `nn.Linear` can be applied again.
    """
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)
    fused_weight_T = R.permute_dims(fused_weight)
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)
```
* Expand description of `max_concat` variable as a temporary solution
---
python/tvm/relax/transform/__init__.py | 1 +
python/tvm/relax/transform/transform.py | 20 ++
.../transform/reorder_permute_dims_after_concat.cc | 187 +++++++++++++++
..._transform_reorder_permute_dims_after_concat.py | 264 +++++++++++++++++++++
4 files changed, 472 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 7efe144c50..c3fb0f23be 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -63,6 +63,7 @@ from .transform import (
RemovePurityChecking,
RemoveUnusedParameters,
RemoveUnusedOutputs,
+ ReorderPermuteDimsAfterConcat,
ReorderTakeAfterMatmul,
RewriteCUDAGraph,
RewriteDataflowReshape,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index c017f0cda7..e4c66558f5 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1325,6 +1325,26 @@ def ExpandMatmulOfSum():
return _ffi_api.ExpandMatmulOfSum() # type: ignore
+def ReorderPermuteDimsAfterConcat():
+ """Reorder `concat(permute_dims(A), permute_dims(B))` into `permute_dims(concat(A,B))`
+
+ Useful for optimizing computations after `CombineParallelMatmul`.
+ The patterns for optimized `nn.Linear` implementations look for
+ `matmul(activations, permute_dims(weights))`. After
+ `CombineParallelMatmul`, the `matmul(activations,
+ concat(permute_dims(A), permute_dims(B)))` no longer matches this
+ pattern. Rearranging into `matmul(activations,
+ permute_dims(concat(A,B)))` restores the pattern match.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The corresponding pass.
+ """
+
+ return _ffi_api.ReorderPermuteDimsAfterConcat() # type: ignore
+
+
def ReorderTakeAfterMatmul():
"""Reorder `matmul(x, take(weights, indices))` to `take(matmul(x,weights),indices)`
diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc
new file mode 100644
index 0000000000..23a9d9670e
--- /dev/null
+++ b/src/relax/transform/reorder_permute_dims_after_concat.cc
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/reorder_permute_dims_after_concat.cc
+ * \brief Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B))
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+/*!
+ * \brief Construct the pattern and rewriter for ReorderPermuteDimsAfterConcat
+ *
+ * The returned pattern matches `relax.concat` of a tuple whose
+ * elements are all `relax.permute_dims` calls, for any number of
+ * tuple elements in the inclusive range `[min_concat, max_concat]`.
+ *
+ * The returned rewriter replaces a match with
+ * `permute_dims(concat(...))` when every `permute_dims` applies the
+ * same known permutation.  Otherwise it returns the matched expression
+ * unmodified.
+ */
+std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
+  // TODO(Lunderberg): Allow pattern-matching to handle a flexible
+  // number of arguments, each of which matches the same type of
+  // pattern.
+  //
+  // Because we instantiate one DFPattern for each value in
+  // `min_concat <= i <= max_concat`, we don't want to set
+  // `max_concat` to an extremely high value.  The current value of 12
+  // was chosen to be significantly higher than the highest value
+  // required so far (3, for query/key/value in attention layers), but
+  // not so high that it requires an excessive number of `DFPattern`.
+  //
+  // This value is deliberately *NOT* exposed, as `max_concat` may be
+  // increased at any point that it is required, and other use cases
+  // should not depend on its value.  If there is a use case that
+  // requires more matmuls to be handled, and pattern-matching does
+  // not yet support a flexible number of `Tuple` elements,
+  // `max_concat` should be increased.
+  size_t min_concat = 2;
+  size_t max_concat = 12;
+
+  // One wildcard per possible tuple element, along with the
+  // `permute_dims` call applied to it.  The rewriter looks these up in
+  // the match results to recover the un-permuted arguments.
+  std::vector<DFPattern> pat_args;
+  std::vector<DFPattern> pat_permute_dims;
+  for (size_t i = 0; i < max_concat; i++) {
+    auto arg = WildcardPattern();
+    pat_args.push_back(arg);
+    pat_permute_dims.push_back(IsOp("relax.permute_dims")(arg));
+  }
+
+  // Pattern for `concat` of a tuple with exactly `num_concat` elements.
+  // Only used while building `pat_concat`, so capturing by reference is safe.
+  auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern {
+    // `num_concat == max_concat` is valid, and uses every element of
+    // `pat_permute_dims`.
+    ICHECK_LE(num_concat, pat_permute_dims.size());
+    auto concat_tuple = TuplePattern(
+        Array<DFPattern>(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat));
+    return IsOp("relax.concat")(concat_tuple);
+  };
+
+  // The upper bound is inclusive, so that a concat of exactly
+  // `max_concat` tensors is also matched.
+  DFPattern pat_concat = make_pattern_with_num_concat(min_concat);
+  for (size_t i = min_concat + 1; i <= max_concat; i++) {
+    pat_concat = pat_concat | make_pattern_with_num_concat(i);
+  }
+
+  // The `axes` attribute of a `relax.permute_dims` call.  This is
+  // undefined when no explicit permutation was provided.
+  auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional<Array<Integer>> {
+    auto call = expr.as<CallNode>();
+    ICHECK(call);
+    auto attrs = call->attrs.as<PermuteDimsAttrs>();
+    ICHECK(attrs);
+
+    return attrs->axes;
+  };
+
+  // The permutation applied by a `relax.permute_dims` call, with
+  // negative axes normalized by adding the number of axes.  Without
+  // explicit axes, `permute_dims` reverses all axes.  Returns NullOpt
+  // if the permutation cannot be determined: no explicit axes, and the
+  // argument's dimensionality is unknown.
+  auto get_permute_dims_axes =
+      [get_permute_dims_optional_axes](const Expr& expr) -> Optional<Array<Integer>> {
+    if (auto opt_axes = get_permute_dims_optional_axes(expr)) {
+      Array<Integer> explicit_axes = opt_axes.value();
+      int64_t ndim = static_cast<int64_t>(explicit_axes.size());
+      Array<Integer> permutation;
+      for (const Integer& axis : explicit_axes) {
+        int64_t value = axis->value;
+        permutation.push_back(Integer(static_cast<int>(value < 0 ? value + ndim : value)));
+      }
+      return permutation;
+    }
+
+    auto call = Downcast<Call>(expr);
+    auto arg_sinfo = call->args[0]->struct_info_.as<TensorStructInfoNode>();
+    CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, "
+                     << "but argument " << call->args[0] << " has struct info "
+                     << call->args[0]->struct_info_;
+    if (arg_sinfo->ndim < 0) {
+      // Unknown dimensionality, so the implicit (reversing) permutation
+      // is unknown.  Report it as such instead of failing the pass.
+      return NullOpt;
+    }
+    size_t ndim = arg_sinfo->ndim;
+    Array<Integer> permutation;
+    for (size_t i = 0; i < ndim; i++) {
+      permutation.push_back(Integer(static_cast<int>(ndim - i - 1)));
+    }
+    return permutation;
+  };
+
+  // True if every `permute_dims` call applies the same, known,
+  // permutation.  Only then can the `permute_dims` be moved after the
+  // `concat`.
+  //
+  // `get_permute_dims_axes` is captured by value: this lambda is copied
+  // into `rewriter`, which is invoked after this function has returned,
+  // so a by-reference capture of the local would dangle.
+  auto permute_dims_axes_are_compatible =
+      [get_permute_dims_axes](const Array<Expr>& permute_dims_calls) -> bool {
+    auto opt_first_axes = get_permute_dims_axes(permute_dims_calls[0]);
+    if (!opt_first_axes.defined()) {
+      return false;
+    }
+    Array<Integer> first_axes = opt_first_axes.value();
+
+    for (size_t i_arg = 1; i_arg < permute_dims_calls.size(); i_arg++) {
+      auto opt_i_axes = get_permute_dims_axes(permute_dims_calls[i_arg]);
+      if (!opt_i_axes.defined()) {
+        return false;
+      }
+      Array<Integer> i_axes = opt_i_axes.value();
+      if (i_axes.size() != first_axes.size()) {
+        return false;
+      }
+      for (size_t i_axis = 0; i_axis < first_axes.size(); i_axis++) {
+        if (i_axes[i_axis]->value != first_axes[i_axis]->value) {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
+
+  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
+    // Collect the matched `permute_dims` calls and their arguments, in
+    // tuple order.  Only the alternative that matched is present in `matches`.
+    Array<Expr> args;
+    Array<Expr> all_permute_dims;
+    for (size_t i = 0; i < max_concat; i++) {
+      if (auto permute_dim_expr = matches.Get(pat_permute_dims[i])) {
+        all_permute_dims.push_back(permute_dim_expr.value());
+        args.push_back(matches[pat_args[i]]);
+      }
+    }
+
+    ICHECK_GE(all_permute_dims.size(), min_concat)
+        << "InternalError: "
+        << "Pattern match should return at least " << min_concat << " items, but only found "
+        << all_permute_dims.size() << ": " << all_permute_dims;
+
+    if (!permute_dims_axes_are_compatible(all_permute_dims)) {
+      return expr;
+    }
+    // All permutations are known and identical, so the first one
+    // stands in for all of them.
+    Array<Integer> permutation = get_permute_dims_axes(all_permute_dims[0]).value();
+    Optional<Array<Integer>> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]);
+
+    Call concat_call = Downcast<Call>(matches[pat_concat]);
+    auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
+    ICHECK(concat_attrs);
+
+    // An undefined concat axis is treated as axis 0.  A negative axis
+    // counts from the end, and must be normalized before it can be used
+    // as an index into `permutation`.
+    int64_t ndim = static_cast<int64_t>(permutation.size());
+    int64_t old_concat_axis =
+        concat_attrs->axis.defined() ? concat_attrs->axis.value()->value : 0;
+    if (old_concat_axis < 0) {
+      old_concat_axis += ndim;
+    }
+    if (old_concat_axis < 0 || old_concat_axis >= ndim) {
+      // Out-of-range axis: leave the expression unchanged rather than
+      // indexing past the end of `permutation`.
+      return expr;
+    }
+
+    // Output axis `k` of `permute_dims` is input axis `permutation[k]`,
+    // so concatenating the permuted tensors along axis `k` is equivalent
+    // to concatenating the original tensors along axis `permutation[k]`.
+    Integer new_concat_axis = permutation[old_concat_axis];
+
+    auto new_concat = concat(Tuple(args), new_concat_axis);
+    auto new_permute_dims = permute_dims(new_concat, permute_axes);
+
+    return new_permute_dims;
+  };
+
+  return {pat_concat, rewriter};
+}
+
+} // namespace
+
+namespace transform {
+Pass ReorderPermuteDimsAfterConcat() {
+  // Rewrite every matching `concat(permute_dims(...), ...)` inside a
+  // single function.  The patterns are rebuilt for each visited function.
+  auto rewrite_function = [](Function function, IRModule /*mod*/,
+                             PassContext /*pass_context*/) -> Function {
+    auto patterns = CreatePatterns();
+    return RewriteCall(std::get<0>(patterns), std::get<1>(patterns), function);
+  };
+
+  constexpr int opt_level = 1;
+  return CreateFunctionPass(rewrite_function, opt_level, "ReorderPermuteDimsAfterConcat", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat")
+    .set_body_typed(ReorderPermuteDimsAfterConcat);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
new file mode 100644
index 0000000000..533ba7b696
--- /dev/null
+++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import inspect
+
+import pytest
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R
+
+
+class Base:
+ def test_compare(self):
+ transform = relax.transform.ReorderPermuteDimsAfterConcat()
+
+ if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
+ with pytest.raises(self.Expected):
+ transform(self.Before)
+ else:
+ after = transform(self.Before)
+ tvm.ir.assert_structural_equal(self.Expected, after)
+
+
+class TestSimple(Base):
+    """Two transposed weights concatenated along axis 1 are reordered
+
+    Both calls use the default `R.permute_dims`, which reverses the
+    axes.  Concatenating the transposed weights along axis 1 is
+    therefore equivalent to concatenating the original weights along
+    axis 0 and transposing the result.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(linear_weight_A)
+                matmul_weight_B = R.permute_dims(linear_weight_B)
+                matmul_weight = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                # Concat axis 1 of the transposed weights maps to axis 0
+                # of the original weights.
+                linear_weight = R.concat([linear_weight_A, linear_weight_B], axis=0)
+                matmul_weight = R.permute_dims(linear_weight)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+
+class TestCombineExplicitAndImplicitAxes(Base):
+    """Explicit and implicit axes for the same permutation can be combined
+
+    If `R.permute_dims` has no axes specified, it reverses the order
+    of all axes.  For a 2-d argument, `R.permute_dims(arg)` and
+    `R.permute_dims(arg, [1,0])` are equivalent, so the two calls
+    below can be combined.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                # Implicit axes (reversed) ...
+                matmul_weight_A = R.permute_dims(linear_weight_A)
+                # ... and the same permutation written explicitly.
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                linear_weight = R.concat([linear_weight_A, linear_weight_B], axis=0)
+                # The first `R.permute_dims` had no explicit axes, so
+                # the rewritten call has none either.
+                matmul_weight = R.permute_dims(linear_weight)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+
+class TestDoNotCombineIncompatibleAxes(Base):
+    """No change should be made for incompatible permutations
+
+    The different `R.permute_dims` must each perform the same
+    permutation for the reordering to be valid.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                # Identity permutation: [32, 128] -> [32, 128]
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                # Transpose: [128, 32] -> [32, 128]
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+    # The pass must leave the module unchanged.
+    Expected = Before
+
+
+class TestCheckForRewriteAfterIncompatibleChange(Base):
+    """Check all R.permute_dims options, not just the first
+
+    The check that all `R.permute_dims` apply the same permutation is
+    implemented in the rewriter, rather than in the pattern match.
+    When the permutations are incompatible, the rewriter returns the
+    matched expression unmodified.  This must not prevent other,
+    compatible matches in the same function from being rewritten.
+
+    Here the incompatible `AB` concat appears before the compatible
+    `CD` concat.  Only `CD` should be rewritten.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                # Incompatible permutations: must not be rewritten.
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                # Compatible permutations: should be rewritten.
+                matmul_weight_C = R.permute_dims(linear_weight_C)
+                matmul_weight_D = R.permute_dims(linear_weight_D)
+                matmul_weight_CD = R.concat([matmul_weight_C, matmul_weight_D], axis=1)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                linear_weight_CD = R.concat([linear_weight_C, linear_weight_D], axis=0)
+                matmul_weight_CD = R.permute_dims(linear_weight_CD)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+
+class TestCheckForRewriteBeforeIncompatibleChange(Base):
+    """Check all R.permute_dims options, not just the first
+
+    The check that all `R.permute_dims` apply the same permutation is
+    implemented in the rewriter, rather than in the pattern match.
+    When the permutations are incompatible, the rewriter returns the
+    matched expression unmodified.  This must not interfere with
+    other, compatible matches in the same function.
+
+    Here the compatible `CD` concat appears before the incompatible
+    `AB` concat.  Only `CD` should be rewritten.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                # Compatible permutations: should be rewritten.
+                matmul_weight_C = R.permute_dims(linear_weight_C)
+                matmul_weight_D = R.permute_dims(linear_weight_D)
+                matmul_weight_CD = R.concat([matmul_weight_C, matmul_weight_D], axis=1)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                # Incompatible permutations: must not be rewritten.
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                linear_weight_CD = R.concat([linear_weight_C, linear_weight_D], axis=0)
+                matmul_weight_CD = R.permute_dims(linear_weight_CD)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()