Posted to commits@spark.apache.org by we...@apache.org on 2023/02/10 15:57:20 UTC
[spark] branch branch-3.4 updated: [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 18b1cb12c58 [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions
18b1cb12c58 is described below
commit 18b1cb12c58382159f66bc7075d9d2d9c3d36e46
Author: Supun Nakandala <su...@databricks.com>
AuthorDate: Fri Feb 10 23:56:34 2023 +0800
[SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions
### What changes were proposed in this pull request?
- This PR introduces a new expression called `MultiCommutativeOp` which is used by the commutative expressions (e.g., `Add`, `Multiply`, `And`, `Or`, `BitwiseOr`, `BitwiseAnd`, `BitwiseXor`) during canonicalization.
- During canonicalization, when there is a chain of consecutive commutative expressions, we now create a single `MultiCommutativeOp` expression that references the original operands, instead of creating new operator objects.
- This new expression is a memory optimization that avoids generating a large number of intermediate objects during canonicalization.
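The strategy in these bullets can be illustrated with a small, self-contained Scala model. Everything in it is hypothetical: `Expr`, `Leaf`, `Plus`, `Neg`, `MultiPlus`, and `ToyCanonicalize` are illustration-only names, not Catalyst classes, and the model ignores data types and eval modes. It mirrors the shape of the change. Operands of consecutive `Plus` nodes are gathered and sorted by `hashCode()`. They are then either reduced back into new binary nodes (below the threshold) or wrapped in one node that references them.

```scala
// Hypothetical toy model of commutative canonicalization; these are not Catalyst classes.
sealed trait Expr
final case class Leaf(name: String) extends Expr
final case class Plus(left: Expr, right: Expr) extends Expr
// A non-commutative operator that interrupts a chain of Plus nodes.
final case class Neg(child: Expr) extends Expr
// One node standing for a whole chain of Plus nodes, referencing the operands directly.
final case class MultiPlus(operands: Seq[Expr]) extends Expr

object ToyCanonicalize {
  // Flatten consecutive Plus nodes into operands; any other node is canonicalized on its own.
  private def gather(e: Expr, threshold: Int): Seq[Expr] = e match {
    case Plus(l, r) => gather(l, threshold) ++ gather(r, threshold)
    case other      => Seq(canonicalize(other, threshold))
  }

  def canonicalize(e: Expr, threshold: Int): Expr = e match {
    case p: Plus =>
      // Sort by hash code so that differently nested trees yield the same operand order.
      val operands = gather(p, threshold).sortBy(_.hashCode())
      if (operands.length < threshold) operands.reduce(Plus(_, _)) // allocates n - 1 new nodes
      else MultiPlus(operands)                                    // allocates a single node
    case Neg(c) => Neg(canonicalize(c, threshold))
    case other  => other
  }
}

object ToyCanonicalizeDemo extends App {
  val (a, b, c) = (Leaf("a"), Leaf("b"), Leaf("c"))
  println(ToyCanonicalize.canonicalize(Plus(a, Plus(b, c)), threshold = 3))
  println(ToyCanonicalize.canonicalize(Plus(Plus(c, a), Neg(b)), threshold = 3))
}
```

With a threshold of 3, `Plus(a, Plus(b, c))` and `Plus(Plus(c, a), b)` canonicalize to the same `MultiPlus` node. With a large threshold, both reduce to the same binary tree.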
### Why are the changes needed?
- With the [recent changes](https://github.com/apache/spark/pull/37851) in the expression canonicalization, a complex query with a large number of commutative operations could end up consuming significantly more (sometimes > 10X) memory on the executors.
- In our case, this issue happens for a specific complex query with a huge expression tree in which Add operators are interleaved with non-Add operators.
- The issue stems from canonicalization. It surfaces on the executors because the codegen component relies on expression canonicalization to deduplicate expressions.
- When we have a large number of Adds interleaved by non-Add operators, [this line](https://github.com/apache/spark/pull/37851/files#diff-7278f2db37934522ee7c74b71525153234cff245cefaf996957e4a9ff3dbaacdR1171) ends up materializing a new canonicalized expression tree at every non-Add operator.
- In our case, analyzing the executor heap histogram shows that the additional memory is consumed by a large number of Add objects.
- The high memory usage causes the executors to lose heartbeat signals and results in task failures.
- The proposed `MultiCommutativeOp` expression avoids generating new Add expressions and keeps the extra memory usage to a minimum.
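The saving described in these bullets can be counted directly. Reducing n sorted operands back into binary nodes creates n - 1 new operator objects per chain. Wrapping them creates one. The following is a minimal sketch with a hypothetical `chainLengths` input listing the operand count of each maximal run of one commutative operator. Runs are assumed to be separated by non-Add operators, as in the query described above. The sketch counts only commutative-operator objects. It ignores three things:

- the operand sequence that both approaches build;
- the copies of non-Add operators;
- the per-expression caching of `canonicalized`.

```scala
object CanonicalNodeCount {
  // Number of new commutative-operator objects created when canonicalizing each chain.
  // chainLengths: operand count of each maximal run of one commutative operator (hypothetical input).
  def newOperatorNodes(chainLengths: Seq[Int], threshold: Int): Long =
    chainLengths.map { n =>
      require(n >= 2, s"a chain needs at least two operands, got $n")
      if (n < threshold) (n - 1).toLong // binary reduce: one new node per combination step
      else 1L                           // a single MultiCommutativeOp-style wrapper
    }.sum

  def main(args: Array[String]): Unit = {
    val chains = Seq.fill(50)(200) // 50 chains of 200 operands each
    println(newOperatorNodes(chains, threshold = Int.MaxValue)) // 9950 (wrapper never used)
    println(newOperatorNodes(chains, threshold = 3))            // 50
  }
}
```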
### Does this PR introduce _any_ user-facing change?
- No
### How was this patch tested?
- Existing unit tests and new unit tests.
Closes #39722 from db-scnakandala/SPARK-42162.
Authored-by: Supun Nakandala <su...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 99431e28f950bb25c421abd51888a3f9f4b46685)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/catalyst/expressions/Expression.scala | 74 +++++++++++++
.../sql/catalyst/expressions/arithmetic.scala | 13 ++-
.../catalyst/expressions/bitwiseExpressions.scala | 15 ++-
.../sql/catalyst/expressions/predicates.scala | 10 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 9 ++
.../catalyst/expressions/CanonicalizeSuite.scala | 122 ++++++++++++++++++++-
6 files changed, 234 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index de0e90285f5..c0439e9a3f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, Tre
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1334,4 +1335,77 @@ trait CommutativeExpression extends Expression {
protected def orderCommutative(
f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] =
gatherCommutative(this, f).sortBy(_.hashCode())
+
+ /**
+ * Helper method to generate a canonicalized plan. If the number of operands is greater than
+ * or equal to MULTI_COMMUTATIVE_OP_OPT_THRESHOLD, this method creates a
+ * [[MultiCommutativeOp]] as the canonicalized plan.
+ */
+ protected def buildCanonicalizedPlan(
+ collectOperands: PartialFunction[Expression, Seq[Expression]],
+ buildBinaryOp: (Expression, Expression) => Expression,
+ evalMode: Option[EvalMode.Value] = None): Expression = {
+ val operands = orderCommutative(collectOperands)
+ val reorderResult =
+ if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) {
+ operands.reduce(buildBinaryOp)
+ } else {
+ MultiCommutativeOp(operands, this.getClass, evalMode)(this)
+ }
+ reorderResult
+ }
+}
+
+/**
+ * A helper class used by commutative expressions during canonicalization. When canonicalization
+ * encounters a long tree of commutative operations, it represents that tree with a single
+ * MultiCommutativeOp expression instead of creating new commutative objects.
+ * This class is added as a memory optimization for processing large commutative operation trees
+ * without creating a large number of new intermediate objects.
+ * The MultiCommutativeOp memory optimization is applied to the following commutative
+ * expressions:
+ * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
+ * @param operands A sequence of operands that produces a commutative expression tree.
+ * @param opCls The class of the root operator of the expression tree.
+ * @param evalMode The optional expression evaluation mode.
+ * @param originalRoot Root operator of the commutative expression tree before canonicalization.
+ * This object reference is used to deduce the return dataType of Add and
+ * Multiply operations when the input datatype is decimal.
+ */
+case class MultiCommutativeOp(
+ operands: Seq[Expression],
+ opCls: Class[_],
+ evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable {
+ // Helper method to deduce the data type of a single operation.
+ private def singleOpDataType(lType: DataType, rType: DataType): DataType = {
+ originalRoot match {
+ case add: Add =>
+ (lType, rType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ add.resultDecimalType(p1, s1, p2, s2)
+ case _ => lType
+ }
+ case multiply: Multiply =>
+ (lType, rType) match {
+ case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+ multiply.resultDecimalType(p1, s1, p2, s2)
+ case _ => lType
+ }
+ }
+ }
+
+ override def dataType: DataType = {
+ originalRoot match {
+ case _: Add | _: Multiply =>
+ operands.map(_.dataType).reduce((l, r) => singleOpDataType(l, r))
+ case other => other.dataType
+ }
+ }
+
+ override def nullable: Boolean = operands.exists(_.nullable)
+
+ override def children: Seq[Expression] = operands
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+ this.copy(operands = newChildren)(originalRoot)
}
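`MultiCommutativeOp.dataType` recovers the decimal result type by folding a per-operation rule over the operand types (`operands.map(_.dataType).reduce(...)`). The sketch below reproduces that fold with plain integers. It assumes Spark's usual decimal rules:

- `Add`: scale = max(s1, s2), precision = max(p1 - s1, p2 - s2) + scale + 1.
- `Multiply`: scale = s1 + s2, precision = p1 + p2 + 1.

It ignores the 38-digit cap and the precision-loss adjustment that `resultDecimalType` applies to larger results. `Dec` and `DecimalFold` are hypothetical names.

```scala
object DecimalFold {
  // Hypothetical stand-in for DecimalType(precision, scale).
  final case class Dec(precision: Int, scale: Int)

  // Add rule, without the 38-digit cap / precision-loss adjustment.
  def addType(l: Dec, r: Dec): Dec = {
    val scale = math.max(l.scale, r.scale)
    Dec(math.max(l.precision - l.scale, r.precision - r.scale) + scale + 1, scale)
  }

  // Multiply rule, with the same simplification.
  def multiplyType(l: Dec, r: Dec): Dec =
    Dec(l.precision + r.precision + 1, l.scale + r.scale)

  // Left fold over operand types, like operands.map(_.dataType).reduce(singleOpDataType).
  def foldType(operandTypes: Seq[Dec], op: (Dec, Dec) => Dec): Dec = operandTypes.reduce(op)

  def main(args: Array[String]): Unit = {
    val types = Seq(Dec(2, 1), Dec(2, 1), Dec(3, 2)) // the literal types used in CanonicalizeSuite
    println(foldType(types, addType))      // Dec(5,2)
    println(foldType(types, multiplyType)) // Dec(9,4)
  }
}
```

In this model, the folded `Add` type can depend on operand order. Folding (6, 0), (2, 0), (2, 0) gives (8, 0), but the reverse order gives (7, 0). This is consistent with `Add.canonicalized` comparing `reorderResult.dataType` against `dataType` before it accepts a reordered result.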
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d5694e58cc9..88f7fabf121 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -479,8 +479,11 @@ case class Add(
override lazy val canonicalized: Expression = {
// TODO: do not reorder consecutive `Add`s with different `evalMode`
- val reorderResult =
- orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
+ val reorderResult = buildCanonicalizedPlan(
+ { case Add(l, r, _) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => Add(l, r, evalMode)},
+ Some(evalMode)
+ )
if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
reorderResult
} else {
@@ -632,7 +635,11 @@ case class Multiply(
override lazy val canonicalized: Expression = {
// TODO: do not reorder consecutive `Multiply`s with different `evalMode`
- orderCommutative({ case Multiply(l, r, _) => Seq(l, r) }).reduce(Multiply(_, _, evalMode))
+ buildCanonicalizedPlan(
+ { case Multiply(l, r, _) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
+ Some(evalMode)
+ )
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 70c3d11deda..6061f625ef0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -62,7 +62,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- orderCommutative({ case BitwiseAnd(l, r) => Seq(l, r) }).reduce(BitwiseAnd)
+ buildCanonicalizedPlan(
+ { case BitwiseAnd(l, r) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => BitwiseAnd(l, r)}
+ )
}
}
@@ -106,7 +109,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- orderCommutative({ case BitwiseOr(l, r) => Seq(l, r) }).reduce(BitwiseOr)
+ buildCanonicalizedPlan(
+ { case BitwiseOr(l, r) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => BitwiseOr(l, r)}
+ )
}
}
@@ -150,7 +156,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- orderCommutative({ case BitwiseXor(l, r) => Seq(l, r) }).reduce(BitwiseXor)
+ buildCanonicalizedPlan(
+ { case BitwiseXor(l, r) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => BitwiseXor(l, r)}
+ )
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b7d7c5e700e..64bee643c86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -805,7 +805,10 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- orderCommutative({ case And(l, r) => Seq(l, r) }).reduce(And)
+ buildCanonicalizedPlan(
+ { case And(l, r) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => And(l, r)}
+ )
}
}
@@ -899,7 +902,10 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- orderCommutative({ case Or(l, r) => Seq(l, r) }).reduce(Or)
+ buildCanonicalizedPlan(
+ { case Or(l, r) => Seq(l, r) },
+ { case (l: Expression, r: Expression) => Or(l, r)}
+ )
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 34295d1c42a..8d8aacbc9cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -240,6 +240,15 @@ object SQLConf {
.intConf
.createWithDefault(100)
+ val MULTI_COMMUTATIVE_OP_OPT_THRESHOLD =
+ buildConf("spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold")
+ .internal()
+ .doc("The minimum number of operands in a commutative expression tree to" +
+ " invoke the MultiCommutativeOp memory optimization during canonicalization.")
+ .version("3.4.0")
+ .intConf
+ .createWithDefault(3)
+
val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules")
.doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " +
"specified by their rule names and separated by comma. It is not guaranteed that all the " +
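The new entry is marked `.internal()`, so it is not part of the documented configuration surface. Like other SQL confs, it can still be changed for a session, for example to compare canonical forms with the optimization effectively disabled. The following is a minimal sketch, assuming Spark 3.4 on the classpath and a local session. The object name is illustrative, and the value 100 is borrowed from the second test.

```scala
import org.apache.spark.sql.SparkSession

object ThresholdConfExample {
  // Key of the SQLConf entry added in this commit (MULTI_COMMUTATIVE_OP_OPT_THRESHOLD).
  private val ThresholdKey =
    "spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold"

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("threshold-example").getOrCreate()
    try {
      // Raise the threshold so that small commutative trees keep their binary canonical form.
      spark.conf.set(ThresholdKey, "100")
      println(spark.conf.get(ThresholdKey))
    } finally {
      spark.stop()
    }
  }
}
```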
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index 057fb98c239..f2a9eac8216 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -23,7 +23,9 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Range
-import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
+import org.apache.spark.sql.types.{BooleanType, Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType}
class CanonicalizeSuite extends SparkFunSuite {
@@ -206,4 +208,122 @@ class CanonicalizeSuite extends SparkFunSuite {
assert(!Add(Add(literal4, literal5), literal1).semanticEquals(
Add(Add(literal1, literal5), literal4)))
}
+
+ test("SPARK-42162: Commutative expression canonicalization should work" +
+ " with the MultiCommutativeOp memory optimization") {
+ val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
+ SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "3")
+
+ // Add
+ val d = Decimal(1.2)
+ val literal1 = Literal.create(d, DecimalType(2, 1))
+ val literal2 = Literal.create(d, DecimalType(2, 1))
+ val literal3 = Literal.create(d, DecimalType(3, 2))
+ assert(Add(literal1, Add(literal2, literal3))
+ .semanticEquals(Add(Add(literal1, literal2), literal3)))
+ assert(Add(literal1, Add(literal2, literal3)).canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // Multiply
+ assert(Multiply(literal1, Multiply(literal2, literal3))
+ .semanticEquals(Multiply(Multiply(literal1, literal2), literal3)))
+ assert(Multiply(literal1, Multiply(literal2, literal3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // And
+ val literalBool1 = Literal.create(true, BooleanType)
+ val literalBool2 = Literal.create(true, BooleanType)
+ val literalBool3 = Literal.create(true, BooleanType)
+ assert(And(literalBool1, And(literalBool2, literalBool3))
+ .semanticEquals(And(And(literalBool1, literalBool2), literalBool3)))
+ assert(And(literalBool1, And(literalBool2, literalBool3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // Or
+ assert(Or(literalBool1, Or(literalBool2, literalBool3))
+ .semanticEquals(Or(Or(literalBool1, literalBool2), literalBool3)))
+ assert(Or(literalBool1, Or(literalBool2, literalBool3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // BitwiseAnd
+ val literalBit1 = Literal(1)
+ val literalBit2 = Literal(2)
+ val literalBit3 = Literal(3)
+ assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+ .semanticEquals(BitwiseAnd(BitwiseAnd(literalBit1, literalBit2), literalBit3)))
+ assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // BitwiseOr
+ assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+ .semanticEquals(BitwiseOr(BitwiseOr(literalBit1, literalBit2), literalBit3)))
+ assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // BitwiseXor
+ assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+ .semanticEquals(BitwiseXor(BitwiseXor(literalBit1, literalBit2), literalBit3)))
+ assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString)
+ }
+
+ test("SPARK-42162: Commutative expression canonicalization should not use" +
+ " MultiCommutativeOp memory optimization when threshold is not met") {
+ val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
+ SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "100")
+
+ // Add
+ val d = Decimal(1.2)
+ val literal1 = Literal.create(d, DecimalType(2, 1))
+ val literal2 = Literal.create(d, DecimalType(2, 1))
+ val literal3 = Literal.create(d, DecimalType(3, 2))
+ assert(Add(literal1, Add(literal2, literal3))
+ .semanticEquals(Add(Add(literal1, literal2), literal3)))
+ assert(!Add(literal1, Add(literal2, literal3)).canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // Multiply
+ assert(Multiply(literal1, Multiply(literal2, literal3))
+ .semanticEquals(Multiply(Multiply(literal1, literal2), literal3)))
+ assert(!Multiply(literal1, Multiply(literal2, literal3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // And
+ val literalBool1 = Literal.create(true, BooleanType)
+ val literalBool2 = Literal.create(true, BooleanType)
+ val literalBool3 = Literal.create(true, BooleanType)
+ assert(And(literalBool1, And(literalBool2, literalBool3))
+ .semanticEquals(And(And(literalBool1, literalBool2), literalBool3)))
+ assert(!And(literalBool1, And(literalBool2, literalBool3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // Or
+ assert(Or(literalBool1, Or(literalBool2, literalBool3))
+ .semanticEquals(Or(Or(literalBool1, literalBool2), literalBool3)))
+ assert(!Or(literalBool1, Or(literalBool2, literalBool3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // BitwiseAnd
+ val literalBit1 = Literal(1)
+ val literalBit2 = Literal(2)
+ val literalBit3 = Literal(3)
+ assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+ .semanticEquals(BitwiseAnd(BitwiseAnd(literalBit1, literalBit2), literalBit3)))
+ assert(!BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // BitwiseOr
+ assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+ .semanticEquals(BitwiseOr(BitwiseOr(literalBit1, literalBit2), literalBit3)))
+ assert(!BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ // BitwiseXor
+ assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+ .semanticEquals(BitwiseXor(BitwiseXor(literalBit1, literalBit2), literalBit3)))
+ assert(!BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3))
+ .canonicalized.isInstanceOf[MultiCommutativeOp])
+
+ SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString)
+ }
}
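Both new tests change the threshold through `SQLConf.get.setConfString` and restore the previous value only on their last line. A failing assertion would therefore skip the restore and could leave the modified value in place for later tests. Spark's catalyst test sources offer a `withSQLConf` helper (in the `SQLHelper` trait) for scoping a conf change. The stdlib-only sketch below shows the same try/finally pattern against a hypothetical mutable map. `ConfScope` and `withConf` are illustration-only names.

```scala
import scala.collection.mutable

object ConfScope {
  // Hypothetical stand-in for a shared, mutable session configuration (not Spark's SQLConf).
  val conf: mutable.Map[String, String] = mutable.Map("threshold" -> "3")

  // Run `body` with `key` set to `value`, restoring the previous state even if `body` throws.
  def withConf[T](key: String, value: String)(body: => T): T = {
    val previous = conf.get(key)
    conf(key) = value
    try body
    finally previous match {
      case Some(v) => conf(key) = v
      case None    => conf.remove(key)
    }
  }

  def main(args: Array[String]): Unit = {
    val failed = scala.util.Try(withConf("threshold", "100") {
      assert(conf("threshold") == "100")
      throw new RuntimeException("simulated assertion failure")
    })
    println(failed.isFailure)  // true
    println(conf("threshold")) // 3
  }
}
```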