You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/02/10 15:57:14 UTC

[spark] branch master updated: [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 99431e28f95 [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions
99431e28f95 is described below

commit 99431e28f950bb25c421abd51888a3f9f4b46685
Author: Supun Nakandala <su...@databricks.com>
AuthorDate: Fri Feb 10 23:56:34 2023 +0800

    [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions
    
    ### What changes were proposed in this pull request?
    - This PR introduces a new expression called `MultiCommutativeOp` which is used by the commutative expressions (e.g., `Add`, `Multiply`, `And`, `Or`, `BitwiseOr`, `BitwiseAnd`, `BitwiseXor`) during canonicalization.
    - During canonicalization, when there is a chain of consecutive commutative expressions, we now create a single `MultiCommutativeOp` expression that references the original operands, instead of creating new objects.
    - This new expression is added as a memory optimization to avoid generating a large number of intermediate objects during canonicalization.
    
    ### Why are the changes needed?
    - With the [recent changes](https://github.com/apache/spark/pull/37851) in the expression canonicalization, a complex query with a large number of commutative operations could end up consuming significantly more (sometimes > 10X) memory on the executors.
    - In our case, this issue happens for a specific complex query that has a huge expression tree containing Add operators interleaved with non-Add operators.
    - The issue stems from canonicalization; it shows up on the executors because the codegen component relies on expression canonicalization to deduplicate expressions.
    - When we have a large number of Adds interleaved with non-Add operators, [this line](https://github.com/apache/spark/pull/37851/files#diff-7278f2db37934522ee7c74b71525153234cff245cefaf996957e4a9ff3dbaacdR1171) ends up materializing a new canonicalized expression tree at every non-Add operator.
    - In our case, analyzing the executor heap histogram shows that the additional memory is consumed by a large number of Add objects.
    - The high memory usage causes the executors to lose heartbeat signals and results in task failures.
    - The proposed `MultiCommutativeOp` expression avoids generating new Add expressions and keeps the extra memory usage to a minimum.
    
    ### Does this PR introduce _any_ user-facing change?
    - No
    
    ### How was this patch tested?
    - Existing unit tests and new unit tests.
    
    Closes #39722 from db-scnakandala/SPARK-42162.
    
    Authored-by: Supun Nakandala <su...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/expressions/Expression.scala      |  74 +++++++++++++
 .../sql/catalyst/expressions/arithmetic.scala      |  13 ++-
 .../catalyst/expressions/bitwiseExpressions.scala  |  15 ++-
 .../sql/catalyst/expressions/predicates.scala      |  10 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |   9 ++
 .../catalyst/expressions/CanonicalizeSuite.scala   | 122 ++++++++++++++++++++-
 6 files changed, 234 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 37ec4802993..2d2236a8a80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, Tre
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
 import org.apache.spark.sql.types._
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1335,4 +1336,77 @@ trait CommutativeExpression extends Expression {
   protected def orderCommutative(
       f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] =
     gatherCommutative(this, f).sortBy(_.hashCode())
+
+  /**
+   * Helper method to generate a canonicalized plan. The commutative operands are gathered
+   * and ordered; if their count is at least MULTI_COMMUTATIVE_OP_OPT_THRESHOLD, a single
+   * [[MultiCommutativeOp]] is returned, otherwise they are folded with `buildBinaryOp`.
+   */
+  protected def buildCanonicalizedPlan(
+      collectOperands: PartialFunction[Expression, Seq[Expression]],
+      buildBinaryOp: (Expression, Expression) => Expression,
+      evalMode: Option[EvalMode.Value] = None): Expression = {
+    val operands = orderCommutative(collectOperands)
+    // Below the threshold, rebuild a binary tree; at or above it, avoid allocating n - 1
+    // intermediate nodes by wrapping the ordered operands in one MultiCommutativeOp.
+    if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) {
+      operands.reduce(buildBinaryOp)
+    } else {
+      MultiCommutativeOp(operands, this.getClass, evalMode)(this)
+    }
+  }
+}
+
+/**
+ * A helper expression used by commutative expressions during canonicalization. When a long
+ * tree of the same commutative operation is canonicalized, its ordered operands are wrapped
+ * in a single MultiCommutativeOp instead of being rebuilt into a new tree of binary nodes.
+ * This is a memory optimization that avoids allocating a large number of intermediate objects.
+ * It is built by CommutativeExpression.buildCanonicalizedPlan and is [[Unevaluable]].
+ * The MultiCommutativeOp memory optimization is applied to the following commutative
+ * expressions:
+ *      Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
+ * @param operands The ordered operands of the flattened commutative expression tree.
+ * @param opCls The class of the root operator of the expression tree.
+ * @param evalMode The optional expression evaluation mode (only set for Add and Multiply).
+ * @param originalRoot Root operator of the commutative expression tree before canonicalization.
+ *                     It is in a second parameter list, so it is excluded from equality and
+ *                     hashCode; it is used to deduce the result type of decimal Add/Multiply.
+ */
+case class MultiCommutativeOp(
+    operands: Seq[Expression],
+    opCls: Class[_],
+    evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable {
+  // Result type of applying the root operator to two operand types. Only decimal Add and
+  // Multiply derive a new precision/scale; every other combination keeps `lType`. Matching on
+  // the root and both types together keeps this total, so no MatchError for other roots.
+  private def singleOpDataType(lType: DataType, rType: DataType): DataType =
+    (originalRoot, lType, rType) match {
+      case (add: Add, DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+        add.resultDecimalType(p1, s1, p2, s2)
+      case (multiply: Multiply, DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+        multiply.resultDecimalType(p1, s1, p2, s2)
+      case _ => lType
+    }
+
+  /**
+   * For Add/Multiply the type is re-derived by folding the operand types pairwise, since the
+   * reordered decimal operands may give a different precision/scale than the original tree.
+   * For every other operator the original root's data type is returned unchanged.
+   */
+  override def dataType: DataType = originalRoot match {
+    case _: Add | _: Multiply =>
+      operands.map(_.dataType).reduce((l, r) => singleOpDataType(l, r))
+    case other => other.dataType
+  }
+
+  // NOTE(review): nullable only looks at the operands; confirm it matches the original root's
+  // nullability for decimal Add/Multiply (e.g. null results on overflow in non-ANSI mode).
+  override def nullable: Boolean = operands.exists(_.nullable)
+
+  override def children: Seq[Expression] = operands
+
+  // `copy` only covers the first parameter list, so originalRoot is passed on explicitly.
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    this.copy(operands = newChildren)(originalRoot)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d5694e58cc9..88f7fabf121 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -479,8 +479,11 @@ case class Add(
 
   override lazy val canonicalized: Expression = {
     // TODO: do not reorder consecutive `Add`s with different `evalMode`
-    val reorderResult =
-      orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
+    val reorderResult = buildCanonicalizedPlan(
+      { case Add(l, r, _) => Seq(l, r) },
+      { case (l: Expression, r: Expression) => Add(l, r, evalMode)},
+      Some(evalMode)
+    )
     if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
       reorderResult
     } else {
@@ -632,7 +635,11 @@ case class Multiply(
 
   override lazy val canonicalized: Expression = {
     // TODO: do not reorder consecutive `Multiply`s with different `evalMode`
-    orderCommutative({ case Multiply(l, r, _) => Seq(l, r) }).reduce(Multiply(_, _, evalMode))
+    // Ordered binary tree below the operand-count threshold, else one MultiCommutativeOp.
+    buildCanonicalizedPlan(
+      { case Multiply(l, r, _) => Seq(l, r) },
+      Multiply(_, _, evalMode),
+      Some(evalMode))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 70c3d11deda..6061f625ef0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -62,7 +62,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
     newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    orderCommutative({ case BitwiseAnd(l, r) => Seq(l, r) }).reduce(BitwiseAnd)
+    // Ordered binary tree below the operand-count threshold, else one MultiCommutativeOp.
+    buildCanonicalizedPlan(
+      { case BitwiseAnd(l, r) => Seq(l, r) },
+      BitwiseAnd(_, _))
   }
 }
 
@@ -106,7 +109,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
     newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    orderCommutative({ case BitwiseOr(l, r) => Seq(l, r) }).reduce(BitwiseOr)
+    // Ordered binary tree below the operand-count threshold, else one MultiCommutativeOp.
+    buildCanonicalizedPlan(
+      { case BitwiseOr(l, r) => Seq(l, r) },
+      BitwiseOr(_, _))
   }
 }
 
@@ -150,7 +156,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
     newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    orderCommutative({ case BitwiseXor(l, r) => Seq(l, r) }).reduce(BitwiseXor)
+    // Ordered binary tree below the operand-count threshold, else one MultiCommutativeOp.
+    buildCanonicalizedPlan(
+      { case BitwiseXor(l, r) => Seq(l, r) },
+      BitwiseXor(_, _))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b7d7c5e700e..64bee643c86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -805,7 +805,10 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
     copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    orderCommutative({ case And(l, r) => Seq(l, r) }).reduce(And)
+    // Ordered binary tree below the operand-count threshold, else one MultiCommutativeOp.
+    buildCanonicalizedPlan(
+      { case And(l, r) => Seq(l, r) },
+      And(_, _))
   }
 }
 
@@ -899,7 +902,10 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
     copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    orderCommutative({ case Or(l, r) => Seq(l, r) }).reduce(Or)
+    // Ordered binary tree below the operand-count threshold, else one MultiCommutativeOp.
+    buildCanonicalizedPlan(
+      { case Or(l, r) => Seq(l, r) },
+      Or(_, _))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 34295d1c42a..8d8aacbc9cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -240,6 +240,15 @@ object SQLConf {
     .intConf
     .createWithDefault(100)
 
+  val MULTI_COMMUTATIVE_OP_OPT_THRESHOLD = buildConf(
+    "spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold")
+    .internal()
+    .doc("The minimum number of operands in a commutative expression tree to" +
+      " invoke the MultiCommutativeOp memory optimization during canonicalization.")
+    .version("3.4.0")
+    .intConf
+    .createWithDefault(3)
+
   val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules")
     .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " +
       "specified by their rule names and separated by comma. It is not guaranteed that all the " +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index 057fb98c239..f2a9eac8216 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -23,7 +23,9 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Range
-import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
+import org.apache.spark.sql.types.{BooleanType, Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType}
 
 class CanonicalizeSuite extends SparkFunSuite {
 
@@ -206,4 +208,122 @@ class CanonicalizeSuite extends SparkFunSuite {
     assert(!Add(Add(literal4, literal5), literal1).semanticEquals(
       Add(Add(literal1, literal5), literal4)))
   }
+
+  test("SPARK-42162: Commutative expression canonicalization should work" +
+    " with the MultiCommutativeOp memory optimization") {
+    val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
+    SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "3")
+    try {
+      // Add
+      val d = Decimal(1.2)
+      val literal1 = Literal.create(d, DecimalType(2, 1))
+      val literal2 = Literal.create(d, DecimalType(2, 1))
+      val literal3 = Literal.create(d, DecimalType(3, 2))
+      assert(Add(literal1, Add(literal2, literal3))
+        .semanticEquals(Add(Add(literal1, literal2), literal3)))
+      assert(Add(literal1, Add(literal2, literal3)).canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // Multiply
+      assert(Multiply(literal1, Multiply(literal2, literal3))
+        .semanticEquals(Multiply(Multiply(literal1, literal2), literal3)))
+      assert(Multiply(literal1, Multiply(literal2, literal3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // And
+      val literalBool1, literalBool2, literalBool3 = Literal.create(true, BooleanType)
+      assert(And(literalBool1, And(literalBool2, literalBool3))
+        .semanticEquals(And(And(literalBool1, literalBool2), literalBool3)))
+      assert(And(literalBool1, And(literalBool2, literalBool3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // Or
+      assert(Or(literalBool1, Or(literalBool2, literalBool3))
+        .semanticEquals(Or(Or(literalBool1, literalBool2), literalBool3)))
+      assert(Or(literalBool1, Or(literalBool2, literalBool3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // BitwiseAnd
+      val literalBit1 = Literal(1)
+      val literalBit2 = Literal(2)
+      val literalBit3 = Literal(3)
+      assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+        .semanticEquals(BitwiseAnd(BitwiseAnd(literalBit1, literalBit2), literalBit3)))
+      assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // BitwiseOr
+      assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+        .semanticEquals(BitwiseOr(BitwiseOr(literalBit1, literalBit2), literalBit3)))
+      assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // BitwiseXor
+      assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+        .semanticEquals(BitwiseXor(BitwiseXor(literalBit1, literalBit2), literalBit3)))
+      assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+    } finally {
+      // Restore the previous threshold even if an assertion above fails.
+      SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString)
+    }
+  }
+
+  test("SPARK-42162: Commutative expression canonicalization should not use" +
+    " MultiCommutativeOp memory optimization when threshold is not met") {
+    val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
+    SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "100")
+    try {
+      // Add
+      val d = Decimal(1.2)
+      val literal1 = Literal.create(d, DecimalType(2, 1))
+      val literal2 = Literal.create(d, DecimalType(2, 1))
+      val literal3 = Literal.create(d, DecimalType(3, 2))
+      assert(Add(literal1, Add(literal2, literal3))
+        .semanticEquals(Add(Add(literal1, literal2), literal3)))
+      assert(!Add(literal1, Add(literal2, literal3)).canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // Multiply
+      assert(Multiply(literal1, Multiply(literal2, literal3))
+        .semanticEquals(Multiply(Multiply(literal1, literal2), literal3)))
+      assert(!Multiply(literal1, Multiply(literal2, literal3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // And
+      val literalBool1, literalBool2, literalBool3 = Literal.create(true, BooleanType)
+      assert(And(literalBool1, And(literalBool2, literalBool3))
+        .semanticEquals(And(And(literalBool1, literalBool2), literalBool3)))
+      assert(!And(literalBool1, And(literalBool2, literalBool3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // Or
+      assert(Or(literalBool1, Or(literalBool2, literalBool3))
+        .semanticEquals(Or(Or(literalBool1, literalBool2), literalBool3)))
+      assert(!Or(literalBool1, Or(literalBool2, literalBool3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // BitwiseAnd
+      val literalBit1 = Literal(1)
+      val literalBit2 = Literal(2)
+      val literalBit3 = Literal(3)
+      assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+        .semanticEquals(BitwiseAnd(BitwiseAnd(literalBit1, literalBit2), literalBit3)))
+      assert(!BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // BitwiseOr
+      assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+        .semanticEquals(BitwiseOr(BitwiseOr(literalBit1, literalBit2), literalBit3)))
+      assert(!BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+      // BitwiseXor
+      assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+        .semanticEquals(BitwiseXor(BitwiseXor(literalBit1, literalBit2), literalBit3)))
+      assert(!BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+        .canonicalized.isInstanceOf[MultiCommutativeOp])
+    } finally {
+      // Restore the previous threshold even if an assertion above fails.
+      SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org