Posted to commits@tvm.apache.org by lu...@apache.org on 2024/02/23 14:28:19 UTC

(tvm) branch main updated: [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b5815753dc [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)
b5815753dc is described below

commit b5815753dcaf533d2fa27048b524623bbdf87376
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Feb 23 08:28:13 2024 -0600

    [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)
    
    * [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat
    
    This commit implements an optional optimization pass
    `relax.transform.ReorderPermuteDimsAfterConcat`, which reorders
    expressions of the form `R.concat(R.permute_dims(A),
    R.permute_dims(B))` into `R.permute_dims(R.concat(A,B))`.

    This pass is intended to be used alongside `CombineParallelMatmul`.
    After parallel matmuls are combined, this pass lifts the shared
    `R.permute_dims` out of the `R.concat`, allowing optimized
    `nn.Linear` kernels to find the `R.matmul(x, R.permute_dims(weights))`
    pattern they are looking for.
    
    ```python
    @R.function
    def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
        """Initial IRModule
    
        The `R.permute_dims` followed by `R.matmul` is the relax
        equivalent of `nn.Linear`, and will frequently have optimized
        kernels.
        """
        weight_query_T = R.permute_dims(weight_query)
        query = R.matmul(x, weight_query_T)
        weight_key_T = R.permute_dims(weight_key)
        key = R.matmul(x, weight_key_T)
        weight_value_T = R.permute_dims(weight_value)
        value = R.matmul(x, weight_value_T)
    
    @R.function
    def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
        """After `CombineParallelMatmul`
    
        There's now only a single matmul to be performed, which is
        generally better than performing three small matmuls.  However,
        the optimized kernels for `nn.Linear` can no longer be applied,
        because the `R.concat` isn't part of the expected pattern.
        """
        weight_query_T = R.permute_dims(weight_query)
        weight_key_T = R.permute_dims(weight_key)
        weight_value_T = R.permute_dims(weight_value)
    
        fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
        fused_qkv = R.matmul(x, fused_weight_T)
    
        query, key, value = R.split(fused_qkv, indices_or_sections=3, axis=-1)
    
    @R.function
    def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
        """After `ReorderPermuteDimsAfterConcat`
    
        There's still only a single matmul, and the optimized kernels for
        `nn.Linear` can be applied again.
        """
        fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)
    
        fused_weight_T = R.permute_dims(fused_weight)
        fused_qkv = R.matmul(x, fused_weight_T)
    
        query, key, value = R.split(fused_qkv, indices_or_sections=3, axis=-1)
    ```
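
    As a usage sketch (hedged: the `Module` name and this exact pass
    ordering are illustrative assumptions, not something this commit
    prescribes), the two passes compose as ordinary relax transforms:

    ```python
    # Hypothetical pipeline: `Module` stands in for an IRModule
    # containing functions like `func` above.
    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential(
        [
            # Fuse the parallel matmuls into a single matmul.
            relax.transform.CombineParallelMatmul(),
            # Lift the shared permute_dims above the concat, restoring
            # the matmul(x, permute_dims(weights)) pattern.
            relax.transform.ReorderPermuteDimsAfterConcat(),
        ]
    )
    optimized = seq(Module)
    ```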
    
    * Expand description of `max_concat` variable as a temporary solution
---
 python/tvm/relax/transform/__init__.py             |   1 +
 python/tvm/relax/transform/transform.py            |  20 ++
 .../transform/reorder_permute_dims_after_concat.cc | 187 +++++++++++++++
 ..._transform_reorder_permute_dims_after_concat.py | 264 +++++++++++++++++++++
 4 files changed, 472 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 7efe144c50..c3fb0f23be 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -63,6 +63,7 @@ from .transform import (
     RemovePurityChecking,
     RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    ReorderPermuteDimsAfterConcat,
     ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index c017f0cda7..e4c66558f5 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1325,6 +1325,26 @@ def ExpandMatmulOfSum():
     return _ffi_api.ExpandMatmulOfSum()  # type: ignore
 
 
+def ReorderPermuteDimsAfterConcat():
+    """Reorder `concat(permute_dims(A), permute_dims(B))` into `permute_dims(concat(A,B))`
+
+    Useful for optimizing computations after `CombineParallelMatmul`.
+    The patterns for optimized `nn.Linear` implementations look for
+    `matmul(activations, permute_dims(weights))`.  After
+    `CombineParallelMatmul`, the `matmul(activations,
+    concat(permute_dims(A), permute_dims(B)))` no longer matches this
+    pattern.  Rearranging into `matmul(activations,
+    permute_dims(concat(A,B)))` restores the pattern match.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
+    """
+
+    return _ffi_api.ReorderPermuteDimsAfterConcat()  # type: ignore
+
+
 def ReorderTakeAfterMatmul():
     """Reorder `matmul(x, take(weights, indices))` to `take(matmul(x,weights),indices)`
 
diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc
new file mode 100644
index 0000000000..23a9d9670e
--- /dev/null
+++ b/src/relax/transform/reorder_permute_dims_after_concat.cc
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/reorder_permute_dims_after_concat.cc
+ * \brief Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B))
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
+  // TODO(Lunderberg): Allow pattern-matching to handle a flexible
+  // number of arguments, each of which matches the same type of
+  // pattern.
+  //
+  // Because we instantiate one DFPattern for each value in
+  // `min_concat <= i <= max_concat`, we don't want to set
+  // `max_concat` to an extremely high value.  The current value of 12
+  // was chosen to be significantly higher than the highest value
+  // required so far (3, for query/key/value in attention layers), but
+  // not so high that it requires an excessive number of `DFPattern`.
+  //
+  // This value is deliberately *NOT* exposed, as `max_concat` may be
+  // increased at any point that it is required, and other use cases
+  // should not depend on its value.  If there is a use case that
+  // requires more matmuls to be handled, and pattern-matching does
+  // not yet support a flexible number of `Tuple` elements,
+  // `max_concat` should be increased.
+  size_t min_concat = 2;
+  size_t max_concat = 12;
+
+  std::vector<DFPattern> pat_args;
+  std::vector<DFPattern> pat_permute_dims;
+  for (size_t i = 0; i < max_concat; i++) {
+    auto arg = WildcardPattern();
+    pat_args.push_back(arg);
+    pat_permute_dims.push_back(IsOp("relax.permute_dims")(arg));
+  }
+
+  auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern {
+    ICHECK_LT(num_concat, pat_permute_dims.size());
+    auto concat_tuple = TuplePattern(
+        Array<DFPattern>(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat));
+    return IsOp("relax.concat")(concat_tuple);
+  };
+
+  DFPattern pat_concat = make_pattern_with_num_concat(min_concat);
+  for (size_t i = min_concat + 1; i < max_concat; i++) {
+    pat_concat = pat_concat | make_pattern_with_num_concat(i);
+  }
+
+  auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional<Array<Integer>> {
+    auto call = expr.as<CallNode>();
+    ICHECK(call);
+    auto attrs = call->attrs.as<PermuteDimsAttrs>();
+    ICHECK(attrs);
+
+    return attrs->axes;
+  };
+
+  auto get_permute_dims_axes =
+      [get_permute_dims_optional_axes](const Expr& expr) -> Array<Integer> {
+    if (auto opt_axes = get_permute_dims_optional_axes(expr)) {
+      return opt_axes.value();
+    } else {
+      auto call = Downcast<Call>(expr);
+      Array<Integer> permutation;
+      auto arg_sinfo = call->args[0]->struct_info_.as<TensorStructInfoNode>();
+      CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, "
+                       << "but argument " << call->args[0] << " has struct info "
+                       << call->args[0]->struct_info_;
+      CHECK_GE(arg_sinfo->ndim, 0);
+      size_t ndim = arg_sinfo->ndim;
+      for (size_t i = 0; i < ndim; i++) {
+        permutation.push_back(Integer(ndim - i - 1));
+      }
+      return permutation;
+    }
+  };
+
+  auto permute_dims_axes_are_compatible = [&](const Array<Expr>& permute_dims) -> bool {
+    auto first_axes = get_permute_dims_axes(permute_dims[0]);
+    for (size_t i_arg = 1; i_arg < permute_dims.size(); i_arg++) {
+      auto i_axes = get_permute_dims_axes(permute_dims[i_arg]);
+      if (i_axes.size() != first_axes.size()) {
+        return false;
+      }
+      for (size_t i_axis = 0; i_axis < first_axes.size(); i_axis++) {
+        if (i_axes[i_axis]->value != first_axes[i_axis]->value) {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
+
+  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
+    Array<Expr> args;
+    Array<Expr> all_permute_dims;
+    for (size_t i = 0; i < max_concat; i++) {
+      if (auto permute_dim_expr = matches.Get(pat_permute_dims[i])) {
+        all_permute_dims.push_back(permute_dim_expr.value());
+        args.push_back(matches[pat_args[i]]);
+      }
+    }
+
+    ICHECK_GE(all_permute_dims.size(), min_concat)
+        << "InternalError: "
+        << "Pattern match should return at least " << min_concat << " items, but only found "
+        << all_permute_dims.size() << ": " << all_permute_dims;
+
+    if (!permute_dims_axes_are_compatible(all_permute_dims)) {
+      return expr;
+    }
+    Optional<Array<Integer>> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]);
+
+    Call concat_call = Downcast<Call>(matches[pat_concat]);
+    auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
+    ICHECK(concat_attrs);
+
+    auto old_concat_axis = [&]() -> size_t {
+      if (concat_attrs->axis.defined()) {
+        return concat_attrs->axis.value()->value;
+      } else {
+        return 0;
+      }
+    }();
+    Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis];
+
+    auto new_concat = concat(Tuple(args), new_concat_axis);
+    auto new_permute_dims = permute_dims(new_concat, permute_axes);
+
+    return new_permute_dims;
+  };
+
+  return {pat_concat, rewriter};
+}
+
+}  // namespace
+
+namespace transform {
+Pass ReorderPermuteDimsAfterConcat() {
+  auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
+    auto [pattern, rewriter] = CreatePatterns();
+    return RewriteCall(pattern, rewriter, func);
+  };
+  return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat")
+    .set_body_typed(ReorderPermuteDimsAfterConcat);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
new file mode 100644
index 0000000000..533ba7b696
--- /dev/null
+++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import inspect
+
+import pytest
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R
+
+
+class Base:
+    def test_compare(self):
+        transform = relax.transform.ReorderPermuteDimsAfterConcat()
+
+        if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
+            with pytest.raises(self.Expected):
+                transform(self.Before)
+        else:
+            after = transform(self.Before)
+            tvm.ir.assert_structural_equal(self.Expected, after)
+
+
+class TestSimple(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(linear_weight_A)
+                matmul_weight_B = R.permute_dims(linear_weight_B)
+                matmul_weight = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                linear_weight = R.concat([linear_weight_A, linear_weight_B], axis=0)
+                matmul_weight = R.permute_dims(linear_weight)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+
+class TestCombineExplicitAndImplicitAxes(Base):
+    """Check for explicit axes to be permuted
+
+    If `R.permute_dims` has no axes specified, it reverses the order
+    of all axes.  For a 2-d argument, `R.permute_dims(arg)` and
+    `R.permute_dims(arg, [1,0])` are equivalent, and should be
+    able to be combinable.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(linear_weight_A)
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            linear_weight_A: R.Tensor([128, 32], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                linear_weight = R.concat([linear_weight_A, linear_weight_B], axis=0)
+                matmul_weight = R.permute_dims(linear_weight)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+
+class TestDoNotCombineIncompatibleAxes(Base):
+    """No change should be made for incompatible permutations
+
+    The different `R.permute_dims` must each perform the same
+    permutation for the reordering to be valid.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out = R.matmul(x, matmul_weight)
+                R.output(out)
+            return out
+
+    Expected = Before
+
+
+class TestCheckForRewriteAfterIncompatibleChange(Base):
+    """Check all R.permute_dims options, not just the first
+
+    Complex conditionals may be implemented in the rewriter, rather
+    than in the pattern match.  In these cases, the rewriter may return
+    the matched expression unmodified.  However, this prevents the
+    pattern-matcher from checking later instances of the match.
+
+    By moving the complex conditional to a `ConstrainedPattern`, the
+    pattern-matcher can check against all possible matches.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                matmul_weight_C = R.permute_dims(linear_weight_C)
+                matmul_weight_D = R.permute_dims(linear_weight_D)
+                matmul_weight_CD = R.concat([matmul_weight_C, matmul_weight_D], axis=1)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                linear_weight_CD = R.concat([linear_weight_C, linear_weight_D], axis=0)
+                matmul_weight_CD = R.permute_dims(linear_weight_CD)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+
+class TestCheckForRewriteBeforeIncompatibleChange(Base):
+    """Check all R.permute_dims options, not just the first
+
+    Complex conditionals may be implemented in the rewriter, rather
+    than in the pattern match.  In these cases, the rewriter may return
+    the matched expression unmodified.  However, this prevents the
+    pattern-matcher from checking later instances of the match.
+
+    By moving the complex conditional to a `ConstrainedPattern`, the
+    pattern-matcher can check against all possible matches.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                matmul_weight_C = R.permute_dims(linear_weight_C)
+                matmul_weight_D = R.permute_dims(linear_weight_D)
+                matmul_weight_CD = R.concat([matmul_weight_C, matmul_weight_D], axis=1)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 32], "float32"),
+            weight_A: R.Tensor([32, 128], "float32"),
+            linear_weight_B: R.Tensor([128, 32], "float32"),
+            linear_weight_C: R.Tensor([128, 32], "float32"),
+            linear_weight_D: R.Tensor([128, 32], "float32"),
+        ):
+            with R.dataflow():
+                linear_weight_CD = R.concat([linear_weight_C, linear_weight_D], axis=0)
+                matmul_weight_CD = R.permute_dims(linear_weight_CD)
+                out_CD = R.matmul(x, matmul_weight_CD)
+
+                matmul_weight_A = R.permute_dims(weight_A, axes=[0, 1])
+                matmul_weight_B = R.permute_dims(linear_weight_B, axes=[1, 0])
+                matmul_weight_AB = R.concat([matmul_weight_A, matmul_weight_B], axis=1)
+                out_AB = R.matmul(x, matmul_weight_AB)
+
+                out = (out_AB, out_CD)
+                R.output(out)
+            return out
+
+
+if __name__ == "__main__":
+    tvm.testing.main()