Posted to commits@spark.apache.org by we...@apache.org on 2023/02/15 12:57:33 UTC

[spark] branch branch-3.4 updated: [SPARK-42436][SQL] Improve multiTransform to generate alternatives dynamically

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new ca491362c93 [SPARK-42436][SQL] Improve multiTransform to generate alternatives dynamically
ca491362c93 is described below

commit ca491362c9360c0f0ed5d8c29124a4e9bd72c902
Author: Peter Toth <pe...@gmail.com>
AuthorDate: Wed Feb 15 20:56:53 2023 +0800

    [SPARK-42436][SQL] Improve multiTransform to generate alternatives dynamically
    
    ### What changes were proposed in this pull request?
    
    This PR improves `TreeNode.multiTransform()` to generate the alternative sequences only if needed and fully dynamically. Consider the following simplified example:
    ```
    (a + b).multiTransform {
      case `a` => Stream(1, 2)
      case `b` => Stream(10, 20)
    }
    ```
    The result is the cartesian product: `Stream(1 + 10, 2 + 10, 1 + 20, 2 + 20)`.
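    The same semantics can be modelled outside of Spark with plain Scala streams. Here is a minimal standalone sketch of the cartesian product (plain `Int`s instead of expressions; not the actual `TreeNode` API):
    ```
    // Scala 2.12-style Stream, as used by Spark 3.4.
    val aAlternatives = Stream(1, 2)   // alternatives for `a` (the "inner" stream)
    val bAlternatives = Stream(10, 20) // alternatives for `b` (the "outer" stream)

    // Combine every alternative of `a` with every alternative of `b`;
    // `a` varies fastest, matching the order above.
    val combined = for {
      bAlt <- bAlternatives // "outer" stream
      aAlt <- aAlternatives // "inner" stream, iterated once per outer element
    } yield aAlt + bAlt

    // combined == Stream(11, 12, 21, 22), i.e. 1+10, 2+10, 1+20, 2+20
    ```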
    Currently `multiTransform` calculates the 2 alternative streams for `a` and `b` **before** it starts building the cartesian product stream using `+`. In effect, it caches the "inner" `Stream(1, 2)` at the beginning and reuses that cache while the "outer" stream (`Stream(10, 20)`) iterates from `10` to `20`. Although this caching is sometimes useful, it has 2 drawbacks:
    - If the "outer" stream (the `b` alternatives) returns `Seq.empty` (to indicate pruning), the alternatives for `a` are calculated unnecessarily and then discarded.
    - The "inner" stream transformation can't depend on the current "outer" stream alternative.
       E.g. take the above `a + b` example, but suppose we want to transform both `a` and `b` to `1` and `2` and keep only those alternatives where the 2 are transformed equally (`Stream(1 + 1, 2 + 2)`). This is currently not possible with a single `multiTransform` call because the inner stream alternatives are calculated in advance and cached.
    But if `multiTransform` were dynamic and the "inner" alternatives stream were recalculated whenever the "outer" alternatives stream iterates, this would be possible:
      ```
      // Cache
      var a_or_b = Option.empty[Seq[Expression]]
      (a + b).multiTransform {
        case `a` | `b` =>
          // Return alternatives from cache if this is not the first encounter
          a_or_b.getOrElse(
            // Besides returning the alternatives for the first encounter, also set up a mechanism to
            // update the cache when the new alternatives are requested.
            Stream(Literal(1), Literal(2)).map { x =>
              a_or_b = Some(Seq(x))
              x
            }.append {
              a_or_b = None
              Seq.empty
            })
      }
      ```
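    Here each value pulled from the returned stream records the current alternative in the `a_or_b` cache, so when the rule later fires for the other operand it returns just that single cached value, keeping the two transformations in sync. The trailing `.append` block resets the cache once the alternatives are exhausted, so the next traversal starts fresh.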
    Please note:
    - This is a simplified example and we could have run 2 simple `transform` calls to get the expected 2 expressions, but `multiTransform` can do other orthogonal transformations in the same run (e.g. `c` -> `Seq(100, 200)`) and has the advantage of returning the results lazily as a stream.
    - The original behaviour of caching the "inner" alternative streams is still possible, and our current use cases in `AliasAwareOutputExpression` and in `BroadcastHashJoinExec` still do it: they store the alternatives in advance in maps, and the `multiTransform` call just gets the alternatives from those maps when needed, as in the sketch after this list.
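    A rough sketch of that map-based pattern (the map contents below are hypothetical; this is not the actual `AliasAwareOutputExpression` code):
    ```
    import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}

    // Alternatives computed once, in advance (illustrative contents).
    val alternatives: Map[Expression, Seq[Expression]] = Map(
      Literal("a") -> Seq(Literal(1), Literal(2)),
      Literal("b") -> Seq(Literal(10), Literal(20))
    )

    val e = Add(Literal("a"), Literal("b"))

    // The rule only looks the alternatives up, so the "inner" streams are
    // effectively cached in advance -- the original multiTransform behaviour.
    val transformed = e.multiTransformDown {
      case x if alternatives.contains(x) => alternatives(x)
    }
    ```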
    
    ### Why are the changes needed?
    Improvement to make `multiTransform` more versatile.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added new UTs.
    
    Closes #40016 from peter-toth/SPARK-42436-multitransform-generate-alternatives-dynamically.
    
    Authored-by: Peter Toth <pe...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 403e3d219fd8771a3cab3f4f58331896ebe16747)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  6 +-
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   | 68 +++++++++++++++++++++-
 2 files changed, 70 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index c8df2086a72..b90fc585a09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -734,7 +734,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
     afterRulesStream.flatMap { afterRule =>
       if (afterRule.containsChild.nonEmpty) {
         MultiTransform.generateCartesianProduct(
-            afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
+            afterRule.children.map(c => () => c.multiTransformDownWithPruning(cond, ruleId)(rule)))
           .map(afterRule.withNewChildren)
       } else {
         Stream(afterRule)
@@ -1373,11 +1373,11 @@ object MultiTransform {
    * @param elementSeqs a list of sequences to build the cartesian product from
    * @return            the stream of generated `Seq` elements
    */
-  def generateCartesianProduct[T](elementSeqs: Seq[Seq[T]]): Stream[Seq[T]] = {
+  def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] = {
     elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
       for {
         elementTail <- elementTails
-        element <- elements
+        element <- elements()
       } yield element +: elementTail
     )
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index e4adf59b392..3411415bbb6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -1000,7 +1000,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     assert(transformed === expected)
   }
 
-  test("multiTransformDown is lazy") {
+  test("multiTransformDown alternatives are accessed only if needed") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
     val transformed = e.multiTransformDown {
       case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
@@ -1080,4 +1080,70 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     }
     assert(transformed2.isEmpty)
   }
+
+  test("multiTransformDown alternatives are generated only if needed") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => newErrorAfterStream()
+      case StringLiteral("b") => Seq.empty
+    }
+    assert(transformed.isEmpty)
+  }
+
+  test("multiTransformDown can do non-cartesian transformations") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    // Suppose that we want to transform both `a` and `b` to `1` and `2`, but we want to have only
+    // those alternatives where these 2 are transformed equal. The first encounter with `a` or `b`
+    // will keep track of the current alternative in a "global" `a_or_b` cache. If we encounter `a`
+    // or `b` again at other places we can return the cached value to keep the transformations in
+    // sync.
+    var a_or_b = Option.empty[Seq[Expression]]
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") | StringLiteral("b") =>
+        // Return alternatives from cache if this is not the first encounter
+        a_or_b.getOrElse(
+          // Besides returning the alternatives for the first encounter, also set up a mechanism to
+          // update the cache when the new alternatives are requested.
+          Stream(Literal(1), Literal(2)).map { x =>
+            a_or_b = Some(Seq(x))
+            x
+          }.append {
+            a_or_b = None
+            Seq.empty
+          })
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200))
+      a_or_b <- Seq(Literal(1), Literal(2))
+    } yield Add(Add(a_or_b, a_or_b), cd)
+    assert(transformed == expected)
+
+    var c_or_d = Option.empty[Seq[Expression]]
+    val transformed2 = e.multiTransformDown {
+      case StringLiteral("a") | StringLiteral("b") =>
+        a_or_b.getOrElse(
+          Stream(Literal(1), Literal(2)).map { x =>
+            a_or_b = Some(Seq(x))
+            x
+          }.append {
+            a_or_b = None
+            Seq.empty
+          })
+      case StringLiteral("c") | StringLiteral("d") =>
+        c_or_d.getOrElse(
+          Stream(Literal(10), Literal(20)).map { x =>
+            c_or_d = Some(Seq(x))
+            x
+          }.append {
+            c_or_d = None
+            Seq.empty
+          })
+    }
+    val expected2 = for {
+      c_or_d <- Seq(Literal(10), Literal(20))
+      a_or_b <- Seq(Literal(1), Literal(2))
+    } yield Add(Add(a_or_b, a_or_b), Add(c_or_d, c_or_d))
+    assert(transformed2 == expected2)
+  }
 }
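
To see the effect of the `() => Seq[T]` signature change in `TreeNode.scala` above, here is a minimal standalone sketch that copies the patched `generateCartesianProduct` logic into plain Scala (the demo thunks are illustrative only):

```
def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] = {
  elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
    for {
      elementTail <- elementTails
      element <- elements()
    } yield element +: elementTail
  )
}

// A thunk is only invoked when a non-empty tail exists to prepend to; if a
// later child prunes with Seq.empty, earlier alternatives are never generated.
val product = generateCartesianProduct[Int](Seq(
  () => { println("computing alternatives"); Seq(1, 2) },
  () => Seq.empty // pruning: no alternatives for the second child
))
assert(product.isEmpty) // nothing is printed: the first thunk was never called
```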

